aboutsummaryrefslogtreecommitdiff
path: root/bumble/host.py
diff options
context:
space:
mode:
Diffstat (limited to 'bumble/host.py')
-rw-r--r--bumble/host.py89
1 files changed, 58 insertions, 31 deletions
diff --git a/bumble/host.py b/bumble/host.py
index e41fd02..02caa46 100644
--- a/bumble/host.py
+++ b/bumble/host.py
@@ -15,23 +15,24 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
+from __future__ import annotations
import asyncio
import collections
import logging
import struct
+from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable
+
from bumble.colors import color
from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper
from bumble import drivers
-from typing import Optional
-
from .hci import (
Address,
HCI_ACL_DATA_PACKET,
- HCI_COMMAND_COMPLETE_EVENT,
HCI_COMMAND_PACKET,
+ HCI_COMMAND_COMPLETE_EVENT,
HCI_EVENT_PACKET,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND,
@@ -45,8 +46,11 @@ from .hci import (
HCI_VERSION_BLUETOOTH_CORE_4_0,
HCI_AclDataPacket,
HCI_AclDataPacketAssembler,
+ HCI_Command,
+ HCI_Command_Complete_Event,
HCI_Constant,
HCI_Error,
+ HCI_Event,
HCI_LE_Long_Term_Key_Request_Negative_Reply_Command,
HCI_LE_Long_Term_Key_Request_Reply_Command,
HCI_LE_Read_Buffer_Size_Command,
@@ -63,16 +67,19 @@ from .hci import (
HCI_Read_Local_Version_Information_Command,
HCI_Reset_Command,
HCI_Set_Event_Mask_Command,
- map_null_terminated_utf8_string,
)
from .core import (
BT_BR_EDR_TRANSPORT,
- BT_CENTRAL_ROLE,
BT_LE_TRANSPORT,
ConnectionPHY,
ConnectionParameters,
+ InvalidStateError,
)
from .utils import AbortableEventEmitter
+from .transport.common import TransportLostError
+
+if TYPE_CHECKING:
+ from .transport.common import TransportSink, TransportSource
# -----------------------------------------------------------------------------
@@ -96,27 +103,38 @@ HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
# -----------------------------------------------------------------------------
class Connection:
- def __init__(self, host, handle, peer_address, transport):
+ def __init__(self, host: Host, handle: int, peer_address: Address, transport: int):
self.host = host
self.handle = handle
self.peer_address = peer_address
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport
- def on_hci_acl_data_packet(self, packet):
+ def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
self.assembler.feed_packet(packet)
- def on_acl_pdu(self, pdu):
+ def on_acl_pdu(self, pdu: bytes) -> None:
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
# -----------------------------------------------------------------------------
class Host(AbortableEventEmitter):
- def __init__(self, controller_source=None, controller_sink=None):
+ connections: Dict[int, Connection]
+ acl_packet_queue: collections.deque[HCI_AclDataPacket]
+ hci_sink: TransportSink
+ long_term_key_provider: Optional[
+ Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
+ ]
+ link_key_provider: Optional[Callable[[Address], Awaitable[Optional[bytes]]]]
+
+ def __init__(
+ self,
+ controller_source: Optional[TransportSource] = None,
+ controller_sink: Optional[TransportSink] = None,
+ ) -> None:
super().__init__()
- self.hci_sink = None
self.hci_metadata = None
self.ready = False # True when we can accept incoming packets
self.reset_done = False
@@ -296,7 +314,7 @@ class Host(AbortableEventEmitter):
self.reset_done = True
@property
- def controller(self):
+ def controller(self) -> TransportSink:
return self.hci_sink
@controller.setter
@@ -305,13 +323,12 @@ class Host(AbortableEventEmitter):
if controller:
controller.set_packet_sink(self)
- def set_packet_sink(self, sink):
+ def set_packet_sink(self, sink: TransportSink) -> None:
self.hci_sink = sink
- def send_hci_packet(self, packet):
+ def send_hci_packet(self, packet: HCI_Packet) -> None:
if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
-
self.hci_sink.on_packet(bytes(packet))
async def send_command(self, command, check_result=False):
@@ -349,7 +366,7 @@ class Host(AbortableEventEmitter):
return response
except Exception as error:
logger.warning(
- f'{color("!!! Exception while sending HCI packet:", "red")} {error}'
+ f'{color("!!! Exception while sending command:", "red")} {error}'
)
raise error
finally:
@@ -357,13 +374,13 @@ class Host(AbortableEventEmitter):
self.pending_response = None
# Use this method to send a command from a task
- def send_command_sync(self, command):
- async def send_command(command):
+ def send_command_sync(self, command: HCI_Command) -> None:
+ async def send_command(command: HCI_Command) -> None:
await self.send_command(command)
asyncio.create_task(send_command(command))
- def send_l2cap_pdu(self, connection_handle, cid, pdu):
+ def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
l2cap_pdu = bytes(L2CAP_PDU(cid, pdu))
# Send the data to the controller via ACL packets
@@ -388,7 +405,7 @@ class Host(AbortableEventEmitter):
offset += data_total_length
bytes_remaining -= data_total_length
- def queue_acl_packet(self, acl_packet):
+ def queue_acl_packet(self, acl_packet: HCI_AclDataPacket) -> None:
self.acl_packet_queue.appendleft(acl_packet)
self.check_acl_packet_queue()
@@ -398,7 +415,7 @@ class Host(AbortableEventEmitter):
f'{len(self.acl_packet_queue)} in queue'
)
- def check_acl_packet_queue(self):
+ def check_acl_packet_queue(self) -> None:
# Send all we can (TODO: support different LE/Classic limits)
while (
len(self.acl_packet_queue) > 0
@@ -444,47 +461,53 @@ class Host(AbortableEventEmitter):
]
# Packet Sink protocol (packets coming from the controller via HCI)
- def on_packet(self, packet):
+ def on_packet(self, packet: bytes) -> None:
hci_packet = HCI_Packet.from_bytes(packet)
if self.ready or (
- hci_packet.hci_packet_type == HCI_EVENT_PACKET
- and hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT
+ isinstance(hci_packet, HCI_Command_Complete_Event)
and hci_packet.command_opcode == HCI_RESET_COMMAND
):
self.on_hci_packet(hci_packet)
else:
logger.debug('reset not done, ignoring packet from controller')
- def on_hci_packet(self, packet):
+ def on_transport_lost(self):
+ # Called by the source when the transport has been lost.
+ if self.pending_response:
+ self.pending_response.set_exception(TransportLostError('transport lost'))
+
+ self.emit('flush')
+
+ def on_hci_packet(self, packet: HCI_Packet) -> None:
logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}')
if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
# If the packet is a command, invoke the handler for this packet
- if packet.hci_packet_type == HCI_COMMAND_PACKET:
+ if isinstance(packet, HCI_Command):
self.on_hci_command_packet(packet)
- elif packet.hci_packet_type == HCI_EVENT_PACKET:
+ elif isinstance(packet, HCI_Event):
self.on_hci_event_packet(packet)
- elif packet.hci_packet_type == HCI_ACL_DATA_PACKET:
+ elif isinstance(packet, HCI_AclDataPacket):
self.on_hci_acl_data_packet(packet)
else:
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
- def on_hci_command_packet(self, command):
+ def on_hci_command_packet(self, command: HCI_Command) -> None:
logger.warning(f'!!! unexpected command packet: {command}')
- def on_hci_event_packet(self, event):
+ def on_hci_event_packet(self, event: HCI_Event) -> None:
handler_name = f'on_{event.name.lower()}'
handler = getattr(self, handler_name, self.on_hci_event)
handler(event)
- def on_hci_acl_data_packet(self, packet):
+ def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
# Look for the connection to which this data belongs
if connection := self.connections.get(packet.connection_handle):
connection.on_hci_acl_data_packet(packet)
- def on_l2cap_pdu(self, connection, cid, pdu):
+ def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
self.emit('l2cap_pdu', connection.handle, cid, pdu)
def on_command_processed(self, event):
@@ -822,6 +845,10 @@ class Host(AbortableEventEmitter):
f'simple pairing complete for {event.bd_addr}: '
f'status={HCI_Constant.status_name(event.status)}'
)
+ if event.status == HCI_SUCCESS:
+ self.emit('classic_pairing', event.bd_addr)
+ else:
+ self.emit('classic_pairing_failure', event.bd_addr, event.status)
def on_hci_pin_code_request_event(self, event):
self.emit('pin_code_request', event.bd_addr)