Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions bumble/a2dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,26 +267,27 @@ class MediaCodecInformation:
def create(
cls, media_codec_type: int, data: bytes
) -> MediaCodecInformation | bytes:
if media_codec_type == CodecType.SBC:
return SbcMediaCodecInformation.from_bytes(data)
elif media_codec_type == CodecType.MPEG_2_4_AAC:
return AacMediaCodecInformation.from_bytes(data)
elif media_codec_type == CodecType.NON_A2DP:
vendor_media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(data)
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
return media_codec_information_class.from_bytes(
vendor_media_codec_information.value
match media_codec_type:
case CodecType.SBC:
return SbcMediaCodecInformation.from_bytes(data)
case CodecType.MPEG_2_4_AAC:
return AacMediaCodecInformation.from_bytes(data)
case CodecType.NON_A2DP:
vendor_media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(data)
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
return media_codec_information_class.from_bytes(
vendor_media_codec_information.value
)
return vendor_media_codec_information

@classmethod
Expand Down
62 changes: 32 additions & 30 deletions bumble/at.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
are ignored [..], unless they are embedded in numeric or string constants"
Raises AtParsingError in case of invalid input string."""

tokens = []
tokens: list[bytearray] = []
in_quotes = False
token = bytearray()
for b in buffer:
Expand All @@ -40,23 +40,24 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
tokens.append(token[1:-1])
token = bytearray()
else:
if char == b' ':
pass
elif char == b',' or char == b')':
tokens.append(token)
tokens.append(char)
token = bytearray()
elif char == b'(':
if len(token) > 0:
raise AtParsingError("open_paren following regular character")
tokens.append(char)
elif char == b'"':
if len(token) > 0:
raise AtParsingError("quote following regular character")
in_quotes = True
token.extend(char)
else:
token.extend(char)
match char:
case b' ':
pass
case b',' | b')':
tokens.append(token)
tokens.append(char)
token = bytearray()
case b'(':
if len(token) > 0:
raise AtParsingError("open_paren following regular character")
tokens.append(char)
case b'"':
if len(token) > 0:
raise AtParsingError("quote following regular character")
in_quotes = True
token.extend(char)
case _:
token.extend(char)

tokens.append(token)
return [bytes(token) for token in tokens if len(token) > 0]
Expand All @@ -71,18 +72,19 @@ def parse_parameters(buffer: bytes) -> list[bytes | list]:
current: bytes | list = b''

for token in tokens:
if token == b',':
accumulator[-1].append(current)
current = b''
elif token == b'(':
accumulator.append([])
elif token == b')':
if len(accumulator) < 2:
raise AtParsingError("close_paren without matching open_paren")
accumulator[-1].append(current)
current = accumulator.pop()
else:
current = token
match token:
case b',':
accumulator[-1].append(current)
current = b''
case b'(':
accumulator.append([])
case b')':
if len(accumulator) < 2:
raise AtParsingError("close_paren without matching open_paren")
accumulator[-1].append(current)
current = accumulator.pop()
case _:
current = token

accumulator[-1].append(current)
if len(accumulator) > 1:
Expand Down
103 changes: 53 additions & 50 deletions bumble/att.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,12 +954,13 @@ def __init__(
self.permissions = permissions

# Convert the type to a UUID object if it isn't already
if isinstance(attribute_type, str):
self.type = UUID(attribute_type)
elif isinstance(attribute_type, bytes):
self.type = UUID.from_bytes(attribute_type)
else:
self.type = attribute_type
match attribute_type:
case str():
self.type = UUID(attribute_type)
case bytes():
self.type = UUID.from_bytes(attribute_type)
case _:
self.type = attribute_type

self.value = value

Expand Down Expand Up @@ -994,30 +995,31 @@ async def read_value(self, bearer: Bearer) -> bytes:
)

value: _T | None
if isinstance(self.value, AttributeValue):
try:
read_value = self.value.read(connection)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
read_value = self.value.read(bearer)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
value = self.value
match self.value:
case AttributeValue():
try:
read_value = self.value.read(connection)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case AttributeValueV2():
try:
read_value = self.value.read(bearer)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case _:
value = self.value

self.emit(self.EVENT_READ, connection, b'' if value is None else value)

Expand Down Expand Up @@ -1049,26 +1051,27 @@ async def write_value(self, bearer: Bearer, value: bytes) -> None:

decoded_value = self.decode_value(value)

if isinstance(self.value, AttributeValue):
try:
result = self.value.write(connection, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
result = self.value.write(bearer, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
self.value = decoded_value
match self.value:
case AttributeValue():
try:
result = self.value.write(connection, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case AttributeValueV2():
try:
result = self.value.write(bearer, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
case _:
self.value = decoded_value

self.emit(self.EVENT_WRITE, connection, decoded_value)

Expand Down
Loading
Loading