diff options
author | Patrice Vignola <vignola.patrice@gmail.com> | 2021-07-11 18:34:38 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-11 18:34:38 -0700 |
commit | c283eb1aeaa355408cee1ddd4e49eb0fc4c02dce (patch) | |
tree | 62072c4c5f654210f2930dd4a808cb4d14c01bb9 | |
parent | 0b178b8e500065f7ad71204c0d1c066766d5d01b (diff) | |
download | portpicker-c283eb1aeaa355408cee1ddd4e49eb0fc4c02dce.tar.gz |
Add Windows support for the port server (#25)
Add Windows support for the port server and Windows named pipe support to the portpicker client.
Contributed by Patrice Vignola
-rw-r--r-- | .github/workflows/python-package.yml | 36 | ||||
-rw-r--r-- | ChangeLog.md | 7 | ||||
-rw-r--r-- | setup.cfg | 5 | ||||
-rw-r--r-- | src/portpicker.py | 93 | ||||
-rw-r--r-- | src/portserver.py | 87 | ||||
-rw-r--r-- | src/tests/portpicker_test.py | 95 | ||||
-rw-r--r-- | src/tests/portserver_test.py | 34 |
7 files changed, 279 insertions, 78 deletions
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index d50e439..88a8115 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -3,16 +3,22 @@ name: Python Portpicker & Portserver -on: [push] +on: + push: + branches: + - 'main' + pull_request: + branches: + - 'main' jobs: - build: + build-ubuntu: runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: [3.6, 3.7, 3.8, 3.9, '3.10.0-beta.1'] + python-version: [3.6, 3.7, 3.8, 3.9, '3.10.0-beta.3'] steps: - uses: actions/checkout@v2 @@ -29,3 +35,27 @@ jobs: run: | # Run tox using the version of Python in `PATH` tox -e py + + build-windows: + + runs-on: windows-latest + strategy: + fail-fast: false + matrix: + python-version: [3.6, 3.7, 3.8, 3.9, '3.10.0-beta.3'] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest tox + if (Test-Path "requirements.txt") { pip install -r requirements.txt } + - name: Test with tox + run: | + # Run tox using the version of Python in `PATH` + tox -e py diff --git a/ChangeLog.md b/ChangeLog.md index b385a38..28ae395 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -1,3 +1,10 @@ +## 1.5.0 + +* Add portserver support to Windows using named pipes. To create or connect + to a server, prefix the name of the server with `@` (e.g. + `@unittest-portserver`). + + ## 1.4.0 * Use `async def` instead of `@asyncio.coroutine` in order to support 3.10. @@ -1,7 +1,7 @@ # https://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files [metadata] name = portpicker -version = 1.4.1b1 +version = 1.5.0b1 maintainer = Google LLC maintainer_email = greg@krypto.org license = Apache 2.0 @@ -29,10 +29,11 @@ classifiers = Programming Language :: Python :: 3.10 Programming Language :: Python :: Implementation :: CPython Programming Language :: Python :: Implementation :: PyPy -platforms = POSIX +platforms = POSIX, Windows requires = [options] +install_requires = psutil python_requires = >= 3.6 package_dir= =src diff --git a/src/portpicker.py b/src/portpicker.py index e54dcbc..4717bbc 100644 --- a/src/portpicker.py +++ b/src/portpicker.py @@ -43,6 +43,11 @@ import random import socket import sys +if sys.platform == 'win32': + import _winapi +else: + _winapi = None + # The legacy Bind, IsPortFree, etc. names are not exported. __all__ = ('bind', 'is_port_free', 'pick_unused_port', 'return_port', 'add_reserved_port', 'get_port_from_port_server') @@ -63,7 +68,6 @@ _random_ports = set() class NoFreePortFoundError(Exception): """Exception indicating that no free port could be found.""" - pass def add_reserved_port(port): @@ -217,6 +221,61 @@ def _pick_unused_port_without_server(): # Protected. pylint: disable=invalid-na raise NoFreePortFoundError() +def _get_linux_port_from_port_server(portserver_address, pid): + # An AF_UNIX address may start with a zero byte, in which case it is in the + # "abstract namespace", and doesn't have any filesystem representation. + # See 'man 7 unix' for details. + # The convention is to write '@' in the address to represent this zero byte. + if portserver_address[0] == '@': + portserver_address = '\0' + portserver_address[1:] + + try: + # Create socket. + if hasattr(socket, 'AF_UNIX'): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) # pylint: disable=no-member + else: + # fallback to AF_INET if this is not unix + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + # Connect to portserver. + sock.connect(portserver_address) + + # Write request. + sock.sendall(('%d\n' % pid).encode('ascii')) + + # Read response. + # 1K should be ample buffer space. + return sock.recv(1024) + finally: + sock.close() + except socket.error as error: + print('Socket error when connecting to portserver:', error, + file=sys.stderr) + return None + + +def _get_windows_port_from_port_server(portserver_address, pid): + if portserver_address[0] == '@': + portserver_address = '\\\\.\\pipe\\' + portserver_address[1:] + + try: + handle = _winapi.CreateFile( + portserver_address, + _winapi.GENERIC_READ | _winapi.GENERIC_WRITE, + 0, + 0, + _winapi.OPEN_EXISTING, + 0, + 0) + + _winapi.WriteFile(handle, ('%d\n' % pid).encode('ascii')) + data, _ = _winapi.ReadFile(handle, 6, 0) + return data + except FileNotFoundError as error: + print('File error when connecting to portserver:', error, + file=sys.stderr) + return None + def get_port_from_port_server(portserver_address, pid=None): """Request a free a port from a system-wide portserver. @@ -240,38 +299,16 @@ def get_port_from_port_server(portserver_address, pid=None): """ if not portserver_address: return None - # An AF_UNIX address may start with a zero byte, in which case it is in the - # "abstract namespace", and doesn't have any filesystem representation. - # See 'man 7 unix' for details. - # The convention is to write '@' in the address to represent this zero byte. - if portserver_address[0] == '@': - portserver_address = '\0' + portserver_address[1:] if pid is None: pid = os.getpid() - try: - # Create socket. - if hasattr(socket, 'AF_UNIX'): - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - else: - # fallback to AF_INET if this is not unix - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - # Connect to portserver. - sock.connect(portserver_address) - - # Write request. - sock.sendall(('%d\n' % pid).encode('ascii')) + if _winapi: + buf = _get_windows_port_from_port_server(portserver_address, pid) + else: + buf = _get_linux_port_from_port_server(portserver_address, pid) - # Read response. - # 1K should be ample buffer space. - buf = sock.recv(1024) - finally: - sock.close() - except socket.error as e: - print('Socket error when connecting to portserver:', e, - file=sys.stderr) + if buf is None: return None try: diff --git a/src/portserver.py b/src/portserver.py index 58b7ecd..f986f3f 100644 --- a/src/portserver.py +++ b/src/portserver.py @@ -31,10 +31,12 @@ import argparse import asyncio import collections import logging -import os import signal import socket import sys +import psutil +import subprocess +from datetime import datetime, timezone, timedelta log = None # Initialized to a logging.Logger by _configure_logging(). @@ -44,18 +46,16 @@ _PROTOS = [(socket.SOCK_STREAM, socket.IPPROTO_TCP), def _get_process_command_line(pid): try: - with open('/proc/{}/cmdline'.format(pid), 'rt') as cmdline_f: - return cmdline_f.read() - except IOError: + return psutil.Process(pid).cmdline() + except psutil.NoSuchProcess: return '' def _get_process_start_time(pid): try: - with open('/proc/{}/stat'.format(pid), 'rt') as pid_stat_f: - return int(pid_stat_f.readline().split()[21]) - except IOError: - return 0 + return psutil.Process(pid).create_time() + except psutil.NoSuchProcess: + return 0.0 # TODO: Consider importing portpicker.bind() instead of duplicating the code. @@ -115,14 +115,27 @@ def _should_allocate_port(pid): # had been reparented to init. log.info('Not allocating a port to init.') return False - try: - os.kill(pid, 0) - except (ProcessLookupError, OverflowError): + + if not psutil.pid_exists(pid): log.info('Not allocating a port to a non-existent process') return False return True +async def _start_windows_server(client_connected_cb, path): + """Start the server on Windows using named pipes.""" + def protocol_factory(): + stream_reader = asyncio.StreamReader() + stream_reader_protocol = asyncio.StreamReaderProtocol( + stream_reader, client_connected_cb) + return stream_reader_protocol + + loop = asyncio.get_event_loop() + server, *_ = await loop.start_serving_pipe(protocol_factory, address=path) + + return server + + class _PortInfo(object): """Container class for information about a given port assignment. @@ -137,7 +150,7 @@ class _PortInfo(object): def __init__(self, port): self.port = port self.pid = 0 - self.start_time = 0 + self.start_time = 0.0 class _PortPool(object): @@ -178,7 +191,7 @@ class _PortPool(object): candidate = self._port_queue.pop() self._port_queue.appendleft(candidate) check_count += 1 - if (candidate.start_time == 0 or + if (candidate.start_time == 0.0 or candidate.start_time != _get_process_start_time(candidate.pid)): if _is_port_free(candidate.port): candidate.pid = pid @@ -287,10 +300,13 @@ def _parse_command_line(): default='15000-24999', help='Comma separated N-P Range(s) of ports to manage (inclusive).') parser.add_argument( - '--portserver_unix_socket_address', + '--portserver_address', + '--portserver_unix_socket_address', # Alias to be backward compatible type=str, default='@unittest-portserver', - help='Address of AF_UNIX socket on which to listen (first @ is a NUL).') + help='Address of AF_UNIX socket on which to listen on Unix (first @ is ' + 'a NUL) or the name of the pipe on Windows (first @ is the ' + r'\\.\pipe\ prefix).') parser.add_argument('--verbose', action='store_true', default=False, @@ -348,14 +364,33 @@ def main(): request_handler = _PortServerRequestHandler(ports_to_serve) + if sys.platform == 'win32': + asyncio.set_event_loop(asyncio.ProactorEventLoop()) + event_loop = asyncio.get_event_loop() - event_loop.add_signal_handler(signal.SIGUSR1, request_handler.dump_stats) - old_py_loop = {'loop': event_loop} if sys.version_info < (3, 10) else {} - coro = asyncio.start_unix_server( - request_handler.handle_port_request, - path=config.portserver_unix_socket_address.replace('@', '\0', 1), - **old_py_loop) - server_address = config.portserver_unix_socket_address + + if sys.platform == 'win32': + # On Windows, we need to periodically pause the loop to allow the user + # to send a break signal (e.g. ctrl+c) + def listen_for_signal(): + event_loop.call_later(0.5, listen_for_signal) + + event_loop.call_later(0.5, listen_for_signal) + + coro = _start_windows_server( + request_handler.handle_port_request, + path=config.portserver_address.replace('@', '\\\\.\\pipe\\', 1)) + else: + event_loop.add_signal_handler( + signal.SIGUSR1, request_handler.dump_stats) # pylint: disable=no-member + + old_py_loop = {'loop': event_loop} if sys.version_info < (3, 10) else {} + coro = asyncio.start_unix_server( + request_handler.handle_port_request, + path=config.portserver_address.replace('@', '\0', 1), + **old_py_loop) + + server_address = config.portserver_address server = event_loop.run_until_complete(coro) log.info('Serving on %s', server_address) @@ -365,8 +400,12 @@ def main(): log.info('Stopping due to ^C.') server.close() - event_loop.run_until_complete(server.wait_closed()) - event_loop.remove_signal_handler(signal.SIGUSR1) + + if sys.platform != 'win32': + # PipeServer doesn't have a wait_closed() function + event_loop.run_until_complete(server.wait_closed()) + event_loop.remove_signal_handler(signal.SIGUSR1) # pylint: disable=no-member + event_loop.close() request_handler.dump_stats() log.info('Goodbye.') diff --git a/src/tests/portpicker_test.py b/src/tests/portpicker_test.py index b82d7bf..e479a46 100644 --- a/src/tests/portpicker_test.py +++ b/src/tests/portpicker_test.py @@ -23,6 +23,12 @@ import random import socket import sys import unittest +from contextlib import ExitStack + +if sys.platform == 'win32': + import _winapi +else: + _winapi = None try: # pylint: disable=no-name-in-module @@ -100,27 +106,82 @@ class PickUnusedPortTest(unittest.TestCase): self.assertTrue(self.IsUnusedUDPPort(port)) def testSendsPidToPortServer(self): - server = mock.Mock() - server.recv.return_value = b'42768\n' - with mock.patch.object(socket, 'socket', return_value=server): - port = portpicker.get_port_from_port_server('portserver', pid=1234) - server.sendall.assert_called_once_with(b'1234\n') + with ExitStack() as stack: + if _winapi: + create_file_mock = mock.Mock() + create_file_mock.return_value = 0 + read_file_mock = mock.Mock() + write_file_mock = mock.Mock() + read_file_mock.return_value = (b'42768\n', 0) + stack.enter_context( + mock.patch('_winapi.CreateFile', new=create_file_mock)) + stack.enter_context( + mock.patch('_winapi.WriteFile', new=write_file_mock)) + stack.enter_context( + mock.patch('_winapi.ReadFile', new=read_file_mock)) + port = portpicker.get_port_from_port_server( + 'portserver', pid=1234) + write_file_mock.assert_called_once_with(0, b'1234\n') + else: + server = mock.Mock() + server.recv.return_value = b'42768\n' + stack.enter_context( + mock.patch.object(socket, 'socket', return_value=server)) + port = portpicker.get_port_from_port_server( + 'portserver', pid=1234) + server.sendall.assert_called_once_with(b'1234\n') + self.assertEqual(port, 42768) def testPidDefaultsToOwnPid(self): - server = mock.Mock() - server.recv.return_value = b'52768\n' - with mock.patch.object(socket, 'socket', return_value=server): - with mock.patch.object(os, 'getpid', return_value=9876): + with ExitStack() as stack: + stack.enter_context( + mock.patch.object(os, 'getpid', return_value=9876)) + + if _winapi: + create_file_mock = mock.Mock() + create_file_mock.return_value = 0 + read_file_mock = mock.Mock() + write_file_mock = mock.Mock() + read_file_mock.return_value = (b'52768\n', 0) + stack.enter_context( + mock.patch('_winapi.CreateFile', new=create_file_mock)) + stack.enter_context( + mock.patch('_winapi.WriteFile', new=write_file_mock)) + stack.enter_context( + mock.patch('_winapi.ReadFile', new=read_file_mock)) + port = portpicker.get_port_from_port_server('portserver') + write_file_mock.assert_called_once_with(0, b'9876\n') + else: + server = mock.Mock() + server.recv.return_value = b'52768\n' + stack.enter_context( + mock.patch.object(socket, 'socket', return_value=server)) port = portpicker.get_port_from_port_server('portserver') server.sendall.assert_called_once_with(b'9876\n') + self.assertEqual(port, 52768) @mock.patch.dict(os.environ,{'PORTSERVER_ADDRESS': 'portserver'}) def testReusesPortServerPorts(self): - server = mock.Mock() - server.recv.side_effect = [b'12345\n', b'23456\n', b'34567\n'] - with mock.patch.object(socket, 'socket', return_value=server): + with ExitStack() as stack: + if _winapi: + read_file_mock = mock.Mock() + read_file_mock.side_effect = [ + (b'12345\n', 0), + (b'23456\n', 0), + (b'34567\n', 0), + ] + stack.enter_context(mock.patch('_winapi.CreateFile')) + stack.enter_context(mock.patch('_winapi.WriteFile')) + stack.enter_context( + mock.patch('_winapi.ReadFile', new=read_file_mock)) + else: + server = mock.Mock() + server.recv.side_effect = [b'12345\n', b'23456\n', b'34567\n'] + stack.enter_context( + mock.patch.object(socket, 'socket', return_value=server)) + self.assertEqual(portpicker.pick_unused_port(), 12345) self.assertEqual(portpicker.pick_unused_port(), 23456) portpicker.return_port(12345) @@ -248,12 +309,18 @@ class PickUnusedPortTest(unittest.TestCase): cases = [ (socket.AF_INET, socket.SOCK_STREAM, None), - (socket.AF_INET6, socket.SOCK_STREAM, 0), (socket.AF_INET6, socket.SOCK_STREAM, 1), (socket.AF_INET, socket.SOCK_DGRAM, None), - (socket.AF_INET6, socket.SOCK_DGRAM, 0), (socket.AF_INET6, socket.SOCK_DGRAM, 1), ] + + # Using v6only=0 on Windows doesn't result in collisions + if not _winapi: + cases.extend([ + (socket.AF_INET6, socket.SOCK_STREAM, 0), + (socket.AF_INET6, socket.SOCK_DGRAM, 0), + ]) + for (sock_family, sock_type, v6only) in cases: # Occupy the port on a subset of possible protocols. try: diff --git a/src/tests/portserver_test.py b/src/tests/portserver_test.py index 394b1b5..b7de094 100644 --- a/src/tests/portserver_test.py +++ b/src/tests/portserver_test.py @@ -25,14 +25,23 @@ import sys import time import unittest from unittest import mock +from multiprocessing import Process import portpicker + +# On Windows, portserver.py is located in the "Scripts" folder, which isn't +# added to the import path by default +if sys.platform == 'win32': + sys.path.append(os.path.join(os.path.split(sys.executable)[0])) + import portserver def setUpModule(): portserver._configure_logging(verbose=True) +def exit_immediately(): + os._exit(0) class PortserverFunctionsTest(unittest.TestCase): @@ -53,12 +62,18 @@ class PortserverFunctionsTest(unittest.TestCase): cases = [ (socket.AF_INET, socket.SOCK_STREAM, None), - (socket.AF_INET6, socket.SOCK_STREAM, 0), (socket.AF_INET6, socket.SOCK_STREAM, 1), (socket.AF_INET, socket.SOCK_DGRAM, None), - (socket.AF_INET6, socket.SOCK_DGRAM, 0), (socket.AF_INET6, socket.SOCK_DGRAM, 1), ] + + # Using v6only=0 on Windows doesn't result in collisions + if sys.platform != 'win32': + cases.extend([ + (socket.AF_INET6, socket.SOCK_STREAM, 0), + (socket.AF_INET6, socket.SOCK_DGRAM, 0), + ]) + for (sock_family, sock_type, v6only) in cases: # Occupy the port on a subset of possible protocols. try: @@ -68,6 +83,10 @@ class PortserverFunctionsTest(unittest.TestCase): file=sys.stderr) # Skip this case, since we cannot occupy a port. continue + + if not hasattr(socket, 'IPPROTO_IPV6'): + v6only = None + if v6only is not None: try: sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, @@ -94,11 +113,12 @@ class PortserverFunctionsTest(unittest.TestCase): self.assertFalse(portserver._should_allocate_port(0)) self.assertFalse(portserver._should_allocate_port(1)) self.assertTrue(portserver._should_allocate_port, os.getpid()) - child_pid = os.fork() - if child_pid == 0: - os._exit(0) - else: - os.waitpid(child_pid, 0) + + p = Process(target=exit_immediately) + p.start() + child_pid = p.pid + p.join() + # This test assumes that after waitpid returns the kernel has finished # cleaning the process. We also assume that the kernel will not reuse # the former child's pid before our next call checks for its existence. |