aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosh Wu <joshwu@google.com>2023-09-14 20:52:33 +0800
committerJosh Wu <joshwu@google.com>2023-09-14 20:52:33 +0800
commit5d9598ea514c4f272196483d8cd39cff0f0b7bb0 (patch)
treebf645d9bad04fdee4f80d29c6367026f220f8d89
parent0d36d99a73dcae215e08002087eb01bc6d03954c (diff)
downloadbumble-5d9598ea514c4f272196483d8cd39cff0f0b7bb0.tar.gz
L2CAP: Refactor states to enums
-rw-r--r--bumble/l2cap.py217
1 files changed, 92 insertions, 125 deletions
diff --git a/bumble/l2cap.py b/bumble/l2cap.py
index 736e155..cccb172 100644
--- a/bumble/l2cap.py
+++ b/bumble/l2cap.py
@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
+import enum
import logging
import struct
@@ -676,56 +677,35 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame):
# -----------------------------------------------------------------------------
class Channel(EventEmitter):
- # States
- CLOSED = 0x00
- WAIT_CONNECT = 0x01
- WAIT_CONNECT_RSP = 0x02
- OPEN = 0x03
- WAIT_DISCONNECT = 0x04
- WAIT_CREATE = 0x05
- WAIT_CREATE_RSP = 0x06
- WAIT_MOVE = 0x07
- WAIT_MOVE_RSP = 0x08
- WAIT_MOVE_CONFIRM = 0x09
- WAIT_CONFIRM_RSP = 0x0A
-
- # CONFIG substates
- WAIT_CONFIG = 0x10
- WAIT_SEND_CONFIG = 0x11
- WAIT_CONFIG_REQ_RSP = 0x12
- WAIT_CONFIG_RSP = 0x13
- WAIT_CONFIG_REQ = 0x14
- WAIT_IND_FINAL_RSP = 0x15
- WAIT_FINAL_RSP = 0x16
- WAIT_CONTROL_IND = 0x17
-
- STATE_NAMES = {
- CLOSED: 'CLOSED',
- WAIT_CONNECT: 'WAIT_CONNECT',
- WAIT_CONNECT_RSP: 'WAIT_CONNECT_RSP',
- OPEN: 'OPEN',
- WAIT_DISCONNECT: 'WAIT_DISCONNECT',
- WAIT_CREATE: 'WAIT_CREATE',
- WAIT_CREATE_RSP: 'WAIT_CREATE_RSP',
- WAIT_MOVE: 'WAIT_MOVE',
- WAIT_MOVE_RSP: 'WAIT_MOVE_RSP',
- WAIT_MOVE_CONFIRM: 'WAIT_MOVE_CONFIRM',
- WAIT_CONFIRM_RSP: 'WAIT_CONFIRM_RSP',
- WAIT_CONFIG: 'WAIT_CONFIG',
- WAIT_SEND_CONFIG: 'WAIT_SEND_CONFIG',
- WAIT_CONFIG_REQ_RSP: 'WAIT_CONFIG_REQ_RSP',
- WAIT_CONFIG_RSP: 'WAIT_CONFIG_RSP',
- WAIT_CONFIG_REQ: 'WAIT_CONFIG_REQ',
- WAIT_IND_FINAL_RSP: 'WAIT_IND_FINAL_RSP',
- WAIT_FINAL_RSP: 'WAIT_FINAL_RSP',
- WAIT_CONTROL_IND: 'WAIT_CONTROL_IND',
- }
+ class State(enum.IntEnum):
+ # States
+ CLOSED = 0x00
+ WAIT_CONNECT = 0x01
+ WAIT_CONNECT_RSP = 0x02
+ OPEN = 0x03
+ WAIT_DISCONNECT = 0x04
+ WAIT_CREATE = 0x05
+ WAIT_CREATE_RSP = 0x06
+ WAIT_MOVE = 0x07
+ WAIT_MOVE_RSP = 0x08
+ WAIT_MOVE_CONFIRM = 0x09
+ WAIT_CONFIRM_RSP = 0x0A
+
+ # CONFIG substates
+ WAIT_CONFIG = 0x10
+ WAIT_SEND_CONFIG = 0x11
+ WAIT_CONFIG_REQ_RSP = 0x12
+ WAIT_CONFIG_RSP = 0x13
+ WAIT_CONFIG_REQ = 0x14
+ WAIT_IND_FINAL_RSP = 0x15
+ WAIT_FINAL_RSP = 0x16
+ WAIT_CONTROL_IND = 0x17
connection_result: Optional[asyncio.Future[None]]
disconnection_result: Optional[asyncio.Future[None]]
response: Optional[asyncio.Future[bytes]]
sink: Optional[Callable[[bytes], Any]]
- state: int
+ state: State
connection: Connection
def __init__(
@@ -741,7 +721,7 @@ class Channel(EventEmitter):
self.manager = manager
self.connection = connection
self.signaling_cid = signaling_cid
- self.state = Channel.CLOSED
+ self.state = self.State.CLOSED
self.mtu = mtu
self.psm = psm
self.source_cid = source_cid
@@ -751,10 +731,8 @@ class Channel(EventEmitter):
self.disconnection_result = None
self.sink = None
- def change_state(self, new_state: int) -> None:
- logger.debug(
- f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}'
- )
+ def _change_state(self, new_state: State) -> None:
+ logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
self.state = new_state
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
@@ -767,7 +745,7 @@ class Channel(EventEmitter):
# Check that there isn't already a request pending
if self.response:
raise InvalidStateError('request already pending')
- if self.state != Channel.OPEN:
+ if self.state != self.State.OPEN:
raise InvalidStateError('channel not open')
self.response = asyncio.get_running_loop().create_future()
@@ -787,14 +765,14 @@ class Channel(EventEmitter):
)
async def connect(self) -> None:
- if self.state != Channel.CLOSED:
+ if self.state != self.State.CLOSED:
raise InvalidStateError('invalid state')
# Check that we can start a new connection
if self.connection_result:
raise RuntimeError('connection already pending')
- self.change_state(Channel.WAIT_CONNECT_RSP)
+ self._change_state(self.State.WAIT_CONNECT_RSP)
self.send_control_frame(
L2CAP_Connection_Request(
identifier=self.manager.next_identifier(self.connection),
@@ -814,10 +792,10 @@ class Channel(EventEmitter):
self.connection_result = None
async def disconnect(self) -> None:
- if self.state != Channel.OPEN:
+ if self.state != self.State.OPEN:
raise InvalidStateError('invalid state')
- self.change_state(Channel.WAIT_DISCONNECT)
+ self._change_state(self.State.WAIT_DISCONNECT)
self.send_control_frame(
L2CAP_Disconnection_Request(
identifier=self.manager.next_identifier(self.connection),
@@ -832,8 +810,8 @@ class Channel(EventEmitter):
return await self.disconnection_result
def abort(self) -> None:
- if self.state == self.OPEN:
- self.change_state(self.CLOSED)
+ if self.state == self.State.OPEN:
+ self._change_state(self.State.CLOSED)
self.emit('close')
def send_configure_request(self) -> None:
@@ -856,7 +834,7 @@ class Channel(EventEmitter):
def on_connection_request(self, request) -> None:
self.destination_cid = request.source_cid
- self.change_state(Channel.WAIT_CONNECT)
+ self._change_state(self.State.WAIT_CONNECT)
self.send_control_frame(
L2CAP_Connection_Response(
identifier=request.identifier,
@@ -866,24 +844,24 @@ class Channel(EventEmitter):
status=0x0000,
)
)
- self.change_state(Channel.WAIT_CONFIG)
+ self._change_state(self.State.WAIT_CONFIG)
self.send_configure_request()
- self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
+ self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
def on_connection_response(self, response):
- if self.state != Channel.WAIT_CONNECT_RSP:
+ if self.state != self.State.WAIT_CONNECT_RSP:
logger.warning(color('invalid state', 'red'))
return
if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
self.destination_cid = response.destination_cid
- self.change_state(Channel.WAIT_CONFIG)
+ self._change_state(self.State.WAIT_CONFIG)
self.send_configure_request()
- self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
+ self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING:
pass
else:
- self.change_state(Channel.CLOSED)
+ self._change_state(self.State.CLOSED)
self.connection_result.set_exception(
ProtocolError(
response.result,
@@ -895,9 +873,9 @@ class Channel(EventEmitter):
def on_configure_request(self, request) -> None:
if self.state not in (
- Channel.WAIT_CONFIG,
- Channel.WAIT_CONFIG_REQ,
- Channel.WAIT_CONFIG_REQ_RSP,
+ self.State.WAIT_CONFIG,
+ self.State.WAIT_CONFIG_REQ,
+ self.State.WAIT_CONFIG_REQ_RSP,
):
logger.warning(color('invalid state', 'red'))
return
@@ -918,25 +896,28 @@ class Channel(EventEmitter):
options=request.options, # TODO: don't accept everything blindly
)
)
- if self.state == Channel.WAIT_CONFIG:
- self.change_state(Channel.WAIT_SEND_CONFIG)
+ if self.state == self.State.WAIT_CONFIG:
+ self._change_state(self.State.WAIT_SEND_CONFIG)
self.send_configure_request()
- self.change_state(Channel.WAIT_CONFIG_RSP)
- elif self.state == Channel.WAIT_CONFIG_REQ:
- self.change_state(Channel.OPEN)
+ self._change_state(self.State.WAIT_CONFIG_RSP)
+ elif self.state == self.State.WAIT_CONFIG_REQ:
+ self._change_state(self.State.OPEN)
if self.connection_result:
self.connection_result.set_result(None)
self.connection_result = None
self.emit('open')
- elif self.state == Channel.WAIT_CONFIG_REQ_RSP:
- self.change_state(Channel.WAIT_CONFIG_RSP)
+ elif self.state == self.State.WAIT_CONFIG_REQ_RSP:
+ self._change_state(self.State.WAIT_CONFIG_RSP)
def on_configure_response(self, response) -> None:
if response.result == L2CAP_Configure_Response.SUCCESS:
- if self.state == Channel.WAIT_CONFIG_REQ_RSP:
- self.change_state(Channel.WAIT_CONFIG_REQ)
- elif self.state in (Channel.WAIT_CONFIG_RSP, Channel.WAIT_CONTROL_IND):
- self.change_state(Channel.OPEN)
+ if self.state == self.State.WAIT_CONFIG_REQ_RSP:
+ self._change_state(self.State.WAIT_CONFIG_REQ)
+ elif self.state in (
+ self.State.WAIT_CONFIG_RSP,
+ self.State.WAIT_CONTROL_IND,
+ ):
+ self._change_state(self.State.OPEN)
if self.connection_result:
self.connection_result.set_result(None)
self.connection_result = None
@@ -966,7 +947,7 @@ class Channel(EventEmitter):
# TODO: decide how to fail gracefully
def on_disconnection_request(self, request) -> None:
- if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT):
+ if self.state in (self.State.OPEN, self.State.WAIT_DISCONNECT):
self.send_control_frame(
L2CAP_Disconnection_Response(
identifier=request.identifier,
@@ -974,14 +955,14 @@ class Channel(EventEmitter):
source_cid=request.source_cid,
)
)
- self.change_state(Channel.CLOSED)
+ self._change_state(self.State.CLOSED)
self.emit('close')
self.manager.on_channel_closed(self)
else:
logger.warning(color('invalid state', 'red'))
def on_disconnection_response(self, response) -> None:
- if self.state != Channel.WAIT_DISCONNECT:
+ if self.state != self.State.WAIT_DISCONNECT:
logger.warning(color('invalid state', 'red'))
return
@@ -992,7 +973,7 @@ class Channel(EventEmitter):
logger.warning('unexpected source or destination CID')
return
- self.change_state(Channel.CLOSED)
+ self._change_state(self.State.CLOSED)
if self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
@@ -1004,7 +985,7 @@ class Channel(EventEmitter):
f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, '
f'MTU={self.mtu}, '
- f'state={Channel.STATE_NAMES[self.state]})'
+ f'state={self.state.name})'
)
@@ -1014,33 +995,21 @@ class LeConnectionOrientedChannel(EventEmitter):
LE Credit-based Connection Oriented Channel
"""
- INIT = 0
- CONNECTED = 1
- CONNECTING = 2
- DISCONNECTING = 3
- DISCONNECTED = 4
- CONNECTION_ERROR = 5
-
- STATE_NAMES = {
- INIT: 'INIT',
- CONNECTED: 'CONNECTED',
- CONNECTING: 'CONNECTING',
- DISCONNECTING: 'DISCONNECTING',
- DISCONNECTED: 'DISCONNECTED',
- CONNECTION_ERROR: 'CONNECTION_ERROR',
- }
+ class State(enum.IntEnum):
+ INIT = 0
+ CONNECTED = 1
+ CONNECTING = 2
+ DISCONNECTING = 3
+ DISCONNECTED = 4
+ CONNECTION_ERROR = 5
out_queue: Deque[bytes]
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
disconnection_result: Optional[asyncio.Future[None]]
out_sdu: Optional[bytes]
- state: int
+ state: State
connection: Connection
- @staticmethod
- def state_name(state: int) -> str:
- return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
-
def __init__(
self,
manager: ChannelManager,
@@ -1083,19 +1052,17 @@ class LeConnectionOrientedChannel(EventEmitter):
self.drained.set()
if connected:
- self.state = LeConnectionOrientedChannel.CONNECTED
+ self.state = self.State.CONNECTED
else:
- self.state = LeConnectionOrientedChannel.INIT
+ self.state = self.State.INIT
- def change_state(self, new_state: int) -> None:
- logger.debug(
- f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
- )
+ def _change_state(self, new_state: State) -> None:
+ logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
self.state = new_state
- if new_state == self.CONNECTED:
+ if new_state == self.State.CONNECTED:
self.emit('open')
- elif new_state == self.DISCONNECTED:
+ elif new_state == self.State.DISCONNECTED:
self.emit('close')
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
@@ -1106,7 +1073,7 @@ class LeConnectionOrientedChannel(EventEmitter):
async def connect(self) -> LeConnectionOrientedChannel:
# Check that we're in the right state
- if self.state != self.INIT:
+ if self.state != self.State.INIT:
raise InvalidStateError('not in a connectable state')
# Check that we can start a new connection
@@ -1114,7 +1081,7 @@ class LeConnectionOrientedChannel(EventEmitter):
if identifier in self.manager.le_coc_requests:
raise RuntimeError('too many concurrent connection requests')
- self.change_state(self.CONNECTING)
+ self._change_state(self.State.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request(
identifier=identifier,
le_psm=self.le_psm,
@@ -1134,10 +1101,10 @@ class LeConnectionOrientedChannel(EventEmitter):
async def disconnect(self) -> None:
# Check that we're connected
- if self.state != self.CONNECTED:
+ if self.state != self.State.CONNECTED:
raise InvalidStateError('not connected')
- self.change_state(self.DISCONNECTING)
+ self._change_state(self.State.DISCONNECTING)
self.flush_output()
self.send_control_frame(
L2CAP_Disconnection_Request(
@@ -1153,15 +1120,15 @@ class LeConnectionOrientedChannel(EventEmitter):
return await self.disconnection_result
def abort(self) -> None:
- if self.state == self.CONNECTED:
- self.change_state(self.DISCONNECTED)
+ if self.state == self.State.CONNECTED:
+ self._change_state(self.State.DISCONNECTED)
def on_pdu(self, pdu: bytes) -> None:
if self.sink is None:
logger.warning('received pdu without a sink')
return
- if self.state != self.CONNECTED:
+ if self.state != self.State.CONNECTED:
logger.warning('received PDU while not connected, dropping')
# Manage the peer credits
@@ -1240,7 +1207,7 @@ class LeConnectionOrientedChannel(EventEmitter):
self.credits = response.initial_credits
self.connected = True
self.connection_result.set_result(self)
- self.change_state(self.CONNECTED)
+ self._change_state(self.State.CONNECTED)
else:
self.connection_result.set_exception(
ProtocolError(
@@ -1251,7 +1218,7 @@ class LeConnectionOrientedChannel(EventEmitter):
),
)
)
- self.change_state(self.CONNECTION_ERROR)
+ self._change_state(self.State.CONNECTION_ERROR)
# Cleanup
self.connection_result = None
@@ -1271,11 +1238,11 @@ class LeConnectionOrientedChannel(EventEmitter):
source_cid=request.source_cid,
)
)
- self.change_state(self.DISCONNECTED)
+ self._change_state(self.State.DISCONNECTED)
self.flush_output()
def on_disconnection_response(self, response) -> None:
- if self.state != self.DISCONNECTING:
+ if self.state != self.State.DISCONNECTING:
logger.warning(color('invalid state', 'red'))
return
@@ -1286,7 +1253,7 @@ class LeConnectionOrientedChannel(EventEmitter):
logger.warning('unexpected source or destination CID')
return
- self.change_state(self.DISCONNECTED)
+ self._change_state(self.State.DISCONNECTED)
if self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
@@ -1339,7 +1306,7 @@ class LeConnectionOrientedChannel(EventEmitter):
return
def write(self, data: bytes) -> None:
- if self.state != self.CONNECTED:
+ if self.state != self.State.CONNECTED:
logger.warning('not connected, dropping data')
return
@@ -1367,7 +1334,7 @@ class LeConnectionOrientedChannel(EventEmitter):
def __str__(self) -> str:
return (
f'CoC({self.source_cid}->{self.destination_cid}, '
- f'State={self.state_name(self.state)}, '
+ f'State={self.state.name}, '
f'PSM={self.le_psm}, '
f'MTU={self.mtu}/{self.peer_mtu}, '
f'MPS={self.mps}/{self.peer_mps}, '