diff --git a/bumble/a2dp.py b/bumble/a2dp.py index 4c9eb320..7daf775c 100644 --- a/bumble/a2dp.py +++ b/bumble/a2dp.py @@ -267,26 +267,27 @@ class MediaCodecInformation: def create( cls, media_codec_type: int, data: bytes ) -> MediaCodecInformation | bytes: - if media_codec_type == CodecType.SBC: - return SbcMediaCodecInformation.from_bytes(data) - elif media_codec_type == CodecType.MPEG_2_4_AAC: - return AacMediaCodecInformation.from_bytes(data) - elif media_codec_type == CodecType.NON_A2DP: - vendor_media_codec_information = ( - VendorSpecificMediaCodecInformation.from_bytes(data) - ) - if ( - vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get( - vendor_media_codec_information.vendor_id - ) - ) and ( - media_codec_information_class := vendor_class_map.get( - vendor_media_codec_information.codec_id - ) - ): - return media_codec_information_class.from_bytes( - vendor_media_codec_information.value + match media_codec_type: + case CodecType.SBC: + return SbcMediaCodecInformation.from_bytes(data) + case CodecType.MPEG_2_4_AAC: + return AacMediaCodecInformation.from_bytes(data) + case CodecType.NON_A2DP: + vendor_media_codec_information = ( + VendorSpecificMediaCodecInformation.from_bytes(data) ) + if ( + vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get( + vendor_media_codec_information.vendor_id + ) + ) and ( + media_codec_information_class := vendor_class_map.get( + vendor_media_codec_information.codec_id + ) + ): + return media_codec_information_class.from_bytes( + vendor_media_codec_information.value + ) return vendor_media_codec_information @classmethod diff --git a/bumble/at.py b/bumble/at.py index 9fe85aec..d49cf9cd 100644 --- a/bumble/at.py +++ b/bumble/at.py @@ -27,7 +27,7 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]: are ignored [..], unless they are embedded in numeric or string constants" Raises AtParsingError in case of invalid input string.""" - tokens = [] + tokens: list[bytearray] = [] in_quotes = False token = bytearray() for b in buffer: @@ -40,23 +40,24 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]: tokens.append(token[1:-1]) token = bytearray() else: - if char == b' ': - pass - elif char == b',' or char == b')': - tokens.append(token) - tokens.append(char) - token = bytearray() - elif char == b'(': - if len(token) > 0: - raise AtParsingError("open_paren following regular character") - tokens.append(char) - elif char == b'"': - if len(token) > 0: - raise AtParsingError("quote following regular character") - in_quotes = True - token.extend(char) - else: - token.extend(char) + match char: + case b' ': + pass + case b',' | b')': + tokens.append(token) + tokens.append(char) + token = bytearray() + case b'(': + if len(token) > 0: + raise AtParsingError("open_paren following regular character") + tokens.append(char) + case b'"': + if len(token) > 0: + raise AtParsingError("quote following regular character") + in_quotes = True + token.extend(char) + case _: + token.extend(char) tokens.append(token) return [bytes(token) for token in tokens if len(token) > 0] @@ -71,18 +72,19 @@ def parse_parameters(buffer: bytes) -> list[bytes | list]: current: bytes | list = b'' for token in tokens: - if token == b',': - accumulator[-1].append(current) - current = b'' - elif token == b'(': - accumulator.append([]) - elif token == b')': - if len(accumulator) < 2: - raise AtParsingError("close_paren without matching open_paren") - accumulator[-1].append(current) - current = accumulator.pop() - else: - current = token + match token: + case b',': + accumulator[-1].append(current) + current = b'' + case b'(': + accumulator.append([]) + case b')': + if len(accumulator) < 2: + raise AtParsingError("close_paren without matching open_paren") + accumulator[-1].append(current) + current = accumulator.pop() + case _: + current = token accumulator[-1].append(current) if len(accumulator) > 1: diff --git a/bumble/att.py b/bumble/att.py index 60e9b5c5..07ebe868 100644 --- a/bumble/att.py +++ b/bumble/att.py @@ -954,12 +954,13 @@ def __init__( self.permissions = permissions # Convert the type to a UUID object if it isn't already - if isinstance(attribute_type, str): - self.type = UUID(attribute_type) - elif isinstance(attribute_type, bytes): - self.type = UUID.from_bytes(attribute_type) - else: - self.type = attribute_type + match attribute_type: + case str(): + self.type = UUID(attribute_type) + case bytes(): + self.type = UUID.from_bytes(attribute_type) + case _: + self.type = attribute_type self.value = value @@ -994,30 +995,31 @@ async def read_value(self, bearer: Bearer) -> bytes: ) value: _T | None - if isinstance(self.value, AttributeValue): - try: - read_value = self.value.read(connection) - if inspect.isawaitable(read_value): - value = await read_value - else: - value = read_value - except ATT_Error as error: - raise ATT_Error( - error_code=error.error_code, att_handle=self.handle - ) from error - elif isinstance(self.value, AttributeValueV2): - try: - read_value = self.value.read(bearer) - if inspect.isawaitable(read_value): - value = await read_value - else: - value = read_value - except ATT_Error as error: - raise ATT_Error( - error_code=error.error_code, att_handle=self.handle - ) from error - else: - value = self.value + match self.value: + case AttributeValue(): + try: + read_value = self.value.read(connection) + if inspect.isawaitable(read_value): + value = await read_value + else: + value = read_value + except ATT_Error as error: + raise ATT_Error( + error_code=error.error_code, att_handle=self.handle + ) from error + case AttributeValueV2(): + try: + read_value = self.value.read(bearer) + if inspect.isawaitable(read_value): + value = await read_value + else: + value = read_value + except ATT_Error as error: + raise ATT_Error( + error_code=error.error_code, att_handle=self.handle + ) from error + case _: + value = self.value self.emit(self.EVENT_READ, connection, b'' if value is None else value) @@ -1049,26 +1051,27 @@ async def write_value(self, bearer: Bearer, value: bytes) -> None: decoded_value = self.decode_value(value) - if isinstance(self.value, AttributeValue): - try: - result = self.value.write(connection, decoded_value) - if inspect.isawaitable(result): - await result - except ATT_Error as error: - raise ATT_Error( - error_code=error.error_code, att_handle=self.handle - ) from error - elif isinstance(self.value, AttributeValueV2): - try: - result = self.value.write(bearer, decoded_value) - if inspect.isawaitable(result): - await result - except ATT_Error as error: - raise ATT_Error( - error_code=error.error_code, att_handle=self.handle - ) from error - else: - self.value = decoded_value + match self.value: + case AttributeValue(): + try: + result = self.value.write(connection, decoded_value) + if inspect.isawaitable(result): + await result + except ATT_Error as error: + raise ATT_Error( + error_code=error.error_code, att_handle=self.handle + ) from error + case AttributeValueV2(): + try: + result = self.value.write(bearer, decoded_value) + if inspect.isawaitable(result): + await result + except ATT_Error as error: + raise ATT_Error( + error_code=error.error_code, att_handle=self.handle + ) from error + case _: + self.value = decoded_value self.emit(self.EVENT_WRITE, connection, decoded_value) diff --git a/bumble/controller.py b/bumble/controller.py index 9fbe80ae..fcb35de5 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -403,14 +403,15 @@ def on_hci_packet(self, packet: hci.HCI_Packet) -> None: ) # If the packet is a command, invoke the handler for this packet - if isinstance(packet, hci.HCI_Command): - self.on_hci_command_packet(packet) - elif isinstance(packet, hci.HCI_AclDataPacket): - self.on_hci_acl_data_packet(packet) - elif isinstance(packet, hci.HCI_Event): - self.on_hci_event_packet(packet) - else: - logger.warning(f'!!! unknown packet type {packet.hci_packet_type}') + match packet: + case hci.HCI_Command(): + self.on_hci_command_packet(packet) + case hci.HCI_AclDataPacket(): + self.on_hci_acl_data_packet(packet) + case hci.HCI_Event(): + self.on_hci_event_packet(packet) + case _: + logger.warning(f'!!! unknown packet type {packet.hci_packet_type}') def on_hci_command_packet(self, command: hci.HCI_Command) -> None: handler_name = f'on_{command.name.lower()}' @@ -517,26 +518,28 @@ def on_ll_control_pdu( logger.error("Cannot find a connection for %s", sender_address) return - if isinstance(packet, ll.TerminateInd): - self.on_le_disconnected(connection, packet.error_code) - elif isinstance(packet, ll.CisReq): - self.on_le_cis_request(connection, packet.cig_id, packet.cis_id) - elif isinstance(packet, ll.CisRsp): - self.on_le_cis_established(packet.cig_id, packet.cis_id) - connection.send_ll_control_pdu(ll.CisInd(packet.cig_id, packet.cis_id)) - elif isinstance(packet, ll.CisInd): - self.on_le_cis_established(packet.cig_id, packet.cis_id) - elif isinstance(packet, ll.CisTerminateInd): - self.on_le_cis_disconnected(packet.cig_id, packet.cis_id) - elif isinstance(packet, ll.EncReq): - self.on_le_encrypted(connection) + match packet: + case ll.TerminateInd(): + self.on_le_disconnected(connection, packet.error_code) + case ll.CisReq(): + self.on_le_cis_request(connection, packet.cig_id, packet.cis_id) + case ll.CisRsp(): + self.on_le_cis_established(packet.cig_id, packet.cis_id) + connection.send_ll_control_pdu(ll.CisInd(packet.cig_id, packet.cis_id)) + case ll.CisInd(): + self.on_le_cis_established(packet.cig_id, packet.cis_id) + case ll.CisTerminateInd(): + self.on_le_cis_disconnected(packet.cig_id, packet.cis_id) + case ll.EncReq(): + self.on_le_encrypted(connection) def on_ll_advertising_pdu(self, packet: ll.AdvertisingPdu) -> None: logger.debug("[%s] <<< Advertising PDU: %s", self.name, packet) - if isinstance(packet, ll.ConnectInd): - self.on_le_connect_ind(packet) - elif isinstance(packet, (ll.AdvInd, ll.AdvExtInd)): - self.on_advertising_pdu(packet) + match packet: + case ll.ConnectInd(): + self.on_le_connect_ind(packet) + case ll.AdvInd() | ll.AdvExtInd(): + self.on_advertising_pdu(packet) def on_le_connect_ind(self, packet: ll.ConnectInd) -> None: ''' @@ -894,51 +897,52 @@ def send_lmp_packet( return future def on_lmp_packet(self, sender_address: hci.Address, packet: lmp.Packet): - if isinstance(packet, (lmp.LmpAccepted, lmp.LmpAcceptedExt)): - if future := self.classic_pending_commands.setdefault( - sender_address, {} - ).get(packet.response_opcode): - future.set_result(hci.HCI_SUCCESS) - else: - logger.error("!!! Unhandled packet: %s", packet) - elif isinstance(packet, (lmp.LmpNotAccepted, lmp.LmpNotAcceptedExt)): - if future := self.classic_pending_commands.setdefault( - sender_address, {} - ).get(packet.response_opcode): - future.set_result(packet.error_code) - else: + match packet: + case lmp.LmpAccepted() | lmp.LmpAcceptedExt(): + if future := self.classic_pending_commands.setdefault( + sender_address, {} + ).get(packet.response_opcode): + future.set_result(hci.HCI_SUCCESS) + else: + logger.error("!!! Unhandled packet: %s", packet) + case lmp.LmpNotAccepted() | lmp.LmpNotAcceptedExt(): + if future := self.classic_pending_commands.setdefault( + sender_address, {} + ).get(packet.response_opcode): + future.set_result(packet.error_code) + else: + logger.error("!!! Unhandled packet: %s", packet) + case lmp.LmpHostConnectionReq(): + self.on_classic_connection_request( + sender_address, hci.HCI_Connection_Complete_Event.LinkType.ACL + ) + case lmp.LmpScoLinkReq(): + self.on_classic_connection_request( + sender_address, hci.HCI_Connection_Complete_Event.LinkType.SCO + ) + case lmp.LmpEscoLinkReq(): + self.on_classic_connection_request( + sender_address, hci.HCI_Connection_Complete_Event.LinkType.ESCO + ) + case lmp.LmpDetach(): + self.on_classic_disconnected( + sender_address, hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR + ) + case lmp.LmpSwitchReq(): + self.on_classic_role_change_request(sender_address) + case lmp.LmpRemoveScoLinkReq() | lmp.LmpRemoveEscoLinkReq(): + self.on_classic_sco_disconnected(sender_address, packet.error_code) + case lmp.LmpNameReq(): + self.on_classic_remote_name_request(sender_address, packet.name_offset) + case lmp.LmpNameRes(): + self.on_classic_remote_name_response( + sender_address, + packet.name_offset, + packet.name_length, + packet.name_fregment, + ) + case _: logger.error("!!! Unhandled packet: %s", packet) - elif isinstance(packet, (lmp.LmpHostConnectionReq)): - self.on_classic_connection_request( - sender_address, hci.HCI_Connection_Complete_Event.LinkType.ACL - ) - elif isinstance(packet, (lmp.LmpScoLinkReq)): - self.on_classic_connection_request( - sender_address, hci.HCI_Connection_Complete_Event.LinkType.SCO - ) - elif isinstance(packet, (lmp.LmpEscoLinkReq)): - self.on_classic_connection_request( - sender_address, hci.HCI_Connection_Complete_Event.LinkType.ESCO - ) - elif isinstance(packet, (lmp.LmpDetach)): - self.on_classic_disconnected( - sender_address, hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR - ) - elif isinstance(packet, (lmp.LmpSwitchReq)): - self.on_classic_role_change_request(sender_address) - elif isinstance(packet, (lmp.LmpRemoveScoLinkReq, lmp.LmpRemoveEscoLinkReq)): - self.on_classic_sco_disconnected(sender_address, packet.error_code) - elif isinstance(packet, lmp.LmpNameReq): - self.on_classic_remote_name_request(sender_address, packet.name_offset) - elif isinstance(packet, lmp.LmpNameRes): - self.on_classic_remote_name_response( - sender_address, - packet.name_offset, - packet.name_length, - packet.name_fregment, - ) - else: - logger.error("!!! Unhandled packet: %s", packet) def on_classic_connection_request( self, peer_address: hci.Address, link_type: int diff --git a/bumble/core.py b/bumble/core.py index 8f68a2b4..3be09deb 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -280,14 +280,15 @@ def to_bytes(self, force_128: bool = False) -> bytes: if not force_128: return self.uuid_bytes - if len(self.uuid_bytes) == 2: - return self.BASE_UUID + self.uuid_bytes + bytes([0, 0]) - elif len(self.uuid_bytes) == 4: - return self.BASE_UUID + self.uuid_bytes - elif len(self.uuid_bytes) == 16: - return self.uuid_bytes - else: - assert False, "unreachable" + match len(self.uuid_bytes): + case 2: + return self.BASE_UUID + self.uuid_bytes + bytes([0, 0]) + case 4: + return self.BASE_UUID + self.uuid_bytes + case 16: + return self.uuid_bytes + case _: + assert False, "unreachable" def to_pdu_bytes(self) -> bytes: ''' @@ -1769,66 +1770,71 @@ def uuid_list_to_string(ad_data, uuid_size): @classmethod def ad_data_to_string(cls, ad_type: int, ad_data: bytes) -> str: - if ad_type == AdvertisingData.FLAGS: - ad_type_str = 'Flags' - ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True) - elif ad_type == AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: - ad_type_str = 'Complete List of 16-bit Service Class UUIDs' - ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2) - elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: - ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs' - ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2) - elif ad_type == AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: - ad_type_str = 'Complete List of 32-bit Service Class UUIDs' - ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4) - elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: - ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs' - ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4) - elif ad_type == AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: - ad_type_str = 'Complete List of 128-bit Service Class UUIDs' - ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16) - elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: - ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs' - ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16) - elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID: - ad_type_str = 'Service Data' - uuid = UUID.from_bytes(ad_data[:2]) - ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}' - elif ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID: - ad_type_str = 'Service Data' - uuid = UUID.from_bytes(ad_data[:4]) - ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}' - elif ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID: - ad_type_str = 'Service Data' - uuid = UUID.from_bytes(ad_data[:16]) - ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}' - elif ad_type == AdvertisingData.SHORTENED_LOCAL_NAME: - ad_type_str = 'Shortened Local Name' - ad_data_str = f'"{ad_data.decode("utf-8")}"' - elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME: - ad_type_str = 'Complete Local Name' - try: + match ad_type: + case AdvertisingData.FLAGS: + ad_type_str = 'Flags' + ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True) + case AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: + ad_type_str = 'Complete List of 16-bit Service Class UUIDs' + ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2) + case AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: + ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs' + ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2) + case AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: + ad_type_str = 'Complete List of 32-bit Service Class UUIDs' + ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4) + case AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: + ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs' + ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4) + case AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: + ad_type_str = 'Complete List of 128-bit Service Class UUIDs' + ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16) + case AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: + ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs' + ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16) + case AdvertisingData.SERVICE_DATA_16_BIT_UUID: + ad_type_str = 'Service Data' + uuid = UUID.from_bytes(ad_data[:2]) + ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}' + case AdvertisingData.SERVICE_DATA_32_BIT_UUID: + ad_type_str = 'Service Data' + uuid = UUID.from_bytes(ad_data[:4]) + ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}' + case AdvertisingData.SERVICE_DATA_128_BIT_UUID: + ad_type_str = 'Service Data' + uuid = UUID.from_bytes(ad_data[:16]) + ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}' + case AdvertisingData.SHORTENED_LOCAL_NAME: + ad_type_str = 'Shortened Local Name' ad_data_str = f'"{ad_data.decode("utf-8")}"' - except UnicodeDecodeError: + case AdvertisingData.COMPLETE_LOCAL_NAME: + ad_type_str = 'Complete Local Name' + try: + ad_data_str = f'"{ad_data.decode("utf-8")}"' + except UnicodeDecodeError: + ad_data_str = ad_data.hex() + case AdvertisingData.TX_POWER_LEVEL: + ad_type_str = 'TX Power Level' + ad_data_str = str(ad_data[0]) + case AdvertisingData.MANUFACTURER_SPECIFIC_DATA: + ad_type_str = 'Manufacturer Specific Data' + company_id = struct.unpack_from(' list[tuple[ValueType, Any]]: value = data[2 : 2 + value_length] typed_value: Any - if value_type == ValueType.END: - break - - if value_type in (ValueType.CNVI, ValueType.CNVR): - (v,) = struct.unpack("> 0) & 0xF) << 12) - | (((v >> 4) & 0xF) << 0) - | (((v >> 8) & 0xF) << 4) - | (((v >> 24) & 0xF) << 8) - ) - elif value_type == ValueType.HARDWARE_INFO: - (v,) = struct.unpack("> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F) - ) - elif value_type in ( - ValueType.USB_VENDOR_ID, - ValueType.USB_PRODUCT_ID, - ValueType.DEVICE_REVISION, - ): - (typed_value,) = struct.unpack("> 0) & 0xF) << 12) + | (((v >> 4) & 0xF) << 0) + | (((v >> 8) & 0xF) << 4) + | (((v >> 24) & 0xF) << 8) + ) + case ValueType.HARDWARE_INFO: + (v,) = struct.unpack("> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F) + ) + case ( + ValueType.USB_VENDOR_ID + | ValueType.USB_PRODUCT_ID + | ValueType.DEVICE_REVISION + ): + (typed_value,) = struct.unpack("2': - # 16-bit unsigned big-endian - return (struct.unpack_from('>H', data, offset)[0], 2) - if field_type == -2: - # 16-bit signed - return (struct.unpack_from('4': - # 32-bit unsigned big-endian - return (struct.unpack_from('>I', data, offset)[0], 4) - if isinstance(field_type, int) and 4 < field_type <= 256: - # Byte array (from 5 up to 256 bytes) - return (data[offset : offset + field_type], field_type) + match field_type: + case '*': + # The rest of the bytes + field_value = data[offset:] + return (field_value, len(field_value)) + case 'v': + # Variable-length bytes field, with 1-byte length at the beginning + field_length = data[offset] + offset += 1 + field_value = data[offset : offset + field_length] + return (field_value, field_length + 1) + case 1: + # 8-bit unsigned + return (data[offset], 1) + case -1: + # 8-bit signed + return (struct.unpack_from('b', data, offset)[0], 1) + case 2: + # 16-bit unsigned + return (struct.unpack_from('2': + # 16-bit unsigned big-endian + return (struct.unpack_from('>H', data, offset)[0], 2) + case -2: + # 16-bit signed + return (struct.unpack_from('4': + # 32-bit unsigned big-endian + return (struct.unpack_from('>I', data, offset)[0], 4) + case int() if 4 < field_type <= 256: + # Byte array (from 5 up to 256 bytes) + return (data[offset : offset + field_type], field_type) + if callable(field_type): new_offset, field_value = field_type(data, offset) return (field_value, new_offset - offset) @@ -1954,60 +1957,58 @@ def serialize_field(field_value: Any, field_type: FieldSpec) -> bytes: # Serialize the field if serializer: - field_bytes = serializer(field_value) - elif field_type == 1: - # 8-bit unsigned - field_bytes = bytes([field_value]) - elif field_type == -1: - # 8-bit signed - field_bytes = struct.pack('b', field_value) - elif field_type == 2: - # 16-bit unsigned - field_bytes = struct.pack('2': - # 16-bit unsigned big-endian - field_bytes = struct.pack('>H', field_value) - elif field_type == -2: - # 16-bit signed - field_bytes = struct.pack('4': - # 32-bit unsigned big-endian - field_bytes = struct.pack('>I', field_value) - elif field_type == '*': - if isinstance(field_value, int): - if 0 <= field_value <= 255: - field_bytes = bytes([field_value]) + return serializer(field_value) + match field_type: + case 1: + # 8-bit unsigned + return bytes([field_value]) + case -1: + # 8-bit signed + return struct.pack('b', field_value) + case 2: + # 16-bit unsigned + return struct.pack('2': + # 16-bit unsigned big-endian + return struct.pack('>H', field_value) + case -2: + # 16-bit signed + return struct.pack('4': + # 32-bit unsigned big-endian + return struct.pack('>I', field_value) + case '*': + if isinstance(field_value, int): + if 0 <= field_value <= 255: + return bytes([field_value]) + else: + raise InvalidArgumentError('value too large for *-typed field') else: - raise InvalidArgumentError('value too large for *-typed field') - else: + return bytes(field_value) + case 'v': + # Variable-length bytes field, with 1-byte length at the beginning field_bytes = bytes(field_value) - elif field_type == 'v': - # Variable-length bytes field, with 1-byte length at the beginning - field_bytes = bytes(field_value) - field_length = len(field_bytes) - field_bytes = bytes([field_length]) + field_bytes - elif isinstance(field_value, (bytes, bytearray)) or hasattr( - field_value, '__bytes__' - ): + field_length = len(field_bytes) + return bytes([field_length]) + field_bytes + if isinstance(field_value, (bytes, bytearray, SupportsBytes)): field_bytes = bytes(field_value) if isinstance(field_type, int) and 4 < field_type <= 256: # Truncate or pad with zeros if the field is too long or too short if len(field_bytes) < field_type: - field_bytes += bytes(field_type - len(field_bytes)) + return field_bytes + bytes(field_type - len(field_bytes)) elif len(field_bytes) > field_type: - field_bytes = field_bytes[:field_type] - else: - raise InvalidArgumentError( - f"don't know how to serialize type {type(field_value)}" - ) + return field_bytes[:field_type] + return field_bytes - return field_bytes + raise InvalidArgumentError( + f"don't know how to serialize type {type(field_value)}" + ) @staticmethod def dict_to_bytes(hci_object, object_fields): diff --git a/bumble/host.py b/bumble/host.py index 98fd131a..f0d0f144 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -22,7 +22,7 @@ import dataclasses import logging from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, TypeVar, overload from bumble import drivers, hci, utils from bumble.colors import color @@ -1002,18 +1002,19 @@ def on_hci_packet(self, packet: hci.HCI_Packet) -> None: self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST) # If the packet is a command, invoke the handler for this packet - if packet.hci_packet_type == hci.HCI_COMMAND_PACKET: - self.on_hci_command_packet(cast(hci.HCI_Command, packet)) - elif packet.hci_packet_type == hci.HCI_EVENT_PACKET: - self.on_hci_event_packet(cast(hci.HCI_Event, packet)) - elif packet.hci_packet_type == hci.HCI_ACL_DATA_PACKET: - self.on_hci_acl_data_packet(cast(hci.HCI_AclDataPacket, packet)) - elif packet.hci_packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET: - self.on_hci_sco_data_packet(cast(hci.HCI_SynchronousDataPacket, packet)) - elif packet.hci_packet_type == hci.HCI_ISO_DATA_PACKET: - self.on_hci_iso_data_packet(cast(hci.HCI_IsoDataPacket, packet)) - else: - logger.warning(f'!!! unknown packet type {packet.hci_packet_type}') + match packet: + case hci.HCI_Command(): + self.on_hci_command_packet(packet) + case hci.HCI_Event(): + self.on_hci_event_packet(packet) + case hci.HCI_AclDataPacket(): + self.on_hci_acl_data_packet(packet) + case hci.HCI_SynchronousDataPacket(): + self.on_hci_sco_data_packet(packet) + case hci.HCI_IsoDataPacket(): + self.on_hci_iso_data_packet(packet) + case _: + logger.warning(f'!!! unknown packet type {packet.hci_packet_type}') def on_hci_command_packet(self, command: hci.HCI_Command) -> None: logger.warning(f'!!! unexpected command packet: {command}') diff --git a/bumble/profiles/ascs.py b/bumble/profiles/ascs.py index 1e9d5910..50113c3b 100644 --- a/bumble/profiles/ascs.py +++ b/bumble/profiles/ascs.py @@ -664,46 +664,44 @@ def on_write_ase_control_point( responses = [] logger.debug(f'*** ASCS Write {operation} ***') - if isinstance(operation, ASE_Config_Codec): - for ase_id, *args in zip( - operation.ase_id, - operation.target_latency, - operation.target_phy, - operation.codec_id, - operation.codec_specific_configuration, + match operation: + case ASE_Config_Codec(): + for ase_id, *args in zip( + operation.ase_id, + operation.target_latency, + operation.target_phy, + operation.codec_id, + operation.codec_specific_configuration, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + case ASE_Config_QOS(): + for ase_id, *args in zip( + operation.ase_id, + operation.cig_id, + operation.cis_id, + operation.sdu_interval, + operation.framing, + operation.phy, + operation.max_sdu, + operation.retransmission_number, + operation.max_transport_latency, + operation.presentation_delay, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + case ASE_Enable() | ASE_Update_Metadata(): + for ase_id, *args in zip( + operation.ase_id, + operation.metadata, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + case ( + ASE_Receiver_Start_Ready() + | ASE_Disable() + | ASE_Receiver_Stop_Ready() + | ASE_Release() ): - responses.append(self.on_operation(operation.op_code, ase_id, args)) - elif isinstance(operation, ASE_Config_QOS): - for ase_id, *args in zip( - operation.ase_id, - operation.cig_id, - operation.cis_id, - operation.sdu_interval, - operation.framing, - operation.phy, - operation.max_sdu, - operation.retransmission_number, - operation.max_transport_latency, - operation.presentation_delay, - ): - responses.append(self.on_operation(operation.op_code, ase_id, args)) - elif isinstance(operation, (ASE_Enable, ASE_Update_Metadata)): - for ase_id, *args in zip( - operation.ase_id, - operation.metadata, - ): - responses.append(self.on_operation(operation.op_code, ase_id, args)) - elif isinstance( - operation, - ( - ASE_Receiver_Start_Ready, - ASE_Disable, - ASE_Receiver_Stop_Ready, - ASE_Release, - ), - ): - for ase_id in operation.ase_id: - responses.append(self.on_operation(operation.op_code, ase_id, [])) + for ase_id in operation.ase_id: + responses.append(self.on_operation(operation.op_code, ase_id, [])) control_point_notification = bytes( [operation.op_code, len(responses)] diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index 49c2e3d7..b569f722 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -333,17 +333,18 @@ def from_bytes(cls, data: bytes) -> CodecSpecificCapabilities: value = int.from_bytes(data[offset : offset + length - 1], 'little') offset += length - 1 - if type == CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY: - supported_sampling_frequencies = SupportedSamplingFrequency(value) - elif type == CodecSpecificCapabilities.Type.FRAME_DURATION: - supported_frame_durations = SupportedFrameDuration(value) - elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT: - supported_audio_channel_count = bits_to_channel_counts(value) - elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME: - min_octets_per_sample = value & 0xFFFF - max_octets_per_sample = value >> 16 - elif type == CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU: - supported_max_codec_frames_per_sdu = value + match type: + case CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY: + supported_sampling_frequencies = SupportedSamplingFrequency(value) + case CodecSpecificCapabilities.Type.FRAME_DURATION: + supported_frame_durations = SupportedFrameDuration(value) + case CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT: + supported_audio_channel_count = bits_to_channel_counts(value) + case CodecSpecificCapabilities.Type.OCTETS_PER_FRAME: + min_octets_per_sample = value & 0xFFFF + max_octets_per_sample = value >> 16 + case CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU: + supported_max_codec_frames_per_sdu = value # It is expected here that if some fields are missing, an error should be raised. # pylint: disable=possibly-used-before-assignment,used-before-assignment diff --git a/bumble/profiles/gap.py b/bumble/profiles/gap.py index 3b818af6..9ff374da 100644 --- a/bumble/profiles/gap.py +++ b/bumble/profiles/gap.py @@ -55,14 +55,15 @@ class GenericAccessService(TemplateService): def __init__( self, device_name: str, appearance: Appearance | tuple[int, int] | int = 0 ): - if isinstance(appearance, int): - appearance_int = appearance - elif isinstance(appearance, tuple): - appearance_int = (appearance[0] << 6) | appearance[1] - elif isinstance(appearance, Appearance): - appearance_int = int(appearance) - else: - raise TypeError() + match appearance: + case int(): + appearance_int = appearance + case tuple(): + appearance_int = (appearance[0] << 6) | appearance[1] + case Appearance(): + appearance_int = int(appearance) + case _: + raise TypeError() self.device_name_characteristic = Characteristic( GATT_DEVICE_NAME_CHARACTERISTIC,