Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "danom"
version = "0.10.2"
version = "0.10.3"
description = "Functional streams and monads"
readme = "README.md"
license = "MIT"
Expand Down
50 changes: 25 additions & 25 deletions src/danom/_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from types import TracebackType
from typing import (
Any,
Concatenate,
Generic,
Literal,
Never,
ParamSpec,
Self,
TypeVar,
Expand All @@ -20,12 +23,12 @@
F_co = TypeVar("F_co", bound=object, covariant=True)
P = ParamSpec("P")

Mappable = Callable[P, U_co]
Bindable = Callable[P, "Result[U_co, E_co]"]
Mappable = Callable[Concatenate[T_co, P], U_co]
Bindable = Callable[Concatenate[T_co, P], "Result[U_co, E_co]"]


@attrs.define(frozen=True)
class Result(ABC):
class Result(ABC, Generic[T_co, E_co]):
"""`Result` monad. Consists of `Ok` and `Err` for successful and failed operations respectively.
Each monad is a frozen instance to prevent further mutation.
"""
Expand Down Expand Up @@ -67,7 +70,7 @@ def is_ok(self) -> bool:
...

@abstractmethod
def map(self, func: Mappable, **kwargs: P.kwargs) -> Result[U_co, E_co]:
def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
"""Pipe a pure function and wrap the return value with `Ok`.
Given an `Err` will return self.

Expand All @@ -81,7 +84,7 @@ def map(self, func: Mappable, **kwargs: P.kwargs) -> Result[U_co, E_co]:
...

@abstractmethod
def map_err(self, func: Mappable, **kwargs: P.kwargs) -> Result[U_co, E_co]:
def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
"""Pipe a pure function and wrap the return value with `Err`.
Given an `Ok` will return self.

Expand All @@ -95,7 +98,7 @@ def map_err(self, func: Mappable, **kwargs: P.kwargs) -> Result[U_co, E_co]:
...

@abstractmethod
def and_then(self, func: Bindable, **kwargs: P.kwargs) -> Result[U_co, E_co]:
def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
"""Pipe another function that returns a monad. For `Err` will return original error.

.. code-block:: python
Expand All @@ -110,7 +113,7 @@ def and_then(self, func: Bindable, **kwargs: P.kwargs) -> Result[U_co, E_co]:
...

@abstractmethod
def or_else(self, func: Bindable, **kwargs: P.kwargs) -> Result[U_co, E_co]:
def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
"""Pipe a function that returns a monad to recover from an `Err`. For `Ok` will return original `Result`.

.. code-block:: python
Expand Down Expand Up @@ -147,27 +150,24 @@ def unwrap(self) -> T_co:
"""
...

def __class_getitem__(cls, _params: tuple) -> Self:
return cls # ty: ignore[invalid-return-type]


@attrs.define(frozen=True, hash=True)
class Ok(Result):
class Ok(Result[T_co, Never]):
inner: Any = attrs.field(default=None)

def is_ok(self) -> Literal[True]:
return True

def map(self, func: Mappable, **kwargs: P.kwargs) -> Ok[U_co]:
return Ok(func(self.inner, **kwargs))
def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Ok[U_co]:
return Ok(func(self.inner, *args, **kwargs))

def map_err(self, func: Mappable, **kwargs: P.kwargs) -> Ok[U_co]: # noqa: ARG002
def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
return self

def and_then(self, func: Bindable, **kwargs: P.kwargs) -> Result[U_co, E_co]:
return func(self.inner, **kwargs)
def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
return func(self.inner, *args, **kwargs)

def or_else(self, func: Bindable, **kwargs: P.kwargs) -> Ok[T_co]: # noqa: ARG002
def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
return self

def unwrap(self) -> T_co:
Expand All @@ -179,7 +179,7 @@ def unwrap(self) -> T_co:


@attrs.define(frozen=True)
class Err(Result):
class Err(Result[Never, E_co]):
error: Any = attrs.field(default=None)
input_args: tuple[()] | SafeArgs | SafeMethodArgs = attrs.field(
default=(), validator=instance_of(tuple), repr=False
Expand Down Expand Up @@ -210,19 +210,19 @@ def _extract_details(self, tb: TracebackType | None) -> list[dict[str, Any]]:
def is_ok(self) -> Literal[False]:
return False

def map(self, func: Mappable, **kwargs: P.kwargs) -> Err[E_co]: # noqa: ARG002
def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
return self

def map_err(self, func: Mappable, **kwargs: P.kwargs) -> Err[F_co]:
return Err(func(self.error, **kwargs))
def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Err[F_co]:
return Err(func(self.error, *args, **kwargs))

def and_then(self, func: Bindable, **kwargs: P.kwargs) -> Err[E_co]: # noqa: ARG002
def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
return self

def or_else(self, func: Bindable, **kwargs: P.kwargs) -> Result[U_co, E_co]:
return func(self.error, **kwargs)
def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
return func(self.error, *args, **kwargs)

def unwrap(self) -> None:
def unwrap(self) -> T_co:
if isinstance(self.error, Exception):
raise self.error
raise ValueError(f"Err does not have a caught error to raise: {self.error = }")
Expand Down
2 changes: 1 addition & 1 deletion src/danom/_safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def add_one(a: int) -> int:
"""

@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[U, E]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[U, Exception]:
try:
return Ok(func(*args, **kwargs))
except Exception as e: # noqa: BLE001
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
from multiprocessing.managers import ListProxy
from pathlib import Path
from typing import Any, Self
from typing import Any, NoReturn, Self

from src.danom import safe, safe_method
from src.danom._result import Err, Ok, Result
Expand Down Expand Up @@ -53,7 +53,7 @@ async def async_read_text(path: str) -> str:


@safe
def safe_add(a: int, b: int) -> Result[int, Exception]:
def safe_add(a: int, b: int) -> int:
return a + b


Expand All @@ -71,7 +71,7 @@ def safe_double[T: (str, float, int)](x: T) -> T:


@safe
def safe_raise_type_error(_a: Any) -> None: # noqa: ANN401
def safe_raise_type_error(_a: Any) -> NoReturn: # noqa: ANN401
raise TypeError


Expand Down
4 changes: 4 additions & 0 deletions tests/test_safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from danom._result import Err
from tests.conftest import (
REPO_ROOT,
Adder,
Expand Down Expand Up @@ -71,6 +72,9 @@ def test_traceback():
"ZeroDivisionError: division by zero",
]

if not isinstance(err, Err):
raise TypeError("This should be an Err by now")

tb_lines = err.traceback.replace(str(REPO_ROOT), ".").splitlines()

missing_lines = [line for line in expected_lines if line not in tb_lines]
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.