diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2021-07-15 02:04:53 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2021-07-15 02:04:53 +0000 |
commit | 8b91ef61d8e0f72fc41b54e3dbdad7b81b91a15c (patch) | |
tree | aab8551ed6cf00ab715b49f4c5eff93f423cf5dd | |
parent | ee022c14d19c7593edaf8ed8130ceab0de182808 (diff) | |
parent | 1af183d66e3733393c6738704047247bd3c2f9a1 (diff) | |
download | tests-android12-mainline-captiveportallogin-release.tar.gz |
Snap for 7550930 from 1af183d66e3733393c6738704047247bd3c2f9a1 to mainline-captiveportallogin-releaseandroid-mainline-12.0.0_r6android-mainline-12.0.0_r23android12-mainline-captiveportallogin-release
Change-Id: I3de2fb7328f9799d912ac3240307f96fd5ebf35b
39 files changed, 1099 insertions, 444 deletions
@@ -1,3 +1,31 @@ +package { + default_applicable_licenses: ["kernel_tests_license"], +} + +// Added automatically by a large-scale-change that took the approach of +// 'apply every license found to every target'. While this makes sure we respect +// every license restriction, it may not be entirely correct. +// +// e.g. GPL in an MIT project might only apply to the contrib/ directory. +// +// Please consider splitting the single license below into multiple licenses, +// taking care not to lose any license_kind information, and overriding the +// default license using the 'licenses: [...]' property on targets as needed. +// +// For unused files, consider creating a 'fileGroup' with "//visibility:private" +// to attach the license to, and including a comment whether the files may be +// used in the current project. +// See: http://go/android-license-faq +license { + name: "kernel_tests_license", + visibility: [":__subpackages__"], + license_kinds: [ + "SPDX-license-identifier-Apache-2.0", + "SPDX-license-identifier-OpenSSL", + ], + // large-scale-change unable to identify any license_text files +} + python_defaults { name: "kernel_tests_defaults", version: { diff --git a/devicetree/early_mount/Android.bp b/devicetree/early_mount/Android.bp index 63131c6..6c60e91 100644 --- a/devicetree/early_mount/Android.bp +++ b/devicetree/early_mount/Android.bp @@ -12,6 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +package { + // See: http://go/android-license-faq + // A large-scale-change added 'default_applicable_licenses' to import + // all of the 'license_kinds' from "kernel_tests_license" + // to get the below license kinds: + // SPDX-license-identifier-Apache-2.0 + default_applicable_licenses: ["kernel_tests_license"], +} + python_test { name: "dt_early_mount_test", srcs: [ diff --git a/net/test/Android.bp b/net/test/Android.bp index 6c4a75f..2ecef87 100644 --- a/net/test/Android.bp +++ b/net/test/Android.bp @@ -1,3 +1,13 @@ +package { + // See: http://go/android-license-faq + // A large-scale-change added 'default_applicable_licenses' to import + // all of the 'license_kinds' from "kernel_tests_license" + // to get the below license kinds: + // SPDX-license-identifier-Apache-2.0 + // SPDX-license-identifier-OpenSSL + default_applicable_licenses: ["kernel_tests_license"], +} + python_defaults { name: "kernel_net_tests_defaults", srcs: [ diff --git a/net/test/bpf.py b/net/test/bpf.py index 5062e31..6d22423 100755 --- a/net/test/bpf.py +++ b/net/test/bpf.py @@ -14,14 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""kernel net test library for bpf testing.""" + import ctypes import os +import platform +import resource +import socket import csocket import cstruct import net_test -import socket -import platform # __NR_bpf syscall numbers for various architectures. # NOTE: If python inherited COMPAT_UTS_MACHINE, uname's 'machine' field will @@ -29,8 +32,11 @@ import platform # around this problem and pick the right syscall nr, we can additionally check # the bitness of the python interpreter. Assume that the 64-bit architectures # are not running with COMPAT_UTS_MACHINE and must be 64-bit at all times. -# TODO: is there a better way of doing this? -__NR_bpf = { +# +# Is there a better way of doing this? +# Is it correct to use os.uname()[4] instead of platform.machine() ? +# Should we use 'sys.maxsize > 2**32' instead of platform.architecture()[0] ? +__NR_bpf = { # pylint: disable=invalid-name "aarch64-32bit": 386, "aarch64-64bit": 280, "armv7l-32bit": 386, @@ -150,13 +156,18 @@ BPF_CALL = 0x80 BPF_EXIT = 0x90 # BPF helper function constants +# pylint: disable=invalid-name BPF_FUNC_unspec = 0 BPF_FUNC_map_lookup_elem = 1 BPF_FUNC_map_update_elem = 2 BPF_FUNC_map_delete_elem = 3 +BPF_FUNC_ktime_get_ns = 5 BPF_FUNC_get_current_uid_gid = 15 +BPF_FUNC_skb_change_head = 43 BPF_FUNC_get_socket_cookie = 46 BPF_FUNC_get_socket_uid = 47 +BPF_FUNC_ktime_get_boot_ns = 125 +# pylint: enable=invalid-name BPF_F_RDONLY = 1 << 3 BPF_F_WRONLY = 1 << 4 @@ -164,19 +175,30 @@ BPF_F_WRONLY = 1 << 4 # These object below belongs to the same kernel union and the types below # (e.g., bpf_attr_create) aren't kernel struct names but just different # variants of the union. -BpfAttrCreate = cstruct.Struct("bpf_attr_create", "=IIIII", - "map_type key_size value_size max_entries, map_flags") -BpfAttrOps = cstruct.Struct("bpf_attr_ops", "=QQQQ", - "map_fd key_ptr value_ptr flags") +# pylint: disable=invalid-name +BpfAttrCreate = cstruct.Struct( + "bpf_attr_create", "=IIIII", + "map_type key_size value_size max_entries, map_flags") +BpfAttrOps = cstruct.Struct( + "bpf_attr_ops", "=QQQQ", + "map_fd key_ptr value_ptr flags") BpfAttrProgLoad = cstruct.Struct( "bpf_attr_prog_load", "=IIQQIIQI", "prog_type insn_cnt insns" " license log_level log_size log_buf kern_version") BpfAttrProgAttach = cstruct.Struct( "bpf_attr_prog_attach", "=III", "target_fd attach_bpf_fd attach_type") BpfInsn = cstruct.Struct("bpf_insn", "=BBhi", "code dst_src_reg off imm") +# pylint: enable=invalid-name libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True) HAVE_EBPF_SUPPORT = net_test.LINUX_VERSION >= (4, 4, 0) +HAVE_EBPF_4_9 = net_test.LINUX_VERSION >= (4, 9, 0) +HAVE_EBPF_4_14 = net_test.LINUX_VERSION >= (4, 14, 0) +HAVE_EBPF_4_19 = net_test.LINUX_VERSION >= (4, 19, 0) +HAVE_EBPF_5_4 = net_test.LINUX_VERSION >= (5, 4, 0) + +# set memlock resource 1 GiB +resource.setrlimit(resource.RLIMIT_MEMLOCK, (1073741824, 1073741824)) # BPF program syscalls @@ -185,6 +207,7 @@ def BpfSyscall(op, attr): csocket.MaybeRaiseSocketError(ret) return ret + def CreateMap(map_type, key_size, value_size, max_entries, map_flags=0): attr = BpfAttrCreate((map_type, key_size, value_size, max_entries, map_flags)) return BpfSyscall(BPF_MAP_CREATE, attr) @@ -209,31 +232,34 @@ def LookupMap(map_fd, key): def GetNextKey(map_fd, key): + """Get the next key in the map after the specified key.""" if key is not None: c_key = ctypes.c_uint32(key) c_next_key = ctypes.c_uint32(0) key_ptr = ctypes.addressof(c_key) else: - key_ptr = 0; + key_ptr = 0 c_next_key = ctypes.c_uint32(0) attr = BpfAttrOps( (map_fd, key_ptr, ctypes.addressof(c_next_key), 0)) BpfSyscall(BPF_MAP_GET_NEXT_KEY, attr) return c_next_key + def GetFirstKey(map_fd): return GetNextKey(map_fd, None) + def DeleteMap(map_fd, key): c_key = ctypes.c_uint32(key) attr = BpfAttrOps((map_fd, ctypes.addressof(c_key), 0, 0)) BpfSyscall(BPF_MAP_DELETE_ELEM, attr) -def BpfProgLoad(prog_type, instructions): +def BpfProgLoad(prog_type, instructions, prog_license=b"GPL"): bpf_prog = "".join(instructions) insn_buff = ctypes.create_string_buffer(bpf_prog) - gpl_license = ctypes.create_string_buffer(b"GPL") + gpl_license = ctypes.create_string_buffer(prog_license) log_buf = ctypes.create_string_buffer(b"", LOG_SIZE) attr = BpfAttrProgLoad((prog_type, len(insn_buff) / len(BpfInsn), ctypes.addressof(insn_buff), @@ -241,6 +267,7 @@ def BpfProgLoad(prog_type, instructions): LOG_SIZE, ctypes.addressof(log_buf), 0)) return BpfSyscall(BPF_PROG_LOAD, attr) + # Attach a socket eBPF filter to a target socket def BpfProgAttachSocket(sock_fd, prog_fd): uint_fd = ctypes.c_uint32(prog_fd) @@ -248,11 +275,13 @@ def BpfProgAttachSocket(sock_fd, prog_fd): ctypes.pointer(uint_fd), ctypes.sizeof(uint_fd)) csocket.MaybeRaiseSocketError(ret) + # Attach a eBPF filter to a cgroup def BpfProgAttach(prog_fd, target_fd, prog_type): attr = BpfAttrProgAttach((target_fd, prog_fd, prog_type)) return BpfSyscall(BPF_PROG_ATTACH, attr) + # Detach a eBPF filter from a cgroup def BpfProgDetach(target_fd, prog_type): attr = BpfAttrProgAttach((target_fd, 0, prog_type)) diff --git a/net/test/bpf_test.py b/net/test/bpf_test.py index ea3e56b..a014918 100755 --- a/net/test/bpf_test.py +++ b/net/test/bpf_test.py @@ -18,19 +18,87 @@ import ctypes import errno import os import socket -import struct import subprocess import tempfile import unittest -from bpf import * # pylint: disable=wildcard-import +import bpf +from bpf import BPF_ADD +from bpf import BPF_AND +from bpf import BPF_CGROUP_INET_EGRESS +from bpf import BPF_CGROUP_INET_INGRESS +from bpf import BPF_CGROUP_INET_SOCK_CREATE +from bpf import BPF_DW +from bpf import BPF_F_RDONLY +from bpf import BPF_F_WRONLY +from bpf import BPF_FUNC_get_current_uid_gid +from bpf import BPF_FUNC_get_socket_cookie +from bpf import BPF_FUNC_get_socket_uid +from bpf import BPF_FUNC_ktime_get_boot_ns +from bpf import BPF_FUNC_ktime_get_ns +from bpf import BPF_FUNC_map_lookup_elem +from bpf import BPF_FUNC_map_update_elem +from bpf import BPF_FUNC_skb_change_head +from bpf import BPF_JNE +from bpf import BPF_MAP_TYPE_HASH +from bpf import BPF_PROG_TYPE_CGROUP_SKB +from bpf import BPF_PROG_TYPE_CGROUP_SOCK +from bpf import BPF_PROG_TYPE_SCHED_CLS +from bpf import BPF_PROG_TYPE_SOCKET_FILTER +from bpf import BPF_REG_0 +from bpf import BPF_REG_1 +from bpf import BPF_REG_10 +from bpf import BPF_REG_2 +from bpf import BPF_REG_3 +from bpf import BPF_REG_4 +from bpf import BPF_REG_6 +from bpf import BPF_REG_7 +from bpf import BPF_STX +from bpf import BPF_W +from bpf import BPF_XADD +from bpf import BpfAlu64Imm +from bpf import BpfExitInsn +from bpf import BpfFuncCall +from bpf import BpfJumpImm +from bpf import BpfLdxMem +from bpf import BpfLoadMapFd +from bpf import BpfMov64Imm +from bpf import BpfMov64Reg +from bpf import BpfProgAttach +from bpf import BpfProgAttachSocket +from bpf import BpfProgDetach +from bpf import BpfProgLoad +from bpf import BpfRawInsn +from bpf import BpfStMem +from bpf import BpfStxMem +from bpf import CreateMap +from bpf import DeleteMap +from bpf import GetFirstKey +from bpf import GetNextKey +from bpf import LookupMap +from bpf import UpdateMap import csocket import net_test +from net_test import LINUX_VERSION import sock_diag libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True) -HAVE_EBPF_ACCOUNTING = net_test.LINUX_VERSION >= (4, 9, 0) -HAVE_EBPF_SOCKET = net_test.LINUX_VERSION >= (4, 14, 0) + +HAVE_EBPF_ACCOUNTING = bpf.HAVE_EBPF_4_9 +HAVE_EBPF_SOCKET = bpf.HAVE_EBPF_4_14 + +# bpf_ktime_get_ns() was made non-GPL requiring in 5.8 and at the same time +# bpf_ktime_get_boot_ns() was added, both of these changes were backported to +# Android Common Kernel in 4.14.221, 4.19.175, 5.4.97. +# As such we require 4.14.222+ 4.19.176+ 5.4.98+ 5.8.0+, +# but since we only really care about LTS releases: +HAVE_EBPF_KTIME_GET_NS_APACHE2 = ( + ((LINUX_VERSION > (4, 14, 221)) and (LINUX_VERSION < (4, 19, 0))) or + ((LINUX_VERSION > (4, 19, 175)) and (LINUX_VERSION < (5, 4, 0))) or + (LINUX_VERSION > (5, 4, 97)) +) +HAVE_EBPF_KTIME_GET_BOOT_NS = HAVE_EBPF_KTIME_GET_NS_APACHE2 + KEY_SIZE = 8 VALUE_SIZE = 4 TOTAL_ENTRIES = 20 @@ -41,18 +109,19 @@ key_offset = -8 # Offset to store the map value in stack register REG10 value_offset = -16 + # Debug usage only. def PrintMapInfo(map_fd): # A random key that the map does not contain. key = 10086 while 1: try: - nextKey = GetNextKey(map_fd, key).value - value = LookupMap(map_fd, nextKey) - print repr(nextKey) + " : " + repr(value.value) - key = nextKey - except: - print "no value" + next_key = GetNextKey(map_fd, key).value + value = LookupMap(map_fd, next_key) + print(repr(next_key) + " : " + repr(value.value)) # pylint: disable=superfluous-parens + key = next_key + except socket.error: + print("no value") # pylint: disable=superfluous-parens break @@ -67,7 +136,7 @@ def SocketUDPLoopBack(packet_count, version, prog_fd): sock.bind((addr, 0)) addr = sock.getsockname() sockaddr = csocket.Sockaddr(addr) - for i in xrange(packet_count): + for _ in range(packet_count): sock.sendto("foo", addr) data, retaddr = csocket.Recvfrom(sock, 4096, 0) assert "foo" == data @@ -91,7 +160,7 @@ def SocketUDPLoopBack(packet_count, version, prog_fd): # the stack. def BpfFuncCountPacketInit(map_fd): key_pos = BPF_REG_7 - insPackCountStart = [ + return [ # Get a preloaded key from BPF_REG_0 and store it at BPF_REG_7 BpfMov64Reg(key_pos, BPF_REG_10), BpfAlu64Imm(BPF_ADD, key_pos, key_offset), @@ -111,7 +180,6 @@ def BpfFuncCountPacketInit(map_fd): BpfMov64Imm(BPF_REG_4, 0), BpfFuncCall(BPF_FUNC_map_update_elem) ] - return insPackCountStart INS_BPF_EXIT_BLOCK = [ @@ -148,11 +216,13 @@ INS_BPF_PARAM_STORE = [ BpfStxMem(BPF_DW, BPF_REG_10, BPF_REG_0, key_offset), ] + @unittest.skipUnless(HAVE_EBPF_ACCOUNTING, "BPF helper function is not fully supported") class BpfTest(net_test.NetworkTest): def setUp(self): + super(BpfTest, self).setUp() self.map_fd = -1 self.prog_fd = -1 self.sock = None @@ -164,38 +234,39 @@ class BpfTest(net_test.NetworkTest): os.close(self.map_fd) if self.sock: self.sock.close() + super(BpfTest, self).tearDown() def testCreateMap(self): key, value = 1, 1 self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, TOTAL_ENTRIES) UpdateMap(self.map_fd, key, value) - self.assertEquals(value, LookupMap(self.map_fd, key).value) + self.assertEqual(value, LookupMap(self.map_fd, key).value) DeleteMap(self.map_fd, key) self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key) - def CheckAllMapEntry(self, nonexistent_key, totalEntries, value): + def CheckAllMapEntry(self, nonexistent_key, total_entries, value): count = 0 key = nonexistent_key while True: - if count == totalEntries: + if count == total_entries: self.assertRaisesErrno(errno.ENOENT, GetNextKey, self.map_fd, key) break else: result = GetNextKey(self.map_fd, key) key = result.value self.assertGreaterEqual(key, 0) - self.assertEquals(value, LookupMap(self.map_fd, key).value) + self.assertEqual(value, LookupMap(self.map_fd, key).value) count += 1 def testIterateMap(self): self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, TOTAL_ENTRIES) value = 1024 - for key in xrange(0, TOTAL_ENTRIES): + for key in range(0, TOTAL_ENTRIES): UpdateMap(self.map_fd, key, value) - for key in xrange(0, TOTAL_ENTRIES): - self.assertEquals(value, LookupMap(self.map_fd, key).value) + for key in range(0, TOTAL_ENTRIES): + self.assertEqual(value, LookupMap(self.map_fd, key).value) self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, 101) nonexistent_key = -1 self.CheckAllMapEntry(nonexistent_key, TOTAL_ENTRIES, value) @@ -204,13 +275,12 @@ class BpfTest(net_test.NetworkTest): self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, TOTAL_ENTRIES) value = 1024 - for key in xrange(0, TOTAL_ENTRIES): + for key in range(0, TOTAL_ENTRIES): UpdateMap(self.map_fd, key, value) - firstKey = GetFirstKey(self.map_fd) - key = firstKey.value + first_key = GetFirstKey(self.map_fd) + key = first_key.value self.CheckAllMapEntry(key, TOTAL_ENTRIES - 1, value) - def testRdOnlyMap(self): self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, TOTAL_ENTRIES, map_flags=BPF_F_RDONLY) @@ -219,8 +289,6 @@ class BpfTest(net_test.NetworkTest): self.assertRaisesErrno(errno.EPERM, UpdateMap, self.map_fd, key, value) self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key) - @unittest.skipUnless(HAVE_EBPF_ACCOUNTING, - "BPF helper function is not fully supported") def testWrOnlyMap(self): self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, TOTAL_ENTRIES, map_flags=BPF_F_WRONLY) @@ -261,10 +329,93 @@ class BpfTest(net_test.NetworkTest): packet_count = 10 SocketUDPLoopBack(packet_count, 4, self.prog_fd) SocketUDPLoopBack(packet_count, 6, self.prog_fd) - self.assertEquals(packet_count * 2, LookupMap(self.map_fd, key).value) + self.assertEqual(packet_count * 2, LookupMap(self.map_fd, key).value) + + ############################################################################## + # + # Test for presence of kernel patch: + # + # ANDROID: net: bpf: Allow TC programs to call BPF_FUNC_skb_change_head + # + # 4.14: https://android-review.googlesource.com/c/kernel/common/+/1237789 + # commit fe82848d9c1c887d2a84d3738c13e644d01b6d6f + # + # 4.19: https://android-review.googlesource.com/c/kernel/common/+/1237788 + # commit 6e04d94ab72435b45c413daff63520fd724e260e + # + # 5.4: https://android-review.googlesource.com/c/kernel/common/+/1237787 + # commit d730995e7bc5b4c10cc176235b704a274e6ec16f + # + # Upstream in Linux v5.8: + # net: bpf: Allow TC programs to call BPF_FUNC_skb_change_head + # commit 6f3f65d80dac8f2bafce2213005821fccdce194c + # + @unittest.skipUnless(bpf.HAVE_EBPF_4_14, + "no bpf_skb_change_head() support for pre-4.14 kernels") + def testSkbChangeHead(self): + # long bpf_skb_change_head(struct sk_buff *skb, u32 len, u64 flags) + instructions = [ + BpfMov64Imm(BPF_REG_2, 14), # u32 len + BpfMov64Imm(BPF_REG_3, 0), # u64 flags + BpfFuncCall(BPF_FUNC_skb_change_head), + ] + INS_BPF_EXIT_BLOCK + self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions, + b"Apache 2.0") + # No exceptions? Good. + + def testKtimeGetNsGPL(self): + instructions = [BpfFuncCall(BPF_FUNC_ktime_get_ns)] + INS_BPF_EXIT_BLOCK + self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions) + # No exceptions? Good. + + ############################################################################## + # + # Test for presence of kernel patch: + # + # UPSTREAM: net: bpf: Make bpf_ktime_get_ns() available to non GPL programs + # + # 4.14: https://android-review.googlesource.com/c/kernel/common/+/1585269 + # commit cbb4c73f9eab8f3c8ac29175d45c99ccba382e15 + # + # 4.19: https://android-review.googlesource.com/c/kernel/common/+/1355243 + # commit 272e21ccc9a92feeee80aff0587410a314b73c5b + # + # 5.4: https://android-review.googlesource.com/c/kernel/common/+/1355422 + # commit 45217b91eaaa3a563247c4f470f4cb785de6b1c6 + # + @unittest.skipUnless(HAVE_EBPF_KTIME_GET_NS_APACHE2, + "no bpf_ktime_get_ns() support for non-GPL programs") + def testKtimeGetNsApache2(self): + instructions = [BpfFuncCall(BPF_FUNC_ktime_get_ns)] + INS_BPF_EXIT_BLOCK + self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions, + b"Apache 2.0") + # No exceptions? Good. + + ############################################################################## + # + # Test for presence of kernel patch: + # + # BACKPORT: bpf: add bpf_ktime_get_boot_ns() + # + # 4.14: https://android-review.googlesource.com/c/kernel/common/+/1585587 + # commit 34073d7a8ee47ca908b56e9a1d14ca0615fdfc09 + # + # 4.19: https://android-review.googlesource.com/c/kernel/common/+/1585606 + # commit 4812ec50935dfe59ba9f48a572e278dd0b02af68 + # + # 5.4: https://android-review.googlesource.com/c/kernel/common/+/1585252 + # commit 57b3f4830fb66a6038c4c1c66ca2e138fe8be231 + # + @unittest.skipUnless(HAVE_EBPF_KTIME_GET_BOOT_NS, + "no bpf_ktime_get_boot_ns() support") + def testKtimeGetBootNs(self): + instructions = [ + BpfFuncCall(BPF_FUNC_ktime_get_boot_ns), + ] + INS_BPF_EXIT_BLOCK + self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions, + b"Apache 2.0") + # No exceptions? Good. - @unittest.skipUnless(HAVE_EBPF_ACCOUNTING, - "BPF helper function is not fully supported") def testGetSocketCookie(self): self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, TOTAL_ENTRIES) @@ -282,13 +433,11 @@ class BpfTest(net_test.NetworkTest): def PacketCountByCookie(version): self.sock = SocketUDPLoopBack(packet_count, version, self.prog_fd) cookie = sock_diag.SockDiag.GetSocketCookie(self.sock) - self.assertEquals(packet_count, LookupMap(self.map_fd, cookie).value) + self.assertEqual(packet_count, LookupMap(self.map_fd, cookie).value) self.sock.close() PacketCountByCookie(4) PacketCountByCookie(6) - @unittest.skipUnless(HAVE_EBPF_ACCOUNTING, - "BPF helper function is not fully supported") def testGetSocketUid(self): self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, TOTAL_ENTRIES) @@ -307,10 +456,11 @@ class BpfTest(net_test.NetworkTest): with net_test.RunAsUid(uid): self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid) SocketUDPLoopBack(packet_count, 4, self.prog_fd) - self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value) - DeleteMap(self.map_fd, uid); + self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value) + DeleteMap(self.map_fd, uid) SocketUDPLoopBack(packet_count, 6, self.prog_fd) - self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value) + self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value) + @unittest.skipUnless(HAVE_EBPF_ACCOUNTING, "Cgroup BPF is not fully supported") @@ -318,6 +468,7 @@ class BpfCgroupTest(net_test.NetworkTest): @classmethod def setUpClass(cls): + super(BpfCgroupTest, cls).setUpClass() cls._cg_dir = tempfile.mkdtemp(prefix="cg_bpf-") cmd = "mount -t cgroup2 cg_bpf %s" % cls._cg_dir try: @@ -332,10 +483,12 @@ class BpfCgroupTest(net_test.NetworkTest): @classmethod def tearDownClass(cls): os.close(cls._cg_fd) - subprocess.call(('umount %s' % cls._cg_dir).split()) + subprocess.call(("umount %s" % cls._cg_dir).split()) os.rmdir(cls._cg_dir) + super(BpfCgroupTest, cls).tearDownClass() def setUp(self): + super(BpfCgroupTest, self).setUp() self.prog_fd = -1 self.map_fd = -1 @@ -356,6 +509,7 @@ class BpfCgroupTest(net_test.NetworkTest): BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE) except socket.error: pass + super(BpfCgroupTest, self).tearDown() def testCgroupBpfAttach(self): self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK) @@ -377,8 +531,8 @@ class BpfCgroupTest(net_test.NetworkTest): self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 4, None) self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 6, None) BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_EGRESS) - SocketUDPLoopBack( 1, 4, None) - SocketUDPLoopBack( 1, 6, None) + SocketUDPLoopBack(1, 4, None) + SocketUDPLoopBack(1, 6, None) def testCgroupBpfUid(self): self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE, @@ -389,7 +543,8 @@ class BpfCgroupTest(net_test.NetworkTest): BpfFuncCall(BPF_FUNC_get_socket_uid) ] instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd) - + INS_CGROUP_ACCEPT + INS_PACK_COUNT_UPDATE + INS_CGROUP_ACCEPT) + + INS_CGROUP_ACCEPT + INS_PACK_COUNT_UPDATE + + INS_CGROUP_ACCEPT) self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, instructions) BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS) packet_count = 20 @@ -397,45 +552,44 @@ class BpfCgroupTest(net_test.NetworkTest): with net_test.RunAsUid(uid): self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid) SocketUDPLoopBack(packet_count, 4, None) - self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value) + self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value) DeleteMap(self.map_fd, uid) SocketUDPLoopBack(packet_count, 6, None) - self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value) + self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value) BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS) def checkSocketCreate(self, family, socktype, success): try: sock = socket.socket(family, socktype, 0) sock.close() - except socket.error, e: + except socket.error as e: if success: self.fail("Failed to create socket family=%d type=%d err=%s" % (family, socktype, os.strerror(e.errno))) - return; + return if not success: - self.fail("unexpected socket family=%d type=%d created, should be blocked" % - (family, socktype)) - + self.fail("unexpected socket family=%d type=%d created, should be blocked" + % (family, socktype)) def trySocketCreate(self, success): - for family in [socket.AF_INET, socket.AF_INET6]: - for socktype in [socket.SOCK_DGRAM, socket.SOCK_STREAM]: - self.checkSocketCreate(family, socktype, success) + for family in [socket.AF_INET, socket.AF_INET6]: + for socktype in [socket.SOCK_DGRAM, socket.SOCK_STREAM]: + self.checkSocketCreate(family, socktype, success) @unittest.skipUnless(HAVE_EBPF_SOCKET, - "Cgroup BPF socket is not supported") + "Cgroup BPF socket is not supported") def testCgroupSocketCreateBlock(self): instructions = [ BpfFuncCall(BPF_FUNC_get_current_uid_gid), BpfAlu64Imm(BPF_AND, BPF_REG_0, 0xfffffff), BpfJumpImm(BPF_JNE, BPF_REG_0, TEST_UID, 2), ] - instructions += INS_BPF_EXIT_BLOCK + INS_CGROUP_ACCEPT; + instructions += INS_BPF_EXIT_BLOCK + INS_CGROUP_ACCEPT self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SOCK, instructions) BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE) with net_test.RunAsUid(TEST_UID): # Socket creation with target uid should fail - self.trySocketCreate(False); + self.trySocketCreate(False) # Socket create with different uid should success self.trySocketCreate(True) BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE) diff --git a/net/test/csocket_test.py b/net/test/csocket_test.py index 2191fec..19760fe 100755 --- a/net/test/csocket_test.py +++ b/net/test/csocket_test.py @@ -16,7 +16,15 @@ """Unit tests for csocket.""" -from socket import * # pylint: disable=wildcard-import +import socket +# pylint: disable=g-importing-member +from socket import AF_INET +from socket import AF_INET6 +from socket import inet_pton +from socket import SOCK_DGRAM +from socket import SOL_IP +# pylint: enable=g-importing-member + import unittest import csocket @@ -29,7 +37,7 @@ SOL_IPV6 = 41 class CsocketTest(unittest.TestCase): def _BuildSocket(self, family, addr): - s = socket(family, SOCK_DGRAM, 0) + s = socket.socket(family, SOCK_DGRAM, 0) s.bind((addr, 0)) return s diff --git a/net/test/cstruct.py b/net/test/cstruct.py index 5e05263..c675c9e 100644 --- a/net/test/cstruct.py +++ b/net/test/cstruct.py @@ -28,24 +28,24 @@ Example usage: >>> # Create instances from a tuple of values, raw bytes, zero-initialized, or >>> # using keywords. ... n1 = NLMsgHdr((44, 32, 0x2, 0, 491)) ->>> print n1 +>>> print(n1) NLMsgHdr(length=44, type=32, flags=2, seq=0, pid=491) >>> >>> n2 = NLMsgHdr("\x2c\x00\x00\x00\x21\x00\x02\x00" ... "\x00\x00\x00\x00\xfe\x01\x00\x00" + "junk at end") ->>> print n2 +>>> print(n2) NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510) >>> >>> n3 = netlink.NLMsgHdr() # Zero-initialized ->>> print n3 +>>> print(n3) NLMsgHdr(length=0, type=0, flags=0, seq=0, pid=0) >>> >>> n4 = netlink.NLMsgHdr(length=44, type=33) # Other fields zero-initialized ->>> print n4 +>>> print(n4) NLMsgHdr(length=44, type=33, flags=0, seq=0, pid=0) >>> >>> # Serialize to raw bytes. -... print n1.Pack().encode("hex") +... print(n1.Pack().encode("hex")) 2c0000002000020000000000eb010000 >>> >>> # Parse the beginning of a byte stream as a struct, and return the struct @@ -70,11 +70,14 @@ NLMsgHdr(length=44, type=33, flags=0, seq=0, pid=0) import ctypes import string import struct +import re def CalcSize(fmt): if "A" in fmt: fmt = fmt.replace("A", "s") + # Remove the last digital since it will cause error in python3. + fmt = (re.split('\d+$', fmt)[0]) return struct.calcsize(fmt) def CalcNumElements(fmt): @@ -82,7 +85,7 @@ def CalcNumElements(fmt): fmt = fmt.replace("S", "") numstructs = prevlen - len(fmt) size = CalcSize(fmt) - elements = struct.unpack(fmt, "\x00" * size) + elements = struct.unpack(fmt, b"\x00" * size) return len(elements) + numstructs @@ -118,7 +121,7 @@ def Struct(name, fmt, fieldnames, substructs={}): # where XX is the length of the struct type's packed representation. _format = "" laststructindex = 0 - for i in xrange(len(fmt)): + for i in range(len(fmt)): if fmt[i] == "S": # Nested struct. Record the index in our struct it should go into. index = CalcNumElements(fmt[:i]) @@ -138,17 +141,17 @@ def Struct(name, fmt, fieldnames, substructs={}): offset_list = [0] last_offset = 0 - for i in xrange(len(_format)): + for i in range(len(_format)): offset = CalcSize(_format[:i]) if offset > last_offset: last_offset = offset offset_list.append(offset) # A dictionary that maps field names to their offsets in the struct. - _offsets = dict(zip(_fieldnames, offset_list)) + _offsets = dict(list(zip(_fieldnames, offset_list))) # Check that the number of field names matches the number of fields. - numfields = len(struct.unpack(_format, "\x00" * _length)) + numfields = len(struct.unpack(_format, b"\x00" * _length)) if len(_fieldnames) != numfields: raise ValueError("Invalid cstruct: \"%s\" has %d elements, \"%s\" has %d." % (fmt, numfields, fieldnames, len(_fieldnames))) @@ -182,7 +185,7 @@ def Struct(name, fmt, fieldnames, substructs={}): # Default construct from null bytes. self._Parse("\x00" * len(self)) # If any keywords were supplied, set those fields. - for k, v in kwargs.iteritems(): + for k, v in kwargs.items(): setattr(self, k, v) elif isinstance(tuple_or_bytes, str): # Initializing from a string. diff --git a/net/test/cstruct_test.py b/net/test/cstruct_test.py index 6b27973..af1fc42 100755 --- a/net/test/cstruct_test.py +++ b/net/test/cstruct_test.py @@ -27,16 +27,16 @@ TestStructB = cstruct.Struct("TestStructB", "=BI", "byte1 int2") class CstructTest(unittest.TestCase): def CheckEquals(self, a, b): - self.assertEquals(a, b) - self.assertEquals(b, a) + self.assertEqual(a, b) + self.assertEqual(b, a) assert a == b assert b == a assert not (a != b) # pylint: disable=g-comparison-negation,superfluous-parens assert not (b != a) # pylint: disable=g-comparison-negation,superfluous-parens def CheckNotEquals(self, a, b): - self.assertNotEquals(a, b) - self.assertNotEquals(b, a) + self.assertNotEqual(a, b) + self.assertNotEqual(b, a) assert a != b assert b != a assert not (a == b) # pylint: disable=g-comparison-negation,superfluous-parens @@ -69,9 +69,9 @@ class CstructTest(unittest.TestCase): expectedlen = (len(TestStructA) + 2 + len(TestStructA) + len(TestStructB) + 4 + 1) - self.assertEquals(expectedlen, len(DoubleNested)) + self.assertEqual(expectedlen, len(DoubleNested)) - self.assertEquals(7, d.nest2.nest3.byte1) + self.assertEqual(7, d.nest2.nest3.byte1) d.byte3 = 252 d.nest2.word1 = 33214 @@ -80,17 +80,17 @@ class CstructTest(unittest.TestCase): t = n.nest3 t.int2 = 33627591 - self.assertEquals(33627591, d.nest2.nest3.int2) + self.assertEqual(33627591, d.nest2.nest3.int2) expected = ( "DoubleNested(nest1=TestStructA(byte1=1, int2=2)," " nest2=Nested(word1=33214, nest2=TestStructA(byte1=3, int2=4)," " nest3=TestStructB(byte1=7, int2=33627591), int4=-55), byte3=252)") - self.assertEquals(expected, str(d)) + self.assertEqual(expected, str(d)) expected = ("01" "02000000" "81be" "03" "04000000" "07" "c71d0102" "ffffffc9" "fc").decode("hex") - self.assertEquals(expected, d.Pack()) + self.assertEqual(expected, d.Pack()) unpacked = DoubleNested(expected) self.CheckEquals(unpacked, d) @@ -102,24 +102,24 @@ class CstructTest(unittest.TestCase): t = TestStruct((2, nullstr, 12345, nullstr, 33210)) expected = ("TestStruct(byte1=2, string2=68656c6c6f0000000000000000000000," " int3=12345, ascii4=hello, word5=33210)") - self.assertEquals(expected, str(t)) + self.assertEqual(expected, str(t)) embeddednull = "hello\x00visible123" t = TestStruct((2, embeddednull, 12345, embeddednull, 33210)) expected = ("TestStruct(byte1=2, string2=68656c6c6f0076697369626c65313233," " int3=12345, ascii4=hello\x00visible123, word5=33210)") - self.assertEquals(expected, str(t)) + self.assertEqual(expected, str(t)) def testZeroInitialization(self): TestStruct = cstruct.Struct("TestStruct", "B16si16AH", "byte1 string2 int3 ascii4 word5") t = TestStruct() - self.assertEquals(0, t.byte1) - self.assertEquals("\x00" * 16, t.string2) - self.assertEquals(0, t.int3) - self.assertEquals("\x00" * 16, t.ascii4) - self.assertEquals(0, t.word5) - self.assertEquals("\x00" * len(TestStruct), t.Pack()) + self.assertEqual(0, t.byte1) + self.assertEqual("\x00" * 16, t.string2) + self.assertEqual(0, t.int3) + self.assertEqual("\x00" * 16, t.ascii4) + self.assertEqual(0, t.word5) + self.assertEqual("\x00" * len(TestStruct), t.Pack()) def testKeywordInitialization(self): TestStruct = cstruct.Struct("TestStruct", "=B16sIH", @@ -130,26 +130,26 @@ class CstructTest(unittest.TestCase): # Populate all fields t1 = TestStruct(byte1=1, string2=text, int3=0xFEDCBA98, word4=0x1234) expected = ("01" + text_bytes + "98BADCFE" "3412").decode("hex") - self.assertEquals(expected, t1.Pack()) + self.assertEqual(expected, t1.Pack()) # Partially populated t1 = TestStruct(string2=text, word4=0x1234) expected = ("00" + text_bytes + "00000000" "3412").decode("hex") - self.assertEquals(expected, t1.Pack()) + self.assertEqual(expected, t1.Pack()) def testCstructOffset(self): TestStruct = cstruct.Struct("TestStruct", "B16si16AH", "byte1 string2 int3 ascii4 word5") nullstr = "hello" + (16 - len("hello")) * "\x00" t = TestStruct((2, nullstr, 12345, nullstr, 33210)) - self.assertEquals(0, t.offset("byte1")) - self.assertEquals(1, t.offset("string2")) # sizeof(byte) - self.assertEquals(17, t.offset("int3")) # sizeof(byte) + 16*sizeof(char) + self.assertEqual(0, t.offset("byte1")) + self.assertEqual(1, t.offset("string2")) # sizeof(byte) + self.assertEqual(17, t.offset("int3")) # sizeof(byte) + 16*sizeof(char) # The integer is automatically padded by the struct module # to match native alignment. # offset = sizeof(byte) + 16*sizeof(char) + padding + sizeof(int) - self.assertEquals(24, t.offset("ascii4")) - self.assertEquals(40, t.offset("word5")) + self.assertEqual(24, t.offset("ascii4")) + self.assertEqual(40, t.offset("word5")) self.assertRaises(KeyError, t.offset, "random") # TODO: Add support for nested struct offset diff --git a/net/test/forwarding_test.py b/net/test/forwarding_test.py index 34394cd..b35e19f 100755 --- a/net/test/forwarding_test.py +++ b/net/test/forwarding_test.py @@ -126,8 +126,8 @@ class ForwardingTest(multinetwork_base.MultiNetworkBaseTest): mydst = "%s:%04X" % (net_test.FormatSockStatAddress(remoteaddr), remoteport) state = None sockets = [s for s in sockets if s[0] == mysrc and s[1] == mydst] - self.assertEquals(1, len(sockets)) - self.assertEquals("%02X" % self.TCP_TIME_WAIT, sockets[0][2]) + self.assertEqual(1, len(sockets)) + self.assertEqual("%02X" % self.TCP_TIME_WAIT, sockets[0][2]) # Remove our IP address. try: diff --git a/net/test/genetlink.py b/net/test/genetlink.py index dda3964..6928f07 100755 --- a/net/test/genetlink.py +++ b/net/test/genetlink.py @@ -120,4 +120,4 @@ class GenericNetlinkControl(GenericNetlink): if __name__ == "__main__": g = GenericNetlinkControl() - print g.GetFamily("tcp_metrics") + print(g.GetFamily("tcp_metrics")) diff --git a/net/test/iproute.py b/net/test/iproute.py index 470cbf1..9036246 100644 --- a/net/test/iproute.py +++ b/net/test/iproute.py @@ -408,7 +408,7 @@ class IPRoute(netlink.NetlinkSocket): while True: try: self._SendNlRequest(RTM_DELRULE, rtmsg) - except IOError, e: + except IOError as e: if e.errno == errno.ENOENT: break else: @@ -459,7 +459,7 @@ class IPRoute(netlink.NetlinkSocket): subject = CommandSubject(command) if "ALL" not in self.NL_DEBUG and subject not in self.NL_DEBUG: return - print self.CommandToString(command, data) + print(self.CommandToString(command, data)) def MaybeDebugMessage(self, message): hdr = netlink.NLMsgHdr(message) @@ -467,7 +467,7 @@ class IPRoute(netlink.NetlinkSocket): def PrintMessage(self, message): hdr = netlink.NLMsgHdr(message) - print self.CommandToString(hdr.type, message) + print(self.CommandToString(hdr.type, message)) def DumpRules(self, version): """Returns the IP rules for the specified IP version.""" @@ -774,4 +774,4 @@ if __name__ == "__main__": iproute.DEBUG = True iproute.DumpRules(6) iproute.DumpLinks() - print iproute.GetRoutes("2001:4860:4860::8888", 0, 0, None) + print(iproute.GetRoutes("2001:4860:4860::8888", 0, 0, None)) diff --git a/net/test/leak_test.py b/net/test/leak_test.py index 8a42611..a245817 100755 --- a/net/test/leak_test.py +++ b/net/test/leak_test.py @@ -63,7 +63,7 @@ class ForceSocketBufferOptionTest(net_test.NetworkTest): val = 4097 self.assertGreater(2 * val, minbuf) s.setsockopt(SOL_SOCKET, force_option, val) - self.assertEquals(2 * val, s.getsockopt(SOL_SOCKET, option)) + self.assertEqual(2 * val, s.getsockopt(SOL_SOCKET, option)) # Check that the force option sets at least the minimum value instead # of a negative value on integer overflow. Because the kernel multiplies diff --git a/net/test/multinetwork_base.py b/net/test/multinetwork_base.py index 8dbd360..6b79d4f 100644 --- a/net/test/multinetwork_base.py +++ b/net/test/multinetwork_base.py @@ -366,7 +366,7 @@ class MultiNetworkBaseTest(net_test.NetworkTest): @classmethod def _RestoreSysctls(cls): - for sysctl, value in cls.saved_sysctls.iteritems(): + for sysctl, value in cls.saved_sysctls.items(): try: open(sysctl, "w").write(value) except IOError: @@ -558,7 +558,7 @@ class MultiNetworkBaseTest(net_test.NetworkTest): # MAC address has 1 in the least-significant bit. if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1: packets.append(ether.payload) - except OSError, e: + except OSError as e: # EAGAIN means there are no more packets waiting. if re.match(e.message, os.strerror(errno.EAGAIN)): # If we didn't see any packets, try again for good luck. @@ -669,7 +669,7 @@ class MultiNetworkBaseTest(net_test.NetworkTest): # repr() can be expensive. Call it only if the test is going to fail and we # want to see the error. if expected_real != actual_real: - self.assertEquals(repr(expected_real), repr(actual_real)) + self.assertEqual(repr(expected_real), repr(actual_real)) def PacketMatches(self, expected, actual): try: @@ -710,7 +710,7 @@ class MultiNetworkBaseTest(net_test.NetworkTest): # expected, but this is good enough for now. try: self.assertPacketMatches(expected, packets[-1]) - except Exception, e: + except Exception as e: raise UnexpectedPacketError( "%s: diff with last packet:\n%s" % (msg, e.message)) diff --git a/net/test/multinetwork_test.py b/net/test/multinetwork_test.py index a0b464a..092736b 100755 --- a/net/test/multinetwork_test.py +++ b/net/test/multinetwork_test.py @@ -134,7 +134,7 @@ class OutgoingTest(multinetwork_base.MultiNetworkBaseTest): self.ExpectPacketOn(netid, msg, expected) def CheckOutgoingPackets(self, routing_mode): - for _ in xrange(self.ITERATIONS): + for _ in range(self.ITERATIONS): for netid in self.tuns: self.CheckPingPacket(4, netid, routing_mode, self.IPV4_PING) @@ -194,7 +194,7 @@ class OutgoingTest(multinetwork_base.MultiNetworkBaseTest): # If we're testing connected sockets, connect the socket on the first # netid now. if use_connect: - netid = self.tuns.keys()[0] + netid = list(self.tuns.keys())[0] self.SelectInterface(s, netid, mode) s.connect((dstaddr, 53)) expected.src = self.MyAddress(version, netid) @@ -270,7 +270,7 @@ class OutgoingTest(multinetwork_base.MultiNetworkBaseTest): self.CheckRemarking(6, True) def testIPv6StickyPktinfo(self): - for _ in xrange(self.ITERATIONS): + for _ in range(self.ITERATIONS): for netid in self.tuns: s = net_test.UDPSocket(AF_INET6) @@ -312,7 +312,7 @@ class OutgoingTest(multinetwork_base.MultiNetworkBaseTest): self.ExpectPacketOn(netid, msg, expected) def CheckPktinfoRouting(self, version): - for _ in xrange(self.ITERATIONS): + for _ in range(self.ITERATIONS): for netid in self.tuns: family = self.GetProtocolFamily(version) s = net_test.UDPSocket(family) @@ -478,11 +478,11 @@ class TCPAcceptTest(multinetwork_base.InboundMarkingTest): self.InvalidateDstCache(version, netid) if mode == self.MODE_INCOMING_MARK: - self.assertEquals(netid, mark & self.NETID_FWMASK, + self.assertEqual(netid, mark & self.NETID_FWMASK, msg + ": Accepted socket: Expected mark %d, got %d" % ( netid, mark)) elif mode != self.MODE_EXPLICIT_MARK: - self.assertEquals(0, self.GetSocketMark(listensocket)) + self.assertEqual(0, self.GetSocketMark(listensocket)) # Check the FIN was sent on the right interface, and ack it. We don't expect # this to fail because by the time the connection is established things are @@ -670,14 +670,14 @@ class RIOTest(multinetwork_base.MultiNetworkBaseTest): table, prefix, plen, router, ifindex, None, None) def testSetAcceptRaRtInfoMinPlen(self): - for plen in xrange(-1, 130): + for plen in range(-1, 130): self.SetAcceptRaRtInfoMinPlen(plen) - self.assertEquals(plen, self.GetAcceptRaRtInfoMinPlen()) + self.assertEqual(plen, self.GetAcceptRaRtInfoMinPlen()) def testSetAcceptRaRtInfoMaxPlen(self): - for plen in xrange(-1, 130): + for plen in range(-1, 130): self.SetAcceptRaRtInfoMaxPlen(plen) - self.assertEquals(plen, self.GetAcceptRaRtInfoMaxPlen()) + self.assertEqual(plen, self.GetAcceptRaRtInfoMaxPlen()) def testZeroRtLifetime(self): PREFIX = "2001:db8:8901:2300::" @@ -700,7 +700,7 @@ class RIOTest(multinetwork_base.MultiNetworkBaseTest): RTLIFETIME = 70372 PRF = 0 # sweep from high to low to avoid spurious failures from late arrivals. - for plen in xrange(130, 1, -1): + for plen in range(130, 1, -1): self.SetAcceptRaRtInfoMinPlen(plen) # RIO with plen < min_plen should be ignored self.SendRIO(RTLIFETIME, plen - 1, PREFIX, PRF) @@ -715,7 +715,7 @@ class RIOTest(multinetwork_base.MultiNetworkBaseTest): RTLIFETIME = 73078 PRF = 0 # sweep from low to high to avoid spurious failures from late arrivals. - for plen in xrange(-1, 128, 1): + for plen in range(-1, 128, 1): self.SetAcceptRaRtInfoMaxPlen(plen) # RIO with plen > max_plen should be ignored self.SendRIO(RTLIFETIME, plen + 1, PREFIX, PRF) @@ -782,16 +782,16 @@ class RIOTest(multinetwork_base.MultiNetworkBaseTest): baseline = self.CountRoutes() self.SetAcceptRaRtInfoMaxPlen(56) # Send many RIOs compared to the expected number on a healthy system. - for i in xrange(0, COUNT): + for i in range(0, COUNT): prefix = "2001:db8:%x:1100::" % i self.SendRIO(RTLIFETIME, PLEN, prefix, PRF) time.sleep(0.1) - self.assertEquals(COUNT + baseline, self.CountRoutes()) - for i in xrange(0, COUNT): + self.assertEqual(COUNT + baseline, self.CountRoutes()) + for i in range(0, COUNT): prefix = "2001:db8:%x:1100::" % i self.DelRA6(prefix, PLEN) # Expect that we can return to baseline config without lingering routes. - self.assertEquals(baseline, self.CountRoutes()) + self.assertEqual(baseline, self.CountRoutes()) class RATest(multinetwork_base.MultiNetworkBaseTest): @@ -874,7 +874,7 @@ class RATest(multinetwork_base.MultiNetworkBaseTest): return len(open("/proc/net/ipv6_route").readlines()) num_routes = GetNumRoutes() - for i in xrange(10, 20): + for i in range(10, 20): try: self.tuns[i] = self.CreateTunInterface(i) self.SendRA(i) @@ -908,23 +908,23 @@ class RATest(multinetwork_base.MultiNetworkBaseTest): csocket.SetSocketTimeout(s.sock, 100) try: data = s._Recv() - except IOError, e: + except IOError as e: self.fail("Should have received an RTM_NEWNDUSEROPT message. " "Please ensure the kernel supports receiving the " "PREF64 RA option. Error: %s" % e) # Check that the message is received correctly. nlmsghdr, data = cstruct.Read(data, netlink.NLMsgHdr) - self.assertEquals(iproute.RTM_NEWNDUSEROPT, nlmsghdr.type) + self.assertEqual(iproute.RTM_NEWNDUSEROPT, nlmsghdr.type) # Check the option contents. ndopthdr, data = cstruct.Read(data, iproute.NdUseroptMsg) - self.assertEquals(AF_INET6, ndopthdr.family) - self.assertEquals(self.ND_ROUTER_ADVERT, ndopthdr.icmp_type) - self.assertEquals(len(opt), ndopthdr.opts_len) + self.assertEqual(AF_INET6, ndopthdr.family) + self.assertEqual(self.ND_ROUTER_ADVERT, ndopthdr.icmp_type) + self.assertEqual(len(opt), ndopthdr.opts_len) actual_opt = self.Pref64Option(data) - self.assertEquals(opt, actual_opt) + self.assertEqual(opt, actual_opt) @@ -986,7 +986,7 @@ class PMTUTest(multinetwork_base.InboundMarkingTest): # Send a packet and receive a packet too big. SendBigPacket(version, s, dstaddr, netid, payload) received = self.ReadAllPacketsOn(netid) - self.assertEquals(1, len(received), + self.assertEqual(1, len(received), "unexpected packets: %s" % received[1:]) _, toobig = packets.ICMPPacketTooBig(version, intermediate, srcaddr, received[0]) @@ -1000,7 +1000,7 @@ class PMTUTest(multinetwork_base.InboundMarkingTest): # If this is a connected socket, make sure the socket MTU was set. # Note that in IPv4 this only started working in Linux 3.6! if use_connect and (version == 6 or net_test.LINUX_VERSION >= (3, 6)): - self.assertEquals(packets.PTB_MTU, self.GetSocketMTU(version, s)) + self.assertEqual(packets.PTB_MTU, self.GetSocketMTU(version, s)) s.close() @@ -1010,16 +1010,16 @@ class PMTUTest(multinetwork_base.InboundMarkingTest): # here we use a mark for simplicity. s2 = self.BuildSocket(version, net_test.UDPSocket, netid, "mark") s2.connect((dstaddr, 1234)) - self.assertEquals(packets.PTB_MTU, self.GetSocketMTU(version, s2)) + self.assertEqual(packets.PTB_MTU, self.GetSocketMTU(version, s2)) # Also check the MTU reported by ip route get, this time using the oif. routes = self.iproute.GetRoutes(dstaddr, self.ifindices[netid], 0, None) self.assertTrue(routes) route = routes[0] rtmsg, attributes = route - self.assertEquals(iproute.RTN_UNICAST, rtmsg.type) + self.assertEqual(iproute.RTN_UNICAST, rtmsg.type) metrics = attributes["RTA_METRICS"] - self.assertEquals(packets.PTB_MTU, metrics["RTAX_MTU"]) + self.assertEqual(packets.PTB_MTU, metrics["RTAX_MTU"]) def testIPv4BasicPMTU(self): """Tests IPv4 path MTU discovery. @@ -1152,11 +1152,11 @@ class UidRoutingTest(multinetwork_base.MultiNetworkBaseTest): rules = self.GetRulesAtPriority(version, priority) self.assertTrue(rules) _, attributes = rules[-1] - self.assertEquals(priority, attributes["FRA_PRIORITY"]) + self.assertEqual(priority, attributes["FRA_PRIORITY"]) uidrange = attributes["FRA_UID_RANGE"] - self.assertEquals(start, uidrange.start) - self.assertEquals(end, uidrange.end) - self.assertEquals(table, attributes["FRA_TABLE"]) + self.assertEqual(start, uidrange.start) + self.assertEqual(end, uidrange.end) + self.assertEqual(table, attributes["FRA_TABLE"]) finally: self.iproute.UidRangeRule(version, False, start, end, table, priority) self.assertRaisesErrno( @@ -1223,15 +1223,15 @@ class UidRoutingTest(multinetwork_base.MultiNetworkBaseTest): try: routes = self.iproute.GetRoutes(addr, oif, mark, uid) rtmsg, _ = routes[0] - self.assertEquals(iproute.RTN_UNREACHABLE, rtmsg.type) - except IOError, e: + self.assertEqual(iproute.RTN_UNREACHABLE, rtmsg.type) + except IOError as e: if int(e.errno) != int(errno.ENETUNREACH): raise e def ExpectRoute(self, addr, oif, mark, uid): routes = self.iproute.GetRoutes(addr, oif, mark, uid) rtmsg, _ = routes[0] - self.assertEquals(iproute.RTN_UNICAST, rtmsg.type) + self.assertEqual(iproute.RTN_UNICAST, rtmsg.type) def CheckGetRoute(self, version, addr): self.ExpectNoRoute(addr, 0, 0, 0) @@ -1257,7 +1257,7 @@ class UidRoutingTest(multinetwork_base.MultiNetworkBaseTest): self.assertRaisesErrno(errno.ENETUNREACH, s.sendto, "foo", (remoteaddr, 53)) def CheckSendSucceeds(): - self.assertEquals(len("foo"), s.sendto("foo", (remoteaddr, 53))) + self.assertEqual(len("foo"), s.sendto("foo", (remoteaddr, 53))) CheckSendFails() self.iproute.UidRangeRule(6, True, uid, uid, table, self.PRIORITY_UID) @@ -1269,7 +1269,7 @@ class UidRoutingTest(multinetwork_base.MultiNetworkBaseTest): CheckSendSucceeds() os.fchown(s.fileno(), -1, 12345) CheckSendSucceeds() - os.fchmod(s.fileno(), 0777) + os.fchmod(s.fileno(), 0o777) CheckSendSucceeds() os.fchown(s.fileno(), 0, -1) CheckSendFails() @@ -1306,8 +1306,8 @@ class RulesTest(net_test.NetworkTest): # Check that the rule pointing at table 301 is still around. attributes = [a for _, a in self.iproute.DumpRules(version) if a.get("FRA_PRIORITY", 0) == self.RULE_PRIORITY] - self.assertEquals(1, len(attributes)) - self.assertEquals(301, attributes[0]["FRA_TABLE"]) + self.assertEqual(1, len(attributes)) + self.assertEqual(301, attributes[0]["FRA_TABLE"]) if __name__ == "__main__": diff --git a/net/test/namespace.py b/net/test/namespace.py index 85db654..c8f8f46 100644 --- a/net/test/namespace.py +++ b/net/test/namespace.py @@ -20,6 +20,7 @@ import ctypes import ctypes.util import os import socket +import sys import net_test import sock_diag @@ -112,10 +113,10 @@ def UnShare(flags): def DumpMounts(hdr): - print - print hdr - print open('/proc/mounts', 'r').read(), - print '---' + print('') + print(hdr) + sys.stdout.write(open('/proc/mounts', 'r').read()) + print('---') # Requires at least kernel configuration options: @@ -125,12 +126,12 @@ def DumpMounts(hdr): def IfPossibleEnterNewNetworkNamespace(): """Instantiate and transition into a fresh new network namespace if possible.""" - print 'Creating clean namespace...', + sys.stdout.write('Creating clean namespace... ') try: UnShare(CLONE_NEWNS | CLONE_NEWUTS | CLONE_NEWNET) except OSError as err: - print 'failed: %s (likely: no privs or lack of kernel support).' % err + print('failed: %s (likely: no privs or lack of kernel support).' % err) return False try: @@ -143,11 +144,11 @@ def IfPossibleEnterNewNetworkNamespace(): SetFileContents('/proc/sys/net/ipv4/ping_group_range', '0 2147483647') net_test.SetInterfaceUp('lo') except: - print 'failed.' + print('failed.') # We've already transitioned into the new netns -- it's too late to recover. raise - print 'succeeded.' + print('succeeded.') return True diff --git a/net/test/neighbour_test.py b/net/test/neighbour_test.py index 2cb5c23..8cea6da 100755 --- a/net/test/neighbour_test.py +++ b/net/test/neighbour_test.py @@ -93,7 +93,7 @@ class NeighbourTest(multinetwork_base.MultiNetworkBaseTest): self.sock.bind((0, RTMGRP_NEIGH)) net_test.SetNonBlocking(self.sock) - self.netid = random.choice(self.tuns.keys()) + self.netid = random.choice(list(self.tuns.keys())) self.ifindex = self.ifindices[self.netid] # MultinetworkBaseTest always uses NUD_PERMANENT for router ARP entries. @@ -144,19 +144,19 @@ class NeighbourTest(multinetwork_base.MultiNetworkBaseTest): self.assertRaisesErrno(errno.EAGAIN, self.sock.recvfrom, 4096, MSG_PEEK) def assertNeighbourState(self, state, addr): - self.assertEquals(state, self.GetNdEntry(addr)[0].state) + self.assertEqual(state, self.GetNdEntry(addr)[0].state) def assertNeighbourAttr(self, addr, name, value): - self.assertEquals(value, self.GetNdEntry(addr)[1][name]) + self.assertEqual(value, self.GetNdEntry(addr)[1][name]) def ExpectNeighbourNotification(self, addr, state, attrs=None): msg = self.sock.recv(4096) msg, actual_attrs = self.iproute.ParseNeighbourMessage(msg) - self.assertEquals(addr, actual_attrs["NDA_DST"]) - self.assertEquals(state, msg.state) + self.assertEqual(addr, actual_attrs["NDA_DST"]) + self.assertEqual(state, msg.state) if attrs: for name in attrs: - self.assertEquals(attrs[name], actual_attrs[name]) + self.assertEqual(attrs[name], actual_attrs[name]) def ExpectProbe(self, is_unicast, addr): version = csocket.AddressVersion(addr) @@ -225,7 +225,7 @@ class NeighbourTest(multinetwork_base.MultiNetworkBaseTest): sleep_ms = min(100, interval - slept) time.sleep(sleep_ms / 1000.0) slept += sleep_ms - print self.GetNdEntry(addr) + print(self.GetNdEntry(addr)) def MonitorSleep(self, intervalseconds, addr): self.MonitorSleepMs(intervalseconds * 1000, addr) @@ -319,7 +319,7 @@ class NeighbourTest(multinetwork_base.MultiNetworkBaseTest): self.assertNeighbourState(NUD_REACHABLE, addr) self.ExpectNeighbourNotification(addr, NUD_REACHABLE) - for _ in xrange(5): + for _ in range(5): ForceProbe(router6, routermac) def testIsRouterFlag(self): diff --git a/net/test/net_test.py b/net/test/net_test.py index 1c7f32f..c762cd8 100755 --- a/net/test/net_test.py +++ b/net/test/net_test.py @@ -20,6 +20,7 @@ import random import re from socket import * # pylint: disable=wildcard-import import struct +import sys import unittest from scapy import all as scapy @@ -91,7 +92,7 @@ AID_INET = 3003 KERN_INFO = 6 LINUX_VERSION = csocket.LinuxVersion() - +LINUX_ANY_VERSION = (0, 0) def GetWildcardAddress(version): return {4: "0.0.0.0", 6: "::"}[version] @@ -250,7 +251,7 @@ def CanonicalizeIPv6Address(addr): def FormatProcAddress(unformatted): groups = [] - for i in xrange(0, len(unformatted), 4): + for i in range(0, len(unformatted), 4): groups.append(unformatted[i:i+4]) formatted = ":".join(groups) # Compress the address. @@ -265,7 +266,7 @@ def FormatSockStatAddress(address): family = AF_INET binary = inet_pton(family, address) out = "" - for i in xrange(0, len(binary), 4): + for i in range(0, len(binary), 4): out += "%08X" % struct.unpack("=L", binary[i:i+4]) return out @@ -368,14 +369,14 @@ class RunAsUidGid(object): self.gid = gid def __enter__(self): + if self.gid: + self.saved_gid = os.getgid() + os.setgid(self.gid) if self.uid: self.saved_uids = os.getresuid() self.saved_groups = os.getgroups() os.setgroups(self.saved_groups + [AID_INET]) os.setresuid(self.uid, self.uid, self.saved_uids[0]) - if self.gid: - self.saved_gid = os.getgid() - os.setgid(self.gid) def __exit__(self, unused_type, unused_value, unused_traceback): if self.uid: @@ -392,6 +393,12 @@ class RunAsUid(RunAsUidGid): class NetworkTest(unittest.TestCase): + def assertRaisesRegex(self, *args, **kwargs): + if sys.version_info.major < 3: + return self.assertRaisesRegexp(*args, **kwargs) + else: + return super().assertRaisesRegex(*args, **kwargs) + def assertRaisesErrno(self, err_num, f=None, *args): """Test that the system returns an errno error. @@ -410,9 +417,9 @@ class NetworkTest(unittest.TestCase): """ msg = os.strerror(err_num) if f is None: - return self.assertRaisesRegexp(EnvironmentError, msg) + return self.assertRaisesRegex(EnvironmentError, msg) else: - self.assertRaisesRegexp(EnvironmentError, msg, f, *args) + self.assertRaisesRegex(EnvironmentError, msg, f, *args) def ReadProcNetSocket(self, protocol): # Read file. diff --git a/net/test/net_test.sh b/net/test/net_test.sh index 6a22c0e..52b168d 100755 --- a/net/test/net_test.sh +++ b/net/test/net_test.sh @@ -161,7 +161,13 @@ fi echo -e "Running $net_test $net_test_args\n" $net_test $net_test_args +rv="$?" # Write exit code of net_test to a file so that the builder can use it # to signal failure if any tests fail. -echo $? >$net_test_exitcode +echo "${rv}" > "${net_test_exitcode}" + +# Additionally on UML make it the exit code of UML kernel binary itself. +if [[ -e '/proc/exitcode' ]]; then + echo "${rv}" > /proc/exitcode +fi diff --git a/net/test/netlink.py b/net/test/netlink.py index 4e230d4..2c9c757 100644 --- a/net/test/netlink.py +++ b/net/test/netlink.py @@ -67,7 +67,7 @@ class NetlinkSocket(object): def _Debug(self, s): if self.DEBUG: - print s + print(s) def _NlAttr(self, nla_type, data): datalen = len(data) @@ -212,7 +212,7 @@ class NetlinkSocket(object): self._Debug(" %s" % nlmsghdr) if nlmsghdr.type == NLMSG_ERROR or nlmsghdr.type == NLMSG_DONE: - print "done" + print("done") return (None, None), data nlmsg, data = cstruct.Read(data, msgtype) diff --git a/net/test/pf_key.py b/net/test/pf_key.py index 875e01c..3136a85 100755 --- a/net/test/pf_key.py +++ b/net/test/pf_key.py @@ -190,6 +190,7 @@ def ParseExtension(exttype, data): return exttype, ext, attrs + class PfKey(object): """PF_KEY interface to kernel IPsec implementation.""" @@ -202,7 +203,7 @@ class PfKey(object): def Recv(self): reply = self.sock.recv(4096) msg = SadbMsg(reply) - # print "RECV:", self.DecodeSadbMsg(msg) + # print("RECV: " + self.DecodeSadbMsg(msg)) if msg.errno != 0: raise OSError(msg.errno, os.strerror(msg.errno)) return reply @@ -213,7 +214,7 @@ class PfKey(object): msg.pid = os.getpid() msg.len = (len(SadbMsg) + len(extensions)) / 8 self.sock.send(msg.Pack() + extensions) - # print "SEND:", self.DecodeSadbMsg(msg) + # print("SEND: " + self.DecodeSadbMsg(msg)) return self.Recv() def PackPfKeyExtensions(self, extlist): @@ -314,13 +315,14 @@ class PfKey(object): def PrintSaInfos(self, dump): for msg, extensions in dump: - print self.DecodeSadbMsg(msg) + print(self.DecodeSadbMsg(msg)) for exttype, ext, attrs in extensions: exttype = _GetMultiConstantName(exttype, ["SADB_EXT", "SADB_X_EXT"]) if exttype == SADB_EXT_SA: - print " ", exttype, self.DecodeSadbSa(ext), attrs.encode("hex") - print " ", exttype, ext, attrs.encode("hex") - print + print(" %s %s %s" % + (exttype, self.DecodeSadbSa(ext), attrs.encode("hex"))) + print(" %s %s %s" % (exttype, ext, attrs.encode("hex"))) + print("") if __name__ == "__main__": diff --git a/net/test/pf_key_test.py b/net/test/pf_key_test.py index e58947c..317ec7e 100755 --- a/net/test/pf_key_test.py +++ b/net/test/pf_key_test.py @@ -49,26 +49,26 @@ class PfKeyTest(unittest.TestCase): pf_key.SADB_X_AALG_SHA2_256HMAC, ENCRYPTION_KEY) sainfos = self.xfrm.DumpSaInfo() - self.assertEquals(2, len(sainfos)) + self.assertEqual(2, len(sainfos)) state4, attrs4 = [(s, a) for s, a in sainfos if s.family == AF_INET][0] state6, attrs6 = [(s, a) for s, a in sainfos if s.family == AF_INET6][0] pfkey_sainfos = self.pf_key.DumpSaInfo() - self.assertEquals(2, len(pfkey_sainfos)) + self.assertEqual(2, len(pfkey_sainfos)) self.assertTrue(all(msg.satype == pf_key.SDB_TYPE_ESP) for msg, _ in pfkey_sainfos) - self.assertEquals(xfrm.IPPROTO_ESP, state4.id.proto) - self.assertEquals(xfrm.IPPROTO_ESP, state6.id.proto) - self.assertEquals(54321, state4.reqid) - self.assertEquals(12345, state6.reqid) - self.assertEquals(0xdeadbeef, state4.id.spi) - self.assertEquals(0xbeefdead, state6.id.spi) + self.assertEqual(xfrm.IPPROTO_ESP, state4.id.proto) + self.assertEqual(xfrm.IPPROTO_ESP, state6.id.proto) + self.assertEqual(54321, state4.reqid) + self.assertEqual(12345, state6.reqid) + self.assertEqual(0xdeadbeef, state4.id.spi) + self.assertEqual(0xbeefdead, state6.id.spi) - self.assertEquals(xfrm.PaddedAddress("192.0.2.1"), state4.saddr) - self.assertEquals(xfrm.PaddedAddress("192.0.2.2"), state4.id.daddr) - self.assertEquals(xfrm.PaddedAddress("2001:db8::1"), state6.saddr) - self.assertEquals(xfrm.PaddedAddress("2001:db8::2"), state6.id.daddr) + self.assertEqual(xfrm.PaddedAddress("192.0.2.1"), state4.saddr) + self.assertEqual(xfrm.PaddedAddress("192.0.2.2"), state4.id.daddr) + self.assertEqual(xfrm.PaddedAddress("2001:db8::1"), state6.saddr) + self.assertEqual(xfrm.PaddedAddress("2001:db8::2"), state6.id.daddr) # The algorithm names are null-terminated, but after that contain garbage. # Kernel bug? @@ -79,20 +79,20 @@ class PfKeyTest(unittest.TestCase): self.assertTrue(attrs4["XFRMA_ALG_AUTH"].name.startswith(sha256_name)) self.assertTrue(attrs6["XFRMA_ALG_AUTH"].name.startswith(sha256_name)) - self.assertEquals(256, attrs4["XFRMA_ALG_CRYPT"].key_len) - self.assertEquals(256, attrs4["XFRMA_ALG_CRYPT"].key_len) - self.assertEquals(256, attrs6["XFRMA_ALG_AUTH"].key_len) - self.assertEquals(256, attrs6["XFRMA_ALG_AUTH"].key_len) - self.assertEquals(256, attrs6["XFRMA_ALG_AUTH_TRUNC"].key_len) - self.assertEquals(256, attrs6["XFRMA_ALG_AUTH_TRUNC"].key_len) + self.assertEqual(256, attrs4["XFRMA_ALG_CRYPT"].key_len) + self.assertEqual(256, attrs4["XFRMA_ALG_CRYPT"].key_len) + self.assertEqual(256, attrs6["XFRMA_ALG_AUTH"].key_len) + self.assertEqual(256, attrs6["XFRMA_ALG_AUTH"].key_len) + self.assertEqual(256, attrs6["XFRMA_ALG_AUTH_TRUNC"].key_len) + self.assertEqual(256, attrs6["XFRMA_ALG_AUTH_TRUNC"].key_len) - self.assertEquals(128, attrs4["XFRMA_ALG_AUTH_TRUNC"].trunc_len) - self.assertEquals(128, attrs4["XFRMA_ALG_AUTH_TRUNC"].trunc_len) + self.assertEqual(128, attrs4["XFRMA_ALG_AUTH_TRUNC"].trunc_len) + self.assertEqual(128, attrs4["XFRMA_ALG_AUTH_TRUNC"].trunc_len) self.pf_key.DelSa(src4, dst4, 0xdeadbeef, pf_key.SADB_TYPE_ESP) - self.assertEquals(1, len(self.xfrm.DumpSaInfo())) + self.assertEqual(1, len(self.xfrm.DumpSaInfo())) self.pf_key.DelSa(src6, dst6, 0xbeefdead, pf_key.SADB_TYPE_ESP) - self.assertEquals(0, len(self.xfrm.DumpSaInfo())) + self.assertEqual(0, len(self.xfrm.DumpSaInfo())) if __name__ == "__main__": diff --git a/net/test/ping6_test.py b/net/test/ping6_test.py index dd73e88..d551b5f 100755 --- a/net/test/ping6_test.py +++ b/net/test/ping6_test.py @@ -185,7 +185,7 @@ class PingReplyThread(threading.Thread): packet = scapy.Ether(src=self._routermac, dst=self._mymac) / packet try: posix.write(self._tun.fileno(), str(packet)) - except Exception, e: + except Exception as e: if not self._stopped: raise e @@ -194,12 +194,12 @@ class PingReplyThread(threading.Thread): while not self._stopped: try: packet = posix.read(self._tun.fileno(), 4096) - except OSError, e: + except OSError as e: if e.errno == errno.EAGAIN: continue else: break - except ValueError, e: + except ValueError as e: if not self._stopped: raise e @@ -220,9 +220,9 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): # tests fail. _INTERVAL = 0.1 _ATTEMPTS = 20 - for i in xrange(0, _ATTEMPTS): + for i in range(0, _ATTEMPTS): for netid in cls.NETIDS: - if all(thread.IsStarted() for thread in cls.reply_threads.values()): + if all(thread.IsStarted() for thread in list(cls.reply_threads.values())): return time.sleep(_INTERVAL) msg = "WARNING: reply threads not all started after %.1f seconds\n" % ( @@ -231,7 +231,7 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): @classmethod def StopReplyThreads(cls): - for thread in cls.reply_threads.values(): + for thread in list(cls.reply_threads.values()): thread.Stop() @classmethod @@ -295,9 +295,9 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): # Check that the flow label is zero and that the scope ID is sane. self.assertEqual(flowlabel, 0) if addr.startswith("fe80::"): - self.assertTrue(scope_id in self.ifindices.values()) + self.assertTrue(scope_id in list(self.ifindices.values())) else: - self.assertEquals(0, scope_id) + self.assertEqual(0, scope_id) # TODO: check the checksum. We can't do this easily now for ICMPv6 because # we don't have the IP addresses so we can't construct the pseudoheader. @@ -356,32 +356,32 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): def testIPv4PingUsingSendto(self): s = net_test.IPv4PingSocket() written = s.sendto(net_test.IPV4_PING, (net_test.IPV4_ADDR, 55)) - self.assertEquals(len(net_test.IPV4_PING), written) + self.assertEqual(len(net_test.IPV4_PING), written) self.assertValidPingResponse(s, net_test.IPV4_PING) def testIPv6PingUsingSendto(self): s = net_test.IPv6PingSocket() written = s.sendto(net_test.IPV6_PING, (net_test.IPV6_ADDR, 55)) - self.assertEquals(len(net_test.IPV6_PING), written) + self.assertEqual(len(net_test.IPV6_PING), written) self.assertValidPingResponse(s, net_test.IPV6_PING) def testIPv4NoCrash(self): # Python 2.x does not provide either read() or recvmsg. s = net_test.IPv4PingSocket() written = s.sendto(net_test.IPV4_PING, ("127.0.0.1", 55)) - self.assertEquals(len(net_test.IPV4_PING), written) + self.assertEqual(len(net_test.IPV4_PING), written) fd = s.fileno() reply = posix.read(fd, 4096) - self.assertEquals(written, len(reply)) + self.assertEqual(written, len(reply)) def testIPv6NoCrash(self): # Python 2.x does not provide either read() or recvmsg. s = net_test.IPv6PingSocket() written = s.sendto(net_test.IPV6_PING, ("::1", 55)) - self.assertEquals(len(net_test.IPV6_PING), written) + self.assertEqual(len(net_test.IPV6_PING), written) fd = s.fileno() reply = posix.read(fd, 4096) - self.assertEquals(written, len(reply)) + self.assertEqual(written, len(reply)) def testCrossProtocolCrash(self): # Checks that an ICMP error containing a ping packet that matches the ID @@ -458,12 +458,12 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): # Bind to unspecified address. s = net_test.IPv4PingSocket() s.bind(("0.0.0.0", 544)) - self.assertEquals(("0.0.0.0", 544), s.getsockname()) + self.assertEqual(("0.0.0.0", 544), s.getsockname()) # Bind to loopback. s = net_test.IPv4PingSocket() s.bind(("127.0.0.1", 99)) - self.assertEquals(("127.0.0.1", 99), s.getsockname()) + self.assertEqual(("127.0.0.1", 99), s.getsockname()) # Binding twice is not allowed. self.assertRaisesErrno(errno.EINVAL, s.bind, ("127.0.0.1", 22)) @@ -471,10 +471,10 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): # But binding two different sockets to the same ID is allowed. s2 = net_test.IPv4PingSocket() s2.bind(("127.0.0.1", 99)) - self.assertEquals(("127.0.0.1", 99), s2.getsockname()) + self.assertEqual(("127.0.0.1", 99), s2.getsockname()) s3 = net_test.IPv4PingSocket() s3.bind(("127.0.0.1", 99)) - self.assertEquals(("127.0.0.1", 99), s3.getsockname()) + self.assertEqual(("127.0.0.1", 99), s3.getsockname()) # If two sockets bind to the same port, the first one to call read() gets # the response. @@ -501,12 +501,12 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): # Bind to unspecified address. s = net_test.IPv6PingSocket() s.bind(("::", 769)) - self.assertEquals(("::", 769, 0, 0), s.getsockname()) + self.assertEqual(("::", 769, 0, 0), s.getsockname()) # Bind to loopback. s = net_test.IPv6PingSocket() s.bind(("::1", 99)) - self.assertEquals(("::1", 99, 0, 0), s.getsockname()) + self.assertEqual(("::1", 99, 0, 0), s.getsockname()) # Binding twice is not allowed. self.assertRaisesErrno(errno.EINVAL, s.bind, ("::1", 22)) @@ -514,10 +514,10 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): # But binding two different sockets to the same ID is allowed. s2 = net_test.IPv6PingSocket() s2.bind(("::1", 99)) - self.assertEquals(("::1", 99, 0, 0), s2.getsockname()) + self.assertEqual(("::1", 99, 0, 0), s2.getsockname()) s3 = net_test.IPv6PingSocket() s3.bind(("::1", 99)) - self.assertEquals(("::1", 99, 0, 0), s3.getsockname()) + self.assertEqual(("::1", 99, 0, 0), s3.getsockname()) # Binding both IPv4 and IPv6 to the same socket works. s4 = net_test.IPv4PingSocket() @@ -542,7 +542,7 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): try: s.setsockopt(SOL_IP, net_test.IP_TRANSPARENT, 1) s.bind((net_test.IPV4_ADDR, 651)) - except IOError, e: + except IOError as e: if e.errno == errno.EACCES: pass # We're not root. let it go for now. @@ -557,7 +557,7 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): try: s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_TRANSPARENT, 1) s.bind((net_test.IPV6_ADDR, 651)) - except IOError, e: + except IOError as e: if e.errno == errno.EACCES: pass # We're not root. let it go for now. @@ -567,7 +567,7 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): sockaddr = csocket.Sockaddr(("0.0.0.0", 12996)) sockaddr.family = AF_UNSPEC csocket.Bind(s4, sockaddr) - self.assertEquals(("0.0.0.0", 12996), s4.getsockname()) + self.assertEqual(("0.0.0.0", 12996), s4.getsockname()) # But not if the address is anything else. sockaddr = csocket.Sockaddr(("127.0.0.1", 58234)) @@ -591,7 +591,7 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): # returns "fe80:1%foo", even though it does not understand it. expected = self.lladdr + "%" + self.ifname s.bind((self.lladdr, 4646, 0, self.ifindex)) - self.assertEquals((expected, 4646, 0, self.ifindex), s.getsockname()) + self.assertEqual((expected, 4646, 0, self.ifindex), s.getsockname()) # Of course, for the above to work the address actually has to be configured # on the machine. @@ -601,18 +601,18 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): # Scope IDs on non-link-local addresses are silently ignored. s = net_test.IPv6PingSocket() s.bind(("::1", 1234, 0, 1)) - self.assertEquals(("::1", 1234, 0, 0), s.getsockname()) + self.assertEqual(("::1", 1234, 0, 0), s.getsockname()) def testBindAffectsIdentifier(self): s = net_test.IPv6PingSocket() s.bind((self.globaladdr, 0xf976)) s.sendto(net_test.IPV6_PING, (net_test.IPV6_ADDR, 55)) - self.assertEquals("\xf9\x76", s.recv(32768)[4:6]) + self.assertEqual("\xf9\x76", s.recv(32768)[4:6]) s = net_test.IPv6PingSocket() s.bind((self.globaladdr, 0xace)) s.sendto(net_test.IPV6_PING, (net_test.IPV6_ADDR, 55)) - self.assertEquals("\x0a\xce", s.recv(32768)[4:6]) + self.assertEqual("\x0a\xce", s.recv(32768)[4:6]) def testLinkLocalAddress(self): s = net_test.IPv6PingSocket() @@ -751,8 +751,8 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): s = net_test.Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP) s.bind(("127.0.0.1", 0xace)) s.connect(("127.0.0.1", 0xbeef)) - self.assertEquals(numrows + 1, len(self.ReadProcNetSocket("icmp"))) - self.assertEquals(numrows6, len(self.ReadProcNetSocket("icmp6"))) + self.assertEqual(numrows + 1, len(self.ReadProcNetSocket("icmp"))) + self.assertEqual(numrows6, len(self.ReadProcNetSocket("icmp6"))) @unittest.skipUnless(HAVE_PROC_NET_ICMP6, "skipping: no /proc/net/icmp6") def testIcmp6SocketsNotInIcmp(self): @@ -761,8 +761,8 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): s = net_test.IPv6PingSocket() s.bind(("::1", 0xace)) s.connect(("::1", 0xbeef)) - self.assertEquals(numrows, len(self.ReadProcNetSocket("icmp"))) - self.assertEquals(numrows6 + 1, len(self.ReadProcNetSocket("icmp6"))) + self.assertEqual(numrows, len(self.ReadProcNetSocket("icmp"))) + self.assertEqual(numrows6 + 1, len(self.ReadProcNetSocket("icmp6"))) def testProcNetIcmp(self): s = net_test.Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP) @@ -780,7 +780,7 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): # Check the row goes away when the socket is closed. s.close() - self.assertEquals(numrows6, len(self.ReadProcNetSocket("icmp6"))) + self.assertEqual(numrows6, len(self.ReadProcNetSocket("icmp6"))) # Try send, bind and connect to check the addresses and the state. s = net_test.IPv6PingSocket() @@ -841,7 +841,7 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): ident = struct.pack("!H", s.getsockname()[1]) pkt = pkt[:4] + ident + pkt[6:] data = data[:2] + "\x00\x00" + pkt[4:] - self.assertEquals(pkt, data) + self.assertEqual(pkt, data) # Check the address that the packet was sent to. # ... except in 4.1, where it just returns an AF_UNSPEC, like this: @@ -851,7 +851,7 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): # msg_controllen=64, {cmsg_len=60, cmsg_level=SOL_IPV6, cmsg_type=, ...}, # msg_flags=MSG_ERRQUEUE}, MSG_ERRQUEUE) = 1232 if net_test.LINUX_VERSION != (4, 1, 0): - self.assertEquals(csocket.Sockaddr(("2001:4860:4860::8888", 0)), addr) + self.assertEqual(csocket.Sockaddr(("2001:4860:4860::8888", 0)), addr) # Check the cmsg data, including the link MTU. mtu = PingReplyThread.LINK_MTU @@ -869,7 +869,7 @@ class Ping6Test(multinetwork_base.MultiNetworkBaseTest): if net_test.LINUX_VERSION <= (3, 14, 0): msglist[0][2][1].port = cmsg[0][2][1].port - self.assertEquals(msglist, cmsg) + self.assertEqual(msglist, cmsg) if __name__ == "__main__": diff --git a/net/test/removed_feature_test.py b/net/test/removed_feature_test.py index 487af41..e58b4e3 100755 --- a/net/test/removed_feature_test.py +++ b/net/test/removed_feature_test.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import errno +from socket import * # pylint: disable=wildcard-import import unittest import gzip @@ -66,6 +68,27 @@ class RemovedFeatureTest(net_test.NetworkTest): self.assertFeatureEnabled("CONFIG_IP6_NF_TARGET_REJECT") self.assertFeatureAbsent("CONFIG_IP6_NF_TARGET_REJECT_SKERR") + @unittest.skipUnless(net_test.LINUX_VERSION >= (4, 19, 0), "removed in 4.14-r") + def testRemovedAndroidParanoidNetwork(self): + """Verify that ANDROID_PARANOID_NETWORK is gone.""" + + AID_NET_RAW = 3004 + with net_test.RunAsUidGid(12345, AID_NET_RAW): + self.assertRaisesErrno(errno.EPERM, socket, AF_PACKET, SOCK_RAW, 0) + + @unittest.skipUnless(net_test.LINUX_VERSION >= (4, 19, 0), "exists in 4.14-P") + def testRemovedQtaguid(self): + self.assertRaisesErrno(errno.ENOENT, open, "/proc/net/xt_qtaguid") + + @unittest.skipUnless(net_test.LINUX_VERSION >= (4, 19, 0), "exists in 4.14-P") + def testRemovedTcpMemSysctls(self): + self.assertRaisesErrno(errno.ENOENT, open, "/sys/kernel/ipv4/tcp_rmem_def") + self.assertRaisesErrno(errno.ENOENT, open, "/sys/kernel/ipv4/tcp_rmem_max") + self.assertRaisesErrno(errno.ENOENT, open, "/sys/kernel/ipv4/tcp_rmem_min") + self.assertRaisesErrno(errno.ENOENT, open, "/sys/kernel/ipv4/tcp_wmem_def") + self.assertRaisesErrno(errno.ENOENT, open, "/sys/kernel/ipv4/tcp_wmem_max") + self.assertRaisesErrno(errno.ENOENT, open, "/sys/kernel/ipv4/tcp_wmem_min") + if __name__ == "__main__": unittest.main() diff --git a/net/test/resilient_rs_test.py b/net/test/resilient_rs_test.py index 12843c4..be3210b 100755 --- a/net/test/resilient_rs_test.py +++ b/net/test/resilient_rs_test.py @@ -97,14 +97,14 @@ class ResilientRouterSolicitationTest(multinetwork_base.MultiNetworkBaseTest): def testRouterSolicitationBackoff(self): # Test error tolerance - EPSILON = 0.1 + EPSILON = 0.15 # Minimum RFC3315 S14 backoff MIN_EXP = 1.9 - EPSILON # Maximum RFC3315 S14 backoff MAX_EXP = 2.1 + EPSILON SOLICITATION_INTERVAL = 1 - # Linear backoff for 4 samples yields 3.6 < T < 4.4 - # Exponential backoff for 4 samples yields 4.83 < T < 9.65 + # Linear backoff for 4 samples yields 3.5 < T < 4.5 + # Exponential backoff for 4 samples yields 4.36 < T < 10.39 REQUIRED_SAMPLES = 4 # Give up after 10 seconds. Tuned for REQUIRED_SAMPLES = 4 SAMPLE_INTERVAL = 10 @@ -159,8 +159,8 @@ class ResilientRouterSolicitationTest(multinetwork_base.MultiNetworkBaseTest): # Compute minimum and maximum bounds for RFC3315 S14 exponential backoff. # First retransmit is linear backoff, subsequent retransmits are exponential - min_exp_bound = accumulate(map(lambda i: MIN_LIN * pow(MIN_EXP, i), range(0, len(rsSendTimes)))) - max_exp_bound = accumulate(map(lambda i: MAX_LIN * pow(MAX_EXP, i), range(0, len(rsSendTimes)))) + min_exp_bound = accumulate([MIN_LIN * pow(MIN_EXP, i) for i in range(0, len(rsSendTimes))]) + max_exp_bound = accumulate([MAX_LIN * pow(MAX_EXP, i) for i in range(0, len(rsSendTimes))]) # Assert that each sample falls within the worst case interval. If all samples fit we accept # the exponential backoff hypothesis diff --git a/net/test/run_net_test.sh b/net/test/run_net_test.sh index d1a66f5..6f32f81 100755 --- a/net/test/run_net_test.sh +++ b/net/test/run_net_test.sh @@ -12,7 +12,7 @@ EOF } # Common kernel options -OPTIONS=" DEBUG_SPINLOCK DEBUG_ATOMIC_SLEEP DEBUG_MUTEXES DEBUG_RT_MUTEXES" +OPTIONS=" ANDROID DEBUG_SPINLOCK DEBUG_ATOMIC_SLEEP DEBUG_MUTEXES DEBUG_RT_MUTEXES" OPTIONS="$OPTIONS WARN_ALL_UNSEEDED_RANDOM IKCONFIG IKCONFIG_PROC" OPTIONS="$OPTIONS DEVTMPFS DEVTMPFS_MOUNT FHANDLE" OPTIONS="$OPTIONS IPV6 IPV6_ROUTER_PREF IPV6_MULTIPLE_TABLES IPV6_ROUTE_INFO" @@ -23,6 +23,7 @@ OPTIONS="$OPTIONS IP_NF_IPTABLES IP_NF_MANGLE IP_NF_FILTER" OPTIONS="$OPTIONS IP6_NF_IPTABLES IP6_NF_MANGLE IP6_NF_FILTER INET6_IPCOMP" OPTIONS="$OPTIONS IPV6_OPTIMISTIC_DAD" OPTIONS="$OPTIONS IPV6_ROUTE_INFO IPV6_ROUTER_PREF" +OPTIONS="$OPTIONS NETFILTER_XT_TARGET_IDLETIMER" OPTIONS="$OPTIONS NETFILTER_XT_TARGET_NFLOG" OPTIONS="$OPTIONS NETFILTER_XT_MATCH_POLICY" OPTIONS="$OPTIONS NETFILTER_XT_MATCH_QUOTA" @@ -36,6 +37,7 @@ OPTIONS="$OPTIONS IP_NF_TARGET_REJECT IP_NF_TARGET_REJECT_SKERR" OPTIONS="$OPTIONS IP6_NF_TARGET_REJECT IP6_NF_TARGET_REJECT_SKERR" OPTIONS="$OPTIONS NET_KEY XFRM_USER XFRM_STATISTICS CRYPTO_CBC" OPTIONS="$OPTIONS CRYPTO_CTR CRYPTO_HMAC CRYPTO_AES CRYPTO_SHA1" +OPTIONS="$OPTIONS CRYPTO_XCBC CRYPTO_CHACHA20POLY1305" OPTIONS="$OPTIONS CRYPTO_USER INET_ESP INET_XFRM_MODE_TRANSPORT" OPTIONS="$OPTIONS INET_XFRM_MODE_TUNNEL INET6_ESP" OPTIONS="$OPTIONS INET6_XFRM_MODE_TRANSPORT INET6_XFRM_MODE_TUNNEL" @@ -45,6 +47,7 @@ OPTIONS="$OPTIONS DUMMY" # Kernel version specific options OPTIONS="$OPTIONS XFRM_INTERFACE" # Various device kernels +OPTIONS="$OPTIONS XFRM_MIGRATE" # Added in 5.10 OPTIONS="$OPTIONS CGROUP_BPF" # Added in android-4.9 OPTIONS="$OPTIONS NF_SOCKET_IPV4 NF_SOCKET_IPV6" # Added in 4.9 OPTIONS="$OPTIONS INET_SCTP_DIAG" # Added in 4.7 @@ -209,6 +212,9 @@ if [ ! -f $ROOTFS ]; then echo "Uncompressing $COMPRESSED_ROOTFS" >&2 unxz $COMPRESSED_ROOTFS fi +if ! [[ "${ROOTFS}" =~ ^/ ]]; then + ROOTFS="${SCRIPT_DIR}/${ROOTFS}" +fi echo "Using $ROOTFS" cd - @@ -257,7 +263,7 @@ if ((nobuild == 0)); then if [ "$ARCH" == "um" ]; then # Exporting ARCH=um SUBARCH=x86_64 doesn't seem to work, as it # "sometimes" (?) results in a 32-bit kernel. - make_flags="$make_flags ARCH=$ARCH SUBARCH=x86_64 CROSS_COMPILE= " + make_flags="$make_flags ARCH=$ARCH SUBARCH=${SUBARCH:-x86_64} CROSS_COMPILE= " fi if [ -n "$CC" ]; then # The CC flag is *not* inherited from the environment, so it must be @@ -318,8 +324,10 @@ if [ "$ARCH" == "um" ]; then # Get the absolute path to the test file that's being run. cmdline="$cmdline net_test=/host$SCRIPT_DIR/$test" - # Use UML's /proc/exitcode feature to communicate errors on test failure - cmdline="$cmdline net_test_exitcode=/proc/exitcode" + # We'd use UML's /proc/exitcode feature to communicate errors on test failure, + # if not for UML having a tendency to crash during shutdown, + # so instead use an extra serial line we'll redirect to an open fd... + cmdline="$cmdline net_test_exitcode=/dev/ttyS3" # Map the --readonly flag to UML block device names if ((nowrite == 0)); then @@ -328,11 +336,30 @@ if [ "$ARCH" == "um" ]; then blockdevice=ubdar fi + # Create a temp file for 'serial line 3' for return code. + SSL3="$(mktemp)" + exitcode=0 - $KERNEL_BINARY >&2 umid=net_test mem=512M \ - $blockdevice=$SCRIPT_DIR/$ROOTFS $netconfig $consolemode $cmdline \ + $KERNEL_BINARY >&2 3>"${SSL3}" umid=net_test mem=512M \ + $blockdevice=$ROOTFS $netconfig $consolemode ssl3=null,fd:3 $cmdline \ || exitcode=$? + if [[ "${exitcode}" == 134 && -s "${SSL3}" && "$(tr -d '\r' < "${SSL3}")" == 0 ]]; then + # Sometimes the tests all pass, but UML crashes during the shutdown process itself. + # As such we can't actually rely on the /proc/exitcode returned value. + echo "Warning: UML appears to have crashed after successfully executing the tests." 1>&2 + elif [[ "${exitcode}" != 0 ]]; then + echo "Warning: UML exited with ${exitcode} instead of zero." 1>&2 + fi + + if [[ -s "${SSL3}" ]]; then + exitcode="$(tr -d '\r' < "${SSL3}")" + echo "Info: retrieved exit code ${exitcode}." 1>&2 + fi + + rm -f "${SSL3}" + unset SSL3 + # UML is kind of crazy in how guest syscalls work. It requires host kernel # to not be in vsyscall=none mode. if [[ "${exitcode}" != '0' ]]; then @@ -369,7 +396,7 @@ else else blockdevice= fi - blockdevice="-drive file=$SCRIPT_DIR/$ROOTFS,format=raw,if=none,id=drive-virtio-disk0$blockdevice" + blockdevice="-drive file=$ROOTFS,format=raw,if=none,id=drive-virtio-disk0$blockdevice" blockdevice="$blockdevice -device virtio-blk-pci,drive=drive-virtio-disk0" # Pass through our current console/screen size to inner shell session diff --git a/net/test/sock_diag.py b/net/test/sock_diag.py index 46cc92d..03d5587 100755 --- a/net/test/sock_diag.py +++ b/net/test/sock_diag.py @@ -164,7 +164,7 @@ class SockDiag(netlink.NetlinkSocket): if "ALL" not in self.NL_DEBUG and "SOCK" not in self.NL_DEBUG: return parsed = self._ParseNLMsg(data, InetDiagReqV2) - print "%s %s" % (name, str(parsed)) + print("%s %s" % (name, str(parsed))) @staticmethod def _EmptyInetDiagSockId(): @@ -246,15 +246,15 @@ class SockDiag(netlink.NetlinkSocket): positions.append(positions[-1] + 4) # Why 4? Because the kernel uses 4. assert len(args) == len(instructions) == len(positions) - 2 - # print positions + # print(positions) packed = "" for i, (op, yes, no, arg) in enumerate(instructions): yes = positions[i + yes] - positions[i] no = positions[i + no] - positions[i] instruction = InetDiagBcOp((op, yes, no)).Pack() + args[i] - #print "%3d: %d %3d %3d %s %s" % (positions[i], op, yes, no, - # arg, instruction.encode("hex")) + #print("%3d: %d %3d %3d %s %s" % (positions[i], op, yes, no, + # arg, instruction.encode("hex"))) packed += instruction #print @@ -363,7 +363,7 @@ class SockDiag(netlink.NetlinkSocket): src, sport = s.getsockname()[:2] try: dst, dport = s.getpeername()[:2] - except error, e: + except error as e: if e.errno == errno.ENOTCONN: dport = 0 dst = "::" if family == AF_INET6 else "0.0.0.0" @@ -430,4 +430,4 @@ if __name__ == "__main__": states = 0xffffffff diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, "", sock_id=sock_id, ext=ext, states=states) - print diag_msgs + print(diag_msgs) diff --git a/net/test/sock_diag_test.py b/net/test/sock_diag_test.py index daa2fa4..39ace4c 100755 --- a/net/test/sock_diag_test.py +++ b/net/test/sock_diag_test.py @@ -105,7 +105,7 @@ class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest): def _CreateLotsOfSockets(socktype): # Dict mapping (addr, sport, dport) tuples to socketpairs. socketpairs = {} - for _ in xrange(NUM_SOCKETS): + for _ in range(NUM_SOCKETS): family, addr = random.choice([ (AF_INET, "127.0.0.1"), (AF_INET6, "::1"), @@ -151,7 +151,7 @@ class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest): def PackAndCheckBytecode(self, instructions): bytecode = self.sock_diag.PackBytecode(instructions) decoded = self.sock_diag.DecodeBytecode(bytecode) - self.assertEquals(len(instructions), len(decoded)) + self.assertEqual(len(instructions), len(decoded)) self.assertFalse("???" in decoded) return bytecode @@ -192,7 +192,7 @@ class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest): self.socketpairs = {} def tearDown(self): - for socketpair in self.socketpairs.values(): + for socketpair in list(self.socketpairs.values()): for s in socketpair: s.close() super(SockDiagBaseTest, self).tearDown() @@ -228,9 +228,9 @@ class SockDiagTest(SockDiagBaseTest): cookies[(addr, sport, dport)] = diag_msg.id.cookie # Did we find all the cookies? - self.assertEquals(2 * NUM_SOCKETS, len(cookies)) + self.assertEqual(2 * NUM_SOCKETS, len(cookies)) - socketpairs = self.socketpairs.values() + socketpairs = list(self.socketpairs.values()) random.shuffle(socketpairs) for socketpair in socketpairs: for sock in socketpair: @@ -284,7 +284,7 @@ class SockDiagTest(SockDiagBaseTest): ) states = 1 << tcp_test.TCP_ESTABLISHED self.assertMultiLineEqual(expected, bytecode.encode("hex")) - self.assertEquals(76, len(bytecode)) + self.assertEqual(76, len(bytecode)) self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM) filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode, states=states) @@ -294,7 +294,7 @@ class SockDiagTest(SockDiagBaseTest): # Pick a few sockets in hash table order, and check that the bytecode we # compiled selects them properly. - for socketpair in self.socketpairs.values()[:20]: + for socketpair in list(self.socketpairs.values())[:20]: for s in socketpair: diag_msg = self.sock_diag.FindSockDiagFromFd(s) instructions = [ @@ -304,12 +304,12 @@ class SockDiagTest(SockDiagBaseTest): (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport), ] bytecode = self.PackAndCheckBytecode(instructions) - self.assertEquals(32, len(bytecode)) + self.assertEqual(32, len(bytecode)) sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode) - self.assertEquals(1, len(sockets)) + self.assertEqual(1, len(sockets)) # TODO: why doesn't comparing the cstructs work? - self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack()) + self.assertEqual(diag_msg.Pack(), sockets[0][0].Pack()) def testCrossFamilyBytecode(self): """Checks for a cross-family bug in inet_diag_hostcond matching. @@ -365,7 +365,7 @@ class SockDiagTest(SockDiagBaseTest): 5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads """ bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8)) - self.assertEquals("???", + self.assertEqual("???", self.sock_diag.DecodeBytecode(bytecode)) self.assertRaisesErrno( EINVAL, @@ -481,7 +481,7 @@ class SockDestroyTest(SockDiagBaseTest): def testClosesSockets(self): self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM) - for _, socketpair in self.socketpairs.iteritems(): + for _, socketpair in self.socketpairs.items(): # Close one of the sockets. # This will send a RST that will close the other side as well. s = random.choice(socketpair) @@ -521,7 +521,7 @@ class SocketExceptionThread(threading.Thread): def run(self): try: self.operation(self.sock) - except (IOError, AssertionError), e: + except (IOError, AssertionError) as e: self.exception = e @@ -534,7 +534,7 @@ class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): android-3.4: 457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state """ - netid = random.choice(self.tuns.keys()) + netid = random.choice(list(self.tuns.keys())) self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid) sock_id = self.sock_diag._EmptyInetDiagSockId() sock_id.sport = self.port @@ -599,7 +599,7 @@ class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): def setUp(self): super(SockDestroyTcpTest, self).setUp() - self.netid = random.choice(self.tuns.keys()) + self.netid = random.choice(list(self.tuns.keys())) def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True): """Closes the socket and checks whether a RST is sent or not.""" @@ -650,7 +650,7 @@ class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): self.accepted.close() diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1 diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req) - self.assertEquals(tcp_test.TCP_FIN_WAIT1, diag_msg.state) + self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state) desc, fin = self.FinPacket() self.ExpectPacketOn(self.netid, "Closing FIN_WAIT1 socket", fin) @@ -660,7 +660,7 @@ class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): # The socket is still there in FIN_WAIT1: SOCK_DESTROY did nothing # because userspace had already closed it. - self.assertEquals(tcp_test.TCP_FIN_WAIT1, diag_msg.state) + self.assertEqual(tcp_test.TCP_FIN_WAIT1, diag_msg.state) # ACK the FIN so we don't trip over retransmits in future tests. finversion = 4 if version == 5 else version @@ -702,7 +702,7 @@ class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): d = self.sock_diag.FindSockDiagFromFd(self.s) parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) children = self.FindChildSockets(self.s) - self.assertEquals(1, len(children)) + self.assertEqual(1, len(children)) is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED) expected_state = tcp_test.TCP_ESTABLISHED if is_established else state @@ -716,7 +716,7 @@ class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): for child in children: if can_close_children: diag_msg, attrs = self.sock_diag.GetSockInfo(child) - self.assertEquals(diag_msg.state, expected_state) + self.assertEqual(diag_msg.state, expected_state) self.assertMarkIs(self.netid, attrs) else: self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child) @@ -772,7 +772,7 @@ class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo") self.assertRaisesErrno(EINVAL, self.s.accept) # TODO: this should really return an error such as ENOTCONN... - self.assertEquals("", self.s.recv(4096)) + self.assertEqual("", self.s.recv(4096)) def testReadInterrupted(self): """Tests that read() is interrupted by SOCK_DESTROY.""" @@ -782,8 +782,8 @@ class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest): ECONNABORTED) # Writing returns EPIPE, and reading returns EOF. self.assertRaisesErrno(EPIPE, self.accepted.send, "foo") - self.assertEquals("", self.accepted.recv(4096)) - self.assertEquals("", self.accepted.recv(4096)) + self.assertEqual("", self.accepted.recv(4096)) + self.assertEqual("", self.accepted.recv(4096)) def testConnectInterrupted(self): """Tests that connect() is interrupted by SOCK_DESTROY.""" @@ -818,7 +818,7 @@ class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest): def setUp(self): super(PollOnCloseTest, self).setUp() - self.netid = random.choice(self.tuns.keys()) + self.netid = random.choice(list(self.tuns.keys())) POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"), (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")] @@ -852,8 +852,8 @@ class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest): # Subsequent operations behave as normal. self.assertRaisesErrno(EPIPE, self.accepted.send, "foo") - self.assertEquals("", self.accepted.recv(4096)) - self.assertEquals("", self.accepted.recv(4096)) + self.assertEqual("", self.accepted.recv(4096)) + self.assertEqual("", self.accepted.recv(4096)) def CheckPollDestroy(self, mask, expected, ignoremask): """Interrupts a poll() with SOCK_DESTROY.""" @@ -917,7 +917,7 @@ class SockDestroyUdpTest(SockDiagBaseTest): def testClosesUdpSockets(self): self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM) - for _, socketpair in self.socketpairs.iteritems(): + for _, socketpair in self.socketpairs.items(): s1, s2 = socketpair self.assertSocketConnected(s1) @@ -930,12 +930,12 @@ class SockDestroyUdpTest(SockDiagBaseTest): def BindToRandomPort(self, s, addr): ATTEMPTS = 20 - for i in xrange(20): + for i in range(20): port = random.randrange(1024, 65535) try: s.bind((addr, port)) return port - except error, e: + except error as e: if e.errno != EADDRINUSE: raise e raise ValueError("Could not find a free port on %s after %d attempts" % @@ -989,13 +989,13 @@ class SockDestroyUdpTest(SockDiagBaseTest): # Check that reads on connected sockets are interrupted. s.connect((addr, 53)) - self.assertEquals(3, s.send("foo")) + self.assertEqual(3, s.send("foo")) self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096), ECONNABORTED) # A destroyed socket is no longer connected, but still usable. self.assertRaisesErrno(EDESTADDRREQ, s.send, "foo") - self.assertEquals(3, s.sendto("foo", (addr, 53))) + self.assertEqual(3, s.sendto("foo", (addr, 53))) # Check that reads on unconnected sockets are also interrupted. self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096), @@ -1049,7 +1049,7 @@ class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest): def assertSamePorts(self, ports, diag_msgs): expected = sorted(ports) actual = sorted([msg[0].id.sport for msg in diag_msgs]) - self.assertEquals(expected, actual) + self.assertEqual(expected, actual) def SockInfoMatchesSocket(self, s, info): try: @@ -1076,7 +1076,7 @@ class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest): self.SocketDescription(s)) for i in infos: - if i not in matches.values(): + if i not in list(matches.values()): self.fail("Too many sockets in dump, first unexpected: %s" % str(i)) def testMarkBytecode(self): @@ -1101,7 +1101,7 @@ class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest): self.assertFoundSockets(infos, [s1, s2]) infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00) - self.assertEquals(0, len(infos)) + self.assertEqual(0, len(infos)) with net_test.RunAsUid(12345): self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets, diff --git a/net/test/srcaddr_selection_test.py b/net/test/srcaddr_selection_test.py index e57ce16..1e7a107 100755 --- a/net/test/srcaddr_selection_test.py +++ b/net/test/srcaddr_selection_test.py @@ -91,11 +91,11 @@ class IPv6SourceAddressSelectionTest(multinetwork_base.MultiNetworkBaseTest): def assertAddressHasExpectedAttributes( self, address, expected_ifindex, expected_flags): ifa_msg = self.iproute.GetAddress(address)[0] - self.assertEquals(AF_INET6 if ":" in address else AF_INET, ifa_msg.family) - self.assertEquals(64, ifa_msg.prefixlen) - self.assertEquals(iproute.RT_SCOPE_UNIVERSE, ifa_msg.scope) - self.assertEquals(expected_ifindex, ifa_msg.index) - self.assertEquals(expected_flags, ifa_msg.flags & expected_flags) + self.assertEqual(AF_INET6 if ":" in address else AF_INET, ifa_msg.family) + self.assertEqual(64, ifa_msg.prefixlen) + self.assertEqual(iproute.RT_SCOPE_UNIVERSE, ifa_msg.scope) + self.assertEqual(expected_ifindex, ifa_msg.index) + self.assertEqual(expected_flags, ifa_msg.flags & expected_flags) def AddressIsTentative(self, address): ifa_msg = self.iproute.GetAddress(address)[0] @@ -122,13 +122,13 @@ class IPv6SourceAddressSelectionTest(multinetwork_base.MultiNetworkBaseTest): self.SendWithSourceAddress, address, netid) def assertAddressSelected(self, address, netid): - self.assertEquals(address, self.GetSourceIP(netid)) + self.assertEqual(address, self.GetSourceIP(netid)) def assertAddressNotSelected(self, address, netid): - self.assertNotEquals(address, self.GetSourceIP(netid)) + self.assertNotEqual(address, self.GetSourceIP(netid)) def WaitForDad(self, address): - for _ in xrange(20): + for _ in range(20): if not self.AddressIsTentative(address): return time.sleep(0.1) @@ -149,7 +149,7 @@ class MultiInterfaceSourceAddressSelectionTest(IPv6SourceAddressSelectionTest): self.SetIPv6Sysctl(ifname, "use_oif_addrs_only", 0) # [1] Pick an interface on which to test. - self.test_netid = random.choice(self.tuns.keys()) + self.test_netid = random.choice(list(self.tuns.keys())) self.test_ip = self.MyAddress(6, self.test_netid) self.test_ifindex = self.ifindices[self.test_netid] self.test_ifname = self.GetInterfaceName(self.test_netid) @@ -254,7 +254,7 @@ class ValidBeforeOptimisticTest(MultiInterfaceSourceAddressSelectionTest): self.iproute.AddAddress(preferred_ip, 64, self.test_ifindex) self.assertAddressHasExpectedAttributes( preferred_ip, self.test_ifindex, iproute.IFA_F_PERMANENT) - self.assertEquals(preferred_ip, self.GetSourceIP(self.test_netid)) + self.assertEqual(preferred_ip, self.GetSourceIP(self.test_netid)) # [4] Get another IPv6 address, in optimistic DAD start-up. self.SetDAD(self.test_ifname, 1) # Enable DAD diff --git a/net/test/sysctls_test.py b/net/test/sysctls_test.py new file mode 100755 index 0000000..cb608f6 --- /dev/null +++ b/net/test/sysctls_test.py @@ -0,0 +1,43 @@ +#!/usr/bin/python +# +# Copyright 2021 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import net_test + + +class SysctlsTest(net_test.NetworkTest): + + def check(self, f): + algs = open(f).readline().strip().split(' ') + bad_algs = [a for a in algs if a not in ['cubic', 'reno']] + msg = ("Obsolete TCP congestion control algorithm found. These " + "algorithms will decrease real-world networking performance for " + "users and must be disabled. Found: %s" % bad_algs) + self.assertEqual(bad_algs, [], msg) + + def testAllowedCongestionControl(self): + self.check('/proc/sys/net/ipv4/tcp_allowed_congestion_control') + + def testAvailableCongestionControl(self): + self.check('/proc/sys/net/ipv4/tcp_available_congestion_control') + + def testCongestionControl(self): + self.check('/proc/sys/net/ipv4/tcp_congestion_control') + + +if __name__ == "__main__": + unittest.main() diff --git a/net/test/tcp_fastopen_test.py b/net/test/tcp_fastopen_test.py index 9257a19..eadae79 100755 --- a/net/test/tcp_fastopen_test.py +++ b/net/test/tcp_fastopen_test.py @@ -55,7 +55,7 @@ class TcpFastOpenTest(multinetwork_base.MultiNetworkBaseTest): daddr = self.GetRemoteAddress(version) self.tcp_metrics.DelMetrics(saddr, daddr) with self.assertRaisesErrno(ESRCH): - print self.tcp_metrics.GetMetrics(saddr, daddr) + print(self.tcp_metrics.GetMetrics(saddr, daddr)) def assertNoTcpMetrics(self, version, netid): saddr = self.MyAddress(version, netid) diff --git a/net/test/tcp_metrics.py b/net/test/tcp_metrics.py index 574a755..03f604f 100755 --- a/net/test/tcp_metrics.py +++ b/net/test/tcp_metrics.py @@ -134,4 +134,4 @@ class TcpMetrics(genetlink.GenericNetlink): if __name__ == "__main__": t = TcpMetrics() - print t.DumpMetrics() + print(t.DumpMetrics()) diff --git a/net/test/tcp_nuke_addr_test.py b/net/test/tcp_nuke_addr_test.py index 1f0de76..e5d17b2 100755 --- a/net/test/tcp_nuke_addr_test.py +++ b/net/test/tcp_nuke_addr_test.py @@ -88,8 +88,8 @@ class TcpNukeAddrTest(net_test.NetworkTest): self.assertRaisesErrno(errno.ENOTTY, KillAddrIoctl, addr) data = "foo" try: - self.assertEquals(len(data), s1.send(data)) - self.assertEquals(data, s2.recv(4096)) + self.assertEqual(len(data), s1.send(data)) + self.assertEqual(data, s2.recv(4096)) self.assertSocketsNotClosed(socketpair) finally: s1.close() diff --git a/net/test/tcp_repair_test.py b/net/test/tcp_repair_test.py index ce54aba..e0b156e 100755 --- a/net/test/tcp_repair_test.py +++ b/net/test/tcp_repair_test.py @@ -149,7 +149,7 @@ class TcpRepairTest(multinetwork_base.MultiNetworkBaseTest): sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_OFF) readData = sock.recv(4096) - self.assertEquals(readData, TEST_RECEIVED) + self.assertEqual(readData, TEST_RECEIVED) sock.close() # Test whether tcp read/write sequence number can be fetched correctly @@ -164,20 +164,20 @@ class TcpRepairTest(multinetwork_base.MultiNetworkBaseTest): # test write queue sequence number sequence_before = self.GetWriteSequenceNumber(version, sock) expect_sequence = self.last_sent.getlayer("TCP").seq - self.assertEquals(sequence_before & 0xffffffff, expect_sequence) + self.assertEqual(sequence_before & 0xffffffff, expect_sequence) TEST_SEND = net_test.UDP_PAYLOAD self.sendData(netid, version, sock, TEST_SEND) sequence_after = self.GetWriteSequenceNumber(version, sock) - self.assertEquals(sequence_before + len(TEST_SEND), sequence_after) + self.assertEqual(sequence_before + len(TEST_SEND), sequence_after) # test read queue sequence number sequence_before = self.GetReadSequenceNumber(version, sock) expect_sequence = self.last_received.getlayer("TCP").seq + 1 - self.assertEquals(sequence_before & 0xffffffff, expect_sequence) + self.assertEqual(sequence_before & 0xffffffff, expect_sequence) TEST_READ = net_test.UDP_PAYLOAD self.receiveData(netid, version, TEST_READ) sequence_after = self.GetReadSequenceNumber(version, sock) - self.assertEquals(sequence_before + len(TEST_READ), sequence_after) + self.assertEqual(sequence_before + len(TEST_READ), sequence_after) sock.close() def GetWriteSequenceNumber(self, version, sock): @@ -267,12 +267,12 @@ class TcpRepairTest(multinetwork_base.MultiNetworkBaseTest): buf = ctypes.c_int() fcntl.ioctl(sock, SIOCINQ, buf) - self.assertEquals(buf.value, 0) + self.assertEqual(buf.value, 0) TEST_RECV_PAYLOAD = net_test.UDP_PAYLOAD self.receiveData(netid, version, TEST_RECV_PAYLOAD) fcntl.ioctl(sock, SIOCINQ, buf) - self.assertEquals(buf.value, len(TEST_RECV_PAYLOAD)) + self.assertEqual(buf.value, len(TEST_RECV_PAYLOAD)) sock.close() def writeQueueIdleTest(self, version): @@ -281,14 +281,14 @@ class TcpRepairTest(multinetwork_base.MultiNetworkBaseTest): sock = self.createConnectedSocket(version, netid) buf = ctypes.c_int() fcntl.ioctl(sock, SIOCOUTQ, buf) - self.assertEquals(buf.value, 0) + self.assertEqual(buf.value, 0) # Change to repair mode with SEND_QUEUE, writing some data to the queue. sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON) TEST_SEND_PAYLOAD = net_test.UDP_PAYLOAD sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_SEND_QUEUE) self.sendData(netid, version, sock, TEST_SEND_PAYLOAD) fcntl.ioctl(sock, SIOCOUTQ, buf) - self.assertEquals(buf.value, len(TEST_SEND_PAYLOAD)) + self.assertEqual(buf.value, len(TEST_SEND_PAYLOAD)) sock.close() # Setup a connected socket again. @@ -297,14 +297,14 @@ class TcpRepairTest(multinetwork_base.MultiNetworkBaseTest): # Send out some data and don't receive ACK yet. self.sendData(netid, version, sock, TEST_SEND_PAYLOAD) fcntl.ioctl(sock, SIOCOUTQ, buf) - self.assertEquals(buf.value, len(TEST_SEND_PAYLOAD)) + self.assertEqual(buf.value, len(TEST_SEND_PAYLOAD)) # Receive response ACK. remoteaddr = self.GetRemoteAddress(version) myaddr = self.MyAddress(version, netid) desc_ack, ack = packets.ACK(version, remoteaddr, myaddr, self.last_sent) self.ReceivePacketOn(netid, ack) fcntl.ioctl(sock, SIOCOUTQ, buf) - self.assertEquals(buf.value, 0) + self.assertEqual(buf.value, 0) sock.close() @@ -315,7 +315,7 @@ class TcpRepairTest(multinetwork_base.MultiNetworkBaseTest): events = p.poll(500) for fd,event in events: if fd == sock.fileno(): - self.assertEquals(event, expected) + self.assertEqual(event, expected) else: raise AssertionError("unexpected poll fd") @@ -334,7 +334,7 @@ class SocketExceptionThread(threading.Thread): def run(self): try: self.operation(self.sock) - except (IOError, AssertionError), e: + except (IOError, AssertionError) as e: self.exception = e if __name__ == '__main__': diff --git a/net/test/tun_twister.py b/net/test/tun_twister.py index 2ed25c9..f42d789 100644 --- a/net/test/tun_twister.py +++ b/net/test/tun_twister.py @@ -52,8 +52,8 @@ class TunTwister(object): # Set up routing so packets go to my_tun. def ValidatePortNumber(packet): - self.assertEquals(8080, packet.getlayer(scapy.UDP).sport) - self.assertEquals(8080, packet.getlayer(scapy.UDP).dport) + self.assertEqual(8080, packet.getlayer(scapy.UDP).sport) + self.assertEqual(8080, packet.getlayer(scapy.UDP).dport) with TunTwister(tun_fd=my_tun, validator=ValidatePortNumber): sock = socket(AF_INET, SOCK_DGRAM, 0) @@ -61,8 +61,8 @@ class TunTwister(object): sock.settimeout(1.0) sock.sendto("hello", ("1.2.3.4", 8080)) data, addr = sock.recvfrom(1024) - self.assertEquals("hello", data) - self.assertEquals(("1.2.3.4", 8080), addr) + self.assertEqual("hello", data) + self.assertEqual(("1.2.3.4", 8080), addr) """ # Hopefully larger than any packet. diff --git a/net/test/xfrm.py b/net/test/xfrm.py index acdfd4f..83437bd 100755 --- a/net/test/xfrm.py +++ b/net/test/xfrm.py @@ -123,12 +123,15 @@ XFRM_STATE_AF_UNSPEC = 32 # XFRM algorithm names, as defined in net/xfrm/xfrm_algo.c. XFRM_EALG_CBC_AES = "cbc(aes)" +XFRM_EALG_CTR_AES = "rfc3686(ctr(aes))" XFRM_AALG_HMAC_MD5 = "hmac(md5)" XFRM_AALG_HMAC_SHA1 = "hmac(sha1)" XFRM_AALG_HMAC_SHA256 = "hmac(sha256)" XFRM_AALG_HMAC_SHA384 = "hmac(sha384)" XFRM_AALG_HMAC_SHA512 = "hmac(sha512)" +XFRM_AALG_AUTH_XCBC_AES = "xcbc(aes)" XFRM_AEAD_GCM_AES = "rfc4106(gcm(aes))" +XFRM_AEAD_CHACHA20_POLY1305 = "rfc7539esp(chacha20,poly1305)" # Data structure formats. # These aren't constants, they're classes. So, pylint: disable=invalid-name @@ -137,6 +140,11 @@ XfrmSelector = cstruct.Struct( "daddr saddr dport dport_mask sport sport_mask " "family prefixlen_d prefixlen_s proto ifindex user") +XfrmMigrate = cstruct.Struct( + "XfrmMigrate", "=16s16s16s16sBBxxIHH", + "old_daddr old_saddr new_daddr new_saddr proto " + "mode reqid old_family new_family") + XfrmLifetimeCfg = cstruct.Struct( "XfrmLifetimeCfg", "=QQQQQQQQ", "soft_byte hard_byte soft_packet hard_packet " @@ -356,9 +364,9 @@ class Xfrm(netlink.NetlinkSocket): cmdname = self._GetConstantName(command, "XFRM_MSG_") if struct_type: - print "%s %s" % (cmdname, str(self._ParseNLMsg(data, struct_type))) + print("%s %s" % (cmdname, str(self._ParseNLMsg(data, struct_type)))) else: - print "%s" % cmdname + print("%s" % cmdname) def _Decode(self, command, unused_msg, nla_type, nla_data): """Decodes netlink attributes to Python types.""" @@ -707,8 +715,54 @@ class Xfrm(netlink.NetlinkSocket): for selector in selectors: self.DeletePolicyInfo(selector, direction, mark, xfrm_if_id) + def MigrateTunnel(self, direction, selector, old_saddr, old_daddr, + new_saddr, new_daddr, spi, + encryption, auth_trunc, aead, + encap, new_output_mark, xfrm_if_id): + """Update addresses and underlying network of Policies and an SA + + Args: + direction: XFRM_POLICY_IN or XFRM_POLICY_OUT + selector: An XfrmSelector of the tunnel that needs to be updated. + If the passed-in selector is None, it means the tunnel is + dual-stack and thus both IPv4 and IPv6 policies will be updated. + old_saddr: the old (current) source address of the tunnel + old_daddr: the old (current) destination address of the tunnel + new_saddr: the new source address the IPsec SA will be migrated to + new_daddr: the new destination address the tunnel will be migrated to + spi: The SPI for the IPsec SA that encapsulates the tunneled packets + encryption: A tuple of an XfrmAlgo and raw key bytes, or None. + auth_trunc: A tuple of an XfrmAlgoAuth and raw key bytes, or None. + aead: A tuple of an XfrmAlgoAead and raw key bytes, or None. + encap: An XfrmEncapTmpl structure, or None. + new_output_mark: The mark used to select the new underlying network + for packets outbound from xfrm. None means unspecified. + xfrm_if_id: The XFRM interface ID + """ + + if selector is None: + selectors = [EmptySelector(AF_INET), EmptySelector(AF_INET6)] + else: + selectors = [selector] + + nlattrs = [] + xfrmMigrate = XfrmMigrate((PaddedAddress(old_daddr), PaddedAddress(old_saddr), + PaddedAddress(new_daddr), PaddedAddress(new_saddr), + IPPROTO_ESP, XFRM_MODE_TUNNEL, 0, + net_test.GetAddressFamily(net_test.GetAddressVersion(old_saddr)), + net_test.GetAddressFamily(net_test.GetAddressVersion(new_saddr)))) + nlattrs.append((XFRMA_MIGRATE, xfrmMigrate)) + + for selector in selectors: + self.SendXfrmNlRequest(XFRM_MSG_MIGRATE, + XfrmUserpolicyId(sel=selector, dir=direction), nlattrs) + + # UPDSA is called exclusively to update the set_mark=new_output_mark. + self.AddSaInfo(new_saddr, new_daddr, spi, XFRM_MODE_TUNNEL, 0, encryption, + auth_trunc, aead, encap, None, new_output_mark, True, xfrm_if_id) + if __name__ == "__main__": x = Xfrm() - print x.DumpSaInfo() - print x.DumpPolicyInfo() + print(x.DumpSaInfo()) + print(x.DumpPolicyInfo()) diff --git a/net/test/xfrm_algorithm_test.py b/net/test/xfrm_algorithm_test.py index 0176265..8a50fde 100755 --- a/net/test/xfrm_algorithm_test.py +++ b/net/test/xfrm_algorithm_test.py @@ -30,29 +30,43 @@ from tun_twister import TapTwister import util import xfrm import xfrm_base +import xfrm_test + +ANY_KVER = net_test.LINUX_ANY_VERSION # List of encryption algorithms for use in ParamTests. CRYPT_ALGOS = [ - xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 128)), - xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 192)), - xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 256)), + (xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 128)), ANY_KVER), + (xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 192)), ANY_KVER), + (xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 256)), ANY_KVER), + # RFC 3686 specifies that key length must be 128, 192 or 256 bits, with + # an additional 4 bytes (32 bits) of nonce. A fresh nonce value MUST be + # assigned for each SA. + # CTR-AES is enforced since kernel version 5.8 + (xfrm.XfrmAlgo((xfrm.XFRM_EALG_CTR_AES, 128+32)), (5, 8)), + (xfrm.XfrmAlgo((xfrm.XFRM_EALG_CTR_AES, 192+32)), (5, 8)), + (xfrm.XfrmAlgo((xfrm.XFRM_EALG_CTR_AES, 256+32)), (5, 8)), ] # List of auth algorithms for use in ParamTests. AUTH_ALGOS = [ # RFC 4868 specifies that the only supported truncation length is half the # hash size. - xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_MD5, 128, 96)), - xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA1, 160, 96)), - xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA256, 256, 128)), - xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA384, 384, 192)), - xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA512, 512, 256)), + (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_MD5, 128, 96)), ANY_KVER), + (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA1, 160, 96)), ANY_KVER), + (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA256, 256, 128)), ANY_KVER), + (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA384, 384, 192)), ANY_KVER), + (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA512, 512, 256)), ANY_KVER), # Test larger truncation lengths for good measure. - xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_MD5, 128, 128)), - xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA1, 160, 160)), - xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA256, 256, 256)), - xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA384, 384, 384)), - xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA512, 512, 512)), + (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_MD5, 128, 128)), ANY_KVER), + (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA1, 160, 160)), ANY_KVER), + (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA256, 256, 256)), ANY_KVER), + (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA384, 384, 384)), ANY_KVER), + (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA512, 512, 512)), ANY_KVER), + # RFC 3566 specifies that the only supported truncation length + # is 96 bits. + # XCBC-AES is enforced since kernel version 5.8 + (xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_AUTH_XCBC_AES, 128, 96)), (5, 8)), ] # List of aead algorithms for use in ParamTests. @@ -61,17 +75,88 @@ AEAD_ALGOS = [ # with an additional 4 bytes (32 bits) of salt. The salt must be unique # for each new SA using the same key. # RFC 4106 specifies that ICV length must be 8, 12, or 16 bytes - xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 8*8)), - xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 12*8)), - xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 16*8)), - xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 8*8)), - xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 12*8)), - xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 16*8)), - xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 8*8)), - xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 12*8)), - xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 16*8)), + (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 8*8)), ANY_KVER), + (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 12*8)), ANY_KVER), + (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 16*8)), ANY_KVER), + (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 8*8)), ANY_KVER), + (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 12*8)), ANY_KVER), + (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 16*8)), ANY_KVER), + (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 8*8)), ANY_KVER), + (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 12*8)), ANY_KVER), + (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 16*8)), ANY_KVER), + # RFC 7634 specifies that key length must be 256 bits, with an additional + # 4 bytes (32 bits) of nonce. A fresh nonce value MUST be assigned for + # each SA. RFC 7634 also specifies that ICV length must be 16 bytes. + # ChaCha20-Poly1305 is enforced since kernel version 5.8 + (xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_CHACHA20_POLY1305, 256+32, 16*8)), (5, 8)), ] +def GenerateKey(key_len): + if key_len % 8 != 0: + raise ValueError("Invalid key length in bits: " + str(key_len)) + return os.urandom(key_len / 8) + +# Does the kernel support this algorithm? +def HaveAlgo(crypt_algo, auth_algo, aead_algo): + try: + test_xfrm = xfrm.Xfrm() + test_xfrm.FlushSaInfo() + test_xfrm.FlushPolicyInfo() + + test_xfrm.AddSaInfo( + src=xfrm_test.TEST_ADDR1, + dst=xfrm_test.TEST_ADDR2, + spi=xfrm_test.TEST_SPI, + mode=xfrm.XFRM_MODE_TRANSPORT, + reqid=100, + encryption=(crypt_algo, + GenerateKey(crypt_algo.key_len)) if crypt_algo else None, + auth_trunc=(auth_algo, + GenerateKey(auth_algo.key_len)) if auth_algo else None, + aead=(aead_algo, GenerateKey(aead_algo.key_len)) if aead_algo else None, + encap=None, + mark=None, + output_mark=None) + + test_xfrm.FlushSaInfo() + test_xfrm.FlushPolicyInfo() + + return True + except IOError as err: + if err.errno == ENOSYS: + return False + else: + print("Unexpected error:", err.errno) + return True + +# Dictionary to record the algorithm state. Mark the state True if this +# algorithm is enforced or enabled on this kernel. Otherwise, mark it +# False. +algoState = {} + +def AlgoEnforcedOrEnabled(crypt, auth, aead, target_algo, target_kernel): + if algoState.get(target_algo) is None: + algoState[target_algo] = net_test.LINUX_VERSION >= target_kernel or HaveAlgo( + crypt, auth, aead) + return algoState.get(target_algo) + +# Return true if this algorithm should be enforced or is enabled on this kernel +def AuthEnforcedOrEnabled(authCase): + auth = authCase[0] + crypt = xfrm.XfrmAlgo(("ecb(cipher_null)", 0)) + return AlgoEnforcedOrEnabled(crypt, auth, None, auth.name, authCase[1]) + +# Return true if this algorithm should be enforced or is enabled on this kernel +def CryptEnforcedOrEnabled(cryptCase): + crypt = cryptCase[0] + auth = xfrm.XfrmAlgoAuth(("digest_null", 0, 0)) + return AlgoEnforcedOrEnabled(crypt, auth, None, crypt.name, cryptCase[1]) + +# Return true if this algorithm should be enforced or is enabled on this kernel +def AeadEnforcedOrEnabled(aeadCase): + aead = aeadCase[0] + return AlgoEnforcedOrEnabled(None, None, aead, aead.name, aeadCase[1]) + def InjectTests(): XfrmAlgorithmTest.InjectTests() @@ -92,18 +177,21 @@ class XfrmAlgorithmTest(xfrm_base.XfrmLazyTest): util.InjectParameterizedTest(cls, param_list, cls.TestNameGenerator) @staticmethod - def TestNameGenerator(version, proto, auth, crypt, aead): + def TestNameGenerator(version, proto, authCase, cryptCase, aeadCase): # Produce a unique and readable name for each test. e.g. # testSocketPolicySimple_cbc-aes_256_hmac-sha512_512_256_IPv6_UDP param_string = "" - if crypt is not None: + if cryptCase is not None: + crypt = cryptCase[0] param_string += "%s_%d_" % (crypt.name, crypt.key_len) - if auth is not None: + if authCase is not None: + auth = authCase[0] param_string += "%s_%d_%d_" % (auth.name, auth.key_len, auth.trunc_len) - if aead is not None: + if aeadCase is not None: + aead = aeadCase[0] param_string += "%s_%d_%d_" % (aead.name, aead.key_len, aead.icv_len) @@ -111,16 +199,29 @@ class XfrmAlgorithmTest(xfrm_base.XfrmLazyTest): "UDP" if proto == SOCK_DGRAM else "TCP") return param_string - def ParamTestSocketPolicySimple(self, version, proto, auth, crypt, aead): + def ParamTestSocketPolicySimple(self, version, proto, authCase, cryptCase, aeadCase): """Test two-way traffic using transport mode and socket policies.""" + # Bypass the test if any algorithm going to be tested is not enforced + # or enabled on this kernel + if authCase is not None and not AuthEnforcedOrEnabled(authCase): + return + if cryptCase is not None and not CryptEnforcedOrEnabled(cryptCase): + return + if aeadCase is not None and not AeadEnforcedOrEnabled(aeadCase): + return + + auth = authCase[0] if authCase else None + crypt = cryptCase[0] if cryptCase else None + aead = aeadCase[0] if aeadCase else None + def AssertEncrypted(packet): # This gives a free pass to ICMP and ICMPv6 packets, which show up # nondeterministically in tests. - self.assertEquals(None, + self.assertEqual(None, packet.getlayer(scapy.UDP), "UDP packet sent in the clear") - self.assertEquals(None, + self.assertEqual(None, packet.getlayer(scapy.TCP), "TCP packet sent in the clear") @@ -237,10 +338,10 @@ class XfrmAlgorithmTest(xfrm_base.XfrmLazyTest): sock.listen(1) server_ready.set() accepted, peer = sock.accept() - self.assertEquals(remote_addr, peer[0]) - self.assertEquals(client_port, peer[1]) + self.assertEqual(remote_addr, peer[0]) + self.assertEqual(client_port, peer[1]) data = accepted.recv(2048) - self.assertEquals("hello request", data) + self.assertEqual("hello request", data) accepted.send("hello response") except Exception as e: server_error = e @@ -251,9 +352,9 @@ class XfrmAlgorithmTest(xfrm_base.XfrmLazyTest): try: server_ready.set() data, peer = sock.recvfrom(2048) - self.assertEquals(remote_addr, peer[0]) - self.assertEquals(client_port, peer[1]) - self.assertEquals("hello request", data) + self.assertEqual(remote_addr, peer[0]) + self.assertEqual(client_port, peer[1]) + self.assertEqual("hello request", data) sock.sendto("hello response", peer) except Exception as e: server_error = e @@ -283,7 +384,7 @@ class XfrmAlgorithmTest(xfrm_base.XfrmLazyTest): sock_left.connect((remote_addr, right_port)) sock_left.send("hello request") data = sock_left.recv(2048) - self.assertEquals("hello response", data) + self.assertEqual("hello response", data) sock_left.close() server.join() if server_error: diff --git a/net/test/xfrm_base.py b/net/test/xfrm_base.py index 1eaa302..e61322e 100644 --- a/net/test/xfrm_base.py +++ b/net/test/xfrm_base.py @@ -196,7 +196,7 @@ def EncryptPacketWithNull(packet, spi, seq, tun_addrs): esplen = (len(inner_layer) + 2) # UDP length plus Pad Length and Next Header. padlen = util.GetPadLength(4, esplen) # The pad bytes are consecutive integers starting from 0x01. - padding = "".join((chr(i) for i in xrange(1, padlen + 1))) + padding = "".join((chr(i) for i in range(1, padlen + 1))) trailer = padding + struct.pack("BB", padlen, esp_nexthdr) # Assemble the packet. @@ -270,6 +270,13 @@ def DecryptPacketWithNull(packet): class XfrmBaseTest(multinetwork_base.MultiNetworkBaseTest): """Base test class for all XFRM-related testing.""" + def _isIcmpv6(self, payload): + if not isinstance(payload, scapy.IPv6): + return False + if payload.nh == IPPROTO_ICMPV6: + return True + return payload.nh == IPPROTO_HOPOPTS and payload.payload.nh == IPPROTO_ICMPV6 + def _ExpectEspPacketOn(self, netid, spi, seq, length, src_addr, dst_addr): """Read a packet from a netid and verify its properties. @@ -284,18 +291,22 @@ class XfrmBaseTest(multinetwork_base.MultiNetworkBaseTest): Returns: scapy.IP/IPv6: the read packet """ - packets = self.ReadAllPacketsOn(netid) - self.assertEquals(1, len(packets)) + packets = [] + for packet in self.ReadAllPacketsOn(netid): + if not self._isIcmpv6(packet): + packets.append(packet) + + self.assertEqual(1, len(packets)) packet = packets[0] if length is not None: - self.assertEquals(length, len(packet.payload)) + self.assertEqual(length, len(packet.payload)) if dst_addr is not None: - self.assertEquals(dst_addr, packet.dst) + self.assertEqual(dst_addr, packet.dst) if src_addr is not None: - self.assertEquals(src_addr, packet.src) + self.assertEqual(src_addr, packet.src) # extract the ESP header esp_hdr, _ = cstruct.Read(str(packet.payload), xfrm.EspHdr) - self.assertEquals(xfrm.EspHdr((spi, seq)), esp_hdr) + self.assertEqual(xfrm.EspHdr((spi, seq)), esp_hdr) return packet diff --git a/net/test/xfrm_test.py b/net/test/xfrm_test.py index 64be084..439a2d2 100755 --- a/net/test/xfrm_test.py +++ b/net/test/xfrm_test.py @@ -53,14 +53,14 @@ TEST_SPI2 = 0x1235 class XfrmFunctionalTest(xfrm_base.XfrmLazyTest): def assertIsUdpEncapEsp(self, packet, spi, seq, length): - self.assertEquals(IPPROTO_UDP, packet.proto) + self.assertEqual(IPPROTO_UDP, packet.proto) udp_hdr = packet[scapy.UDP] - self.assertEquals(4500, udp_hdr.dport) - self.assertEquals(length, len(udp_hdr)) + self.assertEqual(4500, udp_hdr.dport) + self.assertEqual(length, len(udp_hdr)) esp_hdr, _ = cstruct.Read(str(udp_hdr.payload), xfrm.EspHdr) # FIXME: this file currently swaps SPI byte order manually, so SPI needs to # be double-swapped here. - self.assertEquals(xfrm.EspHdr((spi, seq)), esp_hdr) + self.assertEqual(xfrm.EspHdr((spi, seq)), esp_hdr) def CreateNewSa(self, localAddr, remoteAddr, spi, reqId, encap_tmpl, null_auth=False): @@ -93,12 +93,12 @@ class XfrmFunctionalTest(xfrm_base.XfrmLazyTest): self.xfrm.DeleteSaInfo(TEST_ADDR1, TEST_SPI, IPPROTO_ESP) def testFlush(self): - self.assertEquals(0, len(self.xfrm.DumpSaInfo())) + self.assertEqual(0, len(self.xfrm.DumpSaInfo())) self.CreateNewSa("::", "2000::", TEST_SPI, 1234, None) self.CreateNewSa("0.0.0.0", "192.0.2.1", TEST_SPI, 4321, None) - self.assertEquals(2, len(self.xfrm.DumpSaInfo())) + self.assertEqual(2, len(self.xfrm.DumpSaInfo())) self.xfrm.FlushSaInfo() - self.assertEquals(0, len(self.xfrm.DumpSaInfo())) + self.assertEqual(0, len(self.xfrm.DumpSaInfo())) def _TestSocketPolicy(self, version): # Open a UDP socket and connect it. @@ -141,7 +141,7 @@ class XfrmFunctionalTest(xfrm_base.XfrmLazyTest): try: self.xfrm.DeleteSaInfo(self.GetRemoteAddress(xfrm_version), TEST_SPI, IPPROTO_ESP) except IOError as e: - self.assertEquals(ESRCH, e.errno, "Unexpected error when deleting ACQ SA") + self.assertEqual(ESRCH, e.errno, "Unexpected error when deleting ACQ SA") # Adding a matching SA causes the packet to go out encrypted. The SA's # SPI must match the one in our template, and the destination address must @@ -171,11 +171,11 @@ class XfrmFunctionalTest(xfrm_base.XfrmLazyTest): self.SelectInterface(s2, netid, "mark") s2.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53)) pkts = self.ReadAllPacketsOn(netid) - self.assertEquals(1, len(pkts)) + self.assertEqual(1, len(pkts)) packet = pkts[0] protocol = packet.nh if version == 6 else packet.proto - self.assertEquals(IPPROTO_UDP, protocol) + self.assertEqual(IPPROTO_UDP, protocol) # Deleting the SA causes the first socket to return errors again. self.xfrm.DeleteSaInfo(self.GetRemoteAddress(xfrm_version), TEST_SPI, @@ -268,7 +268,7 @@ class XfrmFunctionalTest(xfrm_base.XfrmLazyTest): # Expect to see an UDP encapsulated packet. pkts = self.ReadAllPacketsOn(netid) - self.assertEquals(1, len(pkts)) + self.assertEqual(1, len(pkts)) packet = pkts[0] auth_algo = ( @@ -314,13 +314,13 @@ class XfrmFunctionalTest(xfrm_base.XfrmLazyTest): # integrity-verified portion of the packet. if not null_auth and in_spi != out_spi: self.assertRaisesErrno(EAGAIN, twisted_socket.recv, 4096) - self.assertEquals(start_integrity_failures + 1, + self.assertEqual(start_integrity_failures + 1, sainfo.stats.integrity_failed) else: data, src = twisted_socket.recvfrom(4096) - self.assertEquals(net_test.UDP_PAYLOAD, data) - self.assertEquals((remoteaddr, srcport), src) - self.assertEquals(start_integrity_failures, sainfo.stats.integrity_failed) + self.assertEqual(net_test.UDP_PAYLOAD, data) + self.assertEqual((remoteaddr, srcport), src) + self.assertEqual(start_integrity_failures, sainfo.stats.integrity_failed) # Check that unencrypted packets on twisted_socket are not received. unencrypted = ( @@ -400,13 +400,13 @@ class XfrmFunctionalTest(xfrm_base.XfrmLazyTest): def testAllocSpecificSpi(self): spi = 0xABCD new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi) - self.assertEquals(spi, new_sa.id.spi) + self.assertEqual(spi, new_sa.id.spi) def testAllocSpecificSpiUnavailable(self): """Attempt to allocate the same SPI twice.""" spi = 0xABCD new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi) - self.assertEquals(spi, new_sa.id.spi) + self.assertEqual(spi, new_sa.id.spi) with self.assertRaisesErrno(ENOENT): new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi) @@ -427,7 +427,7 @@ class XfrmFunctionalTest(xfrm_base.XfrmLazyTest): # Allocating range_size + 1 SPIs is guaranteed to fail. Due to the way # kernel picks random SPIs, this has a high probability of failing before # reaching that limit. - for i in xrange(range_size + 1): + for i in range(range_size + 1): new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, start, end) spi = new_sa.id.spi self.assertNotIn(spi, spis) @@ -514,21 +514,21 @@ class XfrmFunctionalTest(xfrm_base.XfrmLazyTest): self.ReceivePacketOn(netid, input_pkt) msg, addr = sock.recvfrom(1024) - self.assertEquals("input hello", msg) - self.assertEquals((remote_addr, remote_port), addr[:2]) + self.assertEqual("input hello", msg) + self.assertEqual((remote_addr, remote_port), addr[:2]) # Send and capture a packet. sock.sendto("output hello", (remote_addr, remote_port)) packets = self.ReadAllPacketsOn(netid) - self.assertEquals(1, len(packets)) + self.assertEqual(1, len(packets)) output_pkt = packets[0] output_pkt, esp_hdr = xfrm_base.DecryptPacketWithNull(output_pkt) - self.assertEquals(output_pkt[scapy.UDP].len, len("output_hello") + 8) - self.assertEquals(remote_addr, output_pkt.dst) - self.assertEquals(remote_port, output_pkt[scapy.UDP].dport) + self.assertEqual(output_pkt[scapy.UDP].len, len("output_hello") + 8) + self.assertEqual(remote_addr, output_pkt.dst) + self.assertEqual(remote_port, output_pkt[scapy.UDP].dport) # length of the payload plus the UDP header - self.assertEquals("output hello", str(output_pkt[scapy.UDP].payload)) - self.assertEquals(0xABCD, esp_hdr.spi) + self.assertEqual("output hello", str(output_pkt[scapy.UDP].payload)) + self.assertEqual(0xABCD, esp_hdr.spi) def testNullEncryptionTunnelMode(self): """Verify null encryption in tunnel mode. @@ -577,21 +577,21 @@ class XfrmFunctionalTest(xfrm_base.XfrmLazyTest): self.ReceivePacketOn(netid, input_pkt) msg, addr = sock.recvfrom(1024) - self.assertEquals("input hello", msg) - self.assertEquals((remote_addr, remote_port), addr[:2]) + self.assertEqual("input hello", msg) + self.assertEqual((remote_addr, remote_port), addr[:2]) # Send and capture a packet. sock.sendto("output hello", (remote_addr, remote_port)) packets = self.ReadAllPacketsOn(netid) - self.assertEquals(1, len(packets)) + self.assertEqual(1, len(packets)) output_pkt = packets[0] output_pkt, esp_hdr = xfrm_base.DecryptPacketWithNull(output_pkt) # length of the payload plus the UDP header - self.assertEquals(output_pkt[scapy.UDP].len, len("output_hello") + 8) - self.assertEquals(remote_addr, output_pkt.dst) - self.assertEquals(remote_port, output_pkt[scapy.UDP].dport) - self.assertEquals("output hello", str(output_pkt[scapy.UDP].payload)) - self.assertEquals(0xABCD, esp_hdr.spi) + self.assertEqual(output_pkt[scapy.UDP].len, len("output_hello") + 8) + self.assertEqual(remote_addr, output_pkt.dst) + self.assertEqual(remote_port, output_pkt[scapy.UDP].dport) + self.assertEqual("output hello", str(output_pkt[scapy.UDP].payload)) + self.assertEqual(0xABCD, esp_hdr.spi) def testNullEncryptionTransportMode(self): """Verify null encryption in transport mode. @@ -638,9 +638,9 @@ class XfrmFunctionalTest(xfrm_base.XfrmLazyTest): def _CheckTemplateMatch(tmpl): """Dump the SPD and match a single template on a single policy.""" dump = self.xfrm.DumpPolicyInfo() - self.assertEquals(1, len(dump)) + self.assertEqual(1, len(dump)) _, attributes = dump[0] - self.assertEquals(attributes['XFRMA_TMPL'], tmpl) + self.assertEqual(attributes['XFRMA_TMPL'], tmpl) # Create a new policy using update. self.xfrm.UpdatePolicyInfo(policy, tmpl1, mark, None) @@ -763,9 +763,9 @@ class XfrmOutputMarkTest(xfrm_base.XfrmLazyTest): xfrm.XFRM_MODE_TUNNEL, 100, xfrm_base._ALGO_CBC_AES_256, xfrm_base._ALGO_HMAC_SHA1, None, None, None, mark) dump = self.xfrm.DumpSaInfo() - self.assertEquals(1, len(dump)) + self.assertEqual(1, len(dump)) sainfo, attributes = dump[0] - self.assertEquals(mark, attributes["XFRMA_OUTPUT_MARK"]) + self.assertEqual(mark, attributes["XFRMA_OUTPUT_MARK"]) def testInvalidAlgorithms(self): key = "af442892cdcd0ef650e9c299f9a8436a".decode("hex") @@ -795,9 +795,9 @@ class XfrmOutputMarkTest(xfrm_base.XfrmLazyTest): xfrm_base._ALGO_HMAC_SHA1, None, None, mark, 0, is_update=True) dump = self.xfrm.DumpSaInfo() - self.assertEquals(1, len(dump)) # check that update updated + self.assertEqual(1, len(dump)) # check that update updated sainfo, attributes = dump[0] - self.assertEquals(mark, attributes["XFRMA_MARK"]) + self.assertEqual(mark, attributes["XFRMA_MARK"]) self.xfrm.DeleteSaInfo(net_test.GetWildcardAddress(version), spi, IPPROTO_ESP, mark) @@ -836,7 +836,7 @@ class XfrmOutputMarkTest(xfrm_base.XfrmLazyTest): s.sendto(net_test.UDP_PAYLOAD, (remote, 53)) # Check to make sure XfrmOutNoStates is incremented by exactly 1 - self.assertEquals(outNoStateCount + 1, + self.assertEqual(outNoStateCount + 1, self.getXfrmStat(XFRM_STATS_OUT_NO_STATES)) length = xfrm_base.GetEspPacketLength(xfrm.XFRM_MODE_TUNNEL, @@ -854,7 +854,7 @@ class XfrmOutputMarkTest(xfrm_base.XfrmLazyTest): xfrm_base._ALGO_HMAC_SHA1, None, None, mark, 0, is_update=False) except IOError as e: - self.assertEquals(EEXIST, e.errno, "SA exists") + self.assertEqual(EEXIST, e.errno, "SA exists") self.xfrm.AddSaInfo(local, remote, TEST_SPI, xfrm.XFRM_MODE_TUNNEL, 0, @@ -893,9 +893,9 @@ class XfrmOutputMarkTest(xfrm_base.XfrmLazyTest): dump = self.xfrm.DumpSaInfo() - self.assertEquals(1, len(dump)) # check that update updated + self.assertEqual(1, len(dump)) # check that update updated sainfo, attributes = dump[0] - self.assertEquals(reroute_netid, attributes["XFRMA_OUTPUT_MARK"]) + self.assertEqual(reroute_netid, attributes["XFRMA_OUTPUT_MARK"]) self.xfrm.DeleteSaInfo(remote, TEST_SPI, IPPROTO_ESP, mark) self.xfrm.DeletePolicyInfo(sel, xfrm.XFRM_POLICY_OUT, mark) diff --git a/net/test/xfrm_tunnel_test.py b/net/test/xfrm_tunnel_test.py index eb1a46e..e319a7d 100755 --- a/net/test/xfrm_tunnel_test.py +++ b/net/test/xfrm_tunnel_test.py @@ -37,9 +37,13 @@ import xfrm_base _LOOPBACK_IFINDEX = 1 _TEST_XFRM_IFNAME = "ipsec42" _TEST_XFRM_IF_ID = 42 +_TEST_SPI = 0x1234 # Does the kernel support xfrmi interfaces? def HaveXfrmInterfaces(): + if net_test.LINUX_VERSION >= (4, 19, 0): + return True + try: i = iproute.IPRoute() i.CreateXfrmInterface(_TEST_XFRM_IFNAME, _TEST_XFRM_IF_ID, @@ -56,13 +60,45 @@ def HaveXfrmInterfaces(): HAVE_XFRM_INTERFACES = HaveXfrmInterfaces() +# Does the kernel support CONFIG_XFRM_MIGRATE? +def SupportsXfrmMigrate(): + if net_test.LINUX_VERSION >= (5, 10, 0): + return True + + # XFRM_MIGRATE depends on xfrmi interfaces + if not HAVE_XFRM_INTERFACES: + return False + + try: + x = xfrm.Xfrm() + wildcard_addr = net_test.GetWildcardAddress(6) + selector = xfrm.EmptySelector(AF_INET6) + + # Expect migration to fail with EINVAL because it is trying to migrate a + # non-existent SA. + x.MigrateTunnel(xfrm.XFRM_POLICY_OUT, selector, wildcard_addr, wildcard_addr, + wildcard_addr, wildcard_addr, _TEST_SPI, + None, None, None, None, None, None) + print("Migration succeeded unexpectedly, assuming XFRM_MIGRATE is enabled") + return True + except IOError as err: + if err.errno == ENOPROTOOPT: + return False + elif err.errno == EINVAL: + return True + else: + print("Unexpected error, assuming XFRM_MIGRATE is enabled:", err.errno) + return True + +SUPPORTS_XFRM_MIGRATE = SupportsXfrmMigrate() + # Parameters to setup tunnels as special networks _TUNNEL_NETID_OFFSET = 0xFC00 # Matches reserved netid range for IpSecService _BASE_TUNNEL_NETID = {4: 40, 6: 60} _BASE_VTI_OKEY = 2000000100 _BASE_VTI_IKEY = 2000000200 -_TEST_OUT_SPI = 0x1234 +_TEST_OUT_SPI = _TEST_SPI _TEST_IN_SPI = _TEST_OUT_SPI _TEST_OKEY = 2000000100 @@ -132,6 +168,7 @@ def InjectTests(): InjectParameterizedTests(XfrmTunnelTest) InjectParameterizedTests(XfrmInterfaceTest) InjectParameterizedTests(XfrmVtiTest) + InjectParameterizedTests(XfrmInterfaceMigrateTest) def InjectParameterizedTests(cls): @@ -169,8 +206,8 @@ class XfrmTunnelTest(xfrm_base.XfrmLazyTest): # Verify that the packet data and src are correct data, src = read_sock.recvfrom(4096) - self.assertEquals(net_test.UDP_PAYLOAD, data) - self.assertEquals((remote_inner, _TEST_REMOTE_PORT), src[:2]) + self.assertEqual(net_test.UDP_PAYLOAD, data) + self.assertEqual((remote_inner, _TEST_REMOTE_PORT), src[:2]) def _TestTunnel(self, inner_version, outer_version, func, direction, test_output_mark_unset): @@ -233,13 +270,13 @@ class XfrmTunnelTest(xfrm_base.XfrmLazyTest): class XfrmAddDeleteVtiTest(xfrm_base.XfrmBaseTest): def _VerifyVtiInfoData(self, vti_info_data, version, local_addr, remote_addr, ikey, okey): - self.assertEquals(vti_info_data["IFLA_VTI_IKEY"], ikey) - self.assertEquals(vti_info_data["IFLA_VTI_OKEY"], okey) + self.assertEqual(vti_info_data["IFLA_VTI_IKEY"], ikey) + self.assertEqual(vti_info_data["IFLA_VTI_OKEY"], okey) family = AF_INET if version == 4 else AF_INET6 - self.assertEquals(inet_ntop(family, vti_info_data["IFLA_VTI_LOCAL"]), + self.assertEqual(inet_ntop(family, vti_info_data["IFLA_VTI_LOCAL"]), local_addr) - self.assertEquals(inet_ntop(family, vti_info_data["IFLA_VTI_REMOTE"]), + self.assertEqual(inet_ntop(family, vti_info_data["IFLA_VTI_REMOTE"]), remote_addr) def testAddVti(self): @@ -275,7 +312,7 @@ class XfrmAddDeleteVtiTest(xfrm_base.XfrmBaseTest): if_index = self.iproute.GetIfIndex(_TEST_XFRM_IFNAME) # Validate that the netlink interface matches the ioctl interface. - self.assertEquals(net_test.GetInterfaceIndex(_TEST_XFRM_IFNAME), if_index) + self.assertEqual(net_test.GetInterfaceIndex(_TEST_XFRM_IFNAME), if_index) self.iproute.DeleteLink(_TEST_XFRM_IFNAME) with self.assertRaises(IOError): self.iproute.GetIfIndex(_TEST_XFRM_IFNAME) @@ -334,6 +371,9 @@ class IpSecBaseInterface(object): else: auth, crypt = xfrm_base._ALGO_HMAC_SHA1, xfrm_base._ALGO_CBC_AES_256 + self.auth = auth + self.crypt = crypt + self._SetupXfrmByType(auth, crypt) def Rekey(self, outer_family, new_out_sa, new_in_sa): @@ -439,7 +479,7 @@ class XfrmAddDeleteXfrmInterfaceTest(xfrm_base.XfrmBaseTest): net_test.SetInterfaceUp(_TEST_XFRM_IFNAME) # Validate that the netlink interface matches the ioctl interface. - self.assertEquals(net_test.GetInterfaceIndex(_TEST_XFRM_IFNAME), if_index) + self.assertEqual(net_test.GetInterfaceIndex(_TEST_XFRM_IFNAME), if_index) self.iproute.DeleteLink(_TEST_XFRM_IFNAME) with self.assertRaises(IOError): self.iproute.GetIfIndex(_TEST_XFRM_IFNAME) @@ -448,7 +488,7 @@ class XfrmAddDeleteXfrmInterfaceTest(xfrm_base.XfrmBaseTest): class XfrmInterface(IpSecBaseInterface): def __init__(self, iface, netid, underlying_netid, ifindex, local, remote, - version): + version, use_null_crypt=False): super(XfrmInterface, self).__init__(iface, netid, underlying_netid, local, remote, version) @@ -456,7 +496,7 @@ class XfrmInterface(IpSecBaseInterface): self.xfrm_if_id = netid self.SetupInterface() - self.SetupXfrm(False) + self.SetupXfrm(use_null_crypt) def SetupInterface(self): """Create an XFRM interface.""" @@ -505,9 +545,30 @@ class XfrmInterface(IpSecBaseInterface): self.xfrm.DeleteSaInfo(self.remote, old_out_spi, IPPROTO_ESP, None, self.xfrm_if_id) + def Migrate(self, new_underlying_netid, new_local, new_remote): + self.xfrm.MigrateTunnel(xfrm.XFRM_POLICY_IN, None, self.remote, self.local, + new_remote, new_local, self.in_sa.spi, + self.crypt, self.auth, None, None, + new_underlying_netid, self.xfrm_if_id) + + self.xfrm.MigrateTunnel(xfrm.XFRM_POLICY_OUT, None, self.local, self.remote, + new_local, new_remote, self.out_sa.spi, + self.crypt, self.auth, None, None, + new_underlying_netid, self.xfrm_if_id) + + self.local = new_local + self.remote = new_remote + self.underlying_netid = new_underlying_netid + class XfrmTunnelBase(xfrm_base.XfrmBaseTest): + # Subclass that does not allow multiple tunnels (e.g. XfrmInterfaceMigrateTest) + # should override this method. + @classmethod + def allowMultipleTunnels(cls): + return True + @classmethod def setUpClass(cls): xfrm_base.XfrmBaseTest.setUpClass() @@ -520,6 +581,10 @@ class XfrmTunnelBase(xfrm_base.XfrmBaseTest): # IPv6 tunnel cls.tunnelsV4 = {} cls.tunnelsV6 = {} + + if not cls.allowMultipleTunnels(): + return + for i, underlying_netid in enumerate(cls.tuns): for version in 4, 6: netid = _BASE_TUNNEL_NETID[version] + _TUNNEL_NETID_OFFSET + i @@ -545,7 +610,7 @@ class XfrmTunnelBase(xfrm_base.XfrmBaseTest): def tearDownClass(cls): # The sysctls are restored by MultinetworkBaseTest.tearDownClass. cls.SetInboundMarks(False) - for tunnel in cls.tunnelsV4.values() + cls.tunnelsV6.values(): + for tunnel in list(cls.tunnelsV4.values()) + list(cls.tunnelsV6.values()): cls._SetInboundMarking(tunnel.netid, tunnel.iface, False) cls._SetupTunnelNetwork(tunnel, False) tunnel.Teardown() @@ -553,7 +618,7 @@ class XfrmTunnelBase(xfrm_base.XfrmBaseTest): def randomTunnel(self, outer_version): version_dict = self.tunnelsV4 if outer_version == 4 else self.tunnelsV6 - return random.choice(version_dict.values()) + return random.choice(list(version_dict.values())) def setUp(self): multinetwork_base.MultiNetworkBaseTest.setUp(self) @@ -636,13 +701,13 @@ class XfrmTunnelBase(xfrm_base.XfrmBaseTest): def assertReceivedPacket(self, tunnel, sa_info): tunnel.rx += 1 - self.assertEquals((tunnel.rx, tunnel.tx), + self.assertEqual((tunnel.rx, tunnel.tx), self.iproute.GetRxTxPackets(tunnel.iface)) sa_info.seq_num += 1 def assertSentPacket(self, tunnel, sa_info): tunnel.tx += 1 - self.assertEquals((tunnel.rx, tunnel.tx), + self.assertEqual((tunnel.rx, tunnel.tx), self.iproute.GetRxTxPackets(tunnel.iface)) sa_info.seq_num += 1 @@ -664,8 +729,8 @@ class XfrmTunnelBase(xfrm_base.XfrmBaseTest): # Verify that the packet data and src are correct data, src = read_sock.recvfrom(4096) self.assertReceivedPacket(tunnel, sa_info) - self.assertEquals(net_test.UDP_PAYLOAD, data) - self.assertEquals((remote_inner, _TEST_REMOTE_PORT), src[:2]) + self.assertEqual(net_test.UDP_PAYLOAD, data) + self.assertEqual((remote_inner, _TEST_REMOTE_PORT), src[:2]) def _CheckTunnelOutput(self, tunnel, inner_version, local_inner, remote_inner, sa_info=None): @@ -701,12 +766,12 @@ class XfrmTunnelBase(xfrm_base.XfrmBaseTest): # Check outer header manually (Avoids having to overwrite outer header's # id, flags or flow label) self.assertSentPacket(tunnel, sa_info) - self.assertEquals(expected.src, pkt.src) - self.assertEquals(expected.dst, pkt.dst) - self.assertEquals(len(expected), len(pkt)) + self.assertEqual(expected.src, pkt.src) + self.assertEqual(expected.dst, pkt.dst) + self.assertEqual(len(expected), len(pkt)) # Check everything else - self.assertEquals(str(expected.payload), str(pkt.payload)) + self.assertEqual(str(expected.payload), str(pkt.payload)) def _CheckTunnelEncryption(self, tunnel, inner_version, local_inner, remote_inner): @@ -728,8 +793,8 @@ class XfrmTunnelBase(xfrm_base.XfrmBaseTest): self.assertTrue(str(net_test.UDP_PAYLOAD) not in str(pkt)) # Check src/dst - self.assertEquals(tunnel.local, pkt.src) - self.assertEquals(tunnel.remote, pkt.dst) + self.assertEqual(tunnel.local, pkt.src) + self.assertEqual(tunnel.remote, pkt.dst) # Check that the interface statistics recorded the outbound packet self.assertSentPacket(tunnel, tunnel.out_sa) @@ -748,8 +813,8 @@ class XfrmTunnelBase(xfrm_base.XfrmBaseTest): # Verify that the packet data and src are correct data, src = read_sock.recvfrom(4096) - self.assertEquals(net_test.UDP_PAYLOAD, data) - self.assertEquals((local_inner, src_port), src[:2]) + self.assertEqual(net_test.UDP_PAYLOAD, data) + self.assertEqual((local_inner, src_port), src[:2]) # Check that the interface statistics recorded the inbound packet self.assertReceivedPacket(tunnel, tunnel.in_sa) @@ -784,10 +849,10 @@ class XfrmTunnelBase(xfrm_base.XfrmBaseTest): # Check that the packet too big reduced the MTU. routes = self.iproute.GetRoutes(tunnel.remote, 0, tunnel.underlying_netid, None) - self.assertEquals(1, len(routes)) + self.assertEqual(1, len(routes)) rtmsg, attributes = routes[0] - self.assertEquals(iproute.RTN_UNICAST, rtmsg.type) - self.assertEquals(packets.PTB_MTU, attributes["RTA_METRICS"]["RTAX_MTU"]) + self.assertEqual(iproute.RTN_UNICAST, rtmsg.type) + self.assertEqual(packets.PTB_MTU, attributes["RTA_METRICS"]["RTAX_MTU"]) # Clear PMTU information so that future tests don't have to worry about it. self.InvalidateDstCache(tunnel.version, tunnel.underlying_netid) @@ -947,6 +1012,80 @@ class XfrmInterfaceTest(XfrmTunnelBase): def ParamTestXfrmIntfRekey(self, inner_version, outer_version): self._TestTunnelRekey(inner_version, outer_version) +@unittest.skipUnless(SUPPORTS_XFRM_MIGRATE, "XFRM migration unsupported") +class XfrmInterfaceMigrateTest(XfrmTunnelBase): + # TODO: b/172497215 There is a kernel issue that XFRM_MIGRATE cannot work correctly + # when there are multiple tunnels with the same selectors. Thus before this issue + # is fixed, #allowMultipleTunnels must be overridden to avoid setting up multiple + # tunnels. This need to be removed after the kernel issue is fixed. + @classmethod + def allowMultipleTunnels(cls): + return False + + def setUpTunnel(self, outer_version, use_null_crypt): + underlying_netid = self.RandomNetid() + netid = _BASE_TUNNEL_NETID[outer_version] + _TUNNEL_NETID_OFFSET + iface = "ipsec%s" % netid + ifindex = self.ifindices[underlying_netid] + + local = self.MyAddress(outer_version, underlying_netid) + remote = net_test.IPV4_ADDR if outer_version == 4 else net_test.IPV6_ADDR + + tunnel = XfrmInterface(iface, netid, underlying_netid, ifindex, + local, remote, outer_version, use_null_crypt) + self._SetInboundMarking(netid, iface, True) + self._SetupTunnelNetwork(tunnel, True) + + return tunnel + + def tearDownTunnel(self, tunnel): + self._SetInboundMarking(tunnel.netid, tunnel.iface, False) + self._SetupTunnelNetwork(tunnel, False) + tunnel.Teardown() + + def _TestTunnel(self, inner_version, outer_version, func, use_null_crypt): + try: + tunnel = self.setUpTunnel(outer_version, use_null_crypt) + + # Verify functionality before migration + local_inner = tunnel.addrs[inner_version] + remote_inner = _GetRemoteInnerAddress(inner_version) + func(tunnel, inner_version, local_inner, remote_inner) + + # Migrate tunnel + # TODO:b/169170981 Add tests that migrate 4 -> 6 and 6 -> 4 + new_underlying_netid = self.RandomNetid(exclude=tunnel.underlying_netid) + new_local = self.MyAddress(outer_version, new_underlying_netid) + new_remote = net_test.IPV4_ADDR2 if outer_version == 4 else net_test.IPV6_ADDR2 + + tunnel.Migrate(new_underlying_netid, new_local, new_remote) + + # Verify functionality after migration + func(tunnel, inner_version, local_inner, remote_inner) + finally: + self.tearDownTunnel(tunnel) + + def ParamTestMigrateXfrmIntfInput(self, inner_version, outer_version): + self._TestTunnel(inner_version, outer_version, self._CheckTunnelInput, True) + + def ParamTestMigrateXfrmIntfOutput(self, inner_version, outer_version): + self._TestTunnel(inner_version, outer_version, self._CheckTunnelOutput, + True) + + def ParamTestMigrateXfrmIntfInOutEncrypted(self, inner_version, outer_version): + self._TestTunnel(inner_version, outer_version, self._CheckTunnelEncryption, + False) + + def ParamTestMigrateXfrmIntfIcmp(self, inner_version, outer_version): + self._TestTunnel(inner_version, outer_version, self._CheckTunnelIcmp, False) + + def ParamTestMigrateXfrmIntfEncryptionWithIcmp(self, inner_version, outer_version): + self._TestTunnel(inner_version, outer_version, + self._CheckTunnelEncryptionWithIcmp, False) + + def ParamTestMigrateXfrmIntfRekey(self, inner_version, outer_version): + self._TestTunnel(inner_version, outer_version, self._CheckTunnelRekey, + True) if __name__ == "__main__": InjectTests() |