OSDN Git Service

Add support for reading routing rules.
authorLorenzo Colitti <lorenzo@google.com>
Fri, 4 Apr 2014 11:18:37 +0000 (20:18 +0900)
committerLorenzo Colitti <lorenzo@google.com>
Mon, 2 Feb 2015 08:47:25 +0000 (17:47 +0900)
Change-Id: I24e04f691cb5688d87da0b880ce6000fcc22c781

tests/net_test/iproute.py

index 268c03f..44c3073 100644 (file)
@@ -6,6 +6,9 @@ import os
 import socket
 import struct
 
+import cstruct
+
+
 ### Base netlink constants. See include/uapi/linux/netlink.h.
 NETLINK_ROUTE = 0
 
@@ -14,33 +17,42 @@ NLM_F_REQUEST = 1
 NLM_F_ACK = 4
 NLM_F_EXCL = 0x200
 NLM_F_CREATE = 0x400
+NLM_F_DUMP = 0x300
 
 # Message types.
 NLMSG_ERROR = 2
 
 # Data structure formats.
-STRUCT_NLMSGHDR = "=LHHLL"
-STRUCT_NLMSGERR = "=i"
-STRUCT_NLATTR = "=HH"
+# These aren't constants, they're classes. So, pylint: disable=invalid-name
+NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
+NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error")
+NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type")
 
 
 ### rtnetlink constants. See include/uapi/linux/rtnetlink.h.
 # Message types.
 RTM_NEWRULE = 32
 RTM_DELRULE = 33
+RTM_GETRULE = 34
 
 # Routing message type values (rtm_type).
 RTN_UNSPEC = 0
 RTN_UNICAST = 1
 
 # Routing protocol values (rtm_protocol).
+RTPROT_UNSPEC = 0
 RTPROT_STATIC = 4
 
 # Route scope values (rtm_scope).
 RT_SCOPE_UNIVERSE = 0
 
+# Named routing tables.
+RT_TABLE_UNSPEC = 0
+
 # Data structure formats.
-STRUCT_RTMSG = "=BBBBBBBBI"
+RTMsg = cstruct.Struct(
+    "RTMsg", "=BBBBBBBBI",
+    "family dst_len src_len tos table protocol scope type flags")
 
 
 ### FIB rule constants. See include/uapi/linux/fib_rules.h.
@@ -49,23 +61,16 @@ FRA_FWMARK = 10
 FRA_TABLE = 15
 
 
-def Unpack(fmt, data):
-  """Unpacks a data structure with variable-size contents at the end."""
-  size = struct.calcsize(fmt)
-  data, remainder = data[:size], data[size:]
-  return struct.unpack(fmt, data), remainder
-
-
 class IPRoute(object):
 
   """Provides a tiny subset of iproute functionality."""
 
-  BUFSIZE = 1024
+  BUFSIZE = 65536
 
   def _NlAttrU32(self, nla_type, value):
     data = struct.pack("=I", value)
-    nla_len = struct.calcsize(STRUCT_NLATTR) + len(data)
-    return struct.pack(STRUCT_NLATTR, nla_len, nla_type) + data
+    nla_len = len(data) + len(NLAttr)
+    return NLAttr((nla_len, nla_type)).Pack() + data
 
   def __init__(self):
     # Global sequence number.
@@ -82,6 +87,17 @@ class IPRoute(object):
   def _Recv(self):
     return self.sock.recv(self.BUFSIZE)
 
+  def _ExpectAck(self):
+    # Find the error code.
+    response = self._Recv()
+    hdr, data = cstruct.Read(response, NLMsgHdr)
+    if hdr.type == NLMSG_ERROR:
+      error = NLMsgErr(data).error
+      if error:
+        raise IOError(error, os.strerror(-error))
+    else:
+      raise ValueError("Unexpected netlink ACK type %d" % hdr.type)
+
   def _Rule(self, version, is_add, table, match_nlattr, priority):
     """Python equivalent of "ip rule <add|del> <match_cond> lookup <table>".
 
@@ -100,8 +116,8 @@ class IPRoute(object):
 
     # Create a struct rtmsg specifying the table and the given match attributes.
     family = {4: socket.AF_INET, 6: socket.AF_INET6}[version]
-    rtmsg = struct.pack(STRUCT_RTMSG, family, 0, 0, 0, 0,
-                        RTPROT_STATIC, RT_SCOPE_UNIVERSE, RTN_UNICAST, 0)
+    rtmsg = RTMsg((family, 0, 0, 0, RT_TABLE_UNSPEC,
+                   RTPROT_STATIC, RT_SCOPE_UNIVERSE, RTN_UNICAST, 0)).Pack()
     rtmsg += self._NlAttrU32(FRA_PRIORITY, priority)
     rtmsg += match_nlattr
     rtmsg += self._NlAttrU32(FRA_TABLE, table)
@@ -112,24 +128,55 @@ class IPRoute(object):
     if is_add:
       flags |= (NLM_F_EXCL | NLM_F_CREATE)
 
-    # Fill in the length field.
-    length = struct.calcsize(STRUCT_NLMSGHDR) + len(rtmsg)
-    nlmsg = struct.pack(STRUCT_NLMSGHDR, length, command, flags,
-                        self.seq, self.pid) + rtmsg
+    length = len(NLMsgHdr) + len(rtmsg)
+    nlmsg = NLMsgHdr((length, command, flags, self.seq, self.pid)).Pack()
 
     # Send the message and block forever until we receive a response.
-    self._Send(nlmsg)
-    response = self._Recv()
+    self._Send(nlmsg + rtmsg)
 
-    # Find the error code.
-    (_, msgtype, _, _, _), msg = Unpack(STRUCT_NLMSGHDR, response)
-    if msgtype == NLMSG_ERROR:
-      ((error,), _) = Unpack(STRUCT_NLMSGERR, msg)
-      if error:
-        raise IOError(error, os.strerror(-error))
-    else:
-      raise ValueError("Unexpected netlink ACK type %d" % msgtype)
+    # Expect a successful ACK.
+    self._ExpectAck()
 
   def FwmarkRule(self, version, is_add, fwmark, table, priority=16383):
     nlattr = self._NlAttrU32(FRA_FWMARK, fwmark)
     return self._Rule(version, is_add, table, nlattr, priority)
+
+  def DumpRules(self, version):
+    """Returns the IP rules for the specified IP version."""
+    # Create a struct rtmsg specifying the table and the given match attributes.
+    family = {4: socket.AF_INET, 6: socket.AF_INET6}[version]
+    rtmsg = RTMsg((family, 0, 0, 0, 0, 0, 0, 0, 0))
+
+    # Create a netlink dump request containing the rtmsg.
+    command = RTM_GETRULE
+    flags = NLM_F_DUMP | NLM_F_REQUEST
+    length = len(NLMsgHdr) + len(rtmsg)
+    nlmsghdr = NLMsgHdr((length, command, flags, self.seq, self.pid))
+
+    self._Send(nlmsghdr.Pack() + rtmsg.Pack())
+    data = self._Recv()
+
+    rules = []
+    while data:
+      # Parse the netlink and rtmsg headers.
+      nlmsghdr, data = cstruct.Read(data, NLMsgHdr)
+      rtmsg, data = cstruct.Read(data, RTMsg)
+
+      # Parse the attributes in the rtmsg.
+      attributes = []
+      bytesleft = nlmsghdr.length - len(nlmsghdr) - len(rtmsg)
+      while bytesleft:
+        # Read the nlattr header.
+        nla, data = cstruct.Read(data, NLAttr)
+
+        # Read the data. We don't know how to parse attributes, so just return
+        # them as raw bytes.
+        datalen = nla.nla_len - len(nla)
+        nla_data, data = data[:datalen], data[datalen:]
+
+        attributes.append((nla, nla_data))
+        bytesleft -= nla.nla_len
+
+      rules.append((rtmsg, attributes))
+
+    return rules