diff options
Diffstat (limited to 'bumble/gatt_client.py')
-rw-r--r-- | bumble/gatt_client.py | 99 |
1 files changed, 67 insertions, 32 deletions
diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index a33039e..e3b8bb2 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -28,7 +28,18 @@ import asyncio import logging import struct from datetime import datetime -from typing import List, Optional, Dict, Tuple, Callable, Union, Any +from typing import ( + List, + Optional, + Dict, + Tuple, + Callable, + Union, + Any, + Iterable, + Type, + TYPE_CHECKING, +) from pyee import EventEmitter @@ -66,8 +77,12 @@ from .gatt import ( GATT_INCLUDE_ATTRIBUTE_TYPE, Characteristic, ClientCharacteristicConfigurationBits, + TemplateService, ) +if TYPE_CHECKING: + from bumble.device import Connection + # ----------------------------------------------------------------------------- # Logging # ----------------------------------------------------------------------------- @@ -78,16 +93,16 @@ logger = logging.getLogger(__name__) # Proxies # ----------------------------------------------------------------------------- class AttributeProxy(EventEmitter): - client: Client - - def __init__(self, client, handle, end_group_handle, attribute_type): + def __init__( + self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID + ) -> None: EventEmitter.__init__(self) self.client = client self.handle = handle self.end_group_handle = end_group_handle self.type = attribute_type - async def read_value(self, no_long_read=False): + async def read_value(self, no_long_read: bool = False) -> bytes: return self.decode_value( await self.client.read_value(self.handle, no_long_read) ) @@ -97,13 +112,13 @@ class AttributeProxy(EventEmitter): self.handle, self.encode_value(value), with_response ) - def encode_value(self, value): + def encode_value(self, value: Any) -> bytes: return value - def decode_value(self, value_bytes): + def decode_value(self, value_bytes: bytes) -> Any: return value_bytes - def __str__(self): + def __str__(self) -> str: return f'Attribute(handle=0x{self.handle:04X}, type={self.type})' @@ -136,14 +151,14 @@ class ServiceProxy(AttributeProxy): def get_characteristics_by_uuid(self, uuid): return self.client.get_characteristics_by_uuid(uuid, self) - def __str__(self): + def __str__(self) -> str: return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})' class CharacteristicProxy(AttributeProxy): properties: Characteristic.Properties descriptors: List[DescriptorProxy] - subscribers: Dict[Any, Callable] + subscribers: Dict[Any, Callable[[bytes], Any]] def __init__( self, @@ -171,7 +186,9 @@ class CharacteristicProxy(AttributeProxy): return await self.client.discover_descriptors(self) async def subscribe( - self, subscriber: Optional[Callable] = None, prefer_notify=True + self, + subscriber: Optional[Callable[[bytes], Any]] = None, + prefer_notify: bool = True, ): if subscriber is not None: if subscriber in self.subscribers: @@ -195,7 +212,7 @@ class CharacteristicProxy(AttributeProxy): return await self.client.unsubscribe(self, subscriber) - def __str__(self): + def __str__(self) -> str: return ( f'Characteristic(handle=0x{self.handle:04X}, ' f'uuid={self.uuid}, ' @@ -207,7 +224,7 @@ class DescriptorProxy(AttributeProxy): def __init__(self, client, handle, descriptor_type): super().__init__(client, handle, 0, descriptor_type) - def __str__(self): + def __str__(self) -> str: return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})' @@ -216,8 +233,10 @@ class ProfileServiceProxy: Base class for profile-specific service proxies ''' + SERVICE_CLASS: Type[TemplateService] + @classmethod - def from_client(cls, client): + def from_client(cls, client: Client) -> ProfileServiceProxy: return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID) @@ -227,8 +246,12 @@ class ProfileServiceProxy: class Client: services: List[ServiceProxy] cached_values: Dict[int, Tuple[datetime, bytes]] + notification_subscribers: Dict[int, Callable[[bytes], Any]] + indication_subscribers: Dict[int, Callable[[bytes], Any]] + pending_response: Optional[asyncio.futures.Future[ATT_PDU]] + pending_request: Optional[ATT_PDU] - def __init__(self, connection): + def __init__(self, connection: Connection) -> None: self.connection = connection self.mtu_exchange_done = False self.request_semaphore = asyncio.Semaphore(1) @@ -241,16 +264,16 @@ class Client: self.services = [] self.cached_values = {} - def send_gatt_pdu(self, pdu): + def send_gatt_pdu(self, pdu: bytes) -> None: self.connection.send_l2cap_pdu(ATT_CID, pdu) - async def send_command(self, command): + async def send_command(self, command: ATT_PDU) -> None: logger.debug( f'GATT Command from client: [0x{self.connection.handle:04X}] {command}' ) self.send_gatt_pdu(command.to_bytes()) - async def send_request(self, request): + async def send_request(self, request: ATT_PDU): logger.debug( f'GATT Request from client: [0x{self.connection.handle:04X}] {request}' ) @@ -279,14 +302,14 @@ class Client: return response - def send_confirmation(self, confirmation): + def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None: logger.debug( f'GATT Confirmation from client: [0x{self.connection.handle:04X}] ' f'{confirmation}' ) self.send_gatt_pdu(confirmation.to_bytes()) - async def request_mtu(self, mtu): + async def request_mtu(self, mtu: int) -> int: # Check the range if mtu < ATT_DEFAULT_MTU: raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}') @@ -313,10 +336,12 @@ class Client: return self.connection.att_mtu - def get_services_by_uuid(self, uuid): + def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]: return [service for service in self.services if service.uuid == uuid] - def get_characteristics_by_uuid(self, uuid, service=None): + def get_characteristics_by_uuid( + self, uuid: UUID, service: Optional[ServiceProxy] = None + ) -> List[CharacteristicProxy]: services = [service] if service else self.services return [ c @@ -363,7 +388,7 @@ class Client: if not already_known: self.services.append(service) - async def discover_services(self, uuids=None) -> List[ServiceProxy]: + async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]: ''' See Vol 3, Part G - 4.4.1 Discover All Primary Services ''' @@ -435,7 +460,7 @@ class Client: return services - async def discover_service(self, uuid): + async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]: ''' See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID ''' @@ -468,7 +493,7 @@ class Client: f'{HCI_Constant.error_name(response.error_code)}' ) # TODO raise appropriate exception - return + return [] break for attribute_handle, end_group_handle in response.handles_information: @@ -480,7 +505,7 @@ class Client: logger.warning( f'bogus handle values: {attribute_handle} {end_group_handle}' ) - return + return [] # Create a service proxy for this service service = ServiceProxy( @@ -721,7 +746,7 @@ class Client: return descriptors - async def discover_attributes(self): + async def discover_attributes(self) -> List[AttributeProxy]: ''' Discover all attributes, regardless of type ''' @@ -844,7 +869,9 @@ class Client: # No more subscribers left await self.write_value(cccd, b'\x00\x00', with_response=True) - async def read_value(self, attribute, no_long_read=False): + async def read_value( + self, attribute: Union[int, AttributeProxy], no_long_read: bool = False + ) -> Any: ''' See Vol 3, Part G - 4.8.1 Read Characteristic Value @@ -905,7 +932,9 @@ class Client: # Return the value as bytes return attribute_value - async def read_characteristics_by_uuid(self, uuid, service): + async def read_characteristics_by_uuid( + self, uuid: UUID, service: Optional[ServiceProxy] + ) -> List[bytes]: ''' See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID ''' @@ -960,7 +989,12 @@ class Client: return characteristics_values - async def write_value(self, attribute, value, with_response=False): + async def write_value( + self, + attribute: Union[int, AttributeProxy], + value: bytes, + with_response: bool = False, + ) -> None: ''' See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic Value @@ -990,7 +1024,7 @@ class Client: ) ) - def on_gatt_pdu(self, att_pdu): + def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None: logger.debug( f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}' ) @@ -1013,6 +1047,7 @@ class Client: return # Return the response to the coroutine that is waiting for it + assert self.pending_response is not None self.pending_response.set_result(att_pdu) else: handler_name = f'on_{att_pdu.name.lower()}' @@ -1060,7 +1095,7 @@ class Client: # Confirm that we received the indication self.send_confirmation(ATT_Handle_Value_Confirmation()) - def cache_value(self, attribute_handle: int, value: bytes): + def cache_value(self, attribute_handle: int, value: bytes) -> None: self.cached_values[attribute_handle] = ( datetime.now(), value, |