summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLorenzo Colitti <lorenzo@google.com>2014-03-31 13:25:48 +0900
committerLorenzo Colitti <lorenzo@google.com>2015-02-02 17:47:24 +0900
commita25ebd97dd121c7f89e59fa440b86ffc9822d012 (patch)
tree4684a3e7184cc78bc65d0f8c936e146121cfd280
parent717357af1ea56f3474a17fb909efc0f86db31e15 (diff)
downloadextras-a25ebd97dd121c7f89e59fa440b86ffc9822d012.tar.gz
Improve MarkTest.
1. Add TCP SYN+ACK tests including syncookies and checks that accepting connections succeeids and that the sockets returned by accept() are marked. 2. Mark the tests more robust with respect to extra packets by always explicitly expecting packets (including when testing outgoing kernel-generated packets) and looking for them anywhere in the queue instead of insisting they're the first packet in the queue. 3. Make the tests more robust by using random source port, disabling ICMP rate limits, setting SO_REUSEADDR, and clearing queues more reliably. 4. Move from 2 to 4 interfaces (mostly made possible by the robustness improvements above). 5. Use named constants instead of repeating the numbers in multiple places. Change-Id: I596e557a7eea02ccf603c812a9b8ea6f5b2f95da
-rwxr-xr-xtests/net_test/mark_test.py427
-rwxr-xr-xtests/net_test/net_test.py10
-rwxr-xr-xtests/net_test/run_net_test.sh2
3 files changed, 307 insertions, 132 deletions
diff --git a/tests/net_test/mark_test.py b/tests/net_test/mark_test.py
index 775b4418..d8376626 100755
--- a/tests/net_test/mark_test.py
+++ b/tests/net_test/mark_test.py
@@ -1,9 +1,11 @@
#!/usr/bin/python
-import fcntl
import errno
+import fcntl
import os
import posix
+import random
+import re
import struct
import time
import unittest
@@ -21,6 +23,21 @@ TUNSETIFF = 0x400454ca
AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/route/autoconf_table_offset"
+PING_IDENT = 0xff19
+PING_PAYLOAD = "foobarbaz"
+PING_SEQ = 3
+PING_TOS = 0x83
+
+TCP_SYN = 2
+TCP_RST = 4
+TCP_ACK = 16
+
+TCP_SEQ = 1692871236
+TCP_WINDOW = 14400
+
+UDP_PAYLOAD = "hello"
+
+
class ConfigurationError(AssertionError):
pass
@@ -32,6 +49,10 @@ class UnexpectedPacketError(AssertionError):
class Packets(object):
@staticmethod
+ def RandomPort():
+ return random.randint(1025, 65535)
+
+ @staticmethod
def _GetIpLayer(version):
return {4: scapy.IP, 6: scapy.IPv6}[version]
@@ -45,19 +66,25 @@ class Packets(object):
raise ValueError("Can't find ToS Field")
@classmethod
- def UdpPacket(self, version, srcaddr, dstaddr):
+ def UDP(self, version, srcaddr, dstaddr, sport=0):
ip = self._GetIpLayer(version)
+ # Can't just use "if sport" because None has meaning (it means unspecified).
+ if sport == 0:
+ sport = self.RandomPort()
return ("UDPv%d packet" % version,
ip(src=srcaddr, dst=dstaddr) /
- scapy.UDP(sport=999, dport=1234) / "hello")
+ scapy.UDP(sport=sport, dport=53) / UDP_PAYLOAD)
@classmethod
- def SYN(self, port, version, srcaddr, dstaddr):
+ def SYN(self, dport, version, srcaddr, dstaddr, sport=0, seq=TCP_SEQ):
ip = self._GetIpLayer(version)
+ if sport == 0:
+ sport = self.RandomPort()
return ("TCP SYN",
ip(src=srcaddr, dst=dstaddr) /
- scapy.TCP(sport=50999, dport=port, seq=1692871236, ack=0,
- flags=2, window=14400))
+ scapy.TCP(sport=sport, dport=dport,
+ seq=seq, ack=0,
+ flags=TCP_SYN, window=TCP_WINDOW))
@classmethod
def RST(self, version, srcaddr, dstaddr, packet):
@@ -67,7 +94,7 @@ class Packets(object):
ip(src=srcaddr, dst=dstaddr) /
scapy.TCP(sport=original.dport, dport=original.sport,
ack=original.seq + 1, seq=None,
- flags=20, window=None))
+ flags=TCP_RST | TCP_ACK, window=TCP_WINDOW))
@classmethod
def SYNACK(self, version, srcaddr, dstaddr, packet):
@@ -77,7 +104,18 @@ class Packets(object):
ip(src=srcaddr, dst=dstaddr) /
scapy.TCP(sport=original.dport, dport=original.sport,
ack=original.seq + 1, seq=None,
- flags=18, window=None))
+ flags=TCP_SYN | TCP_ACK, window=None))
+
+ @classmethod
+ def ACK(self, version, srcaddr, dstaddr, packet):
+ ip = self._GetIpLayer(version)
+ original = packet.getlayer("TCP")
+ was_syn = (original.flags & TCP_SYN) != 0
+ return ("TCP ACK",
+ ip(src=srcaddr, dst=dstaddr) /
+ scapy.TCP(sport=original.dport, dport=original.sport,
+ ack=original.seq + was_syn, seq=original.ack,
+ flags=TCP_ACK, window=TCP_WINDOW))
@classmethod
def ICMPPortUnreachable(self, version, srcaddr, dstaddr, packet):
@@ -97,26 +135,26 @@ class Packets(object):
ip = self._GetIpLayer(version)
icmp = {4: scapy.ICMP, 6: scapy.ICMPv6EchoRequest}[version]
packet = (ip(src=srcaddr, dst=dstaddr) /
- icmp(id=0xff19, seq=3) / "foobarbaz")
- self._SetPacketTos(packet, 0x83)
+ icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
+ self._SetPacketTos(packet, PING_TOS)
return ("ICMPv%d echo" % version, packet)
@classmethod
- def ICMPReply(self, version, srcaddr, dstaddr, packet, tos=None):
+ def ICMPReply(self, version, srcaddr, dstaddr, packet):
ip = self._GetIpLayer(version)
-
# Scapy doesn't provide an ICMP echo reply constructor.
icmpv4_reply = lambda **kwargs: scapy.ICMP(type=0, **kwargs)
icmp = {4: icmpv4_reply, 6: scapy.ICMPv6EchoReply}[version]
packet = (ip(src=srcaddr, dst=dstaddr) /
- icmp(id=0xff19, seq=3) / "foobarbaz")
- self._SetPacketTos(packet, 0x83)
+ icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
+ self._SetPacketTos(packet, PING_TOS)
return ("ICMPv%d echo" % version, packet)
class MarkTest(net_test.NetworkTest):
- NETIDS = [100, 200]
+ # Must be between 1 and 256, since we put them in MAC addresses and IIDs.
+ NETIDS = [100, 150, 200, 250]
@staticmethod
def _RouterMacAddress(netid):
@@ -135,19 +173,27 @@ class MarkTest(net_test.NetworkTest):
else:
raise ValueError("Don't support IPv%s" % version)
- @staticmethod
- def _MyIPv4Address(netid):
+ @classmethod
+ def _MyIPv4Address(self, netid):
return "10.0.%d.2" % netid
@classmethod
+ def _MyIPv6Address(self, netid):
+ return net_test.GetLinkAddress(self._GetInterfaceName(netid), False)
+
+ @classmethod
+ def _MyAddress(self, version, netid):
+ return {4: self._MyIPv4Address(netid),
+ 6: self._MyIPv6Address(netid)}[version]
+
+ @classmethod
def _CreateTunInterface(self, netid):
iface = self._GetInterfaceName(netid)
f = open("/dev/net/tun", "r+b")
ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
- ifr = ifr + "\x00" * (40 - len(ifr))
+ ifr += "\x00" * (40 - len(ifr))
fcntl.ioctl(f, TUNSETIFF, ifr)
# Give ourselves a predictable MAC address.
- macaddr = self._MyMacAddress(netid)
net_test.SetInterfaceHWAddr(iface, self._MyMacAddress(netid))
# Disable DAD so we don't have to wait for it.
open("/proc/sys/net/ipv6/conf/%s/dad_transmits" % iface, "w").write("0")
@@ -244,13 +290,36 @@ class MarkTest(net_test.NetworkTest):
if self.AUTOCONF_TABLE_OFFSET >= 0:
return self.ifindices[netid] + self.AUTOCONF_TABLE_OFFSET
else:
- return netid
+ return netid
+
+ @classmethod
+ def _ICMPRatelimitFilename(self, version):
+ return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
+ 6: "ipv6/icmp/ratelimit"}[version]
+
+ @classmethod
+ def _GetICMPRatelimit(self, version):
+ return int(open(self._ICMPRatelimitFilename(version), "r").read().strip())
+
+ @classmethod
+ def _SetICMPRatelimit(self, version, limit):
+ return open(self._ICMPRatelimitFilename(version), "w").write("%d" % limit)
@classmethod
def setUpClass(self):
+ # This is per-class setup instead of per-testcase setup because shelling out
+ # to ip and iptables is slow, and because routing configuration doesn't
+ # change during the test.
self.tuns = {}
self.ifindices = {}
self._SetAutoconfTableSysctl(1000)
+
+ # Disable ICMP rate limits.
+ self.ratelimits = {}
+ for version in [4, 6]:
+ self.ratelimits[version] = self._GetICMPRatelimit(version)
+ self._SetICMPRatelimit(version, 0)
+
for netid in self.NETIDS:
self.tuns[netid] = self._CreateTunInterface(netid)
@@ -268,6 +337,7 @@ class MarkTest(net_test.NetworkTest):
# combination is tried.
self.listenport = 1234
self.listensocket = net_test.IPv6TCPSocket()
+ self.listensocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
self.listensocket.bind(("::", self.listenport))
self.listensocket.listen(100)
@@ -276,77 +346,108 @@ class MarkTest(net_test.NetworkTest):
# Uncomment to look around at interface and rule configuration while
# running in the background. (Once the test finishes running, all the
# interfaces and rules are gone.)
- #time.sleep(30)
+ # time.sleep(30)
@classmethod
def tearDownClass(self):
for netid in self.tuns:
self._RunSetupCommands(netid, False)
self.tuns[netid].close()
-
- def CheckExpectedPacket(self, expected, actual, msg):
- # Remove the Ethernet header from the incoming packet.
- actual = scapy.Ether(actual).payload
-
- # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
- actualip = actual.getlayer("IP")
- expectedip = expected.getlayer("IP")
- if actualip and expectedip:
- actualip.id = expectedip.id
- actualip.flags &= 5
- actualip.chksum = None # Change the header, recalculate the checksum.
-
- # Blank out TCP fields that we can't predict.
- actualtcp = actual.getlayer("TCP")
- expectedtcp = expected.getlayer("TCP")
- if actualtcp and expectedtcp:
- actualtcp.dataofs = expectedtcp.dataofs
- actualtcp.options = expectedtcp.options
- actualtcp.window = expectedtcp.window
- if expectedtcp.seq is None:
- actualtcp.seq = None
- if expectedtcp.ack is None:
- actualtcp.ack = None
- actualtcp.chksum = None
-
- # Serialize the packet so:
- # - Expected packet fields that are only set when a packet is serialized
- # (e.g., the checksum) are filled in.
- # - The packet is readable. Scapy has detailed dissection capabilities,
- # but they only seem to be usable to print the packet, not return its
- # dissection as a string.
- # TODO: Check if this is true.
- self.assertMultiLineEqual(str(expected).encode("hex"),
- str(actual).encode("hex"))
-
- def assertNoPacketsOn(self, netids, msg):
- for netid in netids:
+ self._SetAutoconfTableSysctl(-1)
+ for version in [4, 6]:
+ self._SetICMPRatelimit(version, self.ratelimits[version])
+
+ def assertPacketMatches(self, expected, actual):
+ # Remove the Ethernet header from the incoming packet.
+ actual = scapy.Ether(actual).payload
+
+ # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
+ actualip = actual.getlayer("IP")
+ expectedip = expected.getlayer("IP")
+ if actualip and expectedip:
+ actualip.id = expectedip.id
+ actualip.flags &= 5
+ actualip.chksum = None # Change the header, recalculate the checksum.
+
+ # Blank out UDP fields that we can't predict (e.g., the source port for
+ # kernel-originated packets).
+ actualudp = actual.getlayer("UDP")
+ expectedudp = expected.getlayer("UDP")
+ if actualudp and expectedudp:
+ if expectedudp.sport is None:
+ actualudp.sport = None
+ actualudp.chksum = None
+
+ # Since the TCP code below messes with options, recalculate the length.
+ if actualip:
+ actualip.len = None
+ actualipv6 = actual.getlayer("IPv6")
+ if actualipv6:
+ actualipv6.plen = None
+
+ # Blank out TCP fields that we can't predict.
+ actualtcp = actual.getlayer("TCP")
+ expectedtcp = expected.getlayer("TCP")
+ if actualtcp and expectedtcp:
+ actualtcp.dataofs = expectedtcp.dataofs
+ actualtcp.options = expectedtcp.options
+ actualtcp.window = expectedtcp.window
+ if expectedtcp.sport is None:
+ actualtcp.sport = None
+ if expectedtcp.seq is None:
+ actualtcp.seq = None
+ if expectedtcp.ack is None:
+ actualtcp.ack = None
+ actualtcp.chksum = None
+
+ # Serialize the packet so:
+ # - Expected packet fields that are only set when a packet is serialized
+ # (e.g., the checksum) are filled in.
+ # - The packet is vaguely human-readable. Scapy has sophisticated packet
+ # dissection capabilities, but unfortunately they can only be used to
+ # print the packet, not to return its dissection as as string.
+ self.assertMultiLineEqual(str(expected).encode("hex"),
+ str(actual).encode("hex"))
+
+ def PacketMatches(self, expected, actual):
+ try:
+ self.assertPacketMatches(expected, actual)
+ return True
+ except AssertionError:
+ return False
+
+ def ReadAllPacketsOn(self, netid):
+ packets = []
+ while True:
try:
- self.assertRaisesErrno(errno.EAGAIN, self.tuns[netid].read, 4096)
- except AssertionError, e:
- raise UnexpectedPacketError("%s: Unexpected packet on %s" % (
- msg, self._GetInterfaceName(netid)))
-
- def assertNoOtherPackets(self, msg):
- self.assertNoPacketsOn([netid for netid in self.tuns], msg)
-
- def assertNoPacketsExceptOn(self, netid, msg):
- self.assertNoPacketsOn([n for n in self.tuns if n != netid], msg)
-
- def ExpectPacketOn(self, netid, msg, expected=None):
- # Check no packets were sent on any other netid.
- self.assertNoPacketsExceptOn(netid, msg)
-
- # Check that a packet was sent on netid.
+ packets.append(posix.read(self.tuns[netid].fileno(), 4096))
+ except OSError, e:
+ # EAGAIN means there are no more packets waiting.
+ if re.match(e.message, os.strerror(errno.EAGAIN)):
+ break
+ # Anything else is unexpected.
+ else:
+ raise e
+ return packets
+
+ def ExpectPacketOn(self, netid, msg, expected):
+ packets = self.ReadAllPacketsOn(netid)
+ self.assertTrue(packets, msg + ": received no packets")
+
+ # If we receive a packet that matches what we expected, return it.
+ for packet in packets:
+ if self.PacketMatches(expected, packet):
+ return scapy.Ether(packet).payload
+
+ # None of the packets matched. Call assertPacketMatches to output a diff
+ # between the expected packet and the last packet we received. In theory,
+ # we'd output a diff to the packet that's the best match for what we
+ # expected, but this is good enough for now.
try:
- actual = self.tuns[netid].read(4096)
- except IOError, e:
- raise AssertionError(msg + ": " + str(e))
- self.assertTrue(actual)
-
- # If we know what sort of packet we expect, check that here.
- if expected:
- self.CheckExpectedPacket(expected, actual, msg)
+ self.assertPacketMatches(expected, packets[-1])
+ except Exception, e:
+ raise UnexpectedPacketError(
+ "%s: diff with last packet:\n%s" % (msg, e.message))
def ReceivePacketOn(self, netid, ip_packet):
routermac = self._RouterMacAddress(netid)
@@ -355,12 +456,10 @@ class MarkTest(net_test.NetworkTest):
posix.write(self.tuns[netid].fileno(), str(packet))
def ClearTunQueues(self):
- for f in self.tuns.values():
- try:
- f.read(4096)
- except IOError:
- continue
- self.assertNoOtherPackets("Unexpected packets after clearing queues")
+ # Keep reading packets on all netids until we get no packets on any of them.
+ waiting = None
+ while waiting != 0:
+ waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)
def setUp(self):
self.ClearTunQueues()
@@ -369,33 +468,53 @@ class MarkTest(net_test.NetworkTest):
def _GetRemoteAddress(version):
return {4: net_test.IPV4_ADDR, 6: net_test.IPV6_ADDR}[version]
- def MarkSocket(self, s, netid):
+ def SetSocketMark(self, s, netid):
s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
+ def GetSocketMark(self, s):
+ return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
+
def GetProtocolFamily(self, version):
return {4: AF_INET, 6: AF_INET6}[version]
def testOutgoingPackets(self):
"""Checks that socket marking selects the right outgoing interface."""
- def CheckPingPacket(version, netid, packet):
+ def CheckPingPacket(version, netid, dstaddr, packet):
s = net_test.PingSocket(self.GetProtocolFamily(version))
- dstaddr = self._GetRemoteAddress(version)
- self.MarkSocket(s, netid)
- s.sendto(packet, (dstaddr, 19321))
- self.ExpectPacketOn(netid, "IPv%d ping: mark %d" % (version, netid))
+ myaddr = self._MyAddress(version, netid)
+ s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+ s.bind((myaddr, PING_IDENT))
+ self.SetSocketMark(s, netid)
+ net_test.SetSocketTos(s, PING_TOS)
+
+ desc, expected = Packets.ICMPEcho(version, myaddr, dstaddr)
+
+ self.ClearTunQueues()
+ s.sendto(packet + PING_PAYLOAD, (dstaddr, 19321))
+ msg = "IPv%d ping: expected %s on %s" % (
+ version, desc, self._GetInterfaceName(netid))
+ self.ExpectPacketOn(netid, msg, expected)
for netid in self.tuns:
- CheckPingPacket(4, netid, net_test.IPV4_PING)
- CheckPingPacket(6, netid, net_test.IPV6_PING)
+ CheckPingPacket(4, netid, net_test.IPV4_ADDR, net_test.IPV4_PING)
+ CheckPingPacket(6, netid, net_test.IPV6_ADDR, net_test.IPV6_PING)
def CheckTCPSYNPacket(version, netid, dstaddr):
s = net_test.TCPSocket(self.GetProtocolFamily(version))
- self.MarkSocket(s, netid)
+ self.SetSocketMark(s, netid)
+ if version == 6 and dstaddr.startswith("::ffff"):
+ version = 4
+ myaddr = self._MyAddress(version, netid)
+ desc, expected = Packets.SYN(53, version, myaddr, dstaddr,
+ sport=None, seq=None)
+
+ self.ClearTunQueues()
# Non-blocking TCP connects always return EINPROGRESS.
self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
- self.ExpectPacketOn(netid, "IPv%d TCP connect: mark %d" % (version,
- netid))
+ msg = "IPv%s TCP connect: expected %s on %s" % (
+ version, desc, self._GetInterfaceName(netid))
+ self.ExpectPacketOn(netid, msg, expected)
s.close()
for netid in self.tuns:
@@ -405,13 +524,22 @@ class MarkTest(net_test.NetworkTest):
def CheckUDPPacket(version, netid, dstaddr):
s = net_test.UDPSocket(self.GetProtocolFamily(version))
- self.MarkSocket(s, netid)
- s.sendto("hello", (dstaddr, 53))
- self.ExpectPacketOn(netid, "IPv%d UDP sendto: mark %d" % (version, netid))
+ self.SetSocketMark(s, netid)
+ if version == 6 and dstaddr.startswith("::ffff"):
+ version = 4
+ myaddr = self._MyAddress(version, netid)
+ desc, expected = Packets.UDP(version, myaddr, dstaddr, sport=None)
+ msg = "IPv%s UDP %%s: expected %s on %s" % (
+ version, desc, self._GetInterfaceName(netid))
+
+ self.ClearTunQueues()
+ s.sendto(UDP_PAYLOAD, (dstaddr, 53))
+ self.ExpectPacketOn(netid, msg % "sendto", expected)
+
+ self.ClearTunQueues()
s.connect((dstaddr, 53))
- s.send("hello")
- self.ExpectPacketOn(netid, "IPv%d UDP connect/send: mark %d" % (version,
- netid))
+ s.send(UDP_PAYLOAD)
+ self.ExpectPacketOn(netid, msg % "connect/send", expected)
s.close()
for netid in self.tuns:
@@ -419,38 +547,53 @@ class MarkTest(net_test.NetworkTest):
CheckUDPPacket(6, netid, net_test.IPV6_ADDR)
CheckUDPPacket(6, netid, "::ffff:" + net_test.IPV4_ADDR)
- def CheckReflection(self, version, packet_generator, reply_generator):
- """Checks that replies go out on the same interface as the original."""
-
+ def CheckReflection(self, version, packet_generator, reply_generator,
+ callback=None):
+ """Checks that replies go out on the same interface as the original.
+
+ Iterates through all the combinations of the interfaces in self.tuns and the
+ IP addresses assigned to them. For each combination:
+ - Calls packet_generator to generate a packet to that IP address.
+ - Writes the packet generated by packet_generator on the given tun
+ interface, causing the kernel to receive it.
+ - Checks that the kernel's reply matches the packet generated by
+ reply_generator.
+ - Calls the given callback function.
+
+ Args:
+ version: An integer, 4 or 6.
+ packet_generator: A function taking an IP version (an integer), a source
+ address and a destination address (strings), and returning a scapy
+ packet.
+ reply_generator: A function taking the same arguments as packet_generator,
+ plus a scapy packet, and returning a scapy packet.
+ callback: A function to call to perform extra checks if the packet
+ matches. Takes netid, version, local address, remote address, original
+ packet, kernel reply, and a message.
+ """
# Check packets addressed to the IP addresses of all our interfaces...
for dest_ip_netid in self.tuns:
dest_ip_iface = self._GetInterfaceName(dest_ip_netid)
- if version == 4:
- myaddr = self._MyIPv4Address(dest_ip_netid)
- else:
- myaddr = net_test.GetLinkAddress(self._GetInterfaceName(dest_ip_netid),
- False)
+ myaddr = self._MyAddress(version, dest_ip_netid)
remote_addr = self._GetRemoteAddress(version)
# ... coming in on all our interfaces.
for iif_netid in self.tuns:
iif = self._GetInterfaceName(iif_netid)
desc, packet = packet_generator(version, remote_addr, myaddr)
- if reply_generator:
- # We know what we want a reply to.
- reply_desc, reply = reply_generator(version, myaddr, remote_addr,
- packet)
- else:
- # Expect any reply.
- reply_desc, reply = "any packet", None
+ reply_desc, reply = reply_generator(version, myaddr, remote_addr,
+ packet)
msg = "Receiving %s on %s to %s IP: Expecting %s on %s" % (
desc, iif, dest_ip_iface, reply_desc, iif)
- # Expect a reply on the interface the original packet came in on.
self.ClearTunQueues()
+ # Cause the kernel to receive packet on iif_netid.
self.ReceivePacketOn(iif_netid, packet)
- self.ExpectPacketOn(iif_netid, msg, reply)
+ # Expect the kernel to send out reply on the same interface.
+ reply = self.ExpectPacketOn(iif_netid, msg, reply)
+ if callback:
+ callback(iif_netid, version, myaddr, remote_addr, packet, reply, msg)
def SYNToClosedPort(self, *args):
return Packets.SYN(999, *args)
@@ -459,10 +602,10 @@ class MarkTest(net_test.NetworkTest):
return Packets.SYN(self.listenport, *args)
def testIPv4ICMPErrorsReflectMark(self):
- self.CheckReflection(4, Packets.UdpPacket, Packets.ICMPPortUnreachable)
+ self.CheckReflection(4, Packets.UDP, Packets.ICMPPortUnreachable)
def testIPv6ICMPErrorsReflectMark(self):
- self.CheckReflection(6, Packets.UdpPacket, Packets.ICMPPortUnreachable)
+ self.CheckReflection(6, Packets.UDP, Packets.ICMPPortUnreachable)
def testIPv4PingRepliesReflectMarkAndTos(self):
self.CheckReflection(4, Packets.ICMPEcho, Packets.ICMPReply)
@@ -476,13 +619,39 @@ class MarkTest(net_test.NetworkTest):
def testIPv6RSTsReflectMark(self):
self.CheckReflection(6, self.SYNToClosedPort, Packets.RST)
- @unittest.skipUnless(False, "skipping: doesn't work yet")
+ def CheckAcceptedSocketMarkCallback(self, netid, version, myaddr,
+ remote_addr, packet, reply, msg):
+ establishing_ack = Packets.ACK(version, remote_addr, myaddr, reply)[1]
+ self.ReceivePacketOn(netid, establishing_ack)
+ s, unused_peer = self.listensocket.accept()
+ try:
+ mark = self.GetSocketMark(s)
+ finally:
+ s.close()
+ self.assertEquals(netid, mark,
+ msg + ": Accepted socket: Expected mark %d, got %d" % (
+ netid, mark))
+
def testIPv4SYNACKsReflectMark(self):
- self.CheckReflection(4, Packets.SYNToOpenPort, Packets.SYNACK)
+ self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK,
+ self.CheckAcceptedSocketMarkCallback)
- @unittest.skipUnless(False, "skipping: doesn't work yet")
def testIPv6SYNACKsReflectMark(self):
- self.CheckReflection(6, Packets.SYNToOpenPort, Packets.SYNACK)
+ self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK,
+ self.CheckAcceptedSocketMarkCallback)
+
+ def testSynCookiesSYNACKsReflectMark(self):
+ # Force SYN cookies on all connections.
+ open("/proc/sys/net/ipv4/tcp_syncookies", "w").write("2")
+ try:
+ self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK,
+ self.CheckAcceptedSocketMarkCallback)
+ self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK,
+ self.CheckAcceptedSocketMarkCallback)
+ finally:
+ # Stop forcing SYN cookies on all connections.
+ open("/proc/sys/net/ipv4/tcp_syncookies", "w").write("1")
+
if __name__ == "__main__":
diff --git a/tests/net_test/net_test.py b/tests/net_test/net_test.py
index caf79a0f..bf67785b 100755
--- a/tests/net_test/net_test.py
+++ b/tests/net_test/net_test.py
@@ -17,10 +17,11 @@ from socket import *
from scapy import all as scapy
SOL_IPV6 = 41
-IP_TRANSPARENT = 19
-IPV6_TRANSPARENT = 75
IP_RECVERR = 11
IPV6_RECVERR = 25
+IP_TRANSPARENT = 19
+IPV6_TRANSPARENT = 75
+IPV6_TCLASS = 67
SO_BINDTODEVICE = 25
SO_MARK = 36
IPV6_FLOWLABEL_MGR = 32
@@ -60,6 +61,11 @@ def SetSocketTimeout(sock, ms):
us = (ms % 1000) * 1000
sock.setsockopt(SOL_SOCKET, SO_RCVTIMEO, struct.pack("LL", s, us))
+def SetSocketTos(s, tos):
+ level = {AF_INET: SOL_IP, AF_INET6: SOL_IPV6}[s.family]
+ option = {AF_INET: IP_TOS, AF_INET6: IPV6_TCLASS}[s.family]
+ s.setsockopt(level, option, tos)
+
def SetNonBlocking(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
diff --git a/tests/net_test/run_net_test.sh b/tests/net_test/run_net_test.sh
index adaa3544..306ebe41 100755
--- a/tests/net_test/run_net_test.sh
+++ b/tests/net_test/run_net_test.sh
@@ -2,7 +2,7 @@
# Kernel configration options.
OPTIONS=" IPV6 IPV6_ROUTER_PREF IPV6_MULTIPLE_TABLES IPV6_ROUTE_INFO"
-OPTIONS="$OPTIONS TUN IP_ADVANCED_ROUTER IP_MULTIPLE_TABLES"
+OPTIONS="$OPTIONS TUN SYN_COOKIES IP_ADVANCED_ROUTER IP_MULTIPLE_TABLES"
OPTIONS="$OPTIONS NETFILTER NETFILTER_ADVANCED NETFILTER_XTABLES"
OPTIONS="$OPTIONS NETFILTER_XT_MARK NETFILTER_XT_TARGET_MARK"
OPTIONS="$OPTIONS IP_NF_IPTABLES IP_NF_MANGLE"