import socket
import struct
+import cstruct
+
+
### Base netlink constants. See include/uapi/linux/netlink.h.
NETLINK_ROUTE = 0
NLM_F_ACK = 4
NLM_F_EXCL = 0x200
NLM_F_CREATE = 0x400
+NLM_F_DUMP = 0x300
# Message types.
NLMSG_ERROR = 2
# Data structure formats.
-STRUCT_NLMSGHDR = "=LHHLL"
-STRUCT_NLMSGERR = "=i"
-STRUCT_NLATTR = "=HH"
+# These aren't constants, they're classes. So, pylint: disable=invalid-name
+NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
+NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error")
+NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type")
### rtnetlink constants. See include/uapi/linux/rtnetlink.h.
# Message types.
RTM_NEWRULE = 32
RTM_DELRULE = 33
+RTM_GETRULE = 34
# Routing message type values (rtm_type).
RTN_UNSPEC = 0
RTN_UNICAST = 1
# Routing protocol values (rtm_protocol).
+RTPROT_UNSPEC = 0
RTPROT_STATIC = 4
# Route scope values (rtm_scope).
RT_SCOPE_UNIVERSE = 0
+# Named routing tables.
+RT_TABLE_UNSPEC = 0
+
# Data structure formats.
-STRUCT_RTMSG = "=BBBBBBBBI"
+RTMsg = cstruct.Struct(
+ "RTMsg", "=BBBBBBBBI",
+ "family dst_len src_len tos table protocol scope type flags")
### FIB rule constants. See include/uapi/linux/fib_rules.h.
FRA_TABLE = 15
-def Unpack(fmt, data):
- """Unpacks a data structure with variable-size contents at the end."""
- size = struct.calcsize(fmt)
- data, remainder = data[:size], data[size:]
- return struct.unpack(fmt, data), remainder
-
-
class IPRoute(object):
"""Provides a tiny subset of iproute functionality."""
- BUFSIZE = 1024
+ BUFSIZE = 65536
def _NlAttrU32(self, nla_type, value):
data = struct.pack("=I", value)
- nla_len = struct.calcsize(STRUCT_NLATTR) + len(data)
- return struct.pack(STRUCT_NLATTR, nla_len, nla_type) + data
+ nla_len = len(data) + len(NLAttr)
+ return NLAttr((nla_len, nla_type)).Pack() + data
def __init__(self):
# Global sequence number.
def _Recv(self):
return self.sock.recv(self.BUFSIZE)
+ def _ExpectAck(self):
+ # Find the error code.
+ response = self._Recv()
+ hdr, data = cstruct.Read(response, NLMsgHdr)
+ if hdr.type == NLMSG_ERROR:
+ error = NLMsgErr(data).error
+ if error:
+ raise IOError(error, os.strerror(-error))
+ else:
+ raise ValueError("Unexpected netlink ACK type %d" % hdr.type)
+
def _Rule(self, version, is_add, table, match_nlattr, priority):
"""Python equivalent of "ip rule <add|del> <match_cond> lookup <table>".
# Create a struct rtmsg specifying the table and the given match attributes.
family = {4: socket.AF_INET, 6: socket.AF_INET6}[version]
- rtmsg = struct.pack(STRUCT_RTMSG, family, 0, 0, 0, 0,
- RTPROT_STATIC, RT_SCOPE_UNIVERSE, RTN_UNICAST, 0)
+ rtmsg = RTMsg((family, 0, 0, 0, RT_TABLE_UNSPEC,
+ RTPROT_STATIC, RT_SCOPE_UNIVERSE, RTN_UNICAST, 0)).Pack()
rtmsg += self._NlAttrU32(FRA_PRIORITY, priority)
rtmsg += match_nlattr
rtmsg += self._NlAttrU32(FRA_TABLE, table)
if is_add:
flags |= (NLM_F_EXCL | NLM_F_CREATE)
- # Fill in the length field.
- length = struct.calcsize(STRUCT_NLMSGHDR) + len(rtmsg)
- nlmsg = struct.pack(STRUCT_NLMSGHDR, length, command, flags,
- self.seq, self.pid) + rtmsg
+ length = len(NLMsgHdr) + len(rtmsg)
+ nlmsg = NLMsgHdr((length, command, flags, self.seq, self.pid)).Pack()
# Send the message and block forever until we receive a response.
- self._Send(nlmsg)
- response = self._Recv()
+ self._Send(nlmsg + rtmsg)
- # Find the error code.
- (_, msgtype, _, _, _), msg = Unpack(STRUCT_NLMSGHDR, response)
- if msgtype == NLMSG_ERROR:
- ((error,), _) = Unpack(STRUCT_NLMSGERR, msg)
- if error:
- raise IOError(error, os.strerror(-error))
- else:
- raise ValueError("Unexpected netlink ACK type %d" % msgtype)
+ # Expect a successful ACK.
+ self._ExpectAck()
def FwmarkRule(self, version, is_add, fwmark, table, priority=16383):
nlattr = self._NlAttrU32(FRA_FWMARK, fwmark)
return self._Rule(version, is_add, table, nlattr, priority)
+
+ def DumpRules(self, version):
+ """Returns the IP rules for the specified IP version."""
+ # Create a struct rtmsg specifying the table and the given match attributes.
+ family = {4: socket.AF_INET, 6: socket.AF_INET6}[version]
+ rtmsg = RTMsg((family, 0, 0, 0, 0, 0, 0, 0, 0))
+
+ # Create a netlink dump request containing the rtmsg.
+ command = RTM_GETRULE
+ flags = NLM_F_DUMP | NLM_F_REQUEST
+ length = len(NLMsgHdr) + len(rtmsg)
+ nlmsghdr = NLMsgHdr((length, command, flags, self.seq, self.pid))
+
+ self._Send(nlmsghdr.Pack() + rtmsg.Pack())
+ data = self._Recv()
+
+ rules = []
+ while data:
+ # Parse the netlink and rtmsg headers.
+ nlmsghdr, data = cstruct.Read(data, NLMsgHdr)
+ rtmsg, data = cstruct.Read(data, RTMsg)
+
+ # Parse the attributes in the rtmsg.
+ attributes = []
+ bytesleft = nlmsghdr.length - len(nlmsghdr) - len(rtmsg)
+ while bytesleft:
+ # Read the nlattr header.
+ nla, data = cstruct.Read(data, NLAttr)
+
+ # Read the data. We don't know how to parse attributes, so just return
+ # them as raw bytes.
+ datalen = nla.nla_len - len(nla)
+ nla_data, data = data[:datalen], data[datalen:]
+
+ attributes.append((nla, nla_data))
+ bytesleft -= nla.nla_len
+
+ rules.append((rtmsg, attributes))
+
+ return rules