diff options
-rw-r--r-- | tests/net_test/packets.py | 5 | ||||
-rwxr-xr-x | tests/net_test/run_net_test.sh | 1 | ||||
-rwxr-xr-x | tests/net_test/sock_diag.py | 10 | ||||
-rwxr-xr-x | tests/net_test/sock_diag_test.py | 335 |
4 files changed, 346 insertions, 5 deletions
diff --git a/tests/net_test/packets.py b/tests/net_test/packets.py index d92a97e4..c02adc0a 100644 --- a/tests/net_test/packets.py +++ b/tests/net_test/packets.py @@ -120,11 +120,12 @@ def ACK(version, srcaddr, dstaddr, packet, payload=""): def FIN(version, srcaddr, dstaddr, packet): ip = _GetIpLayer(version) original = packet.getlayer("TCP") - was_fin = (original.flags & TCP_FIN) != 0 + was_syn_or_fin = (original.flags & (TCP_SYN | TCP_FIN)) != 0 + ack_delta = was_syn_or_fin + len(original.payload) return ("TCP FIN", ip(src=srcaddr, dst=dstaddr) / scapy.TCP(sport=original.dport, dport=original.sport, - ack=original.seq + was_fin, seq=original.ack, + ack=original.seq + ack_delta, seq=original.ack, flags=TCP_ACK | TCP_FIN, window=TCP_WINDOW)) def GRE(version, srcaddr, dstaddr, proto, packet): diff --git a/tests/net_test/run_net_test.sh b/tests/net_test/run_net_test.sh index d745ec31..080aac73 100755 --- a/tests/net_test/run_net_test.sh +++ b/tests/net_test/run_net_test.sh @@ -12,6 +12,7 @@ OPTIONS="$OPTIONS IPV6_PRIVACY IPV6_OPTIMISTIC_DAD" OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_TARGET_NFLOG" OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA CONFIG_NETFILTER_XT_MATCH_QUOTA2" OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA2_LOG" +OPTIONS="$OPTIONS CONFIG_INET_UDP_DIAG CONFIG_INET_DIAG_DESTROY" # For 3.1 kernels, where devtmpfs is not on by default. OPTIONS="$OPTIONS DEVTMPFS DEVTMPFS_MOUNT" diff --git a/tests/net_test/sock_diag.py b/tests/net_test/sock_diag.py index 69785aa6..a9de3458 100755 --- a/tests/net_test/sock_diag.py +++ b/tests/net_test/sock_diag.py @@ -235,6 +235,16 @@ class SockDiag(netlink.NetlinkSocket): """Constructs a diag_req from a diag_msg the kernel has given us.""" return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id)) + def CloseSocket(self, req): + self._SendNlRequest(SOCK_DESTROY, req.Pack(), + netlink.NLM_F_REQUEST | netlink.NLM_F_ACK) + + def CloseSocketFromFd(self, s): + diag_msg = self.FindSockDiagFromFd(s) + protocol = s.getsockopt(SOL_SOCKET, net_test.SO_PROTOCOL) + req = self.DiagReqFromDiagMsg(diag_msg, protocol) + return self.CloseSocket(req) + if __name__ == "__main__": n = SockDiag() diff --git a/tests/net_test/sock_diag_test.py b/tests/net_test/sock_diag_test.py index 7eff7e40..59759315 100755 --- a/tests/net_test/sock_diag_test.py +++ b/tests/net_test/sock_diag_test.py @@ -14,7 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import errno +from errno import * +import os import random from socket import * import time @@ -26,12 +27,15 @@ import multinetwork_base import net_test import packets import sock_diag +import threading NUM_SOCKETS = 100 ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << sock_diag.TCP_TIME_WAIT) +# TODO: Backport SOCK_DESTROY and delete this. +HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= (4, 4) class SockDiagTest(multinetwork_base.MultiNetworkBaseTest): @@ -48,11 +52,13 @@ class SockDiagTest(multinetwork_base.MultiNetworkBaseTest): return socketpairs def setUp(self): + super(SockDiagTest, self).setUp() self.sock_diag = sock_diag.SockDiag() - self.socketpairs = self._CreateLotsOfSockets() + self.socketpairs = {} def tearDown(self): [s.close() for socketpair in self.socketpairs.values() for s in socketpair] + super(SockDiagTest, self).tearDown() def testFixupDiagMsg(self): src = "0a00fa02303030312030312038302031" @@ -78,6 +84,16 @@ class SockDiagTest(multinetwork_base.MultiNetworkBaseTest): msg4.id.dst = dst.decode("hex")[:4] + 12 * "\x00" self.assertEquals(msg4.Pack(), fixed4.Pack()) + def assertSocketClosed(self, sock): + self.assertRaisesErrno(ENOTCONN, sock.getpeername) + + def assertSocketConnected(self, sock): + sock.getpeername() # No errors? Socket is alive and connected. + + def assertSocketsClosed(self, socketpair): + for sock in socketpair: + self.assertSocketClosed(sock) + def assertSockDiagMatchesSocket(self, s, diag_msg): family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN) self.assertEqual(diag_msg.family, family) @@ -93,9 +109,10 @@ class SockDiagTest(multinetwork_base.MultiNetworkBaseTest): self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst)) self.assertEqual(diag_msg.id.dport, dport) else: - assertRaisesErrno(errno.ENOTCONN, s.getpeername) + assertRaisesErrno(ENOTCONN, s.getpeername) def testFindsAllMySockets(self): + self.socketpairs = self._CreateLotsOfSockets() sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, states=ALL_NON_TIME_WAIT) self.assertGreaterEqual(len(sockets), NUM_SOCKETS) @@ -131,6 +148,318 @@ class SockDiagTest(multinetwork_base.MultiNetworkBaseTest): diag_msg = self.sock_diag.GetSockDiag(req) self.assertSockDiagMatchesSocket(sock, diag_msg) + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") + def testClosesSockets(self): + self.socketpairs = self._CreateLotsOfSockets() + for (addr, _, _), socketpair in self.socketpairs.iteritems(): + # Close one of the sockets. + # This will send a RST that will close the other side as well. + s = random.choice(socketpair) + if random.randrange(0, 2) == 1: + self.sock_diag.CloseSocketFromFd(s) + else: + diag_msg = self.sock_diag.FindSockDiagFromFd(s) + family = AF_INET6 if ":" in addr else AF_INET + + # Get the cookie wrong and ensure that we get an error and the socket + # is not closed. + real_cookie = diag_msg.id.cookie + diag_msg.id.cookie = os.urandom(len(real_cookie)) + req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) + self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req) + self.assertSocketConnected(s) + + # Now close it with the correct cookie. + req.id.cookie = real_cookie + self.sock_diag.CloseSocket(req) + + # Check that both sockets in the pair are closed. + self.assertSocketsClosed(socketpair) + + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") + def testNonTcpSockets(self): + s = socket(AF_INET6, SOCK_DGRAM, 0) + s.connect(("::1", 53)) + diag_msg = self.sock_diag.FindSockDiagFromFd(s) + self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s) + + def testNonSockDiagCommand(self): + def DiagDump(code): + sock_id = self.sock_diag._EmptyInetDiagSockId() + req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff, + sock_id)) + self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg) + + op = sock_diag.SOCK_DIAG_BY_FAMILY + DiagDump(op) # No errors? Good. + self.assertRaisesErrno(EINVAL, DiagDump, op + 17) + + # TODO: + # Test that killing unix sockets returns EOPNOTSUPP. + + +class SocketExceptionThread(threading.Thread): + + def __init__(self, sock, operation): + self.exception = None + super(SocketExceptionThread, self).__init__() + self.daemon = True + self.sock = sock + self.operation = operation + + def run(self): + try: + self.operation(self.sock) + except Exception, e: + self.exception = e + + +# TODO: Take a tun fd as input, make this a utility class, and reuse at least +# in forwarding_test. +class TcpTest(SockDiagTest): + + NOT_YET_ACCEPTED = -1 + + def setUp(self): + super(TcpTest, self).setUp() + self.sock_diag = sock_diag.SockDiag() + self.netid = random.choice(self.tuns.keys()) + + def OpenListenSocket(self, version): + self.port = packets.RandomPort() + family = {4: AF_INET, 6: AF_INET6}[version] + address = {4: "0.0.0.0", 6: "::"}[version] + s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP) + s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) + s.bind((address, self.port)) + # We haven't configured inbound iptables marking, so bind explicitly. + self.SelectInterface(s, self.netid, "mark") + s.listen(100) + return s + + def _ReceiveAndExpectResponse(self, netid, packet, reply, msg): + pkt = super(TcpTest, self)._ReceiveAndExpectResponse(netid, packet, + reply, msg) + self.last_packet = pkt + return pkt + + def ReceivePacketOn(self, netid, packet): + super(TcpTest, self).ReceivePacketOn(netid, packet) + self.last_packet = packet + + def RstPacket(self): + return packets.RST(self.version, self.myaddr, self.remoteaddr, + self.last_packet) + + def IncomingConnection(self, version, end_state, netid): + self.version = version + self.s = self.OpenListenSocket(version) + self.end_state = end_state + + remoteaddr = self.remoteaddr = self.GetRemoteAddress(version) + myaddr = self.myaddr = self.MyAddress(version, netid) + + if end_state == sock_diag.TCP_LISTEN: + return + + desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr) + synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn) + msg = "Received %s, expected to see reply %s" % (desc, synack_desc) + reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg) + if end_state == sock_diag.TCP_SYN_RECV: + return + + establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1] + self.ReceivePacketOn(netid, establishing_ack) + + if end_state == self.NOT_YET_ACCEPTED: + return + + self.accepted, _ = self.s.accept() + if end_state == sock_diag.TCP_ESTABLISHED: + return + + desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack, + payload=net_test.UDP_PAYLOAD) + self.accepted.send(net_test.UDP_PAYLOAD) + self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data) + + desc, fin = packets.FIN(version, remoteaddr, myaddr, data) + fin = packets._GetIpLayer(version)(str(fin)) + ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin) + msg = "Received %s, expected to see reply %s" % (desc, ack_desc) + + # TODO: Why can't we use this? + # self._ReceiveAndExpectResponse(netid, fin, ack, msg) + self.ReceivePacketOn(netid, fin) + time.sleep(0.1) + self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack) + if end_state == sock_diag.TCP_CLOSE_WAIT: + return + + raise ValueError("Invalid TCP state %d specified" % end_state) + + def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True): + """Closes the socket and checks whether a RST is sent or not.""" + if sock is not None: + self.assertIsNone(req, "Must specify sock or req, not both") + self.sock_diag.CloseSocketFromFd(sock) + self.assertRaisesErrno(EINVAL, sock.accept) + else: + self.assertIsNone(sock, "Must specify sock or req, not both") + self.sock_diag.CloseSocket(req) + + if expect_reset: + desc, rst = self.RstPacket() + msg = "%s: expecting %s: " % (msg, desc) + self.ExpectPacketOn(self.netid, msg, rst) + else: + msg = "%s: " % msg + self.ExpectNoPacketsOn(self.netid, msg) + + if sock is not None and do_close: + sock.close() + + def CheckTcpReset(self, state, statename): + for version in [4, 6]: + msg = "Closing incoming IPv%d %s socket" % (version, statename) + self.IncomingConnection(version, state, self.netid) + self.CheckRstOnClose(self.s, None, False, msg) + if state != sock_diag.TCP_LISTEN: + msg = "Closing accepted IPv%d %s socket" % (version, statename) + self.CheckRstOnClose(self.accepted, None, True, msg) + + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") + def testTcpResets(self): + """Checks that closing sockets in appropriate states sends a RST.""" + self.CheckTcpReset(sock_diag.TCP_LISTEN, "TCP_LISTEN") + self.CheckTcpReset(sock_diag.TCP_ESTABLISHED, "TCP_ESTABLISHED") + self.CheckTcpReset(sock_diag.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT") + + def FindChildSockets(self, s): + """Finds the SYN_RECV child sockets of a given listening socket.""" + d = self.sock_diag.FindSockDiagFromFd(self.s) + req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) + req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED + req.id.cookie = "\x00" * 8 + children = self.sock_diag.Dump(req) + return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) + for d, _ in children] + + def CheckChildSocket(self, state, statename, parent_first): + for version in [4, 6]: + self.IncomingConnection(version, state, self.netid) + + d = self.sock_diag.FindSockDiagFromFd(self.s) + parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) + children = self.FindChildSockets(self.s) + self.assertEquals(1, len(children)) + + is_established = (state == self.NOT_YET_ACCEPTED) + + # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the + # regular TCP hash tables, and inet_diag_find_one_icsk can find them. + # Before 4.4, we can see those sockets in dumps, but we can't fetch + # or close them. + can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4) + + for child in children: + if can_close_children: + self.sock_diag.GetSockDiag(child) # No errors? Good, child found. + else: + self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child) + + def CloseParent(expect_reset): + msg = "Closing parent IPv%d %s socket %s child" % ( + version, statename, "before" if parent_first else "after") + self.CheckRstOnClose(self.s, None, expect_reset, msg) + self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, parent) + + def CheckChildrenClosed(): + for child in children: + self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child) + + def CloseChildren(): + for child in children: + msg = "Closing child IPv%d %s socket %s parent" % ( + version, statename, "after" if parent_first else "before") + self.sock_diag.GetSockDiag(child) + self.CheckRstOnClose(None, child, is_established, msg) + self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child) + CheckChildrenClosed() + + if parent_first: + # Closing the parent will close child sockets, which will send a RST, + # iff they are already established. + CloseParent(is_established) + if is_established: + CheckChildrenClosed() + elif can_close_children: + CloseChildren() + CheckChildrenClosed() + self.s.close() + else: + if can_close_children: + CloseChildren() + CloseParent(False) + self.s.close() + + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") + def testChildSockets(self): + self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", False) + self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", True) + self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", False) + self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", True) + + def CloseDuringBlockingCall(self, sock, call, expected_errno): + thread = SocketExceptionThread(sock, call) + thread.start() + time.sleep(0.1) + self.sock_diag.CloseSocketFromFd(sock) + thread.join(1) + self.assertFalse(thread.is_alive()) + self.assertIsNotNone(thread.exception) + self.assertTrue(isinstance(thread.exception, IOError), + "Expected IOError, got %s" % thread.exception) + self.assertEqual(expected_errno, thread.exception.errno) + self.assertSocketClosed(sock) + + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") + def testAcceptInterrupted(self): + """Tests that accept() is interrupted by SOCK_DESTROY.""" + for version in [4, 6]: + self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid) + self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL) + self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo") + self.assertRaisesErrno(EINVAL, self.s.accept) + + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") + def testReadInterrupted(self): + """Tests that read() is interrupted by SOCK_DESTROY.""" + for version in [4, 6]: + self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid) + self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096), + ECONNABORTED) + self.assertRaisesErrno(EPIPE, self.accepted.send, "foo") + + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") + def testConnectInterrupted(self): + """Tests that connect() is interrupted by SOCK_DESTROY.""" + for version in [4, 6]: + family = {4: AF_INET, 6: AF_INET6}[version] + s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP) + self.SelectInterface(s, self.netid, "mark") + remoteaddr = self.GetRemoteAddress(version) + s.bind(("", 0)) + _, sport = s.getsockname()[:2] + self.CloseDuringBlockingCall( + s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED) + desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid), + remoteaddr, sport=sport, seq=None) + self.ExpectPacketOn(self.netid, desc, syn) + msg = "SOCK_DESTROY of socket in connect, expected no RST" + self.ExpectNoPacketsOn(self.netid, msg) + if __name__ == "__main__": unittest.main() |