diff options
Diffstat (limited to 'bumble/host.py')
-rw-r--r-- | bumble/host.py | 89 |
1 files changed, 58 insertions, 31 deletions
diff --git a/bumble/host.py b/bumble/host.py index e41fd02..02caa46 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -15,23 +15,24 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from __future__ import annotations import asyncio import collections import logging import struct +from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable + from bumble.colors import color from bumble.l2cap import L2CAP_PDU from bumble.snoop import Snooper from bumble import drivers -from typing import Optional - from .hci import ( Address, HCI_ACL_DATA_PACKET, - HCI_COMMAND_COMPLETE_EVENT, HCI_COMMAND_PACKET, + HCI_COMMAND_COMPLETE_EVENT, HCI_EVENT_PACKET, HCI_LE_READ_BUFFER_SIZE_COMMAND, HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND, @@ -45,8 +46,11 @@ from .hci import ( HCI_VERSION_BLUETOOTH_CORE_4_0, HCI_AclDataPacket, HCI_AclDataPacketAssembler, + HCI_Command, + HCI_Command_Complete_Event, HCI_Constant, HCI_Error, + HCI_Event, HCI_LE_Long_Term_Key_Request_Negative_Reply_Command, HCI_LE_Long_Term_Key_Request_Reply_Command, HCI_LE_Read_Buffer_Size_Command, @@ -63,16 +67,19 @@ from .hci import ( HCI_Read_Local_Version_Information_Command, HCI_Reset_Command, HCI_Set_Event_Mask_Command, - map_null_terminated_utf8_string, ) from .core import ( BT_BR_EDR_TRANSPORT, - BT_CENTRAL_ROLE, BT_LE_TRANSPORT, ConnectionPHY, ConnectionParameters, + InvalidStateError, ) from .utils import AbortableEventEmitter +from .transport.common import TransportLostError + +if TYPE_CHECKING: + from .transport.common import TransportSink, TransportSource # ----------------------------------------------------------------------------- @@ -96,27 +103,38 @@ HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1 # ----------------------------------------------------------------------------- class Connection: - def __init__(self, host, handle, peer_address, transport): + def __init__(self, host: Host, handle: int, peer_address: Address, transport: int): self.host = host self.handle = handle self.peer_address = peer_address self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.transport = transport - def on_hci_acl_data_packet(self, packet): + def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None: self.assembler.feed_packet(packet) - def on_acl_pdu(self, pdu): + def on_acl_pdu(self, pdu: bytes) -> None: l2cap_pdu = L2CAP_PDU.from_bytes(pdu) self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload) # ----------------------------------------------------------------------------- class Host(AbortableEventEmitter): - def __init__(self, controller_source=None, controller_sink=None): + connections: Dict[int, Connection] + acl_packet_queue: collections.deque[HCI_AclDataPacket] + hci_sink: TransportSink + long_term_key_provider: Optional[ + Callable[[int, bytes, int], Awaitable[Optional[bytes]]] + ] + link_key_provider: Optional[Callable[[Address], Awaitable[Optional[bytes]]]] + + def __init__( + self, + controller_source: Optional[TransportSource] = None, + controller_sink: Optional[TransportSink] = None, + ) -> None: super().__init__() - self.hci_sink = None self.hci_metadata = None self.ready = False # True when we can accept incoming packets self.reset_done = False @@ -296,7 +314,7 @@ class Host(AbortableEventEmitter): self.reset_done = True @property - def controller(self): + def controller(self) -> TransportSink: return self.hci_sink @controller.setter @@ -305,13 +323,12 @@ class Host(AbortableEventEmitter): if controller: controller.set_packet_sink(self) - def set_packet_sink(self, sink): + def set_packet_sink(self, sink: TransportSink) -> None: self.hci_sink = sink - def send_hci_packet(self, packet): + def send_hci_packet(self, packet: HCI_Packet) -> None: if self.snooper: self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER) - self.hci_sink.on_packet(bytes(packet)) async def send_command(self, command, check_result=False): @@ -349,7 +366,7 @@ class Host(AbortableEventEmitter): return response except Exception as error: logger.warning( - f'{color("!!! Exception while sending HCI packet:", "red")} {error}' + f'{color("!!! Exception while sending command:", "red")} {error}' ) raise error finally: @@ -357,13 +374,13 @@ class Host(AbortableEventEmitter): self.pending_response = None # Use this method to send a command from a task - def send_command_sync(self, command): - async def send_command(command): + def send_command_sync(self, command: HCI_Command) -> None: + async def send_command(command: HCI_Command) -> None: await self.send_command(command) asyncio.create_task(send_command(command)) - def send_l2cap_pdu(self, connection_handle, cid, pdu): + def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: l2cap_pdu = bytes(L2CAP_PDU(cid, pdu)) # Send the data to the controller via ACL packets @@ -388,7 +405,7 @@ class Host(AbortableEventEmitter): offset += data_total_length bytes_remaining -= data_total_length - def queue_acl_packet(self, acl_packet): + def queue_acl_packet(self, acl_packet: HCI_AclDataPacket) -> None: self.acl_packet_queue.appendleft(acl_packet) self.check_acl_packet_queue() @@ -398,7 +415,7 @@ class Host(AbortableEventEmitter): f'{len(self.acl_packet_queue)} in queue' ) - def check_acl_packet_queue(self): + def check_acl_packet_queue(self) -> None: # Send all we can (TODO: support different LE/Classic limits) while ( len(self.acl_packet_queue) > 0 @@ -444,47 +461,53 @@ class Host(AbortableEventEmitter): ] # Packet Sink protocol (packets coming from the controller via HCI) - def on_packet(self, packet): + def on_packet(self, packet: bytes) -> None: hci_packet = HCI_Packet.from_bytes(packet) if self.ready or ( - hci_packet.hci_packet_type == HCI_EVENT_PACKET - and hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT + isinstance(hci_packet, HCI_Command_Complete_Event) and hci_packet.command_opcode == HCI_RESET_COMMAND ): self.on_hci_packet(hci_packet) else: logger.debug('reset not done, ignoring packet from controller') - def on_hci_packet(self, packet): + def on_transport_lost(self): + # Called by the source when the transport has been lost. + if self.pending_response: + self.pending_response.set_exception(TransportLostError('transport lost')) + + self.emit('flush') + + def on_hci_packet(self, packet: HCI_Packet) -> None: logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}') if self.snooper: self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST) # If the packet is a command, invoke the handler for this packet - if packet.hci_packet_type == HCI_COMMAND_PACKET: + if isinstance(packet, HCI_Command): self.on_hci_command_packet(packet) - elif packet.hci_packet_type == HCI_EVENT_PACKET: + elif isinstance(packet, HCI_Event): self.on_hci_event_packet(packet) - elif packet.hci_packet_type == HCI_ACL_DATA_PACKET: + elif isinstance(packet, HCI_AclDataPacket): self.on_hci_acl_data_packet(packet) else: logger.warning(f'!!! unknown packet type {packet.hci_packet_type}') - def on_hci_command_packet(self, command): + def on_hci_command_packet(self, command: HCI_Command) -> None: logger.warning(f'!!! unexpected command packet: {command}') - def on_hci_event_packet(self, event): + def on_hci_event_packet(self, event: HCI_Event) -> None: handler_name = f'on_{event.name.lower()}' handler = getattr(self, handler_name, self.on_hci_event) handler(event) - def on_hci_acl_data_packet(self, packet): + def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None: # Look for the connection to which this data belongs if connection := self.connections.get(packet.connection_handle): connection.on_hci_acl_data_packet(packet) - def on_l2cap_pdu(self, connection, cid, pdu): + def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None: self.emit('l2cap_pdu', connection.handle, cid, pdu) def on_command_processed(self, event): @@ -822,6 +845,10 @@ class Host(AbortableEventEmitter): f'simple pairing complete for {event.bd_addr}: ' f'status={HCI_Constant.status_name(event.status)}' ) + if event.status == HCI_SUCCESS: + self.emit('classic_pairing', event.bd_addr) + else: + self.emit('classic_pairing_failure', event.bd_addr, event.status) def on_hci_pin_code_request_event(self, event): self.emit('pin_code_request', event.bd_addr) |