diff options
author | Chenbo Feng <fengc@google.com> | 2017-06-23 02:53:39 +0000 |
---|---|---|
committer | Gerrit Code Review <noreply-gerritcodereview@google.com> | 2017-06-23 02:53:41 +0000 |
commit | 48901fd8e9b1c298927ce8961c2f093cbb733794 (patch) | |
tree | 612b7844880f828bb6c30359ffc3cf5539b7aba2 | |
parent | f2fa15277178f17931e4bbcb0ad3d30513ad16a5 (diff) | |
parent | fb6884381835d5eb9257a276c6b95d9ddf384cfd (diff) | |
download | tests-48901fd8e9b1c298927ce8961c2f093cbb733794.tar.gz |
Merge "Add tests for cgroup v2 bpf and new helper functions"
-rwxr-xr-x | net/test/bpf.py | 90 | ||||
-rwxr-xr-x | net/test/bpf_test.py | 372 | ||||
-rwxr-xr-x | net/test/run_net_test.sh | 1 | ||||
-rwxr-xr-x | net/test/sock_diag.py | 5 |
4 files changed, 339 insertions, 129 deletions
diff --git a/net/test/bpf.py b/net/test/bpf.py index 14ee613..9e8cf1e 100755 --- a/net/test/bpf.py +++ b/net/test/bpf.py @@ -36,6 +36,8 @@ BPF_MAP_GET_NEXT_KEY = 4 BPF_PROG_LOAD = 5 BPF_OBJ_PIN = 6 BPF_OBJ_GET = 7 +BPF_PROG_ATTACH = 8 +BPF_PROG_DETACH = 9 SO_ATTACH_BPF = 50 # BPF map type constant. @@ -51,6 +53,16 @@ BPF_PROG_TYPE_SOCKET_FILTER = 1 BPF_PROG_TYPE_KPROBE = 2 BPF_PROG_TYPE_SCHED_CLS = 3 BPF_PROG_TYPE_SCHED_ACT = 4 +BPF_PROG_TYPE_TRACEPOINT = 5 +BPF_PROG_TYPE_XDP = 6 +BPF_PROG_TYPE_PERF_EVENT = 7 +BPF_PROG_TYPE_CGROUP_SKB = 8 +BPF_PROG_TYPE_CGROUP_SOCK = 9 + +# BPF program attach type. +BPF_CGROUP_INET_INGRESS = 0 +BPF_CGROUP_INET_EGRESS = 1 +BPF_CGROUP_INET_SOCK_CREATE = 2 # BPF register constant BPF_REG_0 = 0 @@ -124,8 +136,12 @@ BPF_FUNC_unspec = 0 BPF_FUNC_map_lookup_elem = 1 BPF_FUNC_map_update_elem = 2 BPF_FUNC_map_delete_elem = 3 +BPF_FUNC_get_socket_cookie = 46 +BPF_FUNC_get_socket_uid = 47 -# BPF attr struct +# These object below belongs to the same kernel union and the types below +# (e.g., bpf_attr_create) aren't kernel struct names but just different +# variants of the union. BpfAttrCreate = cstruct.Struct("bpf_attr_create", "=IIII", "map_type key_size value_size max_entries") BpfAttrOps = cstruct.Struct("bpf_attr_ops", "=QQQQ", @@ -133,6 +149,8 @@ BpfAttrOps = cstruct.Struct("bpf_attr_ops", "=QQQQ", BpfAttrProgLoad = cstruct.Struct( "bpf_attr_prog_load", "=IIQQIIQI", "prog_type insn_cnt insns" " license log_level log_size log_buf kern_version") +BpfAttrProgAttach = cstruct.Struct( + "bpf_attr_prog_attach", "=III", "target_fd attach_bpf_fd attach_type") BpfInsn = cstruct.Struct("bpf_insn", "=BBhi", "code dst_src_reg off imm") libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True) @@ -140,12 +158,15 @@ HAVE_EBPF_SUPPORT = net_test.LINUX_VERSION >= (4, 4, 0) # BPF program syscalls -def CreateMap(map_type, key_size, value_size, max_entries): - attr = BpfAttrCreate((map_type, key_size, value_size, max_entries)) - ret = libc.syscall(__NR_bpf, BPF_MAP_CREATE, attr.CPointer(), len(attr)) +def BpfSyscall(op, attr): + ret = libc.syscall(__NR_bpf, op, attr.CPointer(), len(attr)) csocket.MaybeRaiseSocketError(ret) return ret +def CreateMap(map_type, key_size, value_size, max_entries): + attr = BpfAttrCreate((map_type, key_size, value_size, max_entries)) + return BpfSyscall(BPF_MAP_CREATE, attr) + def UpdateMap(map_fd, key, value, flags=0): c_value = ctypes.c_uint32(value) @@ -153,9 +174,7 @@ def UpdateMap(map_fd, key, value, flags=0): value_ptr = ctypes.addressof(c_value) key_ptr = ctypes.addressof(c_key) attr = BpfAttrOps((map_fd, key_ptr, value_ptr, flags)) - ret = libc.syscall(__NR_bpf, BPF_MAP_UPDATE_ELEM, - attr.CPointer(), len(attr)) - csocket.MaybeRaiseSocketError(ret) + BpfSyscall(BPF_MAP_UPDATE_ELEM, attr) def LookupMap(map_fd, key): @@ -163,9 +182,7 @@ def LookupMap(map_fd, key): c_key = ctypes.c_uint32(key) attr = BpfAttrOps( (map_fd, ctypes.addressof(c_key), ctypes.addressof(c_value), 0)) - ret = libc.syscall(__NR_bpf, BPF_MAP_LOOKUP_ELEM, - attr.CPointer(), len(attr)) - csocket.MaybeRaiseSocketError(ret) + BpfSyscall(BPF_MAP_LOOKUP_ELEM, attr) return c_value @@ -174,37 +191,44 @@ def GetNextKey(map_fd, key): c_next_key = ctypes.c_uint32(0) attr = BpfAttrOps( (map_fd, ctypes.addressof(c_key), ctypes.addressof(c_next_key), 0)) - ret = libc.syscall(__NR_bpf, BPF_MAP_GET_NEXT_KEY, - attr.CPointer(), len(attr)) - csocket.MaybeRaiseSocketError(ret) + BpfSyscall(BPF_MAP_GET_NEXT_KEY, attr) return c_next_key def DeleteMap(map_fd, key): c_key = ctypes.c_uint32(key) attr = BpfAttrOps((map_fd, ctypes.addressof(c_key), 0, 0)) - ret = libc.syscall(__NR_bpf, BPF_MAP_DELETE_ELEM, - attr.CPointer(), len(attr)) - csocket.MaybeRaiseSocketError(ret) + BpfSyscall(BPF_MAP_DELETE_ELEM, attr) -def BpfProgLoad(prog_type, insn_ptr, prog_len, insn_len): +def BpfProgLoad(prog_type, instructions): + bpf_prog = "".join(instructions) + insn_buff = ctypes.create_string_buffer(bpf_prog) gpl_license = ctypes.create_string_buffer(b"GPL") log_buf = ctypes.create_string_buffer(b"", LOG_SIZE) - attr = BpfAttrProgLoad( - (prog_type, prog_len / insn_len, insn_ptr, ctypes.addressof(gpl_license), - LOG_LEVEL, LOG_SIZE, ctypes.addressof(log_buf), 0)) - ret = libc.syscall(__NR_bpf, BPF_PROG_LOAD, attr.CPointer(), len(attr)) - csocket.MaybeRaiseSocketError(ret) - return ret - + attr = BpfAttrProgLoad((prog_type, len(insn_buff) / len(BpfInsn), + ctypes.addressof(insn_buff), + ctypes.addressof(gpl_license), LOG_LEVEL, + LOG_SIZE, ctypes.addressof(log_buf), 0)) + return BpfSyscall(BPF_PROG_LOAD, attr) +# Attach a socket eBPF filter to a target socket def BpfProgAttachSocket(sock_fd, prog_fd): prog_ptr = ctypes.c_uint32(prog_fd) ret = libc.setsockopt(sock_fd, socket.SOL_SOCKET, SO_ATTACH_BPF, ctypes.addressof(prog_ptr), ctypes.sizeof(prog_ptr)) csocket.MaybeRaiseSocketError(ret) +# Attach a eBPF filter to a cgroup +def BpfProgAttach(prog_fd, target_fd, prog_type): + attr = BpfAttrProgAttach((target_fd, prog_fd, prog_type)) + return BpfSyscall(BPF_PROG_ATTACH, attr) + +# Detach a eBPF filter from a cgroup +def BpfProgDetach(target_fd, prog_type): + attr = BpfAttrProgAttach((target_fd, 0, prog_type)) + return BpfSyscall(BPF_PROG_DETACH, attr) + # BPF program command constructors def BpfMov64Reg(dst, src): @@ -275,22 +299,8 @@ def BpfLoadMapFd(map_fd, dst): return insn1.Pack() + insn2.Pack() -def BpfFuncLookupMap(): - code = BPF_JMP | BPF_CALL - dst_src = 0 - ret = BpfInsn((code, dst_src, 0, BPF_FUNC_map_lookup_elem)) - return ret.Pack() - - -def BpfFuncUpdateMap(): - code = BPF_JMP | BPF_CALL - dst_src = 0 - ret = BpfInsn((code, dst_src, 0, BPF_FUNC_map_update_elem)) - return ret.Pack() - - -def BpfFuncDeleteMap(): +def BpfFuncCall(func): code = BPF_JMP | BPF_CALL dst_src = 0 - ret = BpfInsn((code, dst_src, 0, BPF_FUNC_map_delete_elem)) + ret = BpfInsn((code, dst_src, 0, func)) return ret.Pack() diff --git a/net/test/bpf_test.py b/net/test/bpf_test.py index 6d17423..33ef22d 100755 --- a/net/test/bpf_test.py +++ b/net/test/bpf_test.py @@ -18,14 +18,131 @@ import ctypes import errno import os import socket +import struct import unittest from bpf import * # pylint: disable=wildcard-import import csocket import net_test +import sock_diag libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True) HAVE_EBPF_SUPPORT = net_test.LINUX_VERSION >= (4, 4, 0) +HAVE_EBPF_ACCOUNTING = net_test.LINUX_VERSION >= (4, 9, 0) +KEY_SIZE = 8 +VALUE_SIZE = 4 +TOTAL_ENTRIES = 20 +# Offset to store the map key in stack register REG10 +key_offset = -8 +# Offset to store the map value in stack register REG10 +value_offset = -16 + +# Debug usage only. +def PrintMapInfo(map_fd): + # A random key that the map does not contain. + key = 10086 + while 1: + try: + nextKey = GetNextKey(map_fd, key).value + value = LookupMap(map_fd, nextKey) + print repr(nextKey) + " : " + repr(value.value) + key = nextKey + except: + print "no value" + break + + +# A dummy loopback function that causes a socket to send traffic to itself. +def SocketUDPLoopBack(packet_count, version, prog_fd): + family = {4: socket.AF_INET, 6: socket.AF_INET6}[version] + sock = socket.socket(family, socket.SOCK_DGRAM, 0) + if prog_fd is not None: + BpfProgAttachSocket(sock.fileno(), prog_fd) + net_test.SetNonBlocking(sock) + addr = {4: "127.0.0.1", 6: "::1"}[version] + sock.bind((addr, 0)) + addr = sock.getsockname() + sockaddr = csocket.Sockaddr(addr) + for i in xrange(packet_count): + sock.sendto("foo", addr) + data, retaddr = csocket.Recvfrom(sock, 4096, 0) + assert "foo" == data + assert sockaddr == retaddr + return sock + + +# The main code block for eBPF packet counting program. It takes a preloaded +# key from BPF_REG_0 and use it to look up the bpf map, if the element does not +# exist in the map yet, the program will update the map with a new <key, 1> +# pair. Otherwise it will jump to next code block to handle it. +# REG0: regiter storing return value from helper function and the final return +# value of eBPF program. +# REG1 - REG5: temporary register used for storing values and load parameters +# into eBPF helper function. After calling helper function, the value for these +# registers will be reset. +# REG6 - REG9: registers store values that will not be cleared when calling +# eBPF helper function. +# REG10: A stack stores values need to be accessed by the address. Program can +# retrieve the address of a value by specifying the position of the value in +# the stack. +def BpfFuncCountPacketInit(map_fd): + key_pos = BPF_REG_7 + insPackCountStart = [ + # Get a preloaded key from BPF_REG_0 and store it at BPF_REG_7 + BpfMov64Reg(key_pos, BPF_REG_10), + BpfAlu64Imm(BPF_ADD, key_pos, key_offset), + # Load map fd and look up the key in the map + BpfLoadMapFd(map_fd, BPF_REG_1), + BpfMov64Reg(BPF_REG_2, key_pos), + BpfFuncCall(BPF_FUNC_map_lookup_elem), + # if the map element already exist, jump out of this + # code block and let next part to handle it + BpfJumpImm(BPF_AND, BPF_REG_0, 0, 10), + BpfLoadMapFd(map_fd, BPF_REG_1), + BpfMov64Reg(BPF_REG_2, key_pos), + # Initial a new <key, value> pair with value equal to 1 and update to map + BpfStMem(BPF_W, BPF_REG_10, value_offset, 1), + BpfMov64Reg(BPF_REG_3, BPF_REG_10), + BpfAlu64Imm(BPF_ADD, BPF_REG_3, value_offset), + BpfMov64Imm(BPF_REG_4, 0), + BpfFuncCall(BPF_FUNC_map_update_elem) + ] + return insPackCountStart + + +INS_BPF_EXIT_BLOCK = [ + BpfMov64Imm(BPF_REG_0, 0), + BpfExitInsn() +] + +# Bpf instruction for cgroup bpf filter to accept a packet and exit. +INS_CGROUP_ACCEPT = [ + # Set return value to 1 and exit. + BpfMov64Imm(BPF_REG_0, 1), + BpfExitInsn() +] + +# Bpf instruction for socket bpf filter to accept a packet and exit. +INS_SK_FILTER_ACCEPT = [ + # Precondition: BPF_REG_6 = sk_buff context + # Load the packet length from BPF_REG_6 and store it in BPF_REG_0 as the + # return value. + BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0), + BpfExitInsn() +] + +# Update a existing map element with +1. +INS_PACK_COUNT_UPDATE = [ + # Precondition: BPF_REG_0 = Value retrieved from BPF maps + # Add one to the corresponding eBPF value field for a specific eBPF key. + BpfMov64Reg(BPF_REG_2, BPF_REG_0), + BpfMov64Imm(BPF_REG_1, 1), + BpfRawInsn(BPF_STX | BPF_XADD | BPF_W, BPF_REG_2, BPF_REG_1, 0, 0), +] + +INS_BPF_PARAM_STORE = [ + BpfStxMem(BPF_DW, BPF_REG_10, BPF_REG_0, key_offset), +] @unittest.skipUnless(HAVE_EBPF_SUPPORT, "eBPF function not fully supported") @@ -34,130 +151,207 @@ class BpfTest(net_test.NetworkTest): def setUp(self): self.map_fd = -1 self.prog_fd = -1 + self.sock = None def tearDown(self): if self.prog_fd >= 0: os.close(self.prog_fd) if self.map_fd >= 0: os.close(self.map_fd) + if self.sock: + self.sock.close() def testCreateMap(self): key, value = 1, 1 - self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, 4, 4, 100) + self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, + TOTAL_ENTRIES) UpdateMap(self.map_fd, key, value) - self.assertEquals(LookupMap(self.map_fd, key).value, value) + self.assertEquals(value, LookupMap(self.map_fd, key).value) DeleteMap(self.map_fd, key) self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key) def testIterateMap(self): - self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, 4, 4, 100) + self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, + TOTAL_ENTRIES) value = 1024 - for key in xrange(1, 100): + for key in xrange(1, TOTAL_ENTRIES): UpdateMap(self.map_fd, key, value) - for key in xrange(1, 100): - self.assertEquals(LookupMap(self.map_fd, key).value, value) + for key in xrange(1, TOTAL_ENTRIES): + self.assertEquals(value, LookupMap(self.map_fd, key).value) self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, 101) key = 0 count = 0 while 1: - if count == 99: + if count == TOTAL_ENTRIES - 1: self.assertRaisesErrno(errno.ENOENT, GetNextKey, self.map_fd, key) break else: result = GetNextKey(self.map_fd, key) key = result.value self.assertGreater(key, 0) - self.assertEquals(LookupMap(self.map_fd, key).value, value) + self.assertEquals(value, LookupMap(self.map_fd, key).value) count += 1 def testProgLoad(self): - bpf_prog = BpfMov64Reg(BPF_REG_6, BPF_REG_1) - bpf_prog += BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0) - bpf_prog += BpfExitInsn() - insn_buff = ctypes.create_string_buffer(bpf_prog) - # Load a program that does nothing except pass every packet it receives - # It should not block the packet transmission otherwise the test fails. - self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, - ctypes.addressof(insn_buff), - len(insn_buff), BpfInsn._length) - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) - sock.settimeout(1) - BpfProgAttachSocket(sock.fileno(), self.prog_fd) - addr = "127.0.0.1" - sock.bind((addr, 0)) - addr = sock.getsockname() - sockaddr = csocket.Sockaddr(addr) - sock.sendto("foo", addr) - data, addr = csocket.Recvfrom(sock, 4096, 0) - self.assertEqual("foo", data) - self.assertEqual(sockaddr, addr) + # Move skb to BPF_REG_6 for further usage + instructions = [ + BpfMov64Reg(BPF_REG_6, BPF_REG_1) + ] + instructions += INS_SK_FILTER_ACCEPT + self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions) + SocketUDPLoopBack(1, 4, self.prog_fd) + SocketUDPLoopBack(1, 6, self.prog_fd) def testPacketBlock(self): - bpf_prog = BpfMov64Reg(BPF_REG_6, BPF_REG_1) - bpf_prog += BpfMov64Imm(BPF_REG_0, 0) - bpf_prog += BpfExitInsn() - insn_buff = ctypes.create_string_buffer(bpf_prog) - self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, - ctypes.addressof(insn_buff), - len(insn_buff), BpfInsn._length) - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) - sock.settimeout(1) - BpfProgAttachSocket(sock.fileno(), self.prog_fd) - addr = "127.0.0.1" - sock.bind((addr, 0)) - addr = sock.getsockname() - sock.sendto("foo", addr) - self.assertRaisesErrno(errno.EAGAIN, csocket.Recvfrom, sock, 4096, 0) + self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, INS_BPF_EXIT_BLOCK) + self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, self.prog_fd) + self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, self.prog_fd) def testPacketCount(self): - self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, 4, 4, 10) + self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, + TOTAL_ENTRIES) key = 0xf0f0 - bpf_prog = BpfMov64Reg(BPF_REG_6, BPF_REG_1) - bpf_prog += BpfLoadMapFd(self.map_fd, BPF_REG_1) - bpf_prog += BpfMov64Imm(BPF_REG_7, key) - bpf_prog += BpfStxMem(BPF_W, BPF_REG_10, BPF_REG_7, -4) - bpf_prog += BpfMov64Reg(BPF_REG_8, BPF_REG_10) - bpf_prog += BpfAlu64Imm(BPF_ADD, BPF_REG_8, -4) - bpf_prog += BpfMov64Reg(BPF_REG_2, BPF_REG_8) - bpf_prog += BpfFuncLookupMap() - bpf_prog += BpfJumpImm(BPF_AND, BPF_REG_0, 0, 10) - bpf_prog += BpfLoadMapFd(self.map_fd, BPF_REG_1) - bpf_prog += BpfMov64Reg(BPF_REG_2, BPF_REG_8) - bpf_prog += BpfStMem(BPF_W, BPF_REG_10, -8, 1) - bpf_prog += BpfMov64Reg(BPF_REG_3, BPF_REG_10) - bpf_prog += BpfAlu64Imm(BPF_ADD, BPF_REG_3, -8) - bpf_prog += BpfMov64Imm(BPF_REG_4, 0) - bpf_prog += BpfFuncUpdateMap() - bpf_prog += BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0) - bpf_prog += BpfExitInsn() - bpf_prog += BpfMov64Reg(BPF_REG_2, BPF_REG_0) - bpf_prog += BpfMov64Imm(BPF_REG_1, 1) - bpf_prog += BpfRawInsn(BPF_STX | BPF_XADD | BPF_W, BPF_REG_2, BPF_REG_1, - 0, 0) - bpf_prog += BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0) - bpf_prog += BpfExitInsn() - insn_buff = ctypes.create_string_buffer(bpf_prog) - # this program loaded is used to counting the packet transmitted through - # a target socket. It will store the packet count into the eBPF map and we - # will verify if the counting result is correct. - self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, - ctypes.addressof(insn_buff), - len(insn_buff), BpfInsn._length) - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0) - sock.settimeout(1) - BpfProgAttachSocket(sock.fileno(), self.prog_fd) - addr = "127.0.0.1" - sock.bind((addr, 0)) - addr = sock.getsockname() - sockaddr = csocket.Sockaddr(addr) - packet_count = 100 - for i in xrange(packet_count): - sock.sendto("foo", addr) - data, retaddr = csocket.Recvfrom(sock, 4096, 0) - self.assertEqual("foo", data) - self.assertEqual(sockaddr, retaddr) - self.assertEquals(LookupMap(self.map_fd, key).value, packet_count) + # Set up instruction block with key loaded at BPF_REG_0. + instructions = [ + BpfMov64Reg(BPF_REG_6, BPF_REG_1), + BpfMov64Imm(BPF_REG_0, key) + ] + # Concatenate the generic packet count bpf program to it. + instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd) + + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE + + INS_SK_FILTER_ACCEPT) + self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions) + packet_count = 10 + SocketUDPLoopBack(packet_count, 4, self.prog_fd) + SocketUDPLoopBack(packet_count, 6, self.prog_fd) + self.assertEquals(packet_count * 2, LookupMap(self.map_fd, key).value) + + @unittest.skipUnless(HAVE_EBPF_ACCOUNTING, + "BPF helper function is not fully supported") + def testGetSocketCookie(self): + self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, + TOTAL_ENTRIES) + # Move skb to REG6 for further usage, call helper function to get socket + # cookie of current skb and return the cookie at REG0 for next code block + instructions = [ + BpfMov64Reg(BPF_REG_6, BPF_REG_1), + BpfFuncCall(BPF_FUNC_get_socket_cookie) + ] + instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd) + + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE + + INS_SK_FILTER_ACCEPT) + self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions) + packet_count = 10 + def PacketCountByCookie(version): + self.sock = SocketUDPLoopBack(packet_count, version, self.prog_fd) + cookie = sock_diag.SockDiag.GetSocketCookie(self.sock) + self.assertEquals(packet_count, LookupMap(self.map_fd, cookie).value) + self.sock.close() + PacketCountByCookie(4) + PacketCountByCookie(6) + + @unittest.skipUnless(HAVE_EBPF_ACCOUNTING, + "BPF helper function is not fully supported") + def testGetSocketUid(self): + self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, + TOTAL_ENTRIES) + # Set up the instruction with uid at BPF_REG_0. + instructions = [ + BpfMov64Reg(BPF_REG_6, BPF_REG_1), + BpfFuncCall(BPF_FUNC_get_socket_uid) + ] + # Concatenate the generic packet count bpf program to it. + instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd) + + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE + + INS_SK_FILTER_ACCEPT) + self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions) + packet_count = 10 + uid = 12345 + with net_test.RunAsUid(uid): + self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid) + SocketUDPLoopBack(packet_count, 4, self.prog_fd) + self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value) + DeleteMap(self.map_fd, uid); + SocketUDPLoopBack(packet_count, 6, self.prog_fd) + self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value) + +@unittest.skipUnless(HAVE_EBPF_ACCOUNTING, + "Cgroup BPF is not fully supported") +class BpfCgroupTest(net_test.NetworkTest): + + @classmethod + def setUpClass(cls): + if not os.path.isdir("/tmp"): + os.mkdir('/tmp') + os.system('mount -t cgroup2 cg_bpf /tmp') + cls._cg_fd = os.open('/tmp', os.O_DIRECTORY | os.O_RDONLY) + + @classmethod + def tearDownClass(cls): + os.close(cls._cg_fd) + os.system('umount cg_bpf') + + def setUp(self): + self.prog_fd = -1 + self.map_fd = -1 + + def tearDown(self): + if self.prog_fd >= 0: + os.close(self.prog_fd) + if self.map_fd >= 0: + os.close(self.map_fd) + try: + BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_EGRESS) + except socket.error: + pass + try: + BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS) + except socket.error: + pass + + def testCgroupBpfAttach(self): + self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK) + BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS) + BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS) + + def testCgroupIngress(self): + self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK) + BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS) + self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, None) + self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, None) + BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS) + SocketUDPLoopBack(1, 4, None) + SocketUDPLoopBack(1, 6, None) + + def testCgroupEgress(self): + self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK) + BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_EGRESS) + self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 4, None) + self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 6, None) + BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_EGRESS) + SocketUDPLoopBack( 1, 4, None) + SocketUDPLoopBack( 1, 6, None) + def testCgroupBpfUid(self): + self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, + TOTAL_ENTRIES) + # Similar to the program used in testGetSocketUid. + instructions = [ + BpfMov64Reg(BPF_REG_6, BPF_REG_1), + BpfFuncCall(BPF_FUNC_get_socket_uid) + ] + instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd) + + INS_CGROUP_ACCEPT + INS_PACK_COUNT_UPDATE + INS_CGROUP_ACCEPT) + self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, instructions) + BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS) + uid = os.getuid() + packet_count = 20 + SocketUDPLoopBack(packet_count, 4, None) + self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value) + DeleteMap(self.map_fd, uid); + SocketUDPLoopBack(packet_count, 6, None) + self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value) + BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS) if __name__ == "__main__": unittest.main() diff --git a/net/test/run_net_test.sh b/net/test/run_net_test.sh index bfba4db..268c6d7 100755 --- a/net/test/run_net_test.sh +++ b/net/test/run_net_test.sh @@ -28,6 +28,7 @@ OPTIONS="$OPTIONS TRANSPORT INET_XFRM_MODE_TUNNEL INET6_AH INET6_ESP" OPTIONS="$OPTIONS INET6_XFRM_MODE_TRANSPORT INET6_XFRM_MODE_TUNNEL" OPTIONS="$OPTIONS CRYPTO_SHA256 CRYPTO_SHA512 CRYPTO_AES_X86_64" OPTIONS="$OPTIONS CRYPTO_ECHAINIV" +OPTIONS="$OPTIONS SOCK_CGROUP_DATA CGROUP_BPF" # For 3.1 kernels, where devtmpfs is not on by default. OPTIONS="$OPTIONS DEVTMPFS DEVTMPFS_MOUNT" diff --git a/net/test/sock_diag.py b/net/test/sock_diag.py index 1865891..c6278e1 100755 --- a/net/test/sock_diag.py +++ b/net/test/sock_diag.py @@ -375,6 +375,11 @@ class SockDiag(netlink.NetlinkSocket): sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8)) return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id)) + @staticmethod + def GetSocketCookie(s): + cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8) + return struct.unpack("=Q", cookie)[0] + def FindSockInfoFromFd(self, s): """Gets a diag_msg and attrs from the kernel for the specified socket.""" req = self.DiagReqFromSocket(s) |