diff options
author | Josh Wu <joshwu@google.com> | 2023-09-14 20:52:33 +0800 |
---|---|---|
committer | Josh Wu <joshwu@google.com> | 2023-09-14 20:52:33 +0800 |
commit | 5d9598ea514c4f272196483d8cd39cff0f0b7bb0 (patch) | |
tree | bf645d9bad04fdee4f80d29c6367026f220f8d89 | |
parent | 0d36d99a73dcae215e08002087eb01bc6d03954c (diff) | |
download | bumble-5d9598ea514c4f272196483d8cd39cff0f0b7bb0.tar.gz |
L2CAP: Refactor states to enums
-rw-r--r-- | bumble/l2cap.py | 217 |
1 files changed, 92 insertions, 125 deletions
diff --git a/bumble/l2cap.py b/bumble/l2cap.py index 736e155..cccb172 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -17,6 +17,7 @@ # ----------------------------------------------------------------------------- from __future__ import annotations import asyncio +import enum import logging import struct @@ -676,56 +677,35 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame): # ----------------------------------------------------------------------------- class Channel(EventEmitter): - # States - CLOSED = 0x00 - WAIT_CONNECT = 0x01 - WAIT_CONNECT_RSP = 0x02 - OPEN = 0x03 - WAIT_DISCONNECT = 0x04 - WAIT_CREATE = 0x05 - WAIT_CREATE_RSP = 0x06 - WAIT_MOVE = 0x07 - WAIT_MOVE_RSP = 0x08 - WAIT_MOVE_CONFIRM = 0x09 - WAIT_CONFIRM_RSP = 0x0A - - # CONFIG substates - WAIT_CONFIG = 0x10 - WAIT_SEND_CONFIG = 0x11 - WAIT_CONFIG_REQ_RSP = 0x12 - WAIT_CONFIG_RSP = 0x13 - WAIT_CONFIG_REQ = 0x14 - WAIT_IND_FINAL_RSP = 0x15 - WAIT_FINAL_RSP = 0x16 - WAIT_CONTROL_IND = 0x17 - - STATE_NAMES = { - CLOSED: 'CLOSED', - WAIT_CONNECT: 'WAIT_CONNECT', - WAIT_CONNECT_RSP: 'WAIT_CONNECT_RSP', - OPEN: 'OPEN', - WAIT_DISCONNECT: 'WAIT_DISCONNECT', - WAIT_CREATE: 'WAIT_CREATE', - WAIT_CREATE_RSP: 'WAIT_CREATE_RSP', - WAIT_MOVE: 'WAIT_MOVE', - WAIT_MOVE_RSP: 'WAIT_MOVE_RSP', - WAIT_MOVE_CONFIRM: 'WAIT_MOVE_CONFIRM', - WAIT_CONFIRM_RSP: 'WAIT_CONFIRM_RSP', - WAIT_CONFIG: 'WAIT_CONFIG', - WAIT_SEND_CONFIG: 'WAIT_SEND_CONFIG', - WAIT_CONFIG_REQ_RSP: 'WAIT_CONFIG_REQ_RSP', - WAIT_CONFIG_RSP: 'WAIT_CONFIG_RSP', - WAIT_CONFIG_REQ: 'WAIT_CONFIG_REQ', - WAIT_IND_FINAL_RSP: 'WAIT_IND_FINAL_RSP', - WAIT_FINAL_RSP: 'WAIT_FINAL_RSP', - WAIT_CONTROL_IND: 'WAIT_CONTROL_IND', - } + class State(enum.IntEnum): + # States + CLOSED = 0x00 + WAIT_CONNECT = 0x01 + WAIT_CONNECT_RSP = 0x02 + OPEN = 0x03 + WAIT_DISCONNECT = 0x04 + WAIT_CREATE = 0x05 + WAIT_CREATE_RSP = 0x06 + WAIT_MOVE = 0x07 + WAIT_MOVE_RSP = 0x08 + WAIT_MOVE_CONFIRM = 0x09 + WAIT_CONFIRM_RSP = 0x0A + + # CONFIG substates + WAIT_CONFIG = 0x10 + WAIT_SEND_CONFIG = 0x11 + WAIT_CONFIG_REQ_RSP = 0x12 + WAIT_CONFIG_RSP = 0x13 + WAIT_CONFIG_REQ = 0x14 + WAIT_IND_FINAL_RSP = 0x15 + WAIT_FINAL_RSP = 0x16 + WAIT_CONTROL_IND = 0x17 connection_result: Optional[asyncio.Future[None]] disconnection_result: Optional[asyncio.Future[None]] response: Optional[asyncio.Future[bytes]] sink: Optional[Callable[[bytes], Any]] - state: int + state: State connection: Connection def __init__( @@ -741,7 +721,7 @@ class Channel(EventEmitter): self.manager = manager self.connection = connection self.signaling_cid = signaling_cid - self.state = Channel.CLOSED + self.state = self.State.CLOSED self.mtu = mtu self.psm = psm self.source_cid = source_cid @@ -751,10 +731,8 @@ class Channel(EventEmitter): self.disconnection_result = None self.sink = None - def change_state(self, new_state: int) -> None: - logger.debug( - f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}' - ) + def _change_state(self, new_state: State) -> None: + logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}') self.state = new_state def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None: @@ -767,7 +745,7 @@ class Channel(EventEmitter): # Check that there isn't already a request pending if self.response: raise InvalidStateError('request already pending') - if self.state != Channel.OPEN: + if self.state != self.State.OPEN: raise InvalidStateError('channel not open') self.response = asyncio.get_running_loop().create_future() @@ -787,14 +765,14 @@ class Channel(EventEmitter): ) async def connect(self) -> None: - if self.state != Channel.CLOSED: + if self.state != self.State.CLOSED: raise InvalidStateError('invalid state') # Check that we can start a new connection if self.connection_result: raise RuntimeError('connection already pending') - self.change_state(Channel.WAIT_CONNECT_RSP) + self._change_state(self.State.WAIT_CONNECT_RSP) self.send_control_frame( L2CAP_Connection_Request( identifier=self.manager.next_identifier(self.connection), @@ -814,10 +792,10 @@ class Channel(EventEmitter): self.connection_result = None async def disconnect(self) -> None: - if self.state != Channel.OPEN: + if self.state != self.State.OPEN: raise InvalidStateError('invalid state') - self.change_state(Channel.WAIT_DISCONNECT) + self._change_state(self.State.WAIT_DISCONNECT) self.send_control_frame( L2CAP_Disconnection_Request( identifier=self.manager.next_identifier(self.connection), @@ -832,8 +810,8 @@ class Channel(EventEmitter): return await self.disconnection_result def abort(self) -> None: - if self.state == self.OPEN: - self.change_state(self.CLOSED) + if self.state == self.State.OPEN: + self._change_state(self.State.CLOSED) self.emit('close') def send_configure_request(self) -> None: @@ -856,7 +834,7 @@ class Channel(EventEmitter): def on_connection_request(self, request) -> None: self.destination_cid = request.source_cid - self.change_state(Channel.WAIT_CONNECT) + self._change_state(self.State.WAIT_CONNECT) self.send_control_frame( L2CAP_Connection_Response( identifier=request.identifier, @@ -866,24 +844,24 @@ class Channel(EventEmitter): status=0x0000, ) ) - self.change_state(Channel.WAIT_CONFIG) + self._change_state(self.State.WAIT_CONFIG) self.send_configure_request() - self.change_state(Channel.WAIT_CONFIG_REQ_RSP) + self._change_state(self.State.WAIT_CONFIG_REQ_RSP) def on_connection_response(self, response): - if self.state != Channel.WAIT_CONNECT_RSP: + if self.state != self.State.WAIT_CONNECT_RSP: logger.warning(color('invalid state', 'red')) return if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL: self.destination_cid = response.destination_cid - self.change_state(Channel.WAIT_CONFIG) + self._change_state(self.State.WAIT_CONFIG) self.send_configure_request() - self.change_state(Channel.WAIT_CONFIG_REQ_RSP) + self._change_state(self.State.WAIT_CONFIG_REQ_RSP) elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING: pass else: - self.change_state(Channel.CLOSED) + self._change_state(self.State.CLOSED) self.connection_result.set_exception( ProtocolError( response.result, @@ -895,9 +873,9 @@ class Channel(EventEmitter): def on_configure_request(self, request) -> None: if self.state not in ( - Channel.WAIT_CONFIG, - Channel.WAIT_CONFIG_REQ, - Channel.WAIT_CONFIG_REQ_RSP, + self.State.WAIT_CONFIG, + self.State.WAIT_CONFIG_REQ, + self.State.WAIT_CONFIG_REQ_RSP, ): logger.warning(color('invalid state', 'red')) return @@ -918,25 +896,28 @@ class Channel(EventEmitter): options=request.options, # TODO: don't accept everything blindly ) ) - if self.state == Channel.WAIT_CONFIG: - self.change_state(Channel.WAIT_SEND_CONFIG) + if self.state == self.State.WAIT_CONFIG: + self._change_state(self.State.WAIT_SEND_CONFIG) self.send_configure_request() - self.change_state(Channel.WAIT_CONFIG_RSP) - elif self.state == Channel.WAIT_CONFIG_REQ: - self.change_state(Channel.OPEN) + self._change_state(self.State.WAIT_CONFIG_RSP) + elif self.state == self.State.WAIT_CONFIG_REQ: + self._change_state(self.State.OPEN) if self.connection_result: self.connection_result.set_result(None) self.connection_result = None self.emit('open') - elif self.state == Channel.WAIT_CONFIG_REQ_RSP: - self.change_state(Channel.WAIT_CONFIG_RSP) + elif self.state == self.State.WAIT_CONFIG_REQ_RSP: + self._change_state(self.State.WAIT_CONFIG_RSP) def on_configure_response(self, response) -> None: if response.result == L2CAP_Configure_Response.SUCCESS: - if self.state == Channel.WAIT_CONFIG_REQ_RSP: - self.change_state(Channel.WAIT_CONFIG_REQ) - elif self.state in (Channel.WAIT_CONFIG_RSP, Channel.WAIT_CONTROL_IND): - self.change_state(Channel.OPEN) + if self.state == self.State.WAIT_CONFIG_REQ_RSP: + self._change_state(self.State.WAIT_CONFIG_REQ) + elif self.state in ( + self.State.WAIT_CONFIG_RSP, + self.State.WAIT_CONTROL_IND, + ): + self._change_state(self.State.OPEN) if self.connection_result: self.connection_result.set_result(None) self.connection_result = None @@ -966,7 +947,7 @@ class Channel(EventEmitter): # TODO: decide how to fail gracefully def on_disconnection_request(self, request) -> None: - if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT): + if self.state in (self.State.OPEN, self.State.WAIT_DISCONNECT): self.send_control_frame( L2CAP_Disconnection_Response( identifier=request.identifier, @@ -974,14 +955,14 @@ class Channel(EventEmitter): source_cid=request.source_cid, ) ) - self.change_state(Channel.CLOSED) + self._change_state(self.State.CLOSED) self.emit('close') self.manager.on_channel_closed(self) else: logger.warning(color('invalid state', 'red')) def on_disconnection_response(self, response) -> None: - if self.state != Channel.WAIT_DISCONNECT: + if self.state != self.State.WAIT_DISCONNECT: logger.warning(color('invalid state', 'red')) return @@ -992,7 +973,7 @@ class Channel(EventEmitter): logger.warning('unexpected source or destination CID') return - self.change_state(Channel.CLOSED) + self._change_state(self.State.CLOSED) if self.disconnection_result: self.disconnection_result.set_result(None) self.disconnection_result = None @@ -1004,7 +985,7 @@ class Channel(EventEmitter): f'Channel({self.source_cid}->{self.destination_cid}, ' f'PSM={self.psm}, ' f'MTU={self.mtu}, ' - f'state={Channel.STATE_NAMES[self.state]})' + f'state={self.state.name})' ) @@ -1014,33 +995,21 @@ class LeConnectionOrientedChannel(EventEmitter): LE Credit-based Connection Oriented Channel """ - INIT = 0 - CONNECTED = 1 - CONNECTING = 2 - DISCONNECTING = 3 - DISCONNECTED = 4 - CONNECTION_ERROR = 5 - - STATE_NAMES = { - INIT: 'INIT', - CONNECTED: 'CONNECTED', - CONNECTING: 'CONNECTING', - DISCONNECTING: 'DISCONNECTING', - DISCONNECTED: 'DISCONNECTED', - CONNECTION_ERROR: 'CONNECTION_ERROR', - } + class State(enum.IntEnum): + INIT = 0 + CONNECTED = 1 + CONNECTING = 2 + DISCONNECTING = 3 + DISCONNECTED = 4 + CONNECTION_ERROR = 5 out_queue: Deque[bytes] connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]] disconnection_result: Optional[asyncio.Future[None]] out_sdu: Optional[bytes] - state: int + state: State connection: Connection - @staticmethod - def state_name(state: int) -> str: - return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state) - def __init__( self, manager: ChannelManager, @@ -1083,19 +1052,17 @@ class LeConnectionOrientedChannel(EventEmitter): self.drained.set() if connected: - self.state = LeConnectionOrientedChannel.CONNECTED + self.state = self.State.CONNECTED else: - self.state = LeConnectionOrientedChannel.INIT + self.state = self.State.INIT - def change_state(self, new_state: int) -> None: - logger.debug( - f'{self} state change -> {color(self.state_name(new_state), "cyan")}' - ) + def _change_state(self, new_state: State) -> None: + logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}') self.state = new_state - if new_state == self.CONNECTED: + if new_state == self.State.CONNECTED: self.emit('open') - elif new_state == self.DISCONNECTED: + elif new_state == self.State.DISCONNECTED: self.emit('close') def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None: @@ -1106,7 +1073,7 @@ class LeConnectionOrientedChannel(EventEmitter): async def connect(self) -> LeConnectionOrientedChannel: # Check that we're in the right state - if self.state != self.INIT: + if self.state != self.State.INIT: raise InvalidStateError('not in a connectable state') # Check that we can start a new connection @@ -1114,7 +1081,7 @@ class LeConnectionOrientedChannel(EventEmitter): if identifier in self.manager.le_coc_requests: raise RuntimeError('too many concurrent connection requests') - self.change_state(self.CONNECTING) + self._change_state(self.State.CONNECTING) request = L2CAP_LE_Credit_Based_Connection_Request( identifier=identifier, le_psm=self.le_psm, @@ -1134,10 +1101,10 @@ class LeConnectionOrientedChannel(EventEmitter): async def disconnect(self) -> None: # Check that we're connected - if self.state != self.CONNECTED: + if self.state != self.State.CONNECTED: raise InvalidStateError('not connected') - self.change_state(self.DISCONNECTING) + self._change_state(self.State.DISCONNECTING) self.flush_output() self.send_control_frame( L2CAP_Disconnection_Request( @@ -1153,15 +1120,15 @@ class LeConnectionOrientedChannel(EventEmitter): return await self.disconnection_result def abort(self) -> None: - if self.state == self.CONNECTED: - self.change_state(self.DISCONNECTED) + if self.state == self.State.CONNECTED: + self._change_state(self.State.DISCONNECTED) def on_pdu(self, pdu: bytes) -> None: if self.sink is None: logger.warning('received pdu without a sink') return - if self.state != self.CONNECTED: + if self.state != self.State.CONNECTED: logger.warning('received PDU while not connected, dropping') # Manage the peer credits @@ -1240,7 +1207,7 @@ class LeConnectionOrientedChannel(EventEmitter): self.credits = response.initial_credits self.connected = True self.connection_result.set_result(self) - self.change_state(self.CONNECTED) + self._change_state(self.State.CONNECTED) else: self.connection_result.set_exception( ProtocolError( @@ -1251,7 +1218,7 @@ class LeConnectionOrientedChannel(EventEmitter): ), ) ) - self.change_state(self.CONNECTION_ERROR) + self._change_state(self.State.CONNECTION_ERROR) # Cleanup self.connection_result = None @@ -1271,11 +1238,11 @@ class LeConnectionOrientedChannel(EventEmitter): source_cid=request.source_cid, ) ) - self.change_state(self.DISCONNECTED) + self._change_state(self.State.DISCONNECTED) self.flush_output() def on_disconnection_response(self, response) -> None: - if self.state != self.DISCONNECTING: + if self.state != self.State.DISCONNECTING: logger.warning(color('invalid state', 'red')) return @@ -1286,7 +1253,7 @@ class LeConnectionOrientedChannel(EventEmitter): logger.warning('unexpected source or destination CID') return - self.change_state(self.DISCONNECTED) + self._change_state(self.State.DISCONNECTED) if self.disconnection_result: self.disconnection_result.set_result(None) self.disconnection_result = None @@ -1339,7 +1306,7 @@ class LeConnectionOrientedChannel(EventEmitter): return def write(self, data: bytes) -> None: - if self.state != self.CONNECTED: + if self.state != self.State.CONNECTED: logger.warning('not connected, dropping data') return @@ -1367,7 +1334,7 @@ class LeConnectionOrientedChannel(EventEmitter): def __str__(self) -> str: return ( f'CoC({self.source_cid}->{self.destination_cid}, ' - f'State={self.state_name(self.state)}, ' + f'State={self.state.name}, ' f'PSM={self.le_psm}, ' f'MTU={self.mtu}/{self.peer_mtu}, ' f'MPS={self.mps}/{self.peer_mps}, ' |