diff options
author | Lorenzo Colitti <lorenzo@google.com> | 2016-01-14 11:49:33 +0900 |
---|---|---|
committer | Lorenzo Colitti <lorenzo@google.com> | 2016-01-14 20:31:54 +0900 |
commit | 093d6d4da21f624c7d2e85a45b6afb062b3ea222 (patch) | |
tree | 948ff990325317c2b80607346fb05b597aa03b6e /tests | |
parent | 59c72161d70447232ee646d2cb9b756d76db91f3 (diff) | |
download | extras-093d6d4da21f624c7d2e85a45b6afb062b3ea222.tar.gz |
Add code and tests for inet_diag bytecode.
Change-Id: I02af43151cf14905cc762455f282cb7fa5a1b003
Diffstat (limited to 'tests')
-rw-r--r-- | tests/net_test/netlink.py | 7 | ||||
-rwxr-xr-x | tests/net_test/sock_diag.py | 114 | ||||
-rwxr-xr-x | tests/net_test/sock_diag_test.py | 55 |
3 files changed, 165 insertions, 11 deletions
diff --git a/tests/net_test/netlink.py b/tests/net_test/netlink.py index 479d3a4d..2b8f7449 100644 --- a/tests/net_test/netlink.py +++ b/tests/net_test/netlink.py @@ -79,11 +79,12 @@ class NetlinkSocket(object): def _GetConstantName(self, module, value, prefix): thismodule = sys.modules[module] for name in dir(thismodule): + if name.startswith("INET_DIAG_BC"): + break if (name.startswith(prefix) and not name.startswith(prefix + "F_") and - name.isupper() and - getattr(thismodule, name) == value): - return name + name.isupper() and getattr(thismodule, name) == value): + return name return value def _Decode(self, command, msg, nla_type, nla_data): diff --git a/tests/net_test/sock_diag.py b/tests/net_test/sock_diag.py index cfce751a..5cb83cf3 100755 --- a/tests/net_test/sock_diag.py +++ b/tests/net_test/sock_diag.py @@ -20,6 +20,7 @@ import errno from socket import * # pylint: disable=wildcard-import +import struct import cstruct import net_test @@ -37,6 +38,9 @@ SOCK_DESTROY = 21 # Message types. TCPDIAG_GETSOCK = 18 +# Request attributes. +INET_DIAG_REQ_BYTECODE = 1 + # Extensions. INET_DIAG_NONE = 0 INET_DIAG_MEMINFO = 1 @@ -49,6 +53,17 @@ INET_DIAG_SKMEMINFO = 7 INET_DIAG_SHUTDOWN = 8 INET_DIAG_DCTCPINFO = 9 +# Bytecode operations. +INET_DIAG_BC_NOP = 0 +INET_DIAG_BC_JMP = 1 +INET_DIAG_BC_S_GE = 2 +INET_DIAG_BC_S_LE = 3 +INET_DIAG_BC_D_GE = 4 +INET_DIAG_BC_D_LE = 5 +INET_DIAG_BC_AUTO = 6 +INET_DIAG_BC_S_COND = 7 +INET_DIAG_BC_D_COND = 8 + # Data structure formats. # These aren't constants, they're classes. So, pylint: disable=invalid-name InetDiagSockId = cstruct.Struct( @@ -62,6 +77,9 @@ InetDiagMsg = cstruct.Struct( [InetDiagSockId]) InetDiagMeminfo = cstruct.Struct( "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem") +InetDiagBcOp = cstruct.Struct("InetDiagBcOp", "BBH", "code yes no") +InetDiagHostcond = cstruct.Struct("InetDiagHostcond", "=BBxxi", + "family prefix_len port") SkMeminfo = cstruct.Struct( "SkMeminfo", "=IIIIIIII", @@ -133,11 +151,94 @@ class SockDiag(netlink.NetlinkSocket): def _EmptyInetDiagSockId(): return InetDiagSockId(("\x00" * len(InetDiagSockId))) - def Dump(self, diag_req): - out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, "") + def PackBytecode(self, instructions): + """Compiles instructions to inet_diag bytecode. + + The input is a list of (INET_DIAG_BC_xxx, yes, no, arg) tuples, where yes + and no are relative jump offsets measured in instructions. The yes branch + is taken if the instruction matches. + + To accept, jump 1 past the last instruction. To reject, jump 2 past the + last instruction. + + The target of a no jump is only valid if it is reachable by following + only yes jumps from the first instruction - see inet_diag_bc_audit and + valid_cc. This means that if cond1 and cond2 are two mutually exclusive + filter terms, it is not possible to implement cond1 OR cond2 using: + + ... + cond1 2 1 arg + cond2 1 2 arg + accept + reject + + but only using: + + ... + cond1 1 2 arg + jmp 1 2 + cond2 1 2 arg + accept + reject + + The jmp instruction ignores yes and always jumps to no, but yes must be 1 + or the bytecode won't validate. It doesn't have to be jmp - any instruction + that is guaranteed not to match on real data will do. + + Args: + instructions: list of instruction tuples + + Returns: + A string, the raw bytecode. + """ + args = [] + positions = [0] + + for op, yes, no, arg in instructions: + + if yes <= 0 or no <= 0: + raise ValueError("Jumps must be > 0") + + if op in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]: + arg = "" + elif op in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE, + INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]: + arg = "\x00\x00" + struct.pack("=H", arg) + elif op in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]: + addr, prefixlen, port = arg + family = AF_INET6 if ":" in addr else AF_INET + addr = inet_pton(family, addr) + arg = InetDiagHostcond((family, prefixlen, port)).Pack() + addr + else: + raise ValueError("Unsupported opcode %d" % op) + + args.append(arg) + length = len(InetDiagBcOp) + len(arg) + positions.append(positions[-1] + length) + + # Reject label. + positions.append(positions[-1] + 4) # Why 4? Because the kernel uses 4. + assert len(args) == len(instructions) == len(positions) - 2 + + # print positions + + packed = "" + for i, (op, yes, no, arg) in enumerate(instructions): + yes = positions[i + yes] - positions[i] + no = positions[i + no] - positions[i] + instruction = InetDiagBcOp((op, yes, no)).Pack() + args[i] + #print "%3d: %d %3d %3d %s %s" % (positions[i], op, yes, no, + # arg, instruction.encode("hex")) + packed += instruction + #print + + return packed + + def Dump(self, diag_req, bytecode=""): + out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode) return out - def DumpAllInetSockets(self, protocol, sock_id=None, ext=0, + def DumpAllInetSockets(self, protocol, bytecode, sock_id=None, ext=0, states=ALL_NON_TIME_WAIT): """Dumps IPv4 or IPv6 sockets matching the specified parameters.""" # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it @@ -145,10 +246,13 @@ class SockDiag(netlink.NetlinkSocket): if sock_id is None: sock_id = self._EmptyInetDiagSockId() + if bytecode: + bytecode = self._NlAttr(INET_DIAG_REQ_BYTECODE, bytecode) + sockets = [] for family in [AF_INET, AF_INET6]: diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id)) - sockets += self.Dump(diag_req) + sockets += self.Dump(diag_req, bytecode) return sockets @@ -255,6 +359,6 @@ if __name__ == "__main__": sock_id.dport = 443 ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1) states = 0xffffffff - diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, + diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, "", sock_id=sock_id, ext=ext, states=states) print diag_msgs diff --git a/tests/net_test/sock_diag_test.py b/tests/net_test/sock_diag_test.py index 97df3bfb..b13befe9 100755 --- a/tests/net_test/sock_diag_test.py +++ b/tests/net_test/sock_diag_test.py @@ -31,6 +31,7 @@ import threading NUM_SOCKETS = 100 +NO_BYTECODE = "" # TODO: Backport SOCK_DESTROY and delete this. HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= (4, 4) @@ -115,7 +116,7 @@ class SockDiagTest(SockDiagBaseTest): def testFindsAllMySockets(self): self.socketpairs = self._CreateLotsOfSockets() - sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP) + sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE) self.assertGreaterEqual(len(sockets), NUM_SOCKETS) # Find the cookies for all of our sockets. @@ -149,6 +150,54 @@ class SockDiagTest(SockDiagBaseTest): diag_msg = self.sock_diag.GetSockDiag(req) self.assertSockDiagMatchesSocket(sock, diag_msg) + def testBytecodeCompilation(self): + instructions = [ + (sock_diag.INET_DIAG_BC_S_GE, 1, 8, 0), # 0 + (sock_diag.INET_DIAG_BC_D_LE, 1, 7, 0xffff), # 8 + (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)), # 16 + (sock_diag.INET_DIAG_BC_JMP, 1, 3, None), # 44 + (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)), # 48 + (sock_diag.INET_DIAG_BC_D_LE, 1, 3, 0x6665), # not used # 64 + (sock_diag.INET_DIAG_BC_NOP, 1, 1, None), # 72 + # 76 acc + # 80 rej + ] + bytecode = self.sock_diag.PackBytecode(instructions) + expected = ( + "0208500000000000" + "050848000000ffff" + "071c20000a800000ffffffff00000000000000000000000000000001" + "01041c00" + "0718200002200000ffffffff7f000001" + "0508100000006566" + "00040400" + ) + self.assertMultiLineEqual(expected, bytecode.encode("hex")) + self.assertEquals(76, len(bytecode)) + self.socketpairs = self._CreateLotsOfSockets() + filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode) + allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE) + self.assertEquals(len(allsockets), len(filteredsockets)) + + # Pick a few sockets in hash table order, and check that the bytecode we + # compiled selects them properly. + for socketpair in self.socketpairs.values()[:20]: + for s in socketpair: + diag_msg = self.sock_diag.FindSockDiagFromFd(s) + instructions = [ + (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport), + (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport), + (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport), + (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport), + ] + bytecode = self.sock_diag.PackBytecode(instructions) + self.assertEquals(32, len(bytecode)) + sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode) + self.assertEquals(1, len(sockets)) + + # TODO: why doesn't comparing the cstructs work? + self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack()) + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") def testClosesSockets(self): self.socketpairs = self._CreateLotsOfSockets() @@ -356,7 +405,7 @@ class TcpTest(SockDiagBaseTest): req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED req.id.cookie = "\x00" * 8 - children = self.sock_diag.Dump(req) + children = self.sock_diag.Dump(req, NO_BYTECODE) return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) for d, _ in children] @@ -486,7 +535,7 @@ class TcpTest(SockDiagBaseTest): sock_id.sport = self.port states = 1 << sock_diag.TCP_SYN_RECV req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id)) - children = self.sock_diag.Dump(req) + children = self.sock_diag.Dump(req, NO_BYTECODE) self.assertTrue(children) for child, unused_args in children: |