diff --git a/pyproject.toml b/pyproject.toml index a183e5b..25caa42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "danom" -version = "0.10.2" +version = "0.10.3" description = "Functional streams and monads" readme = "README.md" license = "MIT" diff --git a/src/danom/_result.py b/src/danom/_result.py index dcbfe86..2c22def 100644 --- a/src/danom/_result.py +++ b/src/danom/_result.py @@ -5,7 +5,10 @@ from types import TracebackType from typing import ( Any, + Concatenate, + Generic, Literal, + Never, ParamSpec, Self, TypeVar, @@ -20,12 +23,12 @@ F_co = TypeVar("F_co", bound=object, covariant=True) P = ParamSpec("P") -Mappable = Callable[P, U_co] -Bindable = Callable[P, "Result[U_co, E_co]"] +Mappable = Callable[Concatenate[T_co, P], U_co] +Bindable = Callable[Concatenate[T_co, P], "Result[U_co, E_co]"] @attrs.define(frozen=True) -class Result(ABC): +class Result(ABC, Generic[T_co, E_co]): """`Result` monad. Consists of `Ok` and `Err` for successful and failed operations respectively. Each monad is a frozen instance to prevent further mutation. """ @@ -67,7 +70,7 @@ def is_ok(self) -> bool: ... @abstractmethod - def map(self, func: Mappable, **kwargs: P.kwargs) -> Result[U_co, E_co]: + def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: """Pipe a pure function and wrap the return value with `Ok`. Given an `Err` will return self. @@ -81,7 +84,7 @@ def map(self, func: Mappable, **kwargs: P.kwargs) -> Result[U_co, E_co]: ... @abstractmethod - def map_err(self, func: Mappable, **kwargs: P.kwargs) -> Result[U_co, E_co]: + def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: """Pipe a pure function and wrap the return value with `Err`. Given an `Ok` will return self. @@ -95,7 +98,7 @@ def map_err(self, func: Mappable, **kwargs: P.kwargs) -> Result[U_co, E_co]: ... @abstractmethod - def and_then(self, func: Bindable, **kwargs: P.kwargs) -> Result[U_co, E_co]: + def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: """Pipe another function that returns a monad. For `Err` will return original error. .. code-block:: python @@ -110,7 +113,7 @@ def and_then(self, func: Bindable, **kwargs: P.kwargs) -> Result[U_co, E_co]: ... @abstractmethod - def or_else(self, func: Bindable, **kwargs: P.kwargs) -> Result[U_co, E_co]: + def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: """Pipe a function that returns a monad to recover from an `Err`. For `Ok` will return original `Result`. .. code-block:: python @@ -147,27 +150,24 @@ def unwrap(self) -> T_co: """ ... - def __class_getitem__(cls, _params: tuple) -> Self: - return cls # ty: ignore[invalid-return-type] - @attrs.define(frozen=True, hash=True) -class Ok(Result): +class Ok(Result[T_co, Never]): inner: Any = attrs.field(default=None) def is_ok(self) -> Literal[True]: return True - def map(self, func: Mappable, **kwargs: P.kwargs) -> Ok[U_co]: - return Ok(func(self.inner, **kwargs)) + def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Ok[U_co]: + return Ok(func(self.inner, *args, **kwargs)) - def map_err(self, func: Mappable, **kwargs: P.kwargs) -> Ok[U_co]: # noqa: ARG002 + def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002 return self - def and_then(self, func: Bindable, **kwargs: P.kwargs) -> Result[U_co, E_co]: - return func(self.inner, **kwargs) + def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: + return func(self.inner, *args, **kwargs) - def or_else(self, func: Bindable, **kwargs: P.kwargs) -> Ok[T_co]: # noqa: ARG002 + def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002 return self def unwrap(self) -> T_co: @@ -179,7 +179,7 @@ def unwrap(self) -> T_co: @attrs.define(frozen=True) -class Err(Result): +class Err(Result[Never, E_co]): error: Any = attrs.field(default=None) input_args: tuple[()] | SafeArgs | SafeMethodArgs = attrs.field( default=(), validator=instance_of(tuple), repr=False @@ -210,19 +210,19 @@ def _extract_details(self, tb: TracebackType | None) -> list[dict[str, Any]]: def is_ok(self) -> Literal[False]: return False - def map(self, func: Mappable, **kwargs: P.kwargs) -> Err[E_co]: # noqa: ARG002 + def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002 return self - def map_err(self, func: Mappable, **kwargs: P.kwargs) -> Err[F_co]: - return Err(func(self.error, **kwargs)) + def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Err[F_co]: + return Err(func(self.error, *args, **kwargs)) - def and_then(self, func: Bindable, **kwargs: P.kwargs) -> Err[E_co]: # noqa: ARG002 + def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002 return self - def or_else(self, func: Bindable, **kwargs: P.kwargs) -> Result[U_co, E_co]: - return func(self.error, **kwargs) + def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: + return func(self.error, *args, **kwargs) - def unwrap(self) -> None: + def unwrap(self) -> T_co: if isinstance(self.error, Exception): raise self.error raise ValueError(f"Err does not have a caught error to raise: {self.error = }") diff --git a/src/danom/_safe.py b/src/danom/_safe.py index aeefd98..20fb410 100644 --- a/src/danom/_safe.py +++ b/src/danom/_safe.py @@ -26,7 +26,7 @@ def add_one(a: int) -> int: """ @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[U, E]: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[U, Exception]: try: return Ok(func(*args, **kwargs)) except Exception as e: # noqa: BLE001 diff --git a/tests/conftest.py b/tests/conftest.py index 3b46ab1..bf93208 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,7 @@ import asyncio from multiprocessing.managers import ListProxy from pathlib import Path -from typing import Any, Self +from typing import Any, NoReturn, Self from src.danom import safe, safe_method from src.danom._result import Err, Ok, Result @@ -53,7 +53,7 @@ async def async_read_text(path: str) -> str: @safe -def safe_add(a: int, b: int) -> Result[int, Exception]: +def safe_add(a: int, b: int) -> int: return a + b @@ -71,7 +71,7 @@ def safe_double[T: (str, float, int)](x: T) -> T: @safe -def safe_raise_type_error(_a: Any) -> None: # noqa: ANN401 +def safe_raise_type_error(_a: Any) -> NoReturn: # noqa: ANN401 raise TypeError diff --git a/tests/test_safe.py b/tests/test_safe.py index 008e8ef..a7ab0ae 100644 --- a/tests/test_safe.py +++ b/tests/test_safe.py @@ -2,6 +2,7 @@ import pytest +from danom._result import Err from tests.conftest import ( REPO_ROOT, Adder, @@ -71,6 +72,9 @@ def test_traceback(): "ZeroDivisionError: division by zero", ] + if not isinstance(err, Err): + raise TypeError("This should be an Err by now") + tb_lines = err.traceback.replace(str(REPO_ROOT), ".").splitlines() missing_lines = [line for line in expected_lines if line not in tb_lines] diff --git a/uv.lock b/uv.lock index 21eabf1..0db4803 100644 --- a/uv.lock +++ b/uv.lock @@ -317,7 +317,7 @@ wheels = [ [[package]] name = "danom" -version = "0.10.2" +version = "0.10.3" source = { editable = "." } dependencies = [ { name = "attrs" },