diff options
Diffstat (limited to 'src/tests/portpicker_test.py')
-rw-r--r-- | src/tests/portpicker_test.py | 390 |
1 files changed, 390 insertions, 0 deletions
diff --git a/src/tests/portpicker_test.py b/src/tests/portpicker_test.py new file mode 100644 index 0000000..c2925db --- /dev/null +++ b/src/tests/portpicker_test.py @@ -0,0 +1,390 @@ +#!/usr/bin/python +# +# Copyright 2007 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Unittests for the portpicker module.""" + +from __future__ import print_function +import errno +import os +import random +import socket +import sys +import unittest +from contextlib import ExitStack + +if sys.platform == 'win32': + import _winapi +else: + _winapi = None + +try: + # pylint: disable=no-name-in-module + from unittest import mock # Python >= 3.3. +except ImportError: + import mock # https://pypi.python.org/pypi/mock + +import portpicker + + +class PickUnusedPortTest(unittest.TestCase): + def IsUnusedTCPPort(self, port): + return self._bind(port, socket.SOCK_STREAM, socket.IPPROTO_TCP) + + def IsUnusedUDPPort(self, port): + return self._bind(port, socket.SOCK_DGRAM, socket.IPPROTO_UDP) + + def setUp(self): + # So we can Bind even if portpicker.bind is stubbed out. + self._bind = portpicker.bind + portpicker._owned_ports.clear() + portpicker._free_ports.clear() + portpicker._random_ports.clear() + + def testPickUnusedPortActuallyWorks(self): + """This test can be flaky.""" + for _ in range(10): + port = portpicker.pick_unused_port() + self.assertTrue(self.IsUnusedTCPPort(port)) + self.assertTrue(self.IsUnusedUDPPort(port)) + + @unittest.skipIf('PORTSERVER_ADDRESS' not in os.environ, + 'no port server to test against') + def testPickUnusedCanSuccessfullyUsePortServer(self): + + with mock.patch.object(portpicker, '_pick_unused_port_without_server'): + portpicker._pick_unused_port_without_server.side_effect = ( + Exception('eek!') + ) + + # Since _PickUnusedPortWithoutServer() raises an exception, if we + # can successfully obtain a port, the portserver must be working. + port = portpicker.pick_unused_port() + self.assertTrue(self.IsUnusedTCPPort(port)) + self.assertTrue(self.IsUnusedUDPPort(port)) + + @unittest.skipIf('PORTSERVER_ADDRESS' not in os.environ, + 'no port server to test against') + def testPickUnusedCanSuccessfullyUsePortServerAddressKwarg(self): + + with mock.patch.object(portpicker, '_pick_unused_port_without_server'): + portpicker._pick_unused_port_without_server.side_effect = ( + Exception('eek!') + ) + + # Since _PickUnusedPortWithoutServer() raises an exception, and + # we've temporarily removed PORTSERVER_ADDRESS from os.environ, if + # we can successfully obtain a port, the portserver must be working. + addr = os.environ.pop('PORTSERVER_ADDRESS') + try: + port = portpicker.pick_unused_port(portserver_address=addr) + self.assertTrue(self.IsUnusedTCPPort(port)) + self.assertTrue(self.IsUnusedUDPPort(port)) + finally: + os.environ['PORTSERVER_ADDRESS'] = addr + + @unittest.skipIf('PORTSERVER_ADDRESS' not in os.environ, + 'no port server to test against') + def testGetPortFromPortServer(self): + """Exercise the get_port_from_port_server() helper function.""" + for _ in range(10): + port = portpicker.get_port_from_port_server( + os.environ['PORTSERVER_ADDRESS']) + self.assertTrue(self.IsUnusedTCPPort(port)) + self.assertTrue(self.IsUnusedUDPPort(port)) + + def testSendsPidToPortServer(self): + with ExitStack() as stack: + if _winapi: + create_file_mock = mock.Mock() + create_file_mock.return_value = 0 + read_file_mock = mock.Mock() + write_file_mock = mock.Mock() + read_file_mock.return_value = (b'42768\n', 0) + stack.enter_context( + mock.patch('_winapi.CreateFile', new=create_file_mock)) + stack.enter_context( + mock.patch('_winapi.WriteFile', new=write_file_mock)) + stack.enter_context( + mock.patch('_winapi.ReadFile', new=read_file_mock)) + port = portpicker.get_port_from_port_server( + 'portserver', pid=1234) + write_file_mock.assert_called_once_with(0, b'1234\n') + else: + server = mock.Mock() + server.recv.return_value = b'42768\n' + stack.enter_context( + mock.patch.object(socket, 'socket', return_value=server)) + port = portpicker.get_port_from_port_server( + 'portserver', pid=1234) + server.sendall.assert_called_once_with(b'1234\n') + + self.assertEqual(port, 42768) + + def testPidDefaultsToOwnPid(self): + with ExitStack() as stack: + stack.enter_context( + mock.patch.object(os, 'getpid', return_value=9876)) + + if _winapi: + create_file_mock = mock.Mock() + create_file_mock.return_value = 0 + read_file_mock = mock.Mock() + write_file_mock = mock.Mock() + read_file_mock.return_value = (b'52768\n', 0) + stack.enter_context( + mock.patch('_winapi.CreateFile', new=create_file_mock)) + stack.enter_context( + mock.patch('_winapi.WriteFile', new=write_file_mock)) + stack.enter_context( + mock.patch('_winapi.ReadFile', new=read_file_mock)) + port = portpicker.get_port_from_port_server('portserver') + write_file_mock.assert_called_once_with(0, b'9876\n') + else: + server = mock.Mock() + server.recv.return_value = b'52768\n' + stack.enter_context( + mock.patch.object(socket, 'socket', return_value=server)) + port = portpicker.get_port_from_port_server('portserver') + server.sendall.assert_called_once_with(b'9876\n') + + self.assertEqual(port, 52768) + + @mock.patch.dict(os.environ,{'PORTSERVER_ADDRESS': 'portserver'}) + def testReusesPortServerPorts(self): + with ExitStack() as stack: + if _winapi: + read_file_mock = mock.Mock() + read_file_mock.side_effect = [ + (b'12345\n', 0), + (b'23456\n', 0), + (b'34567\n', 0), + ] + stack.enter_context(mock.patch('_winapi.CreateFile')) + stack.enter_context(mock.patch('_winapi.WriteFile')) + stack.enter_context( + mock.patch('_winapi.ReadFile', new=read_file_mock)) + else: + server = mock.Mock() + server.recv.side_effect = [b'12345\n', b'23456\n', b'34567\n'] + stack.enter_context( + mock.patch.object(socket, 'socket', return_value=server)) + + self.assertEqual(portpicker.pick_unused_port(), 12345) + self.assertEqual(portpicker.pick_unused_port(), 23456) + portpicker.return_port(12345) + self.assertEqual(portpicker.pick_unused_port(), 12345) + + @mock.patch.dict(os.environ,{'PORTSERVER_ADDRESS': ''}) + def testDoesntReuseRandomPorts(self): + ports = set() + for _ in range(10): + try: + port = portpicker.pick_unused_port() + except portpicker.NoFreePortFoundError: + # This sometimes happens when not using portserver. Just + # skip to the next attempt. + continue + ports.add(port) + portpicker.return_port(port) + self.assertGreater(len(ports), 5) # Allow some random reuse. + + def testReturnsReservedPorts(self): + with mock.patch.object(portpicker, '_pick_unused_port_without_server'): + portpicker._pick_unused_port_without_server.side_effect = ( + Exception('eek!')) + # Arbitrary port. In practice you should get this from somewhere + # that assigns ports. + reserved_port = 28465 + portpicker.add_reserved_port(reserved_port) + ports = set() + for _ in range(10): + port = portpicker.pick_unused_port() + ports.add(port) + portpicker.return_port(port) + self.assertEqual(len(ports), 1) + self.assertEqual(ports.pop(), reserved_port) + + @mock.patch.dict(os.environ,{'PORTSERVER_ADDRESS': ''}) + def testFallsBackToRandomAfterRunningOutOfReservedPorts(self): + # Arbitrary port. In practice you should get this from somewhere + # that assigns ports. + reserved_port = 23456 + portpicker.add_reserved_port(reserved_port) + self.assertEqual(portpicker.pick_unused_port(), reserved_port) + self.assertNotEqual(portpicker.pick_unused_port(), reserved_port) + + def testRandomlyChosenPorts(self): + # Unless this box is under an overwhelming socket load, this test + # will heavily exercise the "pick a port randomly" part of the + # port picking code, but may never hit the "OS assigns a port" + # code. + ports = 0 + for _ in range(100): + try: + port = portpicker._pick_unused_port_without_server() + except portpicker.NoFreePortFoundError: + # Without the portserver, pick_unused_port can sometimes fail + # to find a free port. Check that it passes most of the time. + continue + self.assertTrue(self.IsUnusedTCPPort(port)) + self.assertTrue(self.IsUnusedUDPPort(port)) + ports += 1 + # Getting a port shouldn't have failed very often, even on machines + # with a heavy socket load. + self.assertGreater(ports, 95) + + def testOSAssignedPorts(self): + self.last_assigned_port = None + + def error_for_explicit_ports(port, socket_type, socket_proto): + # Only successfully return a port if an OS-assigned port is + # requested, or if we're checking that the last OS-assigned port + # is unused on the other protocol. + if port == 0 or port == self.last_assigned_port: + self.last_assigned_port = self._bind(port, socket_type, + socket_proto) + return self.last_assigned_port + else: + return None + + with mock.patch.object(portpicker, 'bind', error_for_explicit_ports): + # Without server, this can be little flaky, so check that it + # passes most of the time. + ports = 0 + for _ in range(100): + try: + port = portpicker._pick_unused_port_without_server() + except portpicker.NoFreePortFoundError: + continue + self.assertTrue(self.IsUnusedTCPPort(port)) + self.assertTrue(self.IsUnusedUDPPort(port)) + ports += 1 + self.assertGreater(ports, 70) + + def pickUnusedPortWithoutServer(self): + # Try a few times to pick a port, to avoid flakiness and to make sure + # the code path we want was exercised. + for _ in range(5): + try: + port = portpicker._pick_unused_port_without_server() + except portpicker.NoFreePortFoundError: + continue + else: + self.assertTrue(self.IsUnusedTCPPort(port)) + self.assertTrue(self.IsUnusedUDPPort(port)) + return + self.fail("Failed to find a free port") + + def testPickPortsWithoutServer(self): + # Test the first part of _pick_unused_port_without_server, which + # tries a few random ports and checks is_port_free. + self.pickUnusedPortWithoutServer() + + # Now test the second part, the fallback from above, which asks the + # OS for a port. + def mock_port_free(port): + return False + + with mock.patch.object(portpicker, 'is_port_free', mock_port_free): + self.pickUnusedPortWithoutServer() + + def checkIsPortFree(self): + """This might be flaky unless this test is run with a portserver.""" + # The port should be free initially. + port = portpicker.pick_unused_port() + self.assertTrue(portpicker.is_port_free(port)) + + cases = [ + (socket.AF_INET, socket.SOCK_STREAM, None), + (socket.AF_INET6, socket.SOCK_STREAM, 1), + (socket.AF_INET, socket.SOCK_DGRAM, None), + (socket.AF_INET6, socket.SOCK_DGRAM, 1), + ] + + # Using v6only=0 on Windows doesn't result in collisions + if not _winapi: + cases.extend([ + (socket.AF_INET6, socket.SOCK_STREAM, 0), + (socket.AF_INET6, socket.SOCK_DGRAM, 0), + ]) + + for (sock_family, sock_type, v6only) in cases: + # Occupy the port on a subset of possible protocols. + try: + sock = socket.socket(sock_family, sock_type, 0) + except socket.error: + print('Kernel does not support sock_family=%d' % sock_family, + file=sys.stderr) + # Skip this case, since we cannot occupy a port. + continue + + if not hasattr(socket, 'IPPROTO_IPV6'): + v6only = None + + if v6only is not None: + try: + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, + v6only) + except socket.error: + print('Kernel does not support IPV6_V6ONLY=%d' % v6only, + file=sys.stderr) + # Don't care; just proceed with the default. + + # Socket may have been taken in the mean time, so catch the + # socket.error with errno set to EADDRINUSE and skip this + # attempt. + try: + sock.bind(('', port)) + except socket.error as e: + if e.errno == errno.EADDRINUSE: + raise portpicker.NoFreePortFoundError + raise + + # The port should be busy. + self.assertFalse(portpicker.is_port_free(port)) + sock.close() + + # Now it's free again. + self.assertTrue(portpicker.is_port_free(port)) + + def testIsPortFree(self): + # This can be quite flaky on a busy host, so try a few times. + for _ in range(10): + try: + self.checkIsPortFree() + except portpicker.NoFreePortFoundError: + pass + else: + return + self.fail("checkPortIsFree failed every time.") + + def testIsPortFreeException(self): + port = portpicker.pick_unused_port() + with mock.patch.object(socket, 'socket') as mock_sock: + mock_sock.side_effect = socket.error('fake socket error', 0) + self.assertFalse(portpicker.is_port_free(port)) + + def testThatLegacyCapWordsAPIsExist(self): + """The original APIs were CapWords style, 1.1 added PEP8 names.""" + self.assertEqual(portpicker.bind, portpicker.Bind) + self.assertEqual(portpicker.is_port_free, portpicker.IsPortFree) + self.assertEqual(portpicker.pick_unused_port, portpicker.PickUnusedPort) + self.assertEqual(portpicker.get_port_from_port_server, + portpicker.GetPortFromPortServer) + + +if __name__ == '__main__': + unittest.main() |