OSDN Git Service

Merge "zram performance test."
[android-x86/system-extras.git] / tests / net_test / netlink.py
1 #!/usr/bin/python
2 #
3 # Copyright 2014 The Android Open Source Project
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 # http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16
17 """Partial Python implementation of iproute functionality."""
18
19 # pylint: disable=g-bad-todo
20
21 import errno
22 import os
23 import socket
24 import struct
25 import sys
26
27 import cstruct
28
29
30 # Request constants.
31 NLM_F_REQUEST = 1
32 NLM_F_ACK = 4
33 NLM_F_REPLACE = 0x100
34 NLM_F_EXCL = 0x200
35 NLM_F_CREATE = 0x400
36 NLM_F_DUMP = 0x300
37
38 # Message types.
39 NLMSG_ERROR = 2
40 NLMSG_DONE = 3
41
42 # Data structure formats.
43 # These aren't constants, they're classes. So, pylint: disable=invalid-name
44 NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
45 NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error")
46 NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type")
47
48 # Alignment / padding.
49 NLA_ALIGNTO = 4
50
51
52 def PaddedLength(length):
53   # TODO: This padding is probably overly simplistic.
54   return NLA_ALIGNTO * ((length / NLA_ALIGNTO) + (length % NLA_ALIGNTO != 0))
55
56
57 class NetlinkSocket(object):
58   """A basic netlink socket object."""
59
60   BUFSIZE = 65536
61   DEBUG = False
62   # List of netlink messages to print, e.g., [], ["NEIGH", "ROUTE"], or ["ALL"]
63   NL_DEBUG = []
64
65   def _Debug(self, s):
66     if self.DEBUG:
67       print s
68
69   def _NlAttr(self, nla_type, data):
70     datalen = len(data)
71     # Pad the data if it's not a multiple of NLA_ALIGNTO bytes long.
72     padding = "\x00" * (PaddedLength(datalen) - datalen)
73     nla_len = datalen + len(NLAttr)
74     return NLAttr((nla_len, nla_type)).Pack() + data + padding
75
76   def _NlAttrU32(self, nla_type, value):
77     return self._NlAttr(nla_type, struct.pack("=I", value))
78
79   def _GetConstantName(self, module, value, prefix):
80     thismodule = sys.modules[module]
81     for name in dir(thismodule):
82       if (name.startswith(prefix) and
83           not name.startswith(prefix + "F_") and
84           name.isupper() and
85           getattr(thismodule, name) == value):
86         return name
87     return value
88
89   def _Decode(self, command, msg, nla_type, nla_data):
90     """No-op, nonspecific version of decode."""
91     return nla_type, nla_data
92
93   def _ParseAttributes(self, command, family, msg, data):
94     """Parses and decodes netlink attributes.
95
96     Takes a block of NLAttr data structures, decodes them using Decode, and
97     returns the result in a dict keyed by attribute number.
98
99     Args:
100       command: An integer, the rtnetlink command being carried out.
101       family: The address family.
102       msg: A Struct, the type of the data after the netlink header.
103       data: A byte string containing a sequence of NLAttr data structures.
104
105     Returns:
106       A dictionary mapping attribute types (integers) to decoded values.
107
108     Raises:
109       ValueError: There was a duplicate attribute type.
110     """
111     attributes = {}
112     while data:
113       # Read the nlattr header.
114       nla, data = cstruct.Read(data, NLAttr)
115
116       # Read the data.
117       datalen = nla.nla_len - len(nla)
118       padded_len = PaddedLength(nla.nla_len) - len(nla)
119       nla_data, data = data[:datalen], data[padded_len:]
120
121       # If it's an attribute we know about, try to decode it.
122       nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data)
123
124       # We only support unique attributes for now.
125       if nla_name in attributes:
126         raise ValueError("Duplicate attribute %d" % nla_name)
127
128       attributes[nla_name] = nla_data
129       self._Debug("      %s" % str((nla_name, nla_data)))
130
131     return attributes
132
133   def __init__(self):
134     # Global sequence number.
135     self.seq = 0
136     self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.FAMILY)
137     self.sock.connect((0, 0))  # The kernel.
138     self.pid = self.sock.getsockname()[1]
139
140   def _Send(self, msg):
141     # self._Debug(msg.encode("hex"))
142     self.seq += 1
143     self.sock.send(msg)
144
145   def _Recv(self):
146     data = self.sock.recv(self.BUFSIZE)
147     # self._Debug(data.encode("hex"))
148     return data
149
150   def _ExpectDone(self):
151     response = self._Recv()
152     hdr = NLMsgHdr(response)
153     if hdr.type != NLMSG_DONE:
154       raise ValueError("Expected DONE, got type %d" % hdr.type)
155
156   def _ParseAck(self, response):
157     # Find the error code.
158     hdr, data = cstruct.Read(response, NLMsgHdr)
159     if hdr.type == NLMSG_ERROR:
160       error = NLMsgErr(data).error
161       if error:
162         raise IOError(error, os.strerror(-error))
163     else:
164       raise ValueError("Expected ACK, got type %d" % hdr.type)
165
166   def _ExpectAck(self):
167     response = self._Recv()
168     self._ParseAck(response)
169
170   def _SendNlRequest(self, command, data, flags):
171     """Sends a netlink request and expects an ack."""
172     length = len(NLMsgHdr) + len(data)
173     nlmsg = NLMsgHdr((length, command, flags, self.seq, self.pid)).Pack()
174
175     self.MaybeDebugCommand(command, nlmsg + data)
176
177     # Send the message.
178     self._Send(nlmsg + data)
179
180     if flags & NLM_F_ACK:
181       self._ExpectAck()
182
183   def _ParseNLMsg(self, data, msgtype):
184     """Parses a Netlink message into a header and a dictionary of attributes."""
185     nlmsghdr, data = cstruct.Read(data, NLMsgHdr)
186     self._Debug("  %s" % nlmsghdr)
187
188     if nlmsghdr.type == NLMSG_ERROR or nlmsghdr.type == NLMSG_DONE:
189       print "done"
190       return (None, None), data
191
192     nlmsg, data = cstruct.Read(data, msgtype)
193     self._Debug("    %s" % nlmsg)
194
195     # Parse the attributes in the nlmsg.
196     attrlen = nlmsghdr.length - len(nlmsghdr) - len(nlmsg)
197     attributes = self._ParseAttributes(nlmsghdr.type, nlmsg.family,
198                                        nlmsg, data[:attrlen])
199     data = data[attrlen:]
200     return (nlmsg, attributes), data
201
202   def _GetMsgList(self, msgtype, data, expect_done):
203     out = []
204     while data:
205       msg, data = self._ParseNLMsg(data, msgtype)
206       if msg is None:
207         break
208       out.append(msg)
209     if expect_done:
210       self._ExpectDone()
211     return out
212
213   def _Dump(self, command, msg, msgtype):
214     """Sends a dump request and returns a list of decoded messages."""
215     # Create a netlink dump request containing the msg.
216     flags = NLM_F_DUMP | NLM_F_REQUEST
217     length = len(NLMsgHdr) + len(msg)
218     nlmsghdr = NLMsgHdr((length, command, flags, self.seq, self.pid))
219
220     # Send the request.
221     self._Send(nlmsghdr.Pack() + msg.Pack())
222
223     # Keep reading netlink messages until we get a NLMSG_DONE.
224     out = []
225     while True:
226       data = self._Recv()
227       response_type = NLMsgHdr(data).type
228       if response_type == NLMSG_DONE:
229         break
230       out.extend(self._GetMsgList(msgtype, data, False))
231
232     return out