diff options
author | Tom Craig <tommycraig@gmail.com> | 2023-11-07 16:55:52 +0000 |
---|---|---|
committer | CQ Bot Account <pigweed-scoped@luci-project-accounts.iam.gserviceaccount.com> | 2023-11-07 16:55:52 +0000 |
commit | 4e43f63c083bb14ef1ff0c0c59b3ad8762d44478 (patch) | |
tree | d668cf3185d44ecc5acc3125324070848533fe5e | |
parent | 1283ffa78740b2c36db839f3dcf4cbb9701ae910 (diff) | |
download | pigweed-4e43f63c083bb14ef1ff0c0c59b3ad8762d44478.tar.gz |
pw_console: Improve SocketClient addressing
Update SocketClient's addressing support to handle both ipv6 and ipv4
in addition to unix sockets.
Test: Successfully connected to localhost and unix socket.
Change-Id: I0330394dcb998db9822cd2a0fd654bc7d60cd6a4
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/178921
Reviewed-by: Carlos Chinchilla <cachinchilla@google.com>
Commit-Queue: Carlos Chinchilla <cachinchilla@google.com>
Pigweed-Auto-Submit: Carlos Chinchilla <cachinchilla@google.com>
Reviewed-by: Taylor Cramer <cramertj@google.com>
Reviewed-by: Tom Craig <tommycraig@gmail.com>
-rw-r--r-- | pw_console/embedding.rst | 9 | ||||
-rw-r--r-- | pw_console/py/BUILD.bazel | 11 | ||||
-rw-r--r-- | pw_console/py/BUILD.gn | 1 | ||||
-rw-r--r-- | pw_console/py/pw_console/socket_client.py | 139 | ||||
-rw-r--r-- | pw_console/py/socket_client_test.py | 181 |
5 files changed, 313 insertions, 28 deletions
diff --git a/pw_console/embedding.rst b/pw_console/embedding.rst index a5d896c04..b2263afdb 100644 --- a/pw_console/embedding.rst +++ b/pw_console/embedding.rst @@ -113,7 +113,16 @@ Logging data with sockets from pw_console.socket_client import SocketClientWithLogging + # Name resolution with explicit port serial_device = SocketClientWithLogging('localhost:1234') + # Name resolution with default port. + serial_device = SocketClientWithLogging('pigweed.dev') + # Link-local IPv6 address with explicit port. + serial_device = SocketClientWithLogging('[fe80::100%enp1s0]:1234') + # Link-local IPv6 address with default port. + serial_device = SocketClientWithLogging('[fe80::100%enp1s0]') + # IPv4 address with port. + serial_device = SocketClientWithLogging('1.2.3.4:5678') .. tip:: The ``SocketClient`` takes an optional callback called when a disconnect is diff --git a/pw_console/py/BUILD.bazel b/pw_console/py/BUILD.bazel index 798b3354c..9fdeebd91 100644 --- a/pw_console/py/BUILD.bazel +++ b/pw_console/py/BUILD.bazel @@ -201,6 +201,17 @@ py_test( ) py_test( + name = "socket_client_test", + size = "small", + srcs = [ + "socket_client_test.py", + ], + deps = [ + ":pw_console", + ], +) + +py_test( name = "repl_pane_test", size = "small", srcs = [ diff --git a/pw_console/py/BUILD.gn b/pw_console/py/BUILD.gn index 2cc6d1866..3683004d5 100644 --- a/pw_console/py/BUILD.gn +++ b/pw_console/py/BUILD.gn @@ -87,6 +87,7 @@ pw_python_package("py") { "log_store_test.py", "log_view_test.py", "repl_pane_test.py", + "socket_client_test.py", "table_test.py", "text_formatting_test.py", "window_manager_test.py", diff --git a/pw_console/py/pw_console/socket_client.py b/pw_console/py/pw_console/socket_client.py index 41a1c66c4..5344c3199 100644 --- a/pw_console/py/pw_console/socket_client.py +++ b/pw_console/py/pw_console/socket_client.py @@ -17,6 +17,7 @@ from __future__ import annotations from typing import Callable, Optional, TYPE_CHECKING, Tuple, Union import errno +import re import socket from pw_console.plugins.bandwidth_toolbar import SerialBandwidthTracker @@ -33,43 +34,125 @@ class SocketClient: DEFAULT_SOCKET_PORT = 33000 PW_RPC_MAX_PACKET_SIZE = 256 + _InitArgsType = Tuple[ + socket.AddressFamily, int # pylint: disable=no-member + ] + # Can be a string, (address, port) for AF_INET or (address, port, flowinfo, + # scope_id) AF_INET6. + _AddressType = Union[str, Tuple[str, int], Tuple[str, int, int, int]] + def __init__( self, config: str, on_disconnect: Optional[Callable[[SocketClient], None]] = None, ): - self._connection_type: int - self._interface: Union[str, Tuple[str, int]] + """Creates a socket connection. + + Args: + config: The socket configuration. Accepted values and formats are: + 'default' - uses the default configuration (localhost:33000) + 'address:port' - An IPv4 address and port. + 'address' - An IPv4 address. Uses default port 33000. + '[address]:port' - An IPv6 address and port. + '[address]' - An IPv6 address. Uses default port 33000. + 'file:path_to_file' - A Unix socket at ``path_to_file``. + In the formats above,``address`` can be an actual address or a name + that resolves to an address through name-resolution. + on_disconnect: An optional callback called when the socket + disconnects. + + Raises: + TypeError: The type of socket is not supported. + ValueError: The socket configuration is invalid. + """ + self.socket: socket.socket + ( + self._socket_init_args, + self._address, + ) = SocketClient._parse_socket_config(config) + self._on_disconnect = on_disconnect + self._connected = False + self.connect() + + @staticmethod + def _parse_socket_config( + config: str, + ) -> Tuple[SocketClient._InitArgsType, SocketClient._AddressType]: + """Sets the variables used to create a socket given a config string. + + Raises: + TypeError: The type of socket is not supported. + ValueError: The socket configuration is invalid. + """ + init_args: SocketClient._InitArgsType + address: SocketClient._AddressType + + # Check if this is using the default settings. if config == 'default': - self._connection_type = socket.AF_INET6 - self._interface = ( - self.DEFAULT_SOCKET_SERVER, - self.DEFAULT_SOCKET_PORT, + init_args = socket.AF_INET6, socket.SOCK_STREAM + address = ( + SocketClient.DEFAULT_SOCKET_SERVER, + SocketClient.DEFAULT_SOCKET_PORT, ) - else: - socket_server, socket_port_or_file = config.split(':') - if socket_server == self.FILE_SOCKET_SERVER: - # Unix socket support is available on Windows 10 since April - # 2018. However, there is no Python support on Windows yet. - # See https://bugs.python.org/issue33408 for more information. - if not hasattr(socket, 'AF_UNIX'): - raise TypeError( - 'Unix sockets are not supported in this environment.' - ) - self._connection_type = ( - socket.AF_UNIX # pylint: disable=no-member + return init_args, address + + # Check if this is a UNIX socket. + unix_socket_file_setting = f'{SocketClient.FILE_SOCKET_SERVER}:' + if config.startswith(unix_socket_file_setting): + # Unix socket support is available on Windows 10 since April + # 2018. However, there is no Python support on Windows yet. + # See https://bugs.python.org/issue33408 for more information. + if not hasattr(socket, 'AF_UNIX'): + raise TypeError( + 'Unix sockets are not supported in this environment.' ) - self._interface = socket_port_or_file - else: - self._connection_type = socket.AF_INET6 - self._interface = (socket_server, int(socket_port_or_file)) + init_args = ( + socket.AF_UNIX, # pylint: disable=no-member + socket.SOCK_STREAM, + ) + address = config[len(unix_socket_file_setting) :] + return init_args, address + + # Search for IPv4 or IPv6 address or name and port. + # First, try to capture an IPv6 address as anything inside []. If there + # are no [] capture the IPv4 address. Lastly, capture the port as the + # numbers after :, if any. + match = re.match( + r'(\[(?P<ipv6_addr>.+)\]:?|(?P<ipv4_addr>[a-zA-Z0-9\._\/]+):?)' + r'(?P<port>[0-9]+)?', + config, + ) + invalid_config_message = ( + f'Invalid socket configuration "{config}"' + 'Accepted values are "default", "file:<file_path>", ' + '"<name_or_ipv4_address>" with optional ":<port>", and ' + '"[<name_or_ipv6_address>]" with optional ":<port>".' + ) + if match is None: + raise ValueError(invalid_config_message) + + info = match.groupdict() + if info['port']: + port = int(info['port']) + else: + port = SocketClient.DEFAULT_SOCKET_PORT - self._on_disconnect = on_disconnect - self._connected = False - self.connect() + if info['ipv4_addr']: + ip_addr = info['ipv4_addr'] + elif info['ipv6_addr']: + ip_addr = info['ipv6_addr'] + else: + raise ValueError(invalid_config_message) + + sock_family, sock_type, _, _, address = socket.getaddrinfo( + ip_addr, port, type=socket.SOCK_STREAM + )[0] + init_args = sock_family, sock_type + return init_args, address def __del__(self): - self.socket.close() + if self._connected: + self.socket.close() def write(self, data: ReadableBuffer) -> None: """Writes data and detects disconnects.""" @@ -96,13 +179,13 @@ class SocketClient: def connect(self) -> None: """Connects to socket.""" - self.socket = socket.socket(self._connection_type, socket.SOCK_STREAM) + self.socket = socket.socket(*self._socket_init_args) # Enable reusing address and port for reconnections. self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if hasattr(socket, 'SO_REUSEPORT'): self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - self.socket.connect(self._interface) + self.socket.connect(self._address) self._connected = True def _handle_disconnect(self): diff --git a/pw_console/py/socket_client_test.py b/pw_console/py/socket_client_test.py new file mode 100644 index 000000000..f4e5a9f5d --- /dev/null +++ b/pw_console/py/socket_client_test.py @@ -0,0 +1,181 @@ +# Copyright 2023 The Pigweed Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +"""Tests for pw_console.socket_client""" + +import socket +import unittest + + +from pw_console import socket_client + + +class TestSocketClient(unittest.TestCase): + """Tests for SocketClient.""" + + def test_parse_config_default(self) -> None: + config = "default" + with unittest.mock.patch.object( + socket_client.SocketClient, 'connect', return_value=None + ): + client = socket_client.SocketClient(config) + self.assertEqual( + client._socket_init_args, # pylint: disable=protected-access + (socket.AF_INET6, socket.SOCK_STREAM), + ) + self.assertEqual( + client._address, # pylint: disable=protected-access + ( + socket_client.SocketClient.DEFAULT_SOCKET_SERVER, + socket_client.SocketClient.DEFAULT_SOCKET_PORT, + ), + ) + + def test_parse_config_unix_file(self) -> None: + # Skip test if UNIX sockets are not supported. + if not hasattr(socket, 'AF_UNIX'): + return + + config = 'file:fake_file_path' + with unittest.mock.patch.object( + socket_client.SocketClient, 'connect', return_value=None + ): + client = socket_client.SocketClient(config) + self.assertEqual( + client._socket_init_args, # pylint: disable=protected-access + ( + socket.AF_UNIX, # pylint: disable=no-member + socket.SOCK_STREAM, + ), + ) + self.assertEqual( + client._address, # pylint: disable=protected-access + 'fake_file_path', + ) + + def _check_config_parsing( + self, config: str, expected_address: str, expected_port: int + ) -> None: + with unittest.mock.patch.object( + socket_client.SocketClient, 'connect', return_value=None + ): + fake_getaddrinfo_return_value = [ + (socket.AF_INET6, socket.SOCK_STREAM, 0, None, None) + ] + with unittest.mock.patch.object( + socket, + 'getaddrinfo', + return_value=fake_getaddrinfo_return_value, + ) as mock_getaddrinfo: + client = socket_client.SocketClient(config) + mock_getaddrinfo.assert_called_with( + expected_address, expected_port, type=socket.SOCK_STREAM + ) + # Assert the init args are what is returned by ``getaddrinfo`` + # not necessarily the correct ones, since this test should not + # perform any network action. + self.assertEqual( + client._socket_init_args, # pylint: disable=protected-access + ( + socket.AF_INET6, + socket.SOCK_STREAM, + ), + ) + + def test_parse_config_ipv4_domain(self) -> None: + self._check_config_parsing( + config='file.com/some_long/path:80', + expected_address='file.com/some_long/path', + expected_port=80, + ) + + def test_parse_config_ipv4_domain_no_port(self) -> None: + self._check_config_parsing( + config='file.com/some/path', + expected_address='file.com/some/path', + expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT, + ) + + def test_parse_config_ipv4_address(self) -> None: + self._check_config_parsing( + config='8.8.8.8:8080', + expected_address='8.8.8.8', + expected_port=8080, + ) + + def test_parse_config_ipv4_address_no_port(self) -> None: + self._check_config_parsing( + config='8.8.8.8', + expected_address='8.8.8.8', + expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT, + ) + + def test_parse_config_ipv6_domain(self) -> None: + self._check_config_parsing( + config='[file.com/some_long/path]:80', + expected_address='file.com/some_long/path', + expected_port=80, + ) + + def test_parse_config_ipv6_domain_no_port(self) -> None: + self._check_config_parsing( + config='[file.com/some/path]', + expected_address='file.com/some/path', + expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT, + ) + + def test_parse_config_ipv6_address(self) -> None: + self._check_config_parsing( + config='[2001:4860:4860::8888:8080]:666', + expected_address='2001:4860:4860::8888:8080', + expected_port=666, + ) + + def test_parse_config_ipv6_address_no_port(self) -> None: + self._check_config_parsing( + config='[2001:4860:4860::8844]', + expected_address='2001:4860:4860::8844', + expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT, + ) + + def test_parse_config_ipv6_local(self) -> None: + self._check_config_parsing( + config='[fe80::100%eth0]:80', + expected_address='fe80::100%eth0', + expected_port=80, + ) + + def test_parse_config_ipv6_local_no_port(self) -> None: + self._check_config_parsing( + config='[fe80::100%eth0]', + expected_address='fe80::100%eth0', + expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT, + ) + + def test_parse_config_ipv6_local_windows(self) -> None: + self._check_config_parsing( + config='[fe80::100%4]:80', + expected_address='fe80::100%4', + expected_port=80, + ) + + def test_parse_config_ipv6_local_no_port_windows(self) -> None: + self._check_config_parsing( + config='[fe80::100%4]', + expected_address='fe80::100%4', + expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT, + ) + + +if __name__ == '__main__': + unittest.main() |