diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/net_test/cstruct.py | 31 | ||||
-rwxr-xr-x | tests/net_test/cstruct_test.py | 60 | ||||
-rwxr-xr-x | tests/net_test/sock_diag_test.py | 35 |
3 files changed, 116 insertions, 10 deletions
diff --git a/tests/net_test/cstruct.py b/tests/net_test/cstruct.py index 62f08235..91cd72ec 100644 --- a/tests/net_test/cstruct.py +++ b/tests/net_test/cstruct.py @@ -57,7 +57,7 @@ def CalcNumElements(fmt): return len(elements) -def Struct(name, fmt, fields, substructs={}): +def Struct(name, fmt, fieldnames, substructs={}): """Function that returns struct classes.""" class Meta(type): @@ -77,12 +77,12 @@ def Struct(name, fmt, fields, substructs={}): # Name of the struct. _name = name # List of field names. - _fields = fields + _fieldnames = fieldnames # Dict mapping field indices to nested struct classes. _nested = {} - if isinstance(_fields, str): - _fields = _fields.split(" ") + if isinstance(_fieldnames, str): + _fieldnames = _fieldnames.split(" ") # Parse fmt into _format, converting any S format characters to "XXs", # where XX is the length of the struct type's packed representation. @@ -121,14 +121,14 @@ def Struct(name, fmt, fields, substructs={}): self._Parse(values) else: # Initializing from a tuple. - if len(values) != len(self._fields): - raise TypeError("%s has exactly %d fields (%d given)" % - (self._name, len(self._fields), len(values))) + if len(values) != len(self._fieldnames): + raise TypeError("%s has exactly %d fieldnames (%d given)" % + (self._name, len(self._fieldnames), len(values))) self._SetValues(values) def _FieldIndex(self, attr): try: - return self._fields.index(attr) + return self._fieldnames.index(attr) except ValueError: raise AttributeError("'%s' has no attribute '%s'" % (self._name, attr)) @@ -143,6 +143,15 @@ def Struct(name, fmt, fields, substructs={}): def __len__(cls): return cls._length + def __ne__(self, other): + return not self.__eq__(other) + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self._name == other._name and + self._fieldnames == other._fieldnames and + self._values == other._values) + @staticmethod def _MaybePackStruct(value): if hasattr(value, "__metaclass__"):# and value.__metaclass__ == Meta: @@ -156,12 +165,14 @@ def Struct(name, fmt, fields, substructs={}): def __str__(self): def FieldDesc(index, name, value): - if isinstance(value, str) and any (c not in string.printable for c in value): + if isinstance(value, str) and any( + c not in string.printable for c in value): value = value.encode("hex") return "%s=%s" % (name, value) descriptions = [ - FieldDesc(i, n, v) for i, (n, v) in enumerate(zip(self._fields, self._values))] + FieldDesc(i, n, v) for i, (n, v) in + enumerate(zip(self._fieldnames, self._values))] return "%s(%s)" % (self._name, ", ".join(descriptions)) diff --git a/tests/net_test/cstruct_test.py b/tests/net_test/cstruct_test.py new file mode 100755 index 00000000..2d5a4081 --- /dev/null +++ b/tests/net_test/cstruct_test.py @@ -0,0 +1,60 @@ +#!/usr/bin/python +# +# Copyright 2016 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import cstruct + + +# These aren't constants, they're classes. So, pylint: disable=invalid-name +TestStructA = cstruct.Struct("TestStructA", "=BI", "byte1 int2") +TestStructB = cstruct.Struct("TestStructB", "=BI", "byte1 int2") + + +class CstructTest(unittest.TestCase): + + def CheckEquals(self, a, b): + self.assertEquals(a, b) + self.assertEquals(b, a) + assert a == b + assert b == a + assert not (a != b) # pylint: disable=g-comparison-negation,superfluous-parens + assert not (b != a) # pylint: disable=g-comparison-negation,superfluous-parens + + def CheckNotEquals(self, a, b): + self.assertNotEquals(a, b) + self.assertNotEquals(b, a) + assert a != b + assert b != a + assert not (a == b) # pylint: disable=g-comparison-negation,superfluous-parens + assert not (b == a) # pylint: disable=g-comparison-negation,superfluous-parens + + def testEqAndNe(self): + a1 = TestStructA((1, 2)) + a2 = TestStructA((2, 3)) + a3 = TestStructA((1, 2)) + b = TestStructB((1, 2)) + self.CheckNotEquals(a1, b) + self.CheckNotEquals(a2, b) + self.CheckNotEquals(a1, a2) + self.CheckNotEquals(a2, a3) + for i in [a1, a2, a3, b]: + self.CheckEquals(i, i) + self.CheckEquals(a1, a3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/net_test/sock_diag_test.py b/tests/net_test/sock_diag_test.py index b13befe9..1baafbdc 100755 --- a/tests/net_test/sock_diag_test.py +++ b/tests/net_test/sock_diag_test.py @@ -198,6 +198,41 @@ class SockDiagTest(SockDiagBaseTest): # TODO: why doesn't comparing the cstructs work? self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack()) + def testCrossFamilyBytecode(self): + """Checks for a cross-family bug in inet_diag_hostcond matching. + + Relevant kernel commits: + android-3.4: + f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run() + """ + pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1") + pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1") + + bytecode4 = self.sock_diag.PackBytecode([ + (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))]) + bytecode6 = self.sock_diag.PackBytecode([ + (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))]) + + # IPv4/v6 filters must never match IPv6/IPv4 sockets... + v4sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4) + self.assertTrue(v4sockets) + self.assertTrue(all(d.family == AF_INET for d, _ in v4sockets)) + + v6sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6) + self.assertTrue(v6sockets) + self.assertTrue(all(d.family == AF_INET6 for d, _ in v6sockets)) + + # Except for mapped addresses, which match both IPv4 and IPv6. + pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, + "::ffff:127.0.0.1") + diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5] + v4sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, + bytecode4)] + v6sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, + bytecode6)] + self.assertTrue(all(d in v4sockets for d in diag_msgs)) + self.assertTrue(all(d in v6sockets for d in diag_msgs)) + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") def testClosesSockets(self): self.socketpairs = self._CreateLotsOfSockets() |