diff --git a/livekit-rtc/livekit/rtc/_utils.py b/livekit-rtc/livekit/rtc/_utils.py
index 2342b21a..b24451cd 100644
--- a/livekit-rtc/livekit/rtc/_utils.py
+++ b/livekit-rtc/livekit/rtc/_utils.py
@@ -17,7 +17,7 @@
 from collections import deque
 import ctypes
 import random
-from typing import Callable, Generator, Generic, List, TypeVar
+from typing import Callable, Generator, Generic, List, TypeVar, Union
 
 logger = logging.getLogger("livekit")
 
@@ -40,8 +40,35 @@ def task_done_logger(task: asyncio.Task) -> None:
     return
 
 
-def get_address(mv: memoryview) -> int:
-    return ctypes.addressof(ctypes.c_char.from_buffer(mv))
+def _buffer_supported_or_raise(
+    data: Union[bytes, bytearray, memoryview],
+) -> None:
+    """Validate a buffer for FFI use.
+
+    Raises clear errors for non-contiguous or sliced memoryviews.
+    """
+    if isinstance(data, memoryview):
+        if not data.contiguous:
+            raise ValueError("memoryview must be contiguous")
+        if data.nbytes != len(data.obj):  # type: ignore[arg-type]
+            raise ValueError("sliced memoryviews are not supported")
+    elif not isinstance(data, (bytes, bytearray)):
+        raise TypeError(f"expected bytes, bytearray, or memoryview, got {type(data)}")
+
+
+def get_address(data) -> int:
+    if isinstance(data, memoryview):
+        _buffer_supported_or_raise(data)
+        if not data.readonly:
+            return ctypes.addressof(ctypes.c_char.from_buffer(data))
+        data = data.obj
+    if isinstance(data, bytearray):
+        return ctypes.addressof(ctypes.c_char.from_buffer(data))
+    if isinstance(data, bytes):
+        addr = ctypes.cast(ctypes.c_char_p(data), ctypes.c_void_p).value
+        assert addr is not None
+        return addr
+    raise TypeError(f"expected bytes, bytearray, or memoryview, got {type(data)}")
 
 
 T = TypeVar("T")
diff --git a/livekit-rtc/livekit/rtc/apm.py b/livekit-rtc/livekit/rtc/apm.py
index 043d12ff..efd203a6 100644
--- a/livekit-rtc/livekit/rtc/apm.py
+++ b/livekit-rtc/livekit/rtc/apm.py
@@ -48,12 +48,15 @@ def process_stream(self, data: AudioFrame) -> None:
         Important: Audio frames must be
         exactly 10 ms in duration.
         """
-        bdata = data.data.cast("b")
+        if isinstance(data._data, bytes) or (
+            isinstance(data._data, memoryview) and data._data.readonly
+        ):
+            data._data = bytearray(data._data)
 
         req = proto_ffi.FfiRequest()
         req.apm_process_stream.apm_handle = self._ffi_handle.handle
-        req.apm_process_stream.data_ptr = get_address(memoryview(bdata))
-        req.apm_process_stream.size = len(bdata)
+        req.apm_process_stream.data_ptr = get_address(data._data)
+        req.apm_process_stream.size = len(data._data)
         req.apm_process_stream.sample_rate = data.sample_rate
         req.apm_process_stream.num_channels = data.num_channels
 
@@ -73,12 +76,15 @@ def process_reverse_stream(self, data: AudioFrame) -> None:
         Important: Audio frames must be
         exactly 10 ms in duration.
         """
-        bdata = data.data.cast("b")
+        if isinstance(data._data, bytes) or (
+            isinstance(data._data, memoryview) and data._data.readonly
+        ):
+            data._data = bytearray(data._data)
 
         req = proto_ffi.FfiRequest()
         req.apm_process_reverse_stream.apm_handle = self._ffi_handle.handle
-        req.apm_process_reverse_stream.data_ptr = get_address(memoryview(bdata))
-        req.apm_process_reverse_stream.size = len(bdata)
+        req.apm_process_reverse_stream.data_ptr = get_address(data._data)
+        req.apm_process_reverse_stream.size = len(data._data)
         req.apm_process_reverse_stream.sample_rate = data.sample_rate
         req.apm_process_reverse_stream.num_channels = data.num_channels
 
diff --git a/livekit-rtc/livekit/rtc/audio_frame.py b/livekit-rtc/livekit/rtc/audio_frame.py
index c4f1943c..50f63899 100644
--- a/livekit-rtc/livekit/rtc/audio_frame.py
+++ b/livekit-rtc/livekit/rtc/audio_frame.py
@@ -15,7 +15,7 @@
 import ctypes
 from ._ffi_client import FfiHandle
 from ._proto import audio_frame_pb2 as proto_audio
-from ._utils import get_address
+from ._utils import _buffer_supported_or_raise, get_address
 
 from typing import Any, Union
 
@@ -49,19 +49,21 @@ def __init__(
 
         Raises:
             ValueError: If the length of `data` is smaller than the required size.
         """
-        data = memoryview(data).cast("B")
+        _buffer_supported_or_raise(data)
 
-        if len(data) < num_channels * samples_per_channel * ctypes.sizeof(ctypes.c_int16):
+        min_size = num_channels * samples_per_channel * ctypes.sizeof(ctypes.c_int16)
+        data_len = len(data)
+
+        if data_len < min_size:
             raise ValueError(
                 "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
             )
-        if len(data) % ctypes.sizeof(ctypes.c_int16) != 0:
+        if data_len % ctypes.sizeof(ctypes.c_int16) != 0:
             # can happen if data is bigger than needed
             raise ValueError("data length must be a multiple of sizeof(int16)")
 
-        n = len(data) // ctypes.sizeof(ctypes.c_int16)
-        self._data = (ctypes.c_int16 * n).from_buffer_copy(data)
+        self._data = data
 
         self._sample_rate = sample_rate
         self._num_channels = num_channels
@@ -97,7 +99,7 @@ def _from_owned_info(owned_info: proto_audio.OwnedAudioFrameBuffer) -> "AudioFra
 
     def _proto_info(self) -> proto_audio.AudioFrameBufferInfo:
         audio_info = proto_audio.AudioFrameBufferInfo()
-        audio_info.data_ptr = get_address(memoryview(self._data))
+        audio_info.data_ptr = get_address(self._data)
         audio_info.sample_rate = self.sample_rate
         audio_info.num_channels = self.num_channels
         audio_info.samples_per_channel = self.samples_per_channel
diff --git a/livekit-rtc/livekit/rtc/video_frame.py b/livekit-rtc/livekit/rtc/video_frame.py
index 76b66625..e6ff1ad1 100644
--- a/livekit-rtc/livekit/rtc/video_frame.py
+++ b/livekit-rtc/livekit/rtc/video_frame.py
@@ -18,7 +18,7 @@
 from ._proto import ffi_pb2 as proto
 from typing import List, Optional
 from ._ffi_client import FfiClient, FfiHandle
-from ._utils import get_address
+from ._utils import _buffer_supported_or_raise, get_address
 
 from typing import Any
 
@@ -48,10 +48,12 @@ def __init__(
                 (e.g., RGBA, BGRA, RGB24, etc.).
             data (Union[bytes, bytearray, memoryview]): The raw pixel data for the video frame.
         """
+        _buffer_supported_or_raise(data)
+
        self._width = width
        self._height = height
        self._type = type
-        self._data = bytearray(data)
+        self._data = data
 
     @property
     def width(self) -> int: