aboutsummaryrefslogtreecommitdiff
path: root/bumble/gatt_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'bumble/gatt_client.py')
-rw-r--r--bumble/gatt_client.py99
1 files changed, 67 insertions, 32 deletions
diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py
index a33039e..e3b8bb2 100644
--- a/bumble/gatt_client.py
+++ b/bumble/gatt_client.py
@@ -28,7 +28,18 @@ import asyncio
import logging
import struct
from datetime import datetime
-from typing import List, Optional, Dict, Tuple, Callable, Union, Any
+from typing import (
+ List,
+ Optional,
+ Dict,
+ Tuple,
+ Callable,
+ Union,
+ Any,
+ Iterable,
+ Type,
+ TYPE_CHECKING,
+)
from pyee import EventEmitter
@@ -66,8 +77,12 @@ from .gatt import (
GATT_INCLUDE_ATTRIBUTE_TYPE,
Characteristic,
ClientCharacteristicConfigurationBits,
+ TemplateService,
)
+if TYPE_CHECKING:
+ from bumble.device import Connection
+
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -78,16 +93,16 @@ logger = logging.getLogger(__name__)
# Proxies
# -----------------------------------------------------------------------------
class AttributeProxy(EventEmitter):
- client: Client
-
- def __init__(self, client, handle, end_group_handle, attribute_type):
+ def __init__(
+ self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID
+ ) -> None:
EventEmitter.__init__(self)
self.client = client
self.handle = handle
self.end_group_handle = end_group_handle
self.type = attribute_type
- async def read_value(self, no_long_read=False):
+ async def read_value(self, no_long_read: bool = False) -> bytes:
return self.decode_value(
await self.client.read_value(self.handle, no_long_read)
)
@@ -97,13 +112,13 @@ class AttributeProxy(EventEmitter):
self.handle, self.encode_value(value), with_response
)
- def encode_value(self, value):
+ def encode_value(self, value: Any) -> bytes:
return value
- def decode_value(self, value_bytes):
+ def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
- def __str__(self):
+ def __str__(self) -> str:
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
@@ -136,14 +151,14 @@ class ServiceProxy(AttributeProxy):
def get_characteristics_by_uuid(self, uuid):
return self.client.get_characteristics_by_uuid(uuid, self)
- def __str__(self):
+ def __str__(self) -> str:
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
class CharacteristicProxy(AttributeProxy):
properties: Characteristic.Properties
descriptors: List[DescriptorProxy]
- subscribers: Dict[Any, Callable]
+ subscribers: Dict[Any, Callable[[bytes], Any]]
def __init__(
self,
@@ -171,7 +186,9 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.discover_descriptors(self)
async def subscribe(
- self, subscriber: Optional[Callable] = None, prefer_notify=True
+ self,
+ subscriber: Optional[Callable[[bytes], Any]] = None,
+ prefer_notify: bool = True,
):
if subscriber is not None:
if subscriber in self.subscribers:
@@ -195,7 +212,7 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.unsubscribe(self, subscriber)
- def __str__(self):
+ def __str__(self) -> str:
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'uuid={self.uuid}, '
@@ -207,7 +224,7 @@ class DescriptorProxy(AttributeProxy):
def __init__(self, client, handle, descriptor_type):
super().__init__(client, handle, 0, descriptor_type)
- def __str__(self):
+ def __str__(self) -> str:
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})'
@@ -216,8 +233,10 @@ class ProfileServiceProxy:
Base class for profile-specific service proxies
'''
+ SERVICE_CLASS: Type[TemplateService]
+
@classmethod
- def from_client(cls, client):
+ def from_client(cls, client: Client) -> ProfileServiceProxy:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -227,8 +246,12 @@ class ProfileServiceProxy:
class Client:
services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]]
+ notification_subscribers: Dict[int, Callable[[bytes], Any]]
+ indication_subscribers: Dict[int, Callable[[bytes], Any]]
+ pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
+ pending_request: Optional[ATT_PDU]
- def __init__(self, connection):
+ def __init__(self, connection: Connection) -> None:
self.connection = connection
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
@@ -241,16 +264,16 @@ class Client:
self.services = []
self.cached_values = {}
- def send_gatt_pdu(self, pdu):
+ def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(ATT_CID, pdu)
- async def send_command(self, command):
+ async def send_command(self, command: ATT_PDU) -> None:
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
self.send_gatt_pdu(command.to_bytes())
- async def send_request(self, request):
+ async def send_request(self, request: ATT_PDU):
logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
@@ -279,14 +302,14 @@ class Client:
return response
- def send_confirmation(self, confirmation):
+ def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None:
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
self.send_gatt_pdu(confirmation.to_bytes())
- async def request_mtu(self, mtu):
+ async def request_mtu(self, mtu: int) -> int:
# Check the range
if mtu < ATT_DEFAULT_MTU:
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
@@ -313,10 +336,12 @@ class Client:
return self.connection.att_mtu
- def get_services_by_uuid(self, uuid):
+ def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid]
- def get_characteristics_by_uuid(self, uuid, service=None):
+ def get_characteristics_by_uuid(
+ self, uuid: UUID, service: Optional[ServiceProxy] = None
+ ) -> List[CharacteristicProxy]:
services = [service] if service else self.services
return [
c
@@ -363,7 +388,7 @@ class Client:
if not already_known:
self.services.append(service)
- async def discover_services(self, uuids=None) -> List[ServiceProxy]:
+ async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.1 Discover All Primary Services
'''
@@ -435,7 +460,7 @@ class Client:
return services
- async def discover_service(self, uuid):
+ async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
'''
@@ -468,7 +493,7 @@ class Client:
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
- return
+ return []
break
for attribute_handle, end_group_handle in response.handles_information:
@@ -480,7 +505,7 @@ class Client:
logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
- return
+ return []
# Create a service proxy for this service
service = ServiceProxy(
@@ -721,7 +746,7 @@ class Client:
return descriptors
- async def discover_attributes(self):
+ async def discover_attributes(self) -> List[AttributeProxy]:
'''
Discover all attributes, regardless of type
'''
@@ -844,7 +869,9 @@ class Client:
# No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True)
- async def read_value(self, attribute, no_long_read=False):
+ async def read_value(
+ self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
+ ) -> Any:
'''
See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -905,7 +932,9 @@ class Client:
# Return the value as bytes
return attribute_value
- async def read_characteristics_by_uuid(self, uuid, service):
+ async def read_characteristics_by_uuid(
+ self, uuid: UUID, service: Optional[ServiceProxy]
+ ) -> List[bytes]:
'''
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
'''
@@ -960,7 +989,12 @@ class Client:
return characteristics_values
- async def write_value(self, attribute, value, with_response=False):
+ async def write_value(
+ self,
+ attribute: Union[int, AttributeProxy],
+ value: bytes,
+ with_response: bool = False,
+ ) -> None:
'''
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
Value
@@ -990,7 +1024,7 @@ class Client:
)
)
- def on_gatt_pdu(self, att_pdu):
+ def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
@@ -1013,6 +1047,7 @@ class Client:
return
# Return the response to the coroutine that is waiting for it
+ assert self.pending_response is not None
self.pending_response.set_result(att_pdu)
else:
handler_name = f'on_{att_pdu.name.lower()}'
@@ -1060,7 +1095,7 @@ class Client:
# Confirm that we received the indication
self.send_confirmation(ATT_Handle_Value_Confirmation())
- def cache_value(self, attribute_handle: int, value: bytes):
+ def cache_value(self, attribute_handle: int, value: bytes) -> None:
self.cached_values[attribute_handle] = (
datetime.now(),
value,