aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Craig <tommycraig@gmail.com>2023-11-07 16:55:52 +0000
committerCQ Bot Account <pigweed-scoped@luci-project-accounts.iam.gserviceaccount.com>2023-11-07 16:55:52 +0000
commit4e43f63c083bb14ef1ff0c0c59b3ad8762d44478 (patch)
treed668cf3185d44ecc5acc3125324070848533fe5e
parent1283ffa78740b2c36db839f3dcf4cbb9701ae910 (diff)
downloadpigweed-4e43f63c083bb14ef1ff0c0c59b3ad8762d44478.tar.gz
pw_console: Improve SocketClient addressing
Update SocketClient's addressing support to handle both ipv6 and ipv4 in addition to unix sockets. Test: Successfully connected to localhost and unix socket. Change-Id: I0330394dcb998db9822cd2a0fd654bc7d60cd6a4 Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/178921 Reviewed-by: Carlos Chinchilla <cachinchilla@google.com> Commit-Queue: Carlos Chinchilla <cachinchilla@google.com> Pigweed-Auto-Submit: Carlos Chinchilla <cachinchilla@google.com> Reviewed-by: Taylor Cramer <cramertj@google.com> Reviewed-by: Tom Craig <tommycraig@gmail.com>
-rw-r--r--pw_console/embedding.rst9
-rw-r--r--pw_console/py/BUILD.bazel11
-rw-r--r--pw_console/py/BUILD.gn1
-rw-r--r--pw_console/py/pw_console/socket_client.py139
-rw-r--r--pw_console/py/socket_client_test.py181
5 files changed, 313 insertions, 28 deletions
diff --git a/pw_console/embedding.rst b/pw_console/embedding.rst
index a5d896c04..b2263afdb 100644
--- a/pw_console/embedding.rst
+++ b/pw_console/embedding.rst
@@ -113,7 +113,16 @@ Logging data with sockets
from pw_console.socket_client import SocketClientWithLogging
+ # Name resolution with explicit port
serial_device = SocketClientWithLogging('localhost:1234')
+ # Name resolution with default port.
+ serial_device = SocketClientWithLogging('pigweed.dev')
+ # Link-local IPv6 address with explicit port.
+ serial_device = SocketClientWithLogging('[fe80::100%enp1s0]:1234')
+ # Link-local IPv6 address with default port.
+ serial_device = SocketClientWithLogging('[fe80::100%enp1s0]')
+ # IPv4 address with port.
+ serial_device = SocketClientWithLogging('1.2.3.4:5678')
.. tip::
The ``SocketClient`` takes an optional callback called when a disconnect is
diff --git a/pw_console/py/BUILD.bazel b/pw_console/py/BUILD.bazel
index 798b3354c..9fdeebd91 100644
--- a/pw_console/py/BUILD.bazel
+++ b/pw_console/py/BUILD.bazel
@@ -201,6 +201,17 @@ py_test(
)
py_test(
+ name = "socket_client_test",
+ size = "small",
+ srcs = [
+ "socket_client_test.py",
+ ],
+ deps = [
+ ":pw_console",
+ ],
+)
+
+py_test(
name = "repl_pane_test",
size = "small",
srcs = [
diff --git a/pw_console/py/BUILD.gn b/pw_console/py/BUILD.gn
index 2cc6d1866..3683004d5 100644
--- a/pw_console/py/BUILD.gn
+++ b/pw_console/py/BUILD.gn
@@ -87,6 +87,7 @@ pw_python_package("py") {
"log_store_test.py",
"log_view_test.py",
"repl_pane_test.py",
+ "socket_client_test.py",
"table_test.py",
"text_formatting_test.py",
"window_manager_test.py",
diff --git a/pw_console/py/pw_console/socket_client.py b/pw_console/py/pw_console/socket_client.py
index 41a1c66c4..5344c3199 100644
--- a/pw_console/py/pw_console/socket_client.py
+++ b/pw_console/py/pw_console/socket_client.py
@@ -17,6 +17,7 @@ from __future__ import annotations
from typing import Callable, Optional, TYPE_CHECKING, Tuple, Union
import errno
+import re
import socket
from pw_console.plugins.bandwidth_toolbar import SerialBandwidthTracker
@@ -33,43 +34,125 @@ class SocketClient:
DEFAULT_SOCKET_PORT = 33000
PW_RPC_MAX_PACKET_SIZE = 256
+ _InitArgsType = Tuple[
+ socket.AddressFamily, int # pylint: disable=no-member
+ ]
+ # Can be a string, (address, port) for AF_INET or (address, port, flowinfo,
+ # scope_id) AF_INET6.
+ _AddressType = Union[str, Tuple[str, int], Tuple[str, int, int, int]]
+
def __init__(
self,
config: str,
on_disconnect: Optional[Callable[[SocketClient], None]] = None,
):
- self._connection_type: int
- self._interface: Union[str, Tuple[str, int]]
+ """Creates a socket connection.
+
+ Args:
+ config: The socket configuration. Accepted values and formats are:
+ 'default' - uses the default configuration (localhost:33000)
+ 'address:port' - An IPv4 address and port.
+ 'address' - An IPv4 address. Uses default port 33000.
+ '[address]:port' - An IPv6 address and port.
+ '[address]' - An IPv6 address. Uses default port 33000.
+ 'file:path_to_file' - A Unix socket at ``path_to_file``.
+ In the formats above,``address`` can be an actual address or a name
+ that resolves to an address through name-resolution.
+ on_disconnect: An optional callback called when the socket
+ disconnects.
+
+ Raises:
+ TypeError: The type of socket is not supported.
+ ValueError: The socket configuration is invalid.
+ """
+ self.socket: socket.socket
+ (
+ self._socket_init_args,
+ self._address,
+ ) = SocketClient._parse_socket_config(config)
+ self._on_disconnect = on_disconnect
+ self._connected = False
+ self.connect()
+
+ @staticmethod
+ def _parse_socket_config(
+ config: str,
+ ) -> Tuple[SocketClient._InitArgsType, SocketClient._AddressType]:
+ """Sets the variables used to create a socket given a config string.
+
+ Raises:
+ TypeError: The type of socket is not supported.
+ ValueError: The socket configuration is invalid.
+ """
+ init_args: SocketClient._InitArgsType
+ address: SocketClient._AddressType
+
+ # Check if this is using the default settings.
if config == 'default':
- self._connection_type = socket.AF_INET6
- self._interface = (
- self.DEFAULT_SOCKET_SERVER,
- self.DEFAULT_SOCKET_PORT,
+ init_args = socket.AF_INET6, socket.SOCK_STREAM
+ address = (
+ SocketClient.DEFAULT_SOCKET_SERVER,
+ SocketClient.DEFAULT_SOCKET_PORT,
)
- else:
- socket_server, socket_port_or_file = config.split(':')
- if socket_server == self.FILE_SOCKET_SERVER:
- # Unix socket support is available on Windows 10 since April
- # 2018. However, there is no Python support on Windows yet.
- # See https://bugs.python.org/issue33408 for more information.
- if not hasattr(socket, 'AF_UNIX'):
- raise TypeError(
- 'Unix sockets are not supported in this environment.'
- )
- self._connection_type = (
- socket.AF_UNIX # pylint: disable=no-member
+ return init_args, address
+
+ # Check if this is a UNIX socket.
+ unix_socket_file_setting = f'{SocketClient.FILE_SOCKET_SERVER}:'
+ if config.startswith(unix_socket_file_setting):
+ # Unix socket support is available on Windows 10 since April
+ # 2018. However, there is no Python support on Windows yet.
+ # See https://bugs.python.org/issue33408 for more information.
+ if not hasattr(socket, 'AF_UNIX'):
+ raise TypeError(
+ 'Unix sockets are not supported in this environment.'
)
- self._interface = socket_port_or_file
- else:
- self._connection_type = socket.AF_INET6
- self._interface = (socket_server, int(socket_port_or_file))
+ init_args = (
+ socket.AF_UNIX, # pylint: disable=no-member
+ socket.SOCK_STREAM,
+ )
+ address = config[len(unix_socket_file_setting) :]
+ return init_args, address
+
+ # Search for IPv4 or IPv6 address or name and port.
+ # First, try to capture an IPv6 address as anything inside []. If there
+ # are no [] capture the IPv4 address. Lastly, capture the port as the
+ # numbers after :, if any.
+ match = re.match(
+ r'(\[(?P<ipv6_addr>.+)\]:?|(?P<ipv4_addr>[a-zA-Z0-9\._\/]+):?)'
+ r'(?P<port>[0-9]+)?',
+ config,
+ )
+ invalid_config_message = (
+ f'Invalid socket configuration "{config}"'
+ 'Accepted values are "default", "file:<file_path>", '
+ '"<name_or_ipv4_address>" with optional ":<port>", and '
+ '"[<name_or_ipv6_address>]" with optional ":<port>".'
+ )
+ if match is None:
+ raise ValueError(invalid_config_message)
+
+ info = match.groupdict()
+ if info['port']:
+ port = int(info['port'])
+ else:
+ port = SocketClient.DEFAULT_SOCKET_PORT
- self._on_disconnect = on_disconnect
- self._connected = False
- self.connect()
+ if info['ipv4_addr']:
+ ip_addr = info['ipv4_addr']
+ elif info['ipv6_addr']:
+ ip_addr = info['ipv6_addr']
+ else:
+ raise ValueError(invalid_config_message)
+
+ sock_family, sock_type, _, _, address = socket.getaddrinfo(
+ ip_addr, port, type=socket.SOCK_STREAM
+ )[0]
+ init_args = sock_family, sock_type
+ return init_args, address
def __del__(self):
- self.socket.close()
+ if self._connected:
+ self.socket.close()
def write(self, data: ReadableBuffer) -> None:
"""Writes data and detects disconnects."""
@@ -96,13 +179,13 @@ class SocketClient:
def connect(self) -> None:
"""Connects to socket."""
- self.socket = socket.socket(self._connection_type, socket.SOCK_STREAM)
+ self.socket = socket.socket(*self._socket_init_args)
# Enable reusing address and port for reconnections.
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(socket, 'SO_REUSEPORT'):
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
- self.socket.connect(self._interface)
+ self.socket.connect(self._address)
self._connected = True
def _handle_disconnect(self):
diff --git a/pw_console/py/socket_client_test.py b/pw_console/py/socket_client_test.py
new file mode 100644
index 000000000..f4e5a9f5d
--- /dev/null
+++ b/pw_console/py/socket_client_test.py
@@ -0,0 +1,181 @@
+# Copyright 2023 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Tests for pw_console.socket_client"""
+
+import socket
+import unittest
+
+
+from pw_console import socket_client
+
+
+class TestSocketClient(unittest.TestCase):
+ """Tests for SocketClient."""
+
+ def test_parse_config_default(self) -> None:
+ config = "default"
+ with unittest.mock.patch.object(
+ socket_client.SocketClient, 'connect', return_value=None
+ ):
+ client = socket_client.SocketClient(config)
+ self.assertEqual(
+ client._socket_init_args, # pylint: disable=protected-access
+ (socket.AF_INET6, socket.SOCK_STREAM),
+ )
+ self.assertEqual(
+ client._address, # pylint: disable=protected-access
+ (
+ socket_client.SocketClient.DEFAULT_SOCKET_SERVER,
+ socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+ ),
+ )
+
+ def test_parse_config_unix_file(self) -> None:
+ # Skip test if UNIX sockets are not supported.
+ if not hasattr(socket, 'AF_UNIX'):
+ return
+
+ config = 'file:fake_file_path'
+ with unittest.mock.patch.object(
+ socket_client.SocketClient, 'connect', return_value=None
+ ):
+ client = socket_client.SocketClient(config)
+ self.assertEqual(
+ client._socket_init_args, # pylint: disable=protected-access
+ (
+ socket.AF_UNIX, # pylint: disable=no-member
+ socket.SOCK_STREAM,
+ ),
+ )
+ self.assertEqual(
+ client._address, # pylint: disable=protected-access
+ 'fake_file_path',
+ )
+
+ def _check_config_parsing(
+ self, config: str, expected_address: str, expected_port: int
+ ) -> None:
+ with unittest.mock.patch.object(
+ socket_client.SocketClient, 'connect', return_value=None
+ ):
+ fake_getaddrinfo_return_value = [
+ (socket.AF_INET6, socket.SOCK_STREAM, 0, None, None)
+ ]
+ with unittest.mock.patch.object(
+ socket,
+ 'getaddrinfo',
+ return_value=fake_getaddrinfo_return_value,
+ ) as mock_getaddrinfo:
+ client = socket_client.SocketClient(config)
+ mock_getaddrinfo.assert_called_with(
+ expected_address, expected_port, type=socket.SOCK_STREAM
+ )
+ # Assert the init args are what is returned by ``getaddrinfo``
+ # not necessarily the correct ones, since this test should not
+ # perform any network action.
+ self.assertEqual(
+ client._socket_init_args, # pylint: disable=protected-access
+ (
+ socket.AF_INET6,
+ socket.SOCK_STREAM,
+ ),
+ )
+
+ def test_parse_config_ipv4_domain(self) -> None:
+ self._check_config_parsing(
+ config='file.com/some_long/path:80',
+ expected_address='file.com/some_long/path',
+ expected_port=80,
+ )
+
+ def test_parse_config_ipv4_domain_no_port(self) -> None:
+ self._check_config_parsing(
+ config='file.com/some/path',
+ expected_address='file.com/some/path',
+ expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+ )
+
+ def test_parse_config_ipv4_address(self) -> None:
+ self._check_config_parsing(
+ config='8.8.8.8:8080',
+ expected_address='8.8.8.8',
+ expected_port=8080,
+ )
+
+ def test_parse_config_ipv4_address_no_port(self) -> None:
+ self._check_config_parsing(
+ config='8.8.8.8',
+ expected_address='8.8.8.8',
+ expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+ )
+
+ def test_parse_config_ipv6_domain(self) -> None:
+ self._check_config_parsing(
+ config='[file.com/some_long/path]:80',
+ expected_address='file.com/some_long/path',
+ expected_port=80,
+ )
+
+ def test_parse_config_ipv6_domain_no_port(self) -> None:
+ self._check_config_parsing(
+ config='[file.com/some/path]',
+ expected_address='file.com/some/path',
+ expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+ )
+
+ def test_parse_config_ipv6_address(self) -> None:
+ self._check_config_parsing(
+ config='[2001:4860:4860::8888:8080]:666',
+ expected_address='2001:4860:4860::8888:8080',
+ expected_port=666,
+ )
+
+ def test_parse_config_ipv6_address_no_port(self) -> None:
+ self._check_config_parsing(
+ config='[2001:4860:4860::8844]',
+ expected_address='2001:4860:4860::8844',
+ expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+ )
+
+ def test_parse_config_ipv6_local(self) -> None:
+ self._check_config_parsing(
+ config='[fe80::100%eth0]:80',
+ expected_address='fe80::100%eth0',
+ expected_port=80,
+ )
+
+ def test_parse_config_ipv6_local_no_port(self) -> None:
+ self._check_config_parsing(
+ config='[fe80::100%eth0]',
+ expected_address='fe80::100%eth0',
+ expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+ )
+
+ def test_parse_config_ipv6_local_windows(self) -> None:
+ self._check_config_parsing(
+ config='[fe80::100%4]:80',
+ expected_address='fe80::100%4',
+ expected_port=80,
+ )
+
+ def test_parse_config_ipv6_local_no_port_windows(self) -> None:
+ self._check_config_parsing(
+ config='[fe80::100%4]',
+ expected_address='fe80::100%4',
+ expected_port=socket_client.SocketClient.DEFAULT_SOCKET_PORT,
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()