summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorLorenzo Colitti <lorenzo@google.com>2016-01-14 11:49:33 +0900
committerLorenzo Colitti <lorenzo@google.com>2016-01-14 20:31:54 +0900
commit093d6d4da21f624c7d2e85a45b6afb062b3ea222 (patch)
tree948ff990325317c2b80607346fb05b597aa03b6e /tests
parent59c72161d70447232ee646d2cb9b756d76db91f3 (diff)
downloadextras-093d6d4da21f624c7d2e85a45b6afb062b3ea222.tar.gz
Add code and tests for inet_diag bytecode.
Change-Id: I02af43151cf14905cc762455f282cb7fa5a1b003
Diffstat (limited to 'tests')
-rw-r--r--tests/net_test/netlink.py7
-rwxr-xr-xtests/net_test/sock_diag.py114
-rwxr-xr-xtests/net_test/sock_diag_test.py55
3 files changed, 165 insertions, 11 deletions
diff --git a/tests/net_test/netlink.py b/tests/net_test/netlink.py
index 479d3a4d..2b8f7449 100644
--- a/tests/net_test/netlink.py
+++ b/tests/net_test/netlink.py
@@ -79,11 +79,12 @@ class NetlinkSocket(object):
def _GetConstantName(self, module, value, prefix):
thismodule = sys.modules[module]
for name in dir(thismodule):
+ if name.startswith("INET_DIAG_BC"):
+ break
if (name.startswith(prefix) and
not name.startswith(prefix + "F_") and
- name.isupper() and
- getattr(thismodule, name) == value):
- return name
+ name.isupper() and getattr(thismodule, name) == value):
+ return name
return value
def _Decode(self, command, msg, nla_type, nla_data):
diff --git a/tests/net_test/sock_diag.py b/tests/net_test/sock_diag.py
index cfce751a..5cb83cf3 100755
--- a/tests/net_test/sock_diag.py
+++ b/tests/net_test/sock_diag.py
@@ -20,6 +20,7 @@
import errno
from socket import * # pylint: disable=wildcard-import
+import struct
import cstruct
import net_test
@@ -37,6 +38,9 @@ SOCK_DESTROY = 21
# Message types.
TCPDIAG_GETSOCK = 18
+# Request attributes.
+INET_DIAG_REQ_BYTECODE = 1
+
# Extensions.
INET_DIAG_NONE = 0
INET_DIAG_MEMINFO = 1
@@ -49,6 +53,17 @@ INET_DIAG_SKMEMINFO = 7
INET_DIAG_SHUTDOWN = 8
INET_DIAG_DCTCPINFO = 9
+# Bytecode operations.
+INET_DIAG_BC_NOP = 0
+INET_DIAG_BC_JMP = 1
+INET_DIAG_BC_S_GE = 2
+INET_DIAG_BC_S_LE = 3
+INET_DIAG_BC_D_GE = 4
+INET_DIAG_BC_D_LE = 5
+INET_DIAG_BC_AUTO = 6
+INET_DIAG_BC_S_COND = 7
+INET_DIAG_BC_D_COND = 8
+
# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
InetDiagSockId = cstruct.Struct(
@@ -62,6 +77,9 @@ InetDiagMsg = cstruct.Struct(
[InetDiagSockId])
InetDiagMeminfo = cstruct.Struct(
"InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
+InetDiagBcOp = cstruct.Struct("InetDiagBcOp", "BBH", "code yes no")
+InetDiagHostcond = cstruct.Struct("InetDiagHostcond", "=BBxxi",
+ "family prefix_len port")
SkMeminfo = cstruct.Struct(
"SkMeminfo", "=IIIIIIII",
@@ -133,11 +151,94 @@ class SockDiag(netlink.NetlinkSocket):
def _EmptyInetDiagSockId():
return InetDiagSockId(("\x00" * len(InetDiagSockId)))
- def Dump(self, diag_req):
- out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, "")
+ def PackBytecode(self, instructions):
+ """Compiles instructions to inet_diag bytecode.
+
+ The input is a list of (INET_DIAG_BC_xxx, yes, no, arg) tuples, where yes
+ and no are relative jump offsets measured in instructions. The yes branch
+ is taken if the instruction matches.
+
+ To accept, jump 1 past the last instruction. To reject, jump 2 past the
+ last instruction.
+
+ The target of a no jump is only valid if it is reachable by following
+ only yes jumps from the first instruction - see inet_diag_bc_audit and
+ valid_cc. This means that if cond1 and cond2 are two mutually exclusive
+ filter terms, it is not possible to implement cond1 OR cond2 using:
+
+ ...
+ cond1 2 1 arg
+ cond2 1 2 arg
+ accept
+ reject
+
+ but only using:
+
+ ...
+ cond1 1 2 arg
+ jmp 1 2
+ cond2 1 2 arg
+ accept
+ reject
+
+ The jmp instruction ignores yes and always jumps to no, but yes must be 1
+ or the bytecode won't validate. It doesn't have to be jmp - any instruction
+ that is guaranteed not to match on real data will do.
+
+ Args:
+ instructions: list of instruction tuples
+
+ Returns:
+ A string, the raw bytecode.
+ """
+ args = []
+ positions = [0]
+
+ for op, yes, no, arg in instructions:
+
+ if yes <= 0 or no <= 0:
+ raise ValueError("Jumps must be > 0")
+
+ if op in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
+ arg = ""
+ elif op in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
+ INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
+ arg = "\x00\x00" + struct.pack("=H", arg)
+ elif op in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
+ addr, prefixlen, port = arg
+ family = AF_INET6 if ":" in addr else AF_INET
+ addr = inet_pton(family, addr)
+ arg = InetDiagHostcond((family, prefixlen, port)).Pack() + addr
+ else:
+ raise ValueError("Unsupported opcode %d" % op)
+
+ args.append(arg)
+ length = len(InetDiagBcOp) + len(arg)
+ positions.append(positions[-1] + length)
+
+ # Reject label.
+ positions.append(positions[-1] + 4) # Why 4? Because the kernel uses 4.
+ assert len(args) == len(instructions) == len(positions) - 2
+
+ # print positions
+
+ packed = ""
+ for i, (op, yes, no, arg) in enumerate(instructions):
+ yes = positions[i + yes] - positions[i]
+ no = positions[i + no] - positions[i]
+ instruction = InetDiagBcOp((op, yes, no)).Pack() + args[i]
+ #print "%3d: %d %3d %3d %s %s" % (positions[i], op, yes, no,
+ # arg, instruction.encode("hex"))
+ packed += instruction
+ #print
+
+ return packed
+
+ def Dump(self, diag_req, bytecode=""):
+ out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode)
return out
- def DumpAllInetSockets(self, protocol, sock_id=None, ext=0,
+ def DumpAllInetSockets(self, protocol, bytecode, sock_id=None, ext=0,
states=ALL_NON_TIME_WAIT):
"""Dumps IPv4 or IPv6 sockets matching the specified parameters."""
# DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
@@ -145,10 +246,13 @@ class SockDiag(netlink.NetlinkSocket):
if sock_id is None:
sock_id = self._EmptyInetDiagSockId()
+ if bytecode:
+ bytecode = self._NlAttr(INET_DIAG_REQ_BYTECODE, bytecode)
+
sockets = []
for family in [AF_INET, AF_INET6]:
diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
- sockets += self.Dump(diag_req)
+ sockets += self.Dump(diag_req, bytecode)
return sockets
@@ -255,6 +359,6 @@ if __name__ == "__main__":
sock_id.dport = 443
ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1)
states = 0xffffffff
- diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP,
+ diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, "",
sock_id=sock_id, ext=ext, states=states)
print diag_msgs
diff --git a/tests/net_test/sock_diag_test.py b/tests/net_test/sock_diag_test.py
index 97df3bfb..b13befe9 100755
--- a/tests/net_test/sock_diag_test.py
+++ b/tests/net_test/sock_diag_test.py
@@ -31,6 +31,7 @@ import threading
NUM_SOCKETS = 100
+NO_BYTECODE = ""
# TODO: Backport SOCK_DESTROY and delete this.
HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= (4, 4)
@@ -115,7 +116,7 @@ class SockDiagTest(SockDiagBaseTest):
def testFindsAllMySockets(self):
self.socketpairs = self._CreateLotsOfSockets()
- sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP)
+ sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
# Find the cookies for all of our sockets.
@@ -149,6 +150,54 @@ class SockDiagTest(SockDiagBaseTest):
diag_msg = self.sock_diag.GetSockDiag(req)
self.assertSockDiagMatchesSocket(sock, diag_msg)
+ def testBytecodeCompilation(self):
+ instructions = [
+ (sock_diag.INET_DIAG_BC_S_GE, 1, 8, 0), # 0
+ (sock_diag.INET_DIAG_BC_D_LE, 1, 7, 0xffff), # 8
+ (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)), # 16
+ (sock_diag.INET_DIAG_BC_JMP, 1, 3, None), # 44
+ (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)), # 48
+ (sock_diag.INET_DIAG_BC_D_LE, 1, 3, 0x6665), # not used # 64
+ (sock_diag.INET_DIAG_BC_NOP, 1, 1, None), # 72
+ # 76 acc
+ # 80 rej
+ ]
+ bytecode = self.sock_diag.PackBytecode(instructions)
+ expected = (
+ "0208500000000000"
+ "050848000000ffff"
+ "071c20000a800000ffffffff00000000000000000000000000000001"
+ "01041c00"
+ "0718200002200000ffffffff7f000001"
+ "0508100000006566"
+ "00040400"
+ )
+ self.assertMultiLineEqual(expected, bytecode.encode("hex"))
+ self.assertEquals(76, len(bytecode))
+ self.socketpairs = self._CreateLotsOfSockets()
+ filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
+ allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
+ self.assertEquals(len(allsockets), len(filteredsockets))
+
+ # Pick a few sockets in hash table order, and check that the bytecode we
+ # compiled selects them properly.
+ for socketpair in self.socketpairs.values()[:20]:
+ for s in socketpair:
+ diag_msg = self.sock_diag.FindSockDiagFromFd(s)
+ instructions = [
+ (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
+ (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
+ (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
+ (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
+ ]
+ bytecode = self.sock_diag.PackBytecode(instructions)
+ self.assertEquals(32, len(bytecode))
+ sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
+ self.assertEquals(1, len(sockets))
+
+ # TODO: why doesn't comparing the cstructs work?
+ self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack())
+
@unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
def testClosesSockets(self):
self.socketpairs = self._CreateLotsOfSockets()
@@ -356,7 +405,7 @@ class TcpTest(SockDiagBaseTest):
req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED
req.id.cookie = "\x00" * 8
- children = self.sock_diag.Dump(req)
+ children = self.sock_diag.Dump(req, NO_BYTECODE)
return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
for d, _ in children]
@@ -486,7 +535,7 @@ class TcpTest(SockDiagBaseTest):
sock_id.sport = self.port
states = 1 << sock_diag.TCP_SYN_RECV
req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
- children = self.sock_diag.Dump(req)
+ children = self.sock_diag.Dump(req, NO_BYTECODE)
self.assertTrue(children)
for child, unused_args in children: