diff options
author | markchien <markchien@google.com> | 2019-04-01 02:59:37 -0700 |
---|---|---|
committer | android-build-merger <android-build-merger@google.com> | 2019-04-01 02:59:37 -0700 |
commit | 0192118a6561ab2144ce9248721e89903461b9cc (patch) | |
tree | ce037e417a31b3abc12a7c3e1ad3925c4e67a9e2 | |
parent | 0508643812e44aa8d18ba5aca5aef8794093673c (diff) | |
parent | a3196de82e1c60196fcc69b4fa9c742c580295fc (diff) | |
download | tests-0192118a6561ab2144ce9248721e89903461b9cc.tar.gz |
Add tests for tcp connection repair
am: a3196de82e
Change-Id: Ic21c28d28c01684f73de8c10df82e6a7d29082bf
-rwxr-xr-x | net/test/all_tests.py | 1 | ||||
-rw-r--r-- | net/test/multinetwork_base.py | 2 | ||||
-rwxr-xr-x | net/test/tcp_repair_test.py | 340 |
3 files changed, 342 insertions, 1 deletions
diff --git a/net/test/all_tests.py b/net/test/all_tests.py index 485d55b..bbef3ac 100755 --- a/net/test/all_tests.py +++ b/net/test/all_tests.py @@ -38,6 +38,7 @@ test_modules = [ 'srcaddr_selection_test', 'tcp_fastopen_test', 'tcp_nuke_addr_test', + 'tcp_repair_test', 'tcp_test', 'xfrm_algorithm_test', 'xfrm_test', diff --git a/net/test/multinetwork_base.py b/net/test/multinetwork_base.py index ce653b2..8dbd360 100644 --- a/net/test/multinetwork_base.py +++ b/net/test/multinetwork_base.py @@ -470,7 +470,7 @@ class MultiNetworkBaseTest(net_test.NetworkTest): def GetRemoteAddress(self, version): return {4: self.IPV4_ADDR, - 5: self.IPV4_ADDR, + 5: self.IPV4_ADDR, # see GetRemoteSocketAddress() 6: self.IPV6_ADDR}[version] def GetRemoteSocketAddress(self, version): diff --git a/net/test/tcp_repair_test.py b/net/test/tcp_repair_test.py new file mode 100755 index 0000000..af67def --- /dev/null +++ b/net/test/tcp_repair_test.py @@ -0,0 +1,340 @@ +#!/usr/bin/python +# +# Copyright 2019 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from errno import * # pylint: disable=wildcard-import +from socket import * # pylint: disable=wildcard-import +import ctypes +import fcntl +import os +import random +import select +import termios +import threading +import time +from scapy import all as scapy + +import multinetwork_base +import net_test +import packets + +SOL_TCP = net_test.SOL_TCP +SHUT_RD = net_test.SHUT_RD +SHUT_WR = net_test.SHUT_WR +SHUT_RDWR = net_test.SHUT_RDWR +SIOCINQ = termios.FIONREAD +SIOCOUTQ = termios.TIOCOUTQ + +TEST_PORT = 5555 + +# Following constants are SOL_TCP level options and arguments. +# They are defined in linux-kernel: include/uapi/linux/tcp.h + +# SOL_TCP level options. +TCP_REPAIR = 19 +TCP_REPAIR_QUEUE = 20 +TCP_QUEUE_SEQ = 21 + +# TCP_REPAIR_{OFF, ON} is an argument to TCP_REPAIR. +TCP_REPAIR_OFF = 0 +TCP_REPAIR_ON = 1 + +# TCP_{NO, RECV, SEND}_QUEUE is an argument to TCP_REPAIR_QUEUE. +TCP_NO_QUEUE = 0 +TCP_RECV_QUEUE = 1 +TCP_SEND_QUEUE = 2 + +# This test is aiming to ensure tcp keep alive offload works correctly +# when it fetches tcp information from kernel via tcp repair mode. +class TcpRepairTest(multinetwork_base.MultiNetworkBaseTest): + + def assertSocketNotConnected(self, sock): + self.assertRaisesErrno(ENOTCONN, sock.getpeername) + + def assertSocketConnected(self, sock): + sock.getpeername() # No errors? Socket is alive and connected. + + def createConnectedSocket(self, version, netid): + s = net_test.TCPSocket(net_test.GetAddressFamily(version)) + net_test.DisableFinWait(s) + self.SelectInterface(s, netid, "mark") + + remoteaddr = self.GetRemoteSocketAddress(version) + self.assertRaisesErrno(EINPROGRESS, s.connect, (remoteaddr, TEST_PORT)) + self.assertSocketNotConnected(s) + + myaddr = self.MyAddress(version, netid) + port = s.getsockname()[1] + self.assertNotEqual(0, port) + + desc, expect_syn = packets.SYN(TEST_PORT, version, myaddr, remoteaddr, port, seq=None) + msg = "socket connect: expected %s" % desc + syn = self.ExpectPacketOn(netid, msg, expect_syn) + synack_desc, synack = packets.SYNACK(version, remoteaddr, myaddr, syn) + synack.getlayer("TCP").seq = random.getrandbits(32) + synack.getlayer("TCP").window = 14400 + self.ReceivePacketOn(netid, synack) + desc, ack = packets.ACK(version, myaddr, remoteaddr, synack) + msg = "socket connect: got SYN+ACK, expected %s" % desc + ack = self.ExpectPacketOn(netid, msg, ack) + self.last_sent = ack + self.last_received = synack + return s + + def receiveFin(self, netid, version, sock): + self.assertSocketConnected(sock) + remoteaddr = self.GetRemoteAddress(version) + myaddr = self.MyAddress(version, netid) + desc, fin = packets.FIN(version, remoteaddr, myaddr, self.last_sent) + self.ReceivePacketOn(netid, fin) + self.last_received = fin + + def sendData(self, netid, version, sock, payload): + sock.send(payload) + + remoteaddr = self.GetRemoteAddress(version) + myaddr = self.MyAddress(version, netid) + desc, send = packets.ACK(version, myaddr, remoteaddr, + self.last_received, payload) + self.last_sent = send + + def receiveData(self, netid, version, payload): + remoteaddr = self.GetRemoteAddress(version) + myaddr = self.MyAddress(version, netid) + + desc, received = packets.ACK(version, remoteaddr, myaddr, + self.last_sent, payload) + ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, received) + self.ReceivePacketOn(netid, received) + time.sleep(0.1) + self.ExpectPacketOn(netid, "expecting %s" % ack_desc, ack) + self.last_sent = ack + self.last_received = received + + # Test the behavior of NO_QUEUE. Expect incoming data will be stored into + # the queue, but socket cannot be read/written in NO_QUEUE. + def testTcpRepairInNoQueue(self): + for version in [4, 5, 6]: + self.tcpRepairInNoQueueTest(version) + + def tcpRepairInNoQueueTest(self, version): + netid = self.RandomNetid() + sock = self.createConnectedSocket(version, netid) + sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON) + + # In repair mode with NO_QUEUE, writes fail... + self.assertRaisesErrno(EINVAL, sock.send, "write test") + + # remote data is coming. + TEST_RECEIVED = net_test.UDP_PAYLOAD + self.receiveData(netid, version, TEST_RECEIVED) + + # In repair mode with NO_QUEUE, read fail... + self.assertRaisesErrno(EPERM, sock.recv, 4096) + + sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_OFF) + readData = sock.recv(4096) + self.assertEquals(readData, TEST_RECEIVED) + sock.close() + + # Test whether tcp read/write sequence number can be fetched correctly + # by TCP_QUEUE_SEQ. + def testGetSequenceNumber(self): + for version in [4, 5, 6]: + self.GetSequenceNumberTest(version) + + def GetSequenceNumberTest(self, version): + netid = self.RandomNetid() + sock = self.createConnectedSocket(version, netid) + # test write queue sequence number + sequence_before = self.GetWriteSequenceNumber(version, sock) + expect_sequence = self.last_sent.getlayer("TCP").seq + self.assertEquals(sequence_before & 0xffffffff, expect_sequence) + TEST_SEND = net_test.UDP_PAYLOAD + self.sendData(netid, version, sock, TEST_SEND) + sequence_after = self.GetWriteSequenceNumber(version, sock) + self.assertEquals(sequence_before + len(TEST_SEND), sequence_after) + + # test read queue sequence number + sequence_before = self.GetReadSequenceNumber(version, sock) + expect_sequence = self.last_received.getlayer("TCP").seq + 1 + self.assertEquals(sequence_before & 0xffffffff, expect_sequence) + TEST_READ = net_test.UDP_PAYLOAD + self.receiveData(netid, version, TEST_READ) + sequence_after = self.GetReadSequenceNumber(version, sock) + self.assertEquals(sequence_before + len(TEST_READ), sequence_after) + sock.close() + + def GetWriteSequenceNumber(self, version, sock): + sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON) + sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_SEND_QUEUE) + sequence = sock.getsockopt(SOL_TCP, TCP_QUEUE_SEQ) + sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_NO_QUEUE) + sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_OFF) + return sequence + + def GetReadSequenceNumber(self, version, sock): + sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON) + sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_RECV_QUEUE) + sequence = sock.getsockopt(SOL_TCP, TCP_QUEUE_SEQ) + sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_NO_QUEUE) + sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_OFF) + return sequence + + # Test whether tcp repair socket can be poll()'ed correctly + # in mutiple threads at the same time. + def testMultiThreadedPoll(self): + for version in [4, 5, 6]: + self.PollWhenShutdownTest(version) + self.PollWhenReceiveFinTest(version) + + def PollRepairSocketInMultipleThreads(self, netid, version, expected): + sock = self.createConnectedSocket(version, netid) + sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON) + + multiThreads = [] + for i in [0, 1]: + thread = SocketExceptionThread(sock, lambda sk: self.fdSelect(sock, expected)) + thread.start() + self.assertTrue(thread.is_alive()) + multiThreads.append(thread) + + return sock, multiThreads + + def assertThreadsStopped(self, multiThreads, msg) : + for thread in multiThreads: + if (thread.is_alive()): + thread.join(1) + if (thread.is_alive()): + thread.stop() + raise AssertionError(msg) + + def PollWhenShutdownTest(self, version): + netid = self.RandomNetid() + expected = select.POLLIN + sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected) + # Test shutdown RD. + sock.shutdown(SHUT_RD) + self.assertThreadsStopped(multiThreads, "poll fail during SHUT_RD") + sock.close() + + expected = None + sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected) + # Test shutdown WR. + sock.shutdown(SHUT_WR) + self.assertThreadsStopped(multiThreads, "poll fail during SHUT_WR") + sock.close() + + expected = select.POLLIN | select.POLLHUP + sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected) + # Test shutdown RDWR. + sock.shutdown(SHUT_RDWR) + self.assertThreadsStopped(multiThreads, "poll fail during SHUT_RDWR") + sock.close() + + def PollWhenReceiveFinTest(self, version): + netid = self.RandomNetid() + expected = select.POLLIN + sock, multiThreads = self.PollRepairSocketInMultipleThreads(netid, version, expected) + self.receiveFin(netid, version, sock) + self.assertThreadsStopped(multiThreads, "poll fail during FIN") + sock.close() + + # Test whether socket idle can be detected by SIOCINQ and SIOCOUTQ. + def testSocketIdle(self): + for version in [4, 5, 6]: + self.readQueueIdleTest(version) + self.writeQueueIdleTest(version) + + def readQueueIdleTest(self, version): + netid = self.RandomNetid() + sock = self.createConnectedSocket(version, netid) + + buf = ctypes.c_int() + fcntl.ioctl(sock, SIOCINQ, buf) + self.assertEquals(buf.value, 0) + + TEST_RECV_PAYLOAD = net_test.UDP_PAYLOAD + self.receiveData(netid, version, TEST_RECV_PAYLOAD) + fcntl.ioctl(sock, SIOCINQ, buf) + self.assertEquals(buf.value, len(TEST_RECV_PAYLOAD)) + sock.close() + + def writeQueueIdleTest(self, version): + netid = self.RandomNetid() + # Setup a connected socket, write queue is empty. + sock = self.createConnectedSocket(version, netid) + buf = ctypes.c_int() + fcntl.ioctl(sock, SIOCOUTQ, buf) + self.assertEquals(buf.value, 0) + # Change to repair mode with SEND_QUEUE, writing some data to the queue. + sock.setsockopt(SOL_TCP, TCP_REPAIR, TCP_REPAIR_ON) + TEST_SEND_PAYLOAD = net_test.UDP_PAYLOAD + sock.setsockopt(SOL_TCP, TCP_REPAIR_QUEUE, TCP_SEND_QUEUE) + self.sendData(netid, version, sock, TEST_SEND_PAYLOAD) + fcntl.ioctl(sock, SIOCOUTQ, buf) + self.assertEquals(buf.value, len(TEST_SEND_PAYLOAD)) + sock.close() + + # Setup a connected socket again. + netid = self.RandomNetid() + sock = self.createConnectedSocket(version, netid) + # Send out some data and don't receive ACK yet. + self.sendData(netid, version, sock, TEST_SEND_PAYLOAD) + fcntl.ioctl(sock, SIOCOUTQ, buf) + self.assertEquals(buf.value, len(TEST_SEND_PAYLOAD)) + # Receive response ACK. + remoteaddr = self.GetRemoteAddress(version) + myaddr = self.MyAddress(version, netid) + desc_ack, ack = packets.ACK(version, remoteaddr, myaddr, self.last_sent) + self.ReceivePacketOn(netid, ack) + fcntl.ioctl(sock, SIOCOUTQ, buf) + self.assertEquals(buf.value, 0) + sock.close() + + + def fdSelect(self, sock, expected): + READ_ONLY = select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR | select.POLLNVAL + p = select.poll() + p.register(sock, READ_ONLY) + events = p.poll(500) + for fd,event in events: + if fd == sock.fileno(): + self.assertEquals(event, expected) + else: + raise AssertionError("unexpected poll fd") + +class SocketExceptionThread(threading.Thread): + + def __init__(self, sock, operation): + self.exception = None + super(SocketExceptionThread, self).__init__() + self.daemon = True + self.sock = sock + self.operation = operation + + def stop(self): + self._Thread__stop() + + def run(self): + try: + self.operation(self.sock) + except (IOError, AssertionError), e: + self.exception = e + +if __name__ == '__main__': + unittest.main() |