diff --git a/docs/source/conf.py b/docs/source/conf.py index f0daa9b2d..ecad14af7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -100,7 +100,7 @@ def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None: return link -def _get_obj(_info: dict[str, str]): +def _get_obj(_info: dict[str, str]) -> object: module_name = _info["module"] full_name = _info["fullname"] sub_module = sys.modules.get(module_name) @@ -112,7 +112,7 @@ def _get_obj(_info: dict[str, str]): return obj -def _get_file_name(obj) -> str | None: +def _get_file_name(obj: object) -> str | None: try: file_name = inspect.getsourcefile(obj) file_name = os.path.relpath(file_name, start=_PATH_ROOT) @@ -121,7 +121,7 @@ def _get_file_name(obj) -> str | None: return file_name -def _get_line_str(obj) -> str: +def _get_line_str(obj: object) -> str: source, start = inspect.getsourcelines(obj) end = start + len(source) - 1 line_str = f"#L{start}-L{end}" diff --git a/pyproject.toml b/pyproject.toml index 867277124..e20ac5739 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,6 +134,7 @@ select = [ "I", # isort "UP", # pyupgrade "ARG", # flake8-unused-arguments + "ANN", # flake-8-annotations "B", # flake8-bugbear "C4", # flake8-comprehensions "FIX", # flake8-fixme @@ -156,12 +157,16 @@ ignore = [ "RUF012", # Mutable default value for class attribute (a bit tedious to fix) "RET504", # Unnecessary assignment return statement "COM812", # Trailing comma missing (conflicts with formatter, see https://github.com/astral-sh/ruff/issues/9216) + "ANN401", # Prevent annotating as Any (we rarely do that, and when we do it's hard to find an alternative) ] [tool.ruff.lint.per-file-ignores] "**/conftest.py" = ["ARG"] # Can't change argument names in the functions pytest expects "tests/doc/test_rst.py" = ["ARG"] # For the lightning example +[tool.ruff.lint.flake8-annotations] +suppress-dummy-args = true + [tool.ruff.lint.isort] combine-as-imports = true diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 78168eaed..d4be05e99 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -13,7 +13,7 @@ class Aggregator(nn.Module, ABC): :math:`m \times n` into row vectors of dimension :math:`n`. """ - def __init__(self): + def __init__(self) -> None: super().__init__() @staticmethod @@ -48,7 +48,7 @@ class WeightedAggregator(Aggregator): :param weighting: The object responsible for extracting the vector of weights from the matrix. """ - def __init__(self, weighting: Weighting[Matrix]): + def __init__(self, weighting: Weighting[Matrix]) -> None: super().__init__() self.weighting = weighting @@ -77,6 +77,6 @@ class GramianWeightedAggregator(WeightedAggregator): gramian. """ - def __init__(self, gramian_weighting: Weighting[PSDMatrix]): + def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None: super().__init__(gramian_weighting << compute_gramian) self.gramian_weighting = gramian_weighting diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index fe807e0ac..2230c27f3 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -61,7 +61,7 @@ def __init__( self, pref_vector: Tensor | None = None, scale_mode: SUPPORTED_SCALE_MODE = "min", - ): + ) -> None: self._pref_vector = pref_vector self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode)) @@ -92,7 +92,7 @@ def __init__( self, pref_vector: Tensor | None = None, scale_mode: SUPPORTED_SCALE_MODE = "min", - ): + ) -> None: super().__init__() self._pref_vector = pref_vector self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index d29ca7b35..6731178bb 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -34,7 +34,7 @@ class CAGrad(GramianWeightedAggregator): To install it, use ``pip install "torchjd[cagrad]"``. """ - def __init__(self, c: float, norm_eps: float = 0.0001): + def __init__(self, c: float, norm_eps: float = 0.0001) -> None: super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps)) self._c = c self._norm_eps = norm_eps @@ -67,7 +67,7 @@ class CAGradWeighting(Weighting[PSDMatrix]): function. """ - def __init__(self, c: float, norm_eps: float = 0.0001): + def __init__(self, c: float, norm_eps: float = 0.0001) -> None: super().__init__() if c < 0.0: diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py index 24866c41a..447ccd3af 100644 --- a/src/torchjd/aggregation/_config.py +++ b/src/torchjd/aggregation/_config.py @@ -50,7 +50,7 @@ class ConFIG(Aggregator): `_. """ - def __init__(self, pref_vector: Tensor | None = None): + def __init__(self, pref_vector: Tensor | None = None) -> None: super().__init__() self.weighting = pref_vector_to_weighting(pref_vector, default=SumWeighting()) self._pref_vector = pref_vector diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 03b629eab..a547b813b 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -15,7 +15,7 @@ class Constant(WeightedAggregator): :param weights: The weights associated to the rows of the input matrices. """ - def __init__(self, weights: Tensor): + def __init__(self, weights: Tensor) -> None: super().__init__(weighting=ConstantWeighting(weights=weights)) self._weights = weights @@ -35,7 +35,7 @@ class ConstantWeighting(Weighting[Matrix]): :param weights: The weights to return at each call. """ - def __init__(self, weights: Tensor): + def __init__(self, weights: Tensor) -> None: if weights.dim() != 1: raise ValueError( "Parameter `weights` should be a 1-dimensional tensor. Found `weights.shape = " diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index d7d886489..7e868f620 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -33,7 +33,7 @@ def __init__( norm_eps: float = 0.0001, reg_eps: float = 0.0001, solver: SUPPORTED_SOLVER = "quadprog", - ): + ) -> None: self._pref_vector = pref_vector self._norm_eps = norm_eps self._reg_eps = reg_eps @@ -77,7 +77,7 @@ def __init__( norm_eps: float = 0.0001, reg_eps: float = 0.0001, solver: SUPPORTED_SOLVER = "quadprog", - ): + ) -> None: super().__init__() self._pref_vector = pref_vector self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) diff --git a/src/torchjd/aggregation/_flattening.py b/src/torchjd/aggregation/_flattening.py index 208db3ec2..15736b523 100644 --- a/src/torchjd/aggregation/_flattening.py +++ b/src/torchjd/aggregation/_flattening.py @@ -20,7 +20,7 @@ class Flattening(GeneralizedWeighting): :param weighting: The weighting to apply to the Gramian matrix. """ - def __init__(self, weighting: Weighting): + def __init__(self, weighting: Weighting) -> None: super().__init__() self.weighting = weighting diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index afa164519..61c9354ec 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -26,7 +26,7 @@ class GradDrop(Aggregator): through. Defaults to None, which means no leak. """ - def __init__(self, f: Callable = _identity, leak: Tensor | None = None): + def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None: if leak is not None and leak.dim() != 1: raise ValueError( "Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = " diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index f45e8c2e9..75d00b76e 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -16,7 +16,7 @@ class IMTLG(GramianWeightedAggregator): `_, supports matrices with some linearly dependant rows. """ - def __init__(self): + def __init__(self) -> None: super().__init__(IMTLGWeighting()) # This prevents computing gradients that can be very wrong. diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index 935651746..40285d89c 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -19,7 +19,7 @@ class Krum(GramianWeightedAggregator): :param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1. """ - def __init__(self, n_byzantine: int, n_selected: int = 1): + def __init__(self, n_byzantine: int, n_selected: int = 1) -> None: self._n_byzantine = n_byzantine self._n_selected = n_selected super().__init__(KrumWeighting(n_byzantine=n_byzantine, n_selected=n_selected)) @@ -44,7 +44,7 @@ class KrumWeighting(Weighting[PSDMatrix]): :param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1. """ - def __init__(self, n_byzantine: int, n_selected: int = 1): + def __init__(self, n_byzantine: int, n_selected: int = 1) -> None: super().__init__() if n_byzantine < 0: raise ValueError( diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index d7085e104..8fc5b057a 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -13,7 +13,7 @@ class Mean(WeightedAggregator): matrices. """ - def __init__(self): + def __init__(self) -> None: super().__init__(weighting=MeanWeighting()) diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index 8f753c2a4..aec329470 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -20,7 +20,7 @@ class MGDA(GramianWeightedAggregator): :param max_iters: The maximum number of iterations of the optimization loop. """ - def __init__(self, epsilon: float = 0.001, max_iters: int = 100): + def __init__(self, epsilon: float = 0.001, max_iters: int = 100) -> None: super().__init__(MGDAWeighting(epsilon=epsilon, max_iters=max_iters)) self._epsilon = epsilon self._max_iters = max_iters @@ -38,7 +38,7 @@ class MGDAWeighting(Weighting[PSDMatrix]): :param max_iters: The maximum number of iterations of the optimization loop. """ - def __init__(self, epsilon: float = 0.001, max_iters: int = 100): + def __init__(self, epsilon: float = 0.001, max_iters: int = 100) -> None: super().__init__() self.epsilon = epsilon self.max_iters = max_iters diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 83455245a..06e1293df 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -77,7 +77,7 @@ def __init__( max_norm: float = 1.0, update_weights_every: int = 1, optim_niter: int = 20, - ): + ) -> None: super().__init__( weighting=_NashMTLWeighting( n_tasks=n_tasks, @@ -126,7 +126,7 @@ def __init__( max_norm: float, update_weights_every: int, optim_niter: int, - ): + ) -> None: super().__init__() self.n_tasks = n_tasks diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index d6cc3f104..0f1241df7 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -16,7 +16,7 @@ class PCGrad(GramianWeightedAggregator): `Gradient Surgery for Multi-Task Learning `_. """ - def __init__(self): + def __init__(self) -> None: super().__init__(PCGradWeighting()) # This prevents running into a RuntimeError due to modifying stored tensors in place. diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 53ef188ca..734dfc177 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -16,7 +16,7 @@ class Random(WeightedAggregator): `_. """ - def __init__(self): + def __init__(self) -> None: super().__init__(RandomWeighting()) diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index 0d8bd5d63..aaf73f029 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -13,7 +13,7 @@ class Sum(WeightedAggregator): matrices. """ - def __init__(self): + def __init__(self) -> None: super().__init__(weighting=SumWeighting()) diff --git a/src/torchjd/aggregation/_trimmed_mean.py b/src/torchjd/aggregation/_trimmed_mean.py index 77d33c418..8dffe990d 100644 --- a/src/torchjd/aggregation/_trimmed_mean.py +++ b/src/torchjd/aggregation/_trimmed_mean.py @@ -15,7 +15,7 @@ class TrimmedMean(Aggregator): input matrix (note that ``2 * trim_number`` values are removed from each column). """ - def __init__(self, trim_number: int): + def __init__(self, trim_number: int) -> None: super().__init__() if trim_number < 0: raise ValueError( diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 8234b3a88..45f760be9 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -34,7 +34,7 @@ def __init__( norm_eps: float = 0.0001, reg_eps: float = 0.0001, solver: SUPPORTED_SOLVER = "quadprog", - ): + ) -> None: self._pref_vector = pref_vector self._norm_eps = norm_eps self._reg_eps = reg_eps @@ -78,7 +78,7 @@ def __init__( norm_eps: float = 0.0001, reg_eps: float = 0.0001, solver: SUPPORTED_SOLVER = "quadprog", - ): + ) -> None: super().__init__() self._pref_vector = pref_vector self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) diff --git a/src/torchjd/aggregation/_utils/non_differentiable.py b/src/torchjd/aggregation/_utils/non_differentiable.py index e3f854203..c5fb1ffc5 100644 --- a/src/torchjd/aggregation/_utils/non_differentiable.py +++ b/src/torchjd/aggregation/_utils/non_differentiable.py @@ -2,7 +2,7 @@ class NonDifferentiableError(RuntimeError): - def __init__(self, module: nn.Module): + def __init__(self, module: nn.Module) -> None: super().__init__(f"Trying to differentiate through {module}, which is not differentiable.") diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index dd7c53ee6..e321169c3 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -20,7 +20,7 @@ class Weighting(nn.Module, ABC, Generic[_T]): generally its Gramian, of dimension :math:`m \times m`. """ - def __init__(self): + def __init__(self) -> None: super().__init__() @abstractmethod @@ -46,7 +46,7 @@ class _Composition(Weighting[_T]): output of the function. """ - def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutputT]): + def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutputT]) -> None: super().__init__() self.fn = fn self.weighting = weighting @@ -63,7 +63,7 @@ class GeneralizedWeighting(nn.Module, ABC): :math:`m_1 \times \dots \times m_k \times m_k \times \dots \times m_1`. """ - def __init__(self): + def __init__(self) -> None: super().__init__() @abstractmethod diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 610b37535..7b7eae966 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -183,7 +183,7 @@ def __init__( self, *modules: nn.Module, batch_dim: int | None, - ): + ) -> None: self._gramian_accumulator = GramianAccumulator() self._target_edges = EdgeRegistry() self._batch_dim = batch_dim diff --git a/src/torchjd/autogram/_gramian_computer.py b/src/torchjd/autogram/_gramian_computer.py index 829e5da36..cdc7ce939 100644 --- a/src/torchjd/autogram/_gramian_computer.py +++ b/src/torchjd/autogram/_gramian_computer.py @@ -29,7 +29,7 @@ def reset(self) -> None: class JacobianBasedGramianComputer(GramianComputer, ABC): - def __init__(self, jacobian_computer: JacobianComputer): + def __init__(self, jacobian_computer: JacobianComputer) -> None: self.jacobian_computer = jacobian_computer @@ -39,7 +39,7 @@ class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer): the gramian. """ - def __init__(self, jacobian_computer: JacobianComputer): + def __init__(self, jacobian_computer: JacobianComputer) -> None: super().__init__(jacobian_computer) self.remaining_counter = 0 self.summed_jacobian: Matrix | None = None diff --git a/src/torchjd/autogram/_jacobian_computer.py b/src/torchjd/autogram/_jacobian_computer.py index 8dc88dbab..6bddbd60f 100644 --- a/src/torchjd/autogram/_jacobian_computer.py +++ b/src/torchjd/autogram/_jacobian_computer.py @@ -26,7 +26,7 @@ class JacobianComputer(ABC): :params module: The module to differentiate. """ - def __init__(self, module: nn.Module): + def __init__(self, module: nn.Module) -> None: self.module = module self.rg_params = dict[str, Parameter]() diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index 23f84cbff..39ebf6095 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -1,5 +1,5 @@ import weakref -from typing import cast +from typing import Any, cast import torch from torch import Tensor, nn @@ -34,7 +34,7 @@ def __init__( self, target_edges: EdgeRegistry, gramian_accumulator: GramianAccumulator, - ): + ) -> None: self._target_edges = target_edges self._gramian_accumulator = gramian_accumulator self.gramian_accumulation_phase = BoolRef(False) @@ -80,7 +80,7 @@ def remove_hooks(handles: list[TorchRemovableHandle]) -> None: class BoolRef: """Class wrapping a boolean value, acting as a reference to this boolean value.""" - def __init__(self, value: bool): + def __init__(self, value: bool) -> None: self.value = value def __bool__(self) -> bool: @@ -94,7 +94,7 @@ def __init__( target_edges: EdgeRegistry, gramian_accumulator: GramianAccumulator, gramian_computer: GramianComputer, - ): + ) -> None: self.gramian_accumulation_phase = gramian_accumulation_phase self.target_edges = target_edges self.gramian_accumulator = gramian_accumulator @@ -171,7 +171,7 @@ def forward( # tuple[BoolRef, GramianComputer, tuple[PyTree, ...], dict[str, PyTree], GramianAccumulator, *tuple[Tensor, ...]] @staticmethod def setup_context( - ctx, + ctx: Any, inputs: tuple, _, ) -> None: # type: ignore[reportIncompatibleMethodOverride] @@ -183,7 +183,7 @@ def setup_context( ctx.rg_outputs = inputs[5:] @staticmethod - def backward(ctx, *grad_outputs: Tensor) -> tuple: + def backward(ctx: Any, *grad_outputs: Tensor) -> tuple: # For python > 3.10: -> tuple[None, None, None, None, None, *tuple[Tensor, ...]] if ctx.gramian_accumulation_phase: diff --git a/src/torchjd/autojac/_transform/_base.py b/src/torchjd/autojac/_transform/_base.py index 579b845c7..4cba51208 100644 --- a/src/torchjd/autojac/_transform/_base.py +++ b/src/torchjd/autojac/_transform/_base.py @@ -67,7 +67,7 @@ class Composition(Transform): :param outer: The transform to apply second, to the result of ``inner``. """ - def __init__(self, outer: Transform, inner: Transform): + def __init__(self, outer: Transform, inner: Transform) -> None: self.outer = outer self.inner = inner @@ -92,7 +92,7 @@ class Conjunction(Transform): :param transforms: The transforms to apply. Their outputs should have disjoint sets of keys. """ - def __init__(self, transforms: Sequence[Transform]): + def __init__(self, transforms: Sequence[Transform]) -> None: self.transforms = transforms def __str__(self) -> str: diff --git a/src/torchjd/autojac/_transform/_diagonalize.py b/src/torchjd/autojac/_transform/_diagonalize.py index 7954d7cee..11c951deb 100644 --- a/src/torchjd/autojac/_transform/_diagonalize.py +++ b/src/torchjd/autojac/_transform/_diagonalize.py @@ -51,7 +51,7 @@ class Diagonalize(Transform): Jacobians. """ - def __init__(self, key_order: OrderedSet[Tensor]): + def __init__(self, key_order: OrderedSet[Tensor]) -> None: self.key_order = key_order self.indices: list[tuple[int, int]] = [] begin = 0 diff --git a/src/torchjd/autojac/_transform/_differentiate.py b/src/torchjd/autojac/_transform/_differentiate.py index 1ce264386..458bd8d06 100644 --- a/src/torchjd/autojac/_transform/_differentiate.py +++ b/src/torchjd/autojac/_transform/_differentiate.py @@ -31,7 +31,7 @@ def __init__( inputs: OrderedSet[Tensor], retain_graph: bool, create_graph: bool, - ): + ) -> None: self.outputs = list(outputs) self.inputs = list(inputs) self.retain_graph = retain_graph diff --git a/src/torchjd/autojac/_transform/_grad.py b/src/torchjd/autojac/_transform/_grad.py index e61d47480..a4bd4ff3a 100644 --- a/src/torchjd/autojac/_transform/_grad.py +++ b/src/torchjd/autojac/_transform/_grad.py @@ -31,7 +31,7 @@ def __init__( inputs: OrderedSet[Tensor], retain_graph: bool = False, create_graph: bool = False, - ): + ) -> None: super().__init__(outputs, inputs, retain_graph, create_graph) def _differentiate(self, grad_outputs: Sequence[Tensor], /) -> tuple[Tensor, ...]: diff --git a/src/torchjd/autojac/_transform/_init.py b/src/torchjd/autojac/_transform/_init.py index 508330325..9da503ed9 100644 --- a/src/torchjd/autojac/_transform/_init.py +++ b/src/torchjd/autojac/_transform/_init.py @@ -13,7 +13,7 @@ class Init(Transform): :param values: Tensors for which Gradients must be returned. """ - def __init__(self, values: AbstractSet[Tensor]): + def __init__(self, values: AbstractSet[Tensor]) -> None: self.values = values def __call__(self, _input: TensorDict, /) -> TensorDict: diff --git a/src/torchjd/autojac/_transform/_jac.py b/src/torchjd/autojac/_transform/_jac.py index 1f33d1d90..e245fae02 100644 --- a/src/torchjd/autojac/_transform/_jac.py +++ b/src/torchjd/autojac/_transform/_jac.py @@ -38,7 +38,7 @@ def __init__( chunk_size: int | None, retain_graph: bool = False, create_graph: bool = False, - ): + ) -> None: super().__init__(outputs, inputs, retain_graph, create_graph) self.chunk_size = chunk_size diff --git a/src/torchjd/autojac/_transform/_ordered_set.py b/src/torchjd/autojac/_transform/_ordered_set.py index c929cb45e..e182df895 100644 --- a/src/torchjd/autojac/_transform/_ordered_set.py +++ b/src/torchjd/autojac/_transform/_ordered_set.py @@ -10,7 +10,7 @@ class OrderedSet(MutableSet[_T]): """Ordered collection of distinct elements.""" - def __init__(self, elements: Iterable[_T]): + def __init__(self, elements: Iterable[_T]) -> None: super().__init__() self.ordered_dict = OrderedDict[_T, None]([(element, None) for element in elements]) diff --git a/src/torchjd/autojac/_transform/_select.py b/src/torchjd/autojac/_transform/_select.py index b2e45caae..9df527ff4 100644 --- a/src/torchjd/autojac/_transform/_select.py +++ b/src/torchjd/autojac/_transform/_select.py @@ -12,7 +12,7 @@ class Select(Transform): :param keys: The keys that should be included in the returned subset. """ - def __init__(self, keys: AbstractSet[Tensor]): + def __init__(self, keys: AbstractSet[Tensor]) -> None: self.keys = keys def __call__(self, tensor_dict: TensorDict, /) -> TensorDict: diff --git a/src/torchjd/autojac/_transform/_stack.py b/src/torchjd/autojac/_transform/_stack.py index a4152afcd..ad628e5d9 100644 --- a/src/torchjd/autojac/_transform/_stack.py +++ b/src/torchjd/autojac/_transform/_stack.py @@ -20,7 +20,7 @@ class Stack(Transform): to those dicts. """ - def __init__(self, transforms: Sequence[Transform]): + def __init__(self, transforms: Sequence[Transform]) -> None: self.transforms = transforms def __call__(self, input: TensorDict, /) -> TensorDict: diff --git a/tests/conftest.py b/tests/conftest.py index 25e8281c2..e2b75d673 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import warnings from contextlib import nullcontext +import pytest import torch from pytest import RaisesExc, fixture, mark from settings import DEVICE @@ -30,16 +31,16 @@ def fix_randomness() -> None: torch.use_deterministic_algorithms(True) -def pytest_addoption(parser): +def pytest_addoption(parser: pytest.Parser) -> None: parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") -def pytest_configure(config): +def pytest_configure(config: pytest.Config) -> None: config.addinivalue_line("markers", "slow: mark test as slow to run") config.addinivalue_line("markers", "xfail_if_cuda: mark test as xfail if running on cuda") -def pytest_collection_modifyitems(config, items): +def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: skip_slow = mark.skip(reason="Slow test. Use --runslow to run it.") xfail_cuda = mark.xfail(reason=f"Test expected to fail on {DEVICE}") for item in items: @@ -49,7 +50,7 @@ def pytest_collection_modifyitems(config, items): item.add_marker(xfail_cuda) -def pytest_make_parametrize_id(config, val, argname): +def pytest_make_parametrize_id(config: pytest.Config, val: object, argname: str) -> str | None: MAX_SIZE = 40 optional_string = None # Returning None means using pytest's way of making the string diff --git a/tests/doc/test_aggregation.py b/tests/doc/test_aggregation.py index 75ef0ddcd..a4219e4e6 100644 --- a/tests/doc/test_aggregation.py +++ b/tests/doc/test_aggregation.py @@ -4,7 +4,7 @@ from torch.testing import assert_close -def test_aggregation_and_weighting(): +def test_aggregation_and_weighting() -> None: from torch import tensor from torchjd.aggregation import UPGrad, UPGradWeighting @@ -22,7 +22,7 @@ def test_aggregation_and_weighting(): assert_close(weights, tensor([1.1109, 0.7894]), rtol=0, atol=1e-4) -def test_generalized_weighting(): +def test_generalized_weighting() -> None: from torch import ones from torchjd.aggregation import Flattening, UPGradWeighting diff --git a/tests/doc/test_autogram.py b/tests/doc/test_autogram.py index 436518245..a0861e240 100644 --- a/tests/doc/test_autogram.py +++ b/tests/doc/test_autogram.py @@ -1,7 +1,7 @@ """This file contains tests for the usage examples related to autogram.""" -def test_engine(): +def test_engine() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD diff --git a/tests/doc/test_backward.py b/tests/doc/test_backward.py index 2416210e1..d08d2c2fc 100644 --- a/tests/doc/test_backward.py +++ b/tests/doc/test_backward.py @@ -6,7 +6,7 @@ from utils.asserts import assert_jac_close -def test_backward(): +def test_backward() -> None: import torch from torchjd.autojac import backward @@ -21,7 +21,7 @@ def test_backward(): assert_jac_close(param, torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) -def test_backward2(): +def test_backward2() -> None: import torch from torchjd.autojac import backward diff --git a/tests/doc/test_jac.py b/tests/doc/test_jac.py index e195c233a..1a0b79a20 100644 --- a/tests/doc/test_jac.py +++ b/tests/doc/test_jac.py @@ -5,7 +5,7 @@ from torch.testing import assert_close -def test_jac(): +def test_jac() -> None: import torch from torchjd.autojac import jac @@ -20,7 +20,7 @@ def test_jac(): assert_close(jacobians[0], torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) -def test_jac_2(): +def test_jac_2() -> None: import torch from torchjd.autojac import jac @@ -44,7 +44,7 @@ def test_jac_2(): ) -def test_jac_3(): +def test_jac_3() -> None: import torch from torchjd.autojac import jac diff --git a/tests/doc/test_jac_to_grad.py b/tests/doc/test_jac_to_grad.py index 1f064a6c2..04ca3ac27 100644 --- a/tests/doc/test_jac_to_grad.py +++ b/tests/doc/test_jac_to_grad.py @@ -6,7 +6,7 @@ from utils.asserts import assert_grad_close -def test_jac_to_grad(): +def test_jac_to_grad() -> None: import torch from torchjd.aggregation import UPGrad diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index ac4ac0600..90dc099a4 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -9,7 +9,7 @@ from pytest import mark -def test_amp(): +def test_amp() -> None: import torch from torch.amp import GradScaler from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -51,7 +51,7 @@ def test_amp(): optimizer.zero_grad() -def test_basic_usage(): +def test_basic_usage() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -79,7 +79,7 @@ def test_basic_usage(): optimizer.zero_grad() -def test_iwmtl(): +def test_iwmtl() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -125,8 +125,8 @@ def test_iwmtl(): optimizer.zero_grad() -def test_iwrm(): - def test_autograd(): +def test_iwrm() -> None: + def test_autograd() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -147,7 +147,7 @@ def test_autograd(): optimizer.step() optimizer.zero_grad() - def test_autojac(): + def test_autojac() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -173,7 +173,7 @@ def test_autojac(): optimizer.step() optimizer.zero_grad() - def test_autogram(): + def test_autogram() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -212,7 +212,7 @@ def test_autogram(): "ignore::lightning.fabric.utilities.warnings.PossibleUserWarning", ) @no_type_check # Typing is annoying with Lightning, which would make the example too hard to read. -def test_lightning_integration(): +def test_lightning_integration() -> None: # Extra ---------------------------------------------------------------------------------------- import logging @@ -231,14 +231,14 @@ def test_lightning_integration(): from torchjd.autojac import jac_to_grad, mtl_backward class Model(LightningModule): - def __init__(self): + def __init__(self) -> None: super().__init__() self.feature_extractor = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) self.task1_head = Linear(3, 1) self.task2_head = Linear(3, 1) self.automatic_optimization = False - def training_step(self, batch, batch_idx) -> None: + def training_step(self, batch, batch_idx) -> None: # noqa: ANN001 input, target1, target2 = batch features = self.feature_extractor(input) @@ -278,7 +278,7 @@ def configure_optimizers(self) -> OptimizerLRScheduler: trainer.fit(model=model, train_dataloaders=train_loader) -def test_monitoring(): +def test_monitoring() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.nn.functional import cosine_similarity @@ -331,7 +331,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch. optimizer.zero_grad() -def test_mtl(): +def test_mtl() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -369,7 +369,7 @@ def test_mtl(): optimizer.zero_grad() -def test_partial_jd(): +def test_partial_jd() -> None: import torch from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD @@ -402,7 +402,7 @@ def test_partial_jd(): optimizer.zero_grad() -def test_rnn(): +def test_rnn() -> None: import torch from torch.nn import RNN from torch.optim import SGD diff --git a/tests/plots/_utils.py b/tests/plots/_utils.py index e184d6d92..dc69bfdaf 100644 --- a/tests/plots/_utils.py +++ b/tests/plots/_utils.py @@ -7,7 +7,7 @@ class Plotter: - def __init__(self, aggregators: list[Aggregator], matrix: torch.Tensor, seed: int = 0): + def __init__(self, aggregators: list[Aggregator], matrix: torch.Tensor, seed: int = 0) -> None: self.aggregators = aggregators self.matrix = matrix self.seed = seed diff --git a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index 0d26a1e6b..2a945f939 100644 --- a/tests/plots/interactive_plotter.py +++ b/tests/plots/interactive_plotter.py @@ -120,7 +120,7 @@ def update_seed(value: int) -> Figure: *gradient_slider_inputs, prevent_initial_call=True, ) - def update_gradient_coordinate(*values) -> Figure: + def update_gradient_coordinate(*values: str) -> Figure: values_ = [float(value) for value in values] for j in range(len(values_) // 2): diff --git a/tests/plots/static_plotter.py b/tests/plots/static_plotter.py index c26c0d6b0..de5d68369 100644 --- a/tests/plots/static_plotter.py +++ b/tests/plots/static_plotter.py @@ -25,14 +25,14 @@ def main( *, - gradients=False, - cone=False, - projections=False, - upgrad=False, - mean=False, - dual_proj=False, - mgda=False, -): + gradients: bool = False, + cone: bool = False, + projections: bool = False, + upgrad: bool = False, + mean: bool = False, + dual_proj: bool = False, + mgda: bool = False, +) -> None: angle1 = 2.6 angle2 = 0.3277 norm1 = 0.9 @@ -76,13 +76,13 @@ def main( if cone: filename += "_cone" start_angle, opening = compute_2d_non_conflicting_cone(matrix.numpy()) - cone = make_cone_scatter( + cone_scatter = make_cone_scatter( start_angle, opening, label="Non-conflicting cone", printable=False, ) - fig.add_trace(cone) + fig.add_trace(cone_scatter) if projections: filename += "_projections" diff --git a/tests/profiling/plot_memory_timeline.py b/tests/profiling/plot_memory_timeline.py index 0f792be1c..4c2d94390 100644 --- a/tests/profiling/plot_memory_timeline.py +++ b/tests/profiling/plot_memory_timeline.py @@ -24,7 +24,7 @@ class MemoryFrame: device_id: int # -1 for CPU, 0+ for CUDA devices @staticmethod - def from_event(event: dict): + def from_event(event: dict) -> "MemoryFrame": args = event["args"] return MemoryFrame( timestamp=event["ts"], @@ -112,7 +112,7 @@ def plot_memory_timelines(experiment: str, folders: list[str]) -> None: print("Plot saved successfully!") -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="Plot memory timeline from profiling traces.") parser.add_argument( "experiment", diff --git a/tests/profiling/run_profiler.py b/tests/profiling/run_profiler.py index 7707cb016..9807849bb 100644 --- a/tests/profiling/run_profiler.py +++ b/tests/profiling/run_profiler.py @@ -3,7 +3,9 @@ import torch from settings import DEVICE +from torch import Tensor, nn from torch.profiler import ProfilerActivity, profile +from torch.utils._pytree import PyTree from utils.architectures import ( AlexNet, Cifar10Model, @@ -105,7 +107,9 @@ def _save_and_print_trace( def profile_autojac(factory: ModuleFactory, batch_size: int) -> None: - def forward_backward_fn(model, inputs, loss_fn): + def forward_backward_fn( + model: nn.Module, inputs: PyTree, loss_fn: Callable[[PyTree], list[Tensor]] + ) -> None: aggregator = UPGrad() autojac_forward_backward(model, inputs, loss_fn, aggregator) @@ -113,7 +117,9 @@ def forward_backward_fn(model, inputs, loss_fn): def profile_autogram(factory: ModuleFactory, batch_size: int) -> None: - def forward_backward_fn(model, inputs, loss_fn): + def forward_backward_fn( + model: nn.Module, inputs: PyTree, loss_fn: Callable[[PyTree], list[Tensor]] + ) -> None: engine = Engine(model, batch_dim=0) weighting = UPGradWeighting() autogram_forward_backward(model, inputs, loss_fn, engine, weighting) @@ -121,7 +127,7 @@ def forward_backward_fn(model, inputs, loss_fn): profile_method("autogram", forward_backward_fn, factory, batch_size) -def main(): +def main() -> None: for factory, batch_size in PARAMETRIZATIONS: profile_autojac(factory, batch_size) print("\n" + "=" * 80 + "\n") diff --git a/tests/profiling/speed_grad_vs_jac_vs_gram.py b/tests/profiling/speed_grad_vs_jac_vs_gram.py index d68d67aa0..b670a96d3 100644 --- a/tests/profiling/speed_grad_vs_jac_vs_gram.py +++ b/tests/profiling/speed_grad_vs_jac_vs_gram.py @@ -1,5 +1,6 @@ import gc import time +from collections.abc import Callable import torch from settings import DEVICE @@ -41,13 +42,13 @@ ] -def main(): +def main() -> None: for factory, batch_size in PARAMETRIZATIONS: compare_autograd_autojac_and_autogram_speed(factory, batch_size) print("\n") -def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int): +def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_size: int) -> None: model = factory() inputs, targets = make_inputs_and_targets(model, batch_size) loss_fn = make_mse_loss_fn(targets) @@ -57,47 +58,47 @@ def compare_autograd_autojac_and_autogram_speed(factory: ModuleFactory, batch_si print(f"\nTimes for forward + backward on {factory} with BS={batch_size}, A={A} on {DEVICE}.") - def fn_autograd(): + def fn_autograd() -> None: autograd_forward_backward(model, inputs, loss_fn) - def init_fn_autograd(): + def init_fn_autograd() -> None: torch.cuda.empty_cache() gc.collect() fn_autograd() - def fn_autograd_gramian(): + def fn_autograd_gramian() -> None: autograd_gramian_forward_backward(model, inputs, loss_fn, W) - def init_fn_autograd_gramian(): + def init_fn_autograd_gramian() -> None: torch.cuda.empty_cache() gc.collect() fn_autograd_gramian() - def fn_autojac(): + def fn_autojac() -> None: autojac_forward_backward(model, inputs, loss_fn, A) - def init_fn_autojac(): + def init_fn_autojac() -> None: torch.cuda.empty_cache() gc.collect() fn_autojac() - def fn_autogram(): + def fn_autogram() -> None: autogram_forward_backward(model, inputs, loss_fn, engine, W) - def init_fn_autogram(): + def init_fn_autogram() -> None: torch.cuda.empty_cache() gc.collect() fn_autogram() - def optionally_cuda_sync(): + def optionally_cuda_sync() -> None: if DEVICE.type == "cuda": torch.cuda.synchronize() - def pre_fn(): + def pre_fn() -> None: model.zero_grad() optionally_cuda_sync() - def post_fn(): + def post_fn() -> None: optionally_cuda_sync() n_runs = 10 @@ -121,11 +122,17 @@ def post_fn(): print_times("autogram", autogram_times) -def noop(): +def noop() -> None: pass -def time_call(fn, init_fn=noop, pre_fn=noop, post_fn=noop, n_runs: int = 10) -> Tensor: +def time_call( + fn: Callable[[], None], + init_fn: Callable[[], None] = noop, + pre_fn: Callable[[], None] = noop, + post_fn: Callable[[], None] = noop, + n_runs: int = 10, +) -> Tensor: init_fn() times = [] diff --git a/tests/unit/aggregation/_asserts.py b/tests/unit/aggregation/_asserts.py index 8c119674e..4b85bf09f 100644 --- a/tests/unit/aggregation/_asserts.py +++ b/tests/unit/aggregation/_asserts.py @@ -101,7 +101,7 @@ def assert_strongly_stationary( assert norm > threshold -def assert_non_differentiable(aggregator: Aggregator, matrix: Tensor): +def assert_non_differentiable(aggregator: Aggregator, matrix: Tensor) -> None: """ Tests empirically that a given non-differentiable `Aggregator` correctly raises a NonDifferentiableError whenever we try to backward through it. diff --git a/tests/unit/aggregation/_matrix_samplers.py b/tests/unit/aggregation/_matrix_samplers.py index 1b5cc8abb..68699ccef 100644 --- a/tests/unit/aggregation/_matrix_samplers.py +++ b/tests/unit/aggregation/_matrix_samplers.py @@ -9,7 +9,7 @@ class MatrixSampler(ABC): """Abstract base class for sampling matrices of a given shape, rank.""" - def __init__(self, m: int, n: int, rank: int): + def __init__(self, m: int, n: int, rank: int) -> None: self._check_params(m, n, rank) self.m = m self.n = n diff --git a/tests/unit/aggregation/_utils/test_dual_cone.py b/tests/unit/aggregation/_utils/test_dual_cone.py index 3923d3f0d..68a8a75d7 100644 --- a/tests/unit/aggregation/_utils/test_dual_cone.py +++ b/tests/unit/aggregation/_utils/test_dual_cone.py @@ -8,7 +8,7 @@ @mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)]) -def test_solution_weights(shape: tuple[int, int]): +def test_solution_weights(shape: tuple[int, int]) -> None: r""" Tests that `_project_weights` returns valid weights corresponding to the projection onto the dual cone of a matrix with the specified shape. @@ -54,7 +54,7 @@ def test_solution_weights(shape: tuple[int, int]): @mark.parametrize("shape", [(5, 7), (9, 37), (32, 114)]) @mark.parametrize("scaling", [2 ** (-4), 2 ** (-2), 2**2, 2**4]) -def test_scale_invariant(shape: tuple[int, int], scaling: float): +def test_scale_invariant(shape: tuple[int, int], scaling: float) -> None: """ Tests that `_project_weights` is invariant under scaling. """ @@ -70,7 +70,7 @@ def test_scale_invariant(shape: tuple[int, int], scaling: float): @mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)]) -def test_tensorization_shape(shape: tuple[int, ...]): +def test_tensorization_shape(shape: tuple[int, ...]) -> None: """ Tests that applying `_project_weights` on a tensor is equivalent to applying it on the tensor reshaped as matrix and to reshape the result back to the original tensor's shape. @@ -88,7 +88,7 @@ def test_tensorization_shape(shape: tuple[int, ...]): assert_close(W_matrix.reshape(shape), W_tensor) -def test_project_weight_vector_failure(): +def test_project_weight_vector_failure() -> None: """Tests that `_project_weight_vector` raises an error when the input G has too large values.""" large_J = np.random.randn(10, 100) * 1e5 diff --git a/tests/unit/aggregation/_utils/test_pref_vector.py b/tests/unit/aggregation/_utils/test_pref_vector.py index 159582dd1..0871726a5 100644 --- a/tests/unit/aggregation/_utils/test_pref_vector.py +++ b/tests/unit/aggregation/_utils/test_pref_vector.py @@ -21,6 +21,8 @@ (ones_([1, 1, 1]), raises(ValueError)), ], ) -def test_pref_vector_to_weighting_check(pref_vector: Tensor | None, expectation: ExceptionContext): +def test_pref_vector_to_weighting_check( + pref_vector: Tensor | None, expectation: ExceptionContext +) -> None: with expectation: _ = pref_vector_to_weighting(pref_vector, default=MeanWeighting()) diff --git a/tests/unit/aggregation/test_aggregator_bases.py b/tests/unit/aggregation/test_aggregator_bases.py index b08c37a89..80c9aeaa1 100644 --- a/tests/unit/aggregation/test_aggregator_bases.py +++ b/tests/unit/aggregation/test_aggregator_bases.py @@ -18,6 +18,6 @@ ([1, 2, 3, 4], raises(ValueError)), ], ) -def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext): +def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext) -> None: with expectation: Aggregator._check_is_matrix(randn_(shape)) diff --git a/tests/unit/aggregation/test_aligned_mtl.py b/tests/unit/aggregation/test_aligned_mtl.py index 70847cbf1..db89207bc 100644 --- a/tests/unit/aggregation/test_aligned_mtl.py +++ b/tests/unit/aggregation/test_aligned_mtl.py @@ -19,16 +19,16 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: AlignedMTL, matrix: Tensor): +def test_expected_structure(aggregator: AlignedMTL, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor): +def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = AlignedMTL(pref_vector=None) assert repr(A) == "AlignedMTL(pref_vector=None, scale_mode='min')" assert str(A) == "AlignedMTL" @@ -38,7 +38,7 @@ def test_representations(): assert str(A) == "AlignedMTL([1., 2., 3.])" -def test_invalid_scale_mode(): +def test_invalid_scale_mode() -> None: aggregator = AlignedMTL(scale_mode="test") # type: ignore[arg-type] matrix = ones_(3, 4) with raises(ValueError, match=r"Invalid scale_mode=.*Expected"): diff --git a/tests/unit/aggregation/test_cagrad.py b/tests/unit/aggregation/test_cagrad.py index e77cb729b..c7d18b1f2 100644 --- a/tests/unit/aggregation/test_cagrad.py +++ b/tests/unit/aggregation/test_cagrad.py @@ -23,17 +23,17 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: CAGrad, matrix: Tensor): +def test_expected_structure(aggregator: CAGrad, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: CAGrad, matrix: Tensor): +def test_non_differentiable(aggregator: CAGrad, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], non_conflicting_pairs_1 + non_conflicting_pairs_2) -def test_non_conflicting(aggregator: CAGrad, matrix: Tensor): +def test_non_conflicting(aggregator: CAGrad, matrix: Tensor) -> None: """Tests that CAGrad is non-conflicting when c >= 1 (it should not hold when c < 1).""" assert_non_conflicting(aggregator, matrix) @@ -48,12 +48,12 @@ def test_non_conflicting(aggregator: CAGrad, matrix: Tensor): (50.0, does_not_raise()), ], ) -def test_c_check(c: float, expectation: ExceptionContext): +def test_c_check(c: float, expectation: ExceptionContext) -> None: with expectation: _ = CAGrad(c=c) -def test_representations(): +def test_representations() -> None: A = CAGrad(c=0.5, norm_eps=0.0001) assert repr(A) == "CAGrad(c=0.5, norm_eps=0.0001)" assert str(A) == "CAGrad0.5" diff --git a/tests/unit/aggregation/test_config.py b/tests/unit/aggregation/test_config.py index 69cc4af14..2db2ea0fa 100644 --- a/tests/unit/aggregation/test_config.py +++ b/tests/unit/aggregation/test_config.py @@ -20,26 +20,26 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: ConFIG, matrix: Tensor): +def test_expected_structure(aggregator: ConFIG, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: ConFIG, matrix: Tensor): +def test_permutation_invariant(aggregator: ConFIG, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_linear_under_scaling(aggregator: ConFIG, matrix: Tensor): +def test_linear_under_scaling(aggregator: ConFIG, matrix: Tensor) -> None: assert_linear_under_scaling(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: ConFIG, matrix: Tensor): +def test_non_differentiable(aggregator: ConFIG, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = ConFIG() assert repr(A) == "ConFIG(pref_vector=None)" assert str(A) == "ConFIG" diff --git a/tests/unit/aggregation/test_constant.py b/tests/unit/aggregation/test_constant.py index 984b7a1fe..aa1332fcb 100644 --- a/tests/unit/aggregation/test_constant.py +++ b/tests/unit/aggregation/test_constant.py @@ -28,17 +28,17 @@ def _make_aggregator(matrix: Tensor) -> Constant: @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: Constant, matrix: Tensor): +def test_expected_structure(aggregator: Constant, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_linear_under_scaling(aggregator: Constant, matrix: Tensor): +def test_linear_under_scaling(aggregator: Constant, matrix: Tensor) -> None: assert_linear_under_scaling(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) -def test_strongly_stationary(aggregator: Constant, matrix: Tensor): +def test_strongly_stationary(aggregator: Constant, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) @@ -57,7 +57,7 @@ def test_strongly_stationary(aggregator: Constant, matrix: Tensor): ([1, 1, 1, 1, 1], raises(ValueError)), ], ) -def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionContext): +def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionContext) -> None: weights = ones_(weights_shape) with expectation: _ = Constant(weights=weights) @@ -75,7 +75,9 @@ def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionCon ([5], 4, raises(ValueError)), ], ) -def test_matrix_shape_check(weights_shape: list[int], n_rows: int, expectation: ExceptionContext): +def test_matrix_shape_check( + weights_shape: list[int], n_rows: int, expectation: ExceptionContext +) -> None: matrix = ones_([n_rows, 5]) weights = ones_(weights_shape) aggregator = Constant(weights) @@ -84,7 +86,7 @@ def test_matrix_shape_check(weights_shape: list[int], n_rows: int, expectation: _ = aggregator(matrix) -def test_representations(): +def test_representations() -> None: A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu")) assert repr(A) == "Constant(weights=tensor([1., 2.]))" assert str(A) == "Constant([1., 2.])" diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 0f4407d23..5bd0e71af 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -21,31 +21,31 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: DualProj, matrix: Tensor): +def test_expected_structure(aggregator: DualProj, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_non_conflicting(aggregator: DualProj, matrix: Tensor): +def test_non_conflicting(aggregator: DualProj, matrix: Tensor) -> None: assert_non_conflicting(aggregator, matrix, atol=1e-04, rtol=1e-04) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: DualProj, matrix: Tensor): +def test_permutation_invariant(aggregator: DualProj, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=2e-07, rtol=2e-07) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) -def test_strongly_stationary(aggregator: DualProj, matrix: Tensor): +def test_strongly_stationary(aggregator: DualProj, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix, threshold=3e-03) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: DualProj, matrix: Tensor): +def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") assert ( repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" diff --git a/tests/unit/aggregation/test_graddrop.py b/tests/unit/aggregation/test_graddrop.py index 59e6e1ae0..2868dca0d 100644 --- a/tests/unit/aggregation/test_graddrop.py +++ b/tests/unit/aggregation/test_graddrop.py @@ -18,12 +18,12 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: GradDrop, matrix: Tensor): +def test_expected_structure(aggregator: GradDrop, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: GradDrop, matrix: Tensor): +def test_non_differentiable(aggregator: GradDrop, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) @@ -42,7 +42,7 @@ def test_non_differentiable(aggregator: GradDrop, matrix: Tensor): ([1, 1, 1, 1, 1], raises(ValueError)), ], ) -def test_leak_shape_check(leak_shape: list[int], expectation: ExceptionContext): +def test_leak_shape_check(leak_shape: list[int], expectation: ExceptionContext) -> None: leak = ones_(leak_shape) with expectation: _ = GradDrop(leak=leak) @@ -60,7 +60,9 @@ def test_leak_shape_check(leak_shape: list[int], expectation: ExceptionContext): ([5], 4, raises(ValueError)), ], ) -def test_matrix_shape_check(leak_shape: list[int], n_rows: int, expectation: ExceptionContext): +def test_matrix_shape_check( + leak_shape: list[int], n_rows: int, expectation: ExceptionContext +) -> None: matrix = ones_([n_rows, 5]) leak = ones_(leak_shape) aggregator = GradDrop(leak=leak) @@ -69,7 +71,7 @@ def test_matrix_shape_check(leak_shape: list[int], n_rows: int, expectation: Exc _ = aggregator(matrix) -def test_representations(): +def test_representations() -> None: A = GradDrop(leak=torch.tensor([0.0, 1.0], device="cpu")) assert re.match( r"GradDrop\(f=, leak=tensor\(\[0\., 1\.\]\)\)", diff --git a/tests/unit/aggregation/test_imtl_g.py b/tests/unit/aggregation/test_imtl_g.py index e9ba838cb..03c41d5ef 100644 --- a/tests/unit/aggregation/test_imtl_g.py +++ b/tests/unit/aggregation/test_imtl_g.py @@ -18,21 +18,21 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: IMTLG, matrix: Tensor): +def test_expected_structure(aggregator: IMTLG, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: IMTLG, matrix: Tensor): +def test_permutation_invariant(aggregator: IMTLG, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: IMTLG, matrix: Tensor): +def test_non_differentiable(aggregator: IMTLG, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) -def test_imtlg_zero(): +def test_imtlg_zero() -> None: """ Tests that IMTLG correctly returns the 0 vector in the special case where input matrix only consists of zeros. @@ -43,7 +43,7 @@ def test_imtlg_zero(): assert_close(A(J), zeros_(3)) -def test_representations(): +def test_representations() -> None: A = IMTLG() assert repr(A) == "IMTLG()" assert str(A) == "IMTLG" diff --git a/tests/unit/aggregation/test_krum.py b/tests/unit/aggregation/test_krum.py index 48fa4019a..4097f2ebe 100644 --- a/tests/unit/aggregation/test_krum.py +++ b/tests/unit/aggregation/test_krum.py @@ -15,7 +15,7 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: Krum, matrix: Tensor): +def test_expected_structure(aggregator: Krum, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @@ -29,7 +29,7 @@ def test_expected_structure(aggregator: Krum, matrix: Tensor): (5, does_not_raise()), ], ) -def test_n_byzantine_check(n_byzantine: int, expectation: ExceptionContext): +def test_n_byzantine_check(n_byzantine: int, expectation: ExceptionContext) -> None: with expectation: _ = Krum(n_byzantine=n_byzantine, n_selected=1) @@ -44,7 +44,7 @@ def test_n_byzantine_check(n_byzantine: int, expectation: ExceptionContext): (5, does_not_raise()), ], ) -def test_n_selected_check(n_selected: int, expectation: ExceptionContext): +def test_n_selected_check(n_selected: int, expectation: ExceptionContext) -> None: with expectation: _ = Krum(n_byzantine=1, n_selected=n_selected) @@ -66,7 +66,7 @@ def test_matrix_shape_check( n_selected: int, n_rows: int, expectation: ExceptionContext, -): +) -> None: aggregator = Krum(n_byzantine=n_byzantine, n_selected=n_selected) matrix = ones_([n_rows, 5]) @@ -74,7 +74,7 @@ def test_matrix_shape_check( _ = aggregator(matrix) -def test_representations(): +def test_representations() -> None: A = Krum(n_byzantine=1, n_selected=2) assert repr(A) == "Krum(n_byzantine=1, n_selected=2)" assert str(A) == "Krum1-2" diff --git a/tests/unit/aggregation/test_mean.py b/tests/unit/aggregation/test_mean.py index 4d3fbf3a0..88c28e937 100644 --- a/tests/unit/aggregation/test_mean.py +++ b/tests/unit/aggregation/test_mean.py @@ -17,26 +17,26 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: Mean, matrix: Tensor): +def test_expected_structure(aggregator: Mean, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: Mean, matrix: Tensor): +def test_permutation_invariant(aggregator: Mean, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_linear_under_scaling(aggregator: Mean, matrix: Tensor): +def test_linear_under_scaling(aggregator: Mean, matrix: Tensor) -> None: assert_linear_under_scaling(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) -def test_strongly_stationary(aggregator: Mean, matrix: Tensor): +def test_strongly_stationary(aggregator: Mean, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = Mean() assert repr(A) == "Mean()" assert str(A) == "Mean" diff --git a/tests/unit/aggregation/test_mgda.py b/tests/unit/aggregation/test_mgda.py index 2d1fe068d..5c925b8fe 100644 --- a/tests/unit/aggregation/test_mgda.py +++ b/tests/unit/aggregation/test_mgda.py @@ -19,17 +19,17 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: MGDA, matrix: Tensor): +def test_expected_structure(aggregator: MGDA, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_non_conflicting(aggregator: MGDA, matrix: Tensor): +def test_non_conflicting(aggregator: MGDA, matrix: Tensor) -> None: assert_non_conflicting(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: MGDA, matrix: Tensor): +def test_permutation_invariant(aggregator: MGDA, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) @@ -43,7 +43,7 @@ def test_permutation_invariant(aggregator: MGDA, matrix: Tensor): (50, 100), ], ) -def test_mgda_satisfies_kkt_conditions(shape: tuple[int, int]): +def test_mgda_satisfies_kkt_conditions(shape: tuple[int, int]) -> None: matrix = randn_(shape) gramian = compute_gramian(matrix) @@ -66,7 +66,7 @@ def test_mgda_satisfies_kkt_conditions(shape: tuple[int, int]): assert_close(positive_mu.norm(), mu.norm(), atol=1e-02, rtol=0.0) -def test_representations(): +def test_representations() -> None: A = MGDA(epsilon=0.001, max_iters=100) assert repr(A) == "MGDA(epsilon=0.001, max_iters=100)" assert str(A) == "MGDA" diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py index 44e154002..a1200d465 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -29,18 +29,18 @@ def _make_aggregator(matrix: Tensor) -> NashMTL: "ignore:You are solving a parameterized problem that is not DPP.", ) @mark.parametrize(["aggregator", "matrix"], standard_pairs) -def test_expected_structure(aggregator: NashMTL, matrix: Tensor): +def test_expected_structure(aggregator: NashMTL, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.filterwarnings("ignore:You are solving a parameterized problem that is not DPP.") @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: NashMTL, matrix: Tensor): +def test_non_differentiable(aggregator: NashMTL, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) @mark.filterwarnings("ignore: You are solving a parameterized problem that is not DPP.") -def test_nash_mtl_reset(): +def test_nash_mtl_reset() -> None: """ Tests that the reset method of NashMTL correctly resets its internal state, by verifying that the result is the same after reset as it is right after instantiation. @@ -59,7 +59,7 @@ def test_nash_mtl_reset(): assert_close(result, expected) -def test_representations(): +def test_representations() -> None: A = NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5) assert repr(A) == "NashMTL(n_tasks=2, max_norm=1.5, update_weights_every=2, optim_niter=5)" assert str(A) == "NashMTL" diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index 57a9120cf..b776071d3 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -17,12 +17,12 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: PCGrad, matrix: Tensor): +def test_expected_structure(aggregator: PCGrad, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: PCGrad, matrix: Tensor): +def test_non_differentiable(aggregator: PCGrad, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) @@ -41,7 +41,7 @@ def test_non_differentiable(aggregator: PCGrad, matrix: Tensor): (2, 11100), ], ) -def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]): +def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]) -> None: """ Tests that UPGradWeighting of a SumWeighting is equivalent to PCGradWeighting for matrices of 2 rows. @@ -64,7 +64,7 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]): assert_close(result, expected, atol=4e-04, rtol=0.0) -def test_representations(): +def test_representations() -> None: A = PCGrad() assert repr(A) == "PCGrad()" assert str(A) == "PCGrad" diff --git a/tests/unit/aggregation/test_random.py b/tests/unit/aggregation/test_random.py index d93929e2b..77ab7f423 100644 --- a/tests/unit/aggregation/test_random.py +++ b/tests/unit/aggregation/test_random.py @@ -12,16 +12,16 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: Random, matrix: Tensor): +def test_expected_structure(aggregator: Random, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) -def test_strongly_stationary(aggregator: Random, matrix: Tensor): +def test_strongly_stationary(aggregator: Random, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = Random() assert repr(A) == "Random()" assert str(A) == "Random" diff --git a/tests/unit/aggregation/test_sum.py b/tests/unit/aggregation/test_sum.py index 99fe4e9f5..386c507f9 100644 --- a/tests/unit/aggregation/test_sum.py +++ b/tests/unit/aggregation/test_sum.py @@ -17,26 +17,26 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: Sum, matrix: Tensor): +def test_expected_structure(aggregator: Sum, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: Sum, matrix: Tensor): +def test_permutation_invariant(aggregator: Sum, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_linear_under_scaling(aggregator: Sum, matrix: Tensor): +def test_linear_under_scaling(aggregator: Sum, matrix: Tensor) -> None: assert_linear_under_scaling(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) -def test_strongly_stationary(aggregator: Sum, matrix: Tensor): +def test_strongly_stationary(aggregator: Sum, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = Sum() assert repr(A) == "Sum()" assert str(A) == "Sum" diff --git a/tests/unit/aggregation/test_trimmed_mean.py b/tests/unit/aggregation/test_trimmed_mean.py index cdeb93986..3a6ccb2bc 100644 --- a/tests/unit/aggregation/test_trimmed_mean.py +++ b/tests/unit/aggregation/test_trimmed_mean.py @@ -15,12 +15,12 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: TrimmedMean, matrix: Tensor): +def test_expected_structure(aggregator: TrimmedMean, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: TrimmedMean, matrix: Tensor): +def test_permutation_invariant(aggregator: TrimmedMean, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) @@ -34,7 +34,7 @@ def test_permutation_invariant(aggregator: TrimmedMean, matrix: Tensor): (5, does_not_raise()), ], ) -def test_trim_number_check(trim_number: int, expectation: ExceptionContext): +def test_trim_number_check(trim_number: int, expectation: ExceptionContext) -> None: with expectation: _ = TrimmedMean(trim_number=trim_number) @@ -49,7 +49,7 @@ def test_trim_number_check(trim_number: int, expectation: ExceptionContext): (10, 5, raises(ValueError)), ], ) -def test_matrix_shape_check(n_rows: int, trim_number: int, expectation: ExceptionContext): +def test_matrix_shape_check(n_rows: int, trim_number: int, expectation: ExceptionContext) -> None: matrix = ones_([n_rows, 5]) aggregator = TrimmedMean(trim_number=trim_number) @@ -57,7 +57,7 @@ def test_matrix_shape_check(n_rows: int, trim_number: int, expectation: Exceptio _ = aggregator(matrix) -def test_representations(): +def test_representations() -> None: aggregator = TrimmedMean(trim_number=2) assert repr(aggregator) == "TrimmedMean(trim_number=2)" assert str(aggregator) == "TM2" diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 9fc480d2f..1859b6625 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -22,36 +22,36 @@ @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) -def test_expected_structure(aggregator: UPGrad, matrix: Tensor): +def test_expected_structure(aggregator: UPGrad, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_non_conflicting(aggregator: UPGrad, matrix: Tensor): +def test_non_conflicting(aggregator: UPGrad, matrix: Tensor) -> None: assert_non_conflicting(aggregator, matrix, atol=4e-04, rtol=4e-04) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor): +def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=5e-07, rtol=5e-07) @mark.parametrize(["aggregator", "matrix"], typical_pairs) -def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor): +def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor) -> None: assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=6e-02, rtol=6e-02) @mark.parametrize(["aggregator", "matrix"], non_strong_pairs) -def test_strongly_stationary(aggregator: UPGrad, matrix: Tensor): +def test_strongly_stationary(aggregator: UPGrad, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix, threshold=5e-03) @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) -def test_non_differentiable(aggregator: UPGrad, matrix: Tensor): +def test_non_differentiable(aggregator: UPGrad, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) -def test_representations(): +def test_representations() -> None: A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" assert str(A) == "UPGrad" diff --git a/tests/unit/aggregation/test_values.py b/tests/unit/aggregation/test_values.py index 719be0c39..860f313d8 100644 --- a/tests/unit/aggregation/test_values.py +++ b/tests/unit/aggregation/test_values.py @@ -110,14 +110,14 @@ @mark.parametrize(["A", "J", "expected_output"], AGGREGATOR_PARAMETRIZATIONS) -def test_aggregator_output(A: Aggregator, J: Tensor, expected_output: Tensor): +def test_aggregator_output(A: Aggregator, J: Tensor, expected_output: Tensor) -> None: """Test that the output values of an aggregator are fixed (on cpu).""" assert_close(A(J), expected_output, rtol=0, atol=1e-4) @mark.parametrize(["W", "G", "expected_output"], WEIGHTING_PARAMETRIZATIONS) -def test_weighting_output(W: Weighting, G: Tensor, expected_output: Tensor): +def test_weighting_output(W: Weighting, G: Tensor, expected_output: Tensor) -> None: """Test that the output values of a weighting are fixed (on cpu).""" assert_close(W(G), expected_output, rtol=0, atol=1e-4) diff --git a/tests/unit/autogram/test_edge_registry.py b/tests/unit/autogram/test_edge_registry.py index 88d6da8cc..56e5f7204 100644 --- a/tests/unit/autogram/test_edge_registry.py +++ b/tests/unit/autogram/test_edge_registry.py @@ -4,7 +4,7 @@ from torchjd.autogram._edge_registry import EdgeRegistry -def test_all_edges_are_leaves1(): +def test_all_edges_are_leaves1() -> None: """Tests that get_leaf_edges works correctly when all edges are already leaves.""" a = randn_([3, 4], requires_grad=True) @@ -22,7 +22,7 @@ def test_all_edges_are_leaves1(): assert leaves == expected_leaves -def test_all_edges_are_leaves2(): +def test_all_edges_are_leaves2() -> None: """ Tests that get_leaf_edges works correctly when all edges are already leaves of the graph of edges leading to them, but are not leaves of the autograd graph. @@ -46,7 +46,7 @@ def test_all_edges_are_leaves2(): assert leaves == expected_leaves -def test_some_edges_are_not_leaves1(): +def test_some_edges_are_not_leaves1() -> None: """Tests that get_leaf_edges works correctly when some edges are leaves and some are not.""" a = randn_([3, 4], requires_grad=True) @@ -67,7 +67,7 @@ def test_some_edges_are_not_leaves1(): assert leaves == expected_leaves -def test_some_edges_are_not_leaves2(): +def test_some_edges_are_not_leaves2() -> None: """ Tests that get_leaf_edges works correctly when some edges are leaves and some are not. This time, not all tensors in the graph are registered so not all leavese in the graph have to be diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index b96824ec5..76fdc41fc 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -147,7 +147,7 @@ def _assert_gramian_is_equivalent_to_autograd( factory: ModuleFactory, batch_size: int, batch_dim: int | None, -): +) -> None: model_autograd, model_autogram = factory(), factory() engine = Engine(model_autogram, batch_dim=batch_dim) inputs, targets = make_inputs_and_targets(model_autograd, batch_size) @@ -191,7 +191,7 @@ def _get_losses_and_params_without_cross_terms( @mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) @mark.parametrize("batch_dim", [0, None]) -def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int | None): +def test_compute_gramian(factory: ModuleFactory, batch_size: int, batch_dim: int | None) -> None: """Tests that the autograd and the autogram engines compute the same gramian.""" _assert_gramian_is_equivalent_to_autograd(factory, batch_size, batch_dim) @@ -213,7 +213,7 @@ def test_compute_gramian_with_weird_modules( factory: ModuleFactory, batch_size: int, batch_dim: int | None, -): +) -> None: """ Tests that compute_gramian works even with some problematic modules when batch_dim is None. It is expected to fail on those when the engine uses the batched optimization (when batch_dim=0). @@ -237,7 +237,7 @@ def test_compute_gramian_unsupported_architectures( factory: ModuleFactory, batch_size: int, batch_dim: int | None, -): +) -> None: """ Tests compute_gramian on some architectures that are known to be unsupported. It is expected to fail. @@ -275,7 +275,7 @@ def test_compute_gramian_various_output_shapes( batch_dim: int | None, movedim_source: list[int], movedim_destination: list[int], -): +) -> None: """ Tests that the autograd and the autogram engines compute the same gramian when the output can have various different shapes, and can be batched in any of its dimensions. @@ -312,7 +312,7 @@ def _non_empty_subsets(S: set) -> list[list]: @mark.parametrize("gramian_module_names", _non_empty_subsets({"fc0", "fc1", "fc2", "fc3", "fc4"})) @mark.parametrize("batch_dim", [0, None]) -def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int | None): +def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int | None) -> None: """ Tests that the autograd and the autogram engines compute the same gramian when only a subset of the model parameters is specified. @@ -340,7 +340,9 @@ def test_compute_partial_gramian(gramian_module_names: set[str], batch_dim: int @mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) @mark.parametrize("batch_dim", [0, None]) -def test_iwrm_steps_with_autogram(factory: ModuleFactory, batch_size: int, batch_dim: int | None): +def test_iwrm_steps_with_autogram( + factory: ModuleFactory, batch_size: int, batch_dim: int | None +) -> None: """Tests that the autogram engine doesn't raise any error during several IWRM iterations.""" n_iter = 3 @@ -365,7 +367,7 @@ def test_autograd_while_modules_are_hooked( batch_size: int, use_engine: bool, batch_dim: int | None, -): +) -> None: """ Tests that the hooks added when constructing the engine do not interfere with a simple autograd call. @@ -402,7 +404,7 @@ def test_autograd_while_modules_are_hooked( (ModuleFactory(BatchNorm2d, num_features=3, affine=True, track_running_stats=False), 0), ], ) -def test_incompatible_modules(factory: ModuleFactory, batch_dim: int | None): +def test_incompatible_modules(factory: ModuleFactory, batch_dim: int | None) -> None: """Tests that the engine cannot be constructed with incompatible modules.""" model = factory() @@ -410,7 +412,7 @@ def test_incompatible_modules(factory: ModuleFactory, batch_dim: int | None): _ = Engine(model, batch_dim=batch_dim) -def test_compute_gramian_manual(): +def test_compute_gramian_manual() -> None: """ Tests that the Gramian computed by the `Engine` equals to a manual computation of the expected Gramian. @@ -454,7 +456,7 @@ def test_compute_gramian_manual(): [1], ], ) -def test_reshape_equivariance(shape: list[int]): +def test_reshape_equivariance(shape: list[int]) -> None: """ Test equivariance of `compute_gramian` under reshape operation. More precisely, if we reshape the `output` to some `shape`, then the result is the same as reshaping the Gramian to the @@ -492,7 +494,7 @@ def test_reshape_equivariance(shape: list[int]): ([1, 1, 1], [1, 0], [0, 1]), ], ) -def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]): +def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]) -> None: """ Test equivariance of `compute_gramian` under movedim operation. More precisely, if we movedim the `output` on some dimensions, then the result is the same as movedim on the Gramian with the @@ -532,7 +534,7 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: ([4, 3, 1], 2), ], ) -def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): +def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int) -> None: """ Tests that for a vector with some batched dimensions, the gramian is the same if we use the appropriate `batch_dim` or if we don't use any. @@ -558,7 +560,7 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): @mark.parametrize(["factory", "batch_size"], PARAMETRIZATIONS) -def test_batched_non_batched_equivalence_2(factory: ModuleFactory, batch_size: int): +def test_batched_non_batched_equivalence_2(factory: ModuleFactory, batch_size: int) -> None: """ Same as test_batched_non_batched_equivalence but on real architectures, and thus only between batch_size=0 and batch_size=None. diff --git a/tests/unit/autogram/test_gramian_utils.py b/tests/unit/autogram/test_gramian_utils.py index 5f74df8dc..7d4c22162 100644 --- a/tests/unit/autogram/test_gramian_utils.py +++ b/tests/unit/autogram/test_gramian_utils.py @@ -26,7 +26,7 @@ ([6, 7, 9], [6, 7, 9]), ], ) -def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]): +def test_reshape_equivarience(original_shape: list[int], target_shape: list[int]) -> None: """Tests that reshape_gramian is such that compute_gramian is equivariant to a reshape.""" original_matrix = randn_([*original_shape, 2]) @@ -55,7 +55,7 @@ def test_reshape_equivarience(original_shape: list[int], target_shape: list[int] ([6, 7, 9], [6, 7, 9]), ], ) -def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]): +def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]) -> None: matrix = randn_([*original_shape, 2]) gramian = compute_gramian(matrix, 1) reshaped_gramian = reshape(gramian, target_shape) @@ -72,7 +72,7 @@ def test_reshape_yields_psd(original_shape: list[int], target_shape: list[int]): [6, 7, 9], ], ) -def test_flatten_yields_matrix(shape: list[int]): +def test_flatten_yields_matrix(shape: list[int]) -> None: matrix = randn_([*shape, 2]) gramian = compute_gramian(matrix, 1) flattened_gramian = flatten(gramian) @@ -89,7 +89,7 @@ def test_flatten_yields_matrix(shape: list[int]): [6, 7, 9], ], ) -def test_flatten_yields_psd(shape: list[int]): +def test_flatten_yields_psd(shape: list[int]) -> None: matrix = randn_([*shape, 2]) gramian = compute_gramian(matrix, 1) flattened_gramian = flatten(gramian) @@ -114,7 +114,7 @@ def test_flatten_yields_psd(shape: list[int]): ([2, 2, 3], [0, 2, 1], [1, 0, 2]), ], ) -def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]): +def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]) -> None: """Tests that movedim_gramian is such that compute_gramian is equivariant to a movedim.""" original_matrix = randn_([*shape, 2]) @@ -146,7 +146,7 @@ def test_movedim_equivariance(shape: list[int], source: list[int], destination: ([2, 2, 3], [0, 2, 1], [1, 0, 2]), ], ) -def test_movedim_yields_psd(shape: list[int], source: list[int], destination: list[int]): +def test_movedim_yields_psd(shape: list[int], source: list[int], destination: list[int]) -> None: matrix = randn_([*shape, 2]) gramian = compute_gramian(matrix, 1) moveddim_gramian = movedim(gramian, source, destination) diff --git a/tests/unit/autojac/_transform/test_accumulate.py b/tests/unit/autojac/_transform/test_accumulate.py index 8c179a896..6e2a137e3 100644 --- a/tests/unit/autojac/_transform/test_accumulate.py +++ b/tests/unit/autojac/_transform/test_accumulate.py @@ -6,7 +6,7 @@ from torchjd.autojac._transform import AccumulateGrad, AccumulateJac -def test_single_grad_accumulation(): +def test_single_grad_accumulation() -> None: """ Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run once. @@ -27,7 +27,7 @@ def test_single_grad_accumulation(): @mark.parametrize("iterations", [1, 2, 4, 10, 13]) -def test_multiple_grad_accumulations(iterations: int): +def test_multiple_grad_accumulations(iterations: int) -> None: """ Tests that the AccumulateGrad transform correctly accumulates gradients in .grad fields when run `iterations` times. @@ -47,7 +47,7 @@ def test_multiple_grad_accumulations(iterations: int): assert_grad_close(key, iterations * value) -def test_accumulate_grad_fails_when_no_requires_grad(): +def test_accumulate_grad_fails_when_no_requires_grad() -> None: """ Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a tensor that does not require grad. @@ -63,7 +63,7 @@ def test_accumulate_grad_fails_when_no_requires_grad(): accumulate(input) -def test_accumulate_grad_fails_when_no_leaf_and_no_retains_grad(): +def test_accumulate_grad_fails_when_no_leaf_and_no_retains_grad() -> None: """ Tests that the AccumulateGrad transform raises an error when it tries to populate a .grad of a tensor that is not a leaf and that does not retain grad. @@ -79,7 +79,7 @@ def test_accumulate_grad_fails_when_no_leaf_and_no_retains_grad(): accumulate(input) -def test_accumulate_grad_check_keys(): +def test_accumulate_grad_check_keys() -> None: """Tests that the `check_keys` method works correctly for AccumulateGrad.""" key = tensor_([1.0], requires_grad=True) @@ -89,7 +89,7 @@ def test_accumulate_grad_check_keys(): assert output_keys == set() -def test_single_jac_accumulation(): +def test_single_jac_accumulation() -> None: """ Tests that the AccumulateJac transform correctly accumulates jacobians in .jac fields when run once. @@ -110,7 +110,7 @@ def test_single_jac_accumulation(): @mark.parametrize("iterations", [1, 2, 4, 10, 13]) -def test_multiple_jac_accumulations(iterations: int): +def test_multiple_jac_accumulations(iterations: int) -> None: """ Tests that the AccumulateJac transform correctly accumulates jacobians in .jac fields when run `iterations` times. @@ -131,7 +131,7 @@ def test_multiple_jac_accumulations(iterations: int): assert_jac_close(key, iterations * value) -def test_accumulate_jac_fails_when_no_requires_grad(): +def test_accumulate_jac_fails_when_no_requires_grad() -> None: """ Tests that the AccumulateJac transform raises an error when it tries to populate a .jac of a tensor that does not require grad. @@ -147,7 +147,7 @@ def test_accumulate_jac_fails_when_no_requires_grad(): accumulate(input) -def test_accumulate_jac_fails_when_no_leaf_and_no_retains_grad(): +def test_accumulate_jac_fails_when_no_leaf_and_no_retains_grad() -> None: """ Tests that the AccumulateJac transform raises an error when it tries to populate a .jac of a tensor that is not a leaf and that does not retain grad. @@ -163,7 +163,7 @@ def test_accumulate_jac_fails_when_no_leaf_and_no_retains_grad(): accumulate(input) -def test_accumulate_jac_fails_when_shape_mismatch(): +def test_accumulate_jac_fails_when_shape_mismatch() -> None: """ Tests that the AccumulateJac transform raises an error when the jacobian shape does not match the parameter shape (ignoring the first dimension). @@ -179,7 +179,7 @@ def test_accumulate_jac_fails_when_shape_mismatch(): accumulate(input) -def test_accumulate_jac_check_keys(): +def test_accumulate_jac_check_keys() -> None: """Tests that the `check_keys` method works correctly for AccumulateJac.""" key = tensor_([1.0], requires_grad=True) diff --git a/tests/unit/autojac/_transform/test_base.py b/tests/unit/autojac/_transform/test_base.py index 254147bdc..d0d7f96d5 100644 --- a/tests/unit/autojac/_transform/test_base.py +++ b/tests/unit/autojac/_transform/test_base.py @@ -10,11 +10,11 @@ class FakeTransform(Transform): Fake ``Transform`` to test `check_keys` when composing and conjuncting. """ - def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]): + def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]) -> None: self._required_keys = required_keys self._output_keys = output_keys - def __str__(self): + def __str__(self) -> str: return "T" def __call__(self, _input: TensorDict, /) -> TensorDict: @@ -29,7 +29,7 @@ def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: return self._output_keys -def test_composition_check_keys(): +def test_composition_check_keys() -> None: """ Tests that `check_keys` works correctly for a composition of transforms: the inner transform's `output_keys` has to satisfy the outer transform's requirements. @@ -52,7 +52,7 @@ def test_composition_check_keys(): (t2 << t1).check_keys({a1}) -def test_conjunct_check_keys_1(): +def test_conjunct_check_keys_1() -> None: """ Tests that `check_keys` works correctly for a conjunction of transforms: all transforms should successfully check their keys. @@ -75,7 +75,7 @@ def test_conjunct_check_keys_1(): (t1 | t2 | t3).check_keys({a1, a2}) -def test_conjunct_check_keys_2(): +def test_conjunct_check_keys_2() -> None: """ Tests that `check_keys` works correctly for a conjunction of transforms: their `output_keys` should be disjoint. @@ -98,7 +98,7 @@ def test_conjunct_check_keys_2(): (t1 | t2 | t3).check_keys(set()) -def test_empty_conjunction(): +def test_empty_conjunction() -> None: """ Tests that it is possible to take the conjunction of no transform. This should return an empty dictionary. @@ -109,7 +109,7 @@ def test_empty_conjunction(): assert len(conjunction({})) == 0 -def test_str(): +def test_str() -> None: """ Tests that the __str__ method works correctly even for transform involving compositions and conjunctions. diff --git a/tests/unit/autojac/_transform/test_diagonalize.py b/tests/unit/autojac/_transform/test_diagonalize.py index c1b30d31e..b4e8e2554 100644 --- a/tests/unit/autojac/_transform/test_diagonalize.py +++ b/tests/unit/autojac/_transform/test_diagonalize.py @@ -6,7 +6,7 @@ from torchjd.autojac._transform import Diagonalize, OrderedSet, RequirementError -def test_single_input(): +def test_single_input() -> None: """Tests that the Diagonalize transform works when given a single input.""" key = tensor_([1.0, 2.0, 3.0]) @@ -21,7 +21,7 @@ def test_single_input(): assert_tensor_dicts_are_close(output, expected_output) -def test_multiple_inputs(): +def test_multiple_inputs() -> None: """Tests that the Diagonalize transform works when given multiple inputs.""" key1 = tensor_([[1.0, 2.0], [4.0, 5.0]]) @@ -77,7 +77,7 @@ def test_multiple_inputs(): assert_tensor_dicts_are_close(output, expected_output) -def test_permute_order(): +def test_permute_order() -> None: """ Tests that the Diagonalize transform outputs a permuted mapping when its keys are permuted. """ @@ -98,7 +98,7 @@ def test_permute_order(): assert_tensor_dicts_are_close(output, expected_output) -def test_check_keys(): +def test_check_keys() -> None: """ Tests that the `check_keys` method works correctly. The input_keys must match the stored considered keys. diff --git a/tests/unit/autojac/_transform/test_grad.py b/tests/unit/autojac/_transform/test_grad.py index f834f73a6..daa35bdfc 100644 --- a/tests/unit/autojac/_transform/test_grad.py +++ b/tests/unit/autojac/_transform/test_grad.py @@ -6,7 +6,7 @@ from torchjd.autojac._transform import Grad, OrderedSet, RequirementError -def test_single_input(): +def test_single_input() -> None: """ Tests that the Grad transform works correctly for a very simple example of differentiation. Here, the function considered is: `y = a * x`. We want to compute the derivative of `y` with @@ -26,7 +26,7 @@ def test_single_input(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_empty_inputs_1(): +def test_empty_inputs_1() -> None: """ Tests that the Grad transform works correctly when the `inputs` parameter is an empty `Iterable`. @@ -43,7 +43,7 @@ def test_empty_inputs_1(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_empty_inputs_2(): +def test_empty_inputs_2() -> None: """ Tests that the Grad transform works correctly when the `inputs` parameter is an empty `Iterable`. @@ -62,7 +62,7 @@ def test_empty_inputs_2(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_empty_outputs(): +def test_empty_outputs() -> None: """ Tests that the Grad transform works correctly when the `outputs` parameter is an empty `Iterable`. @@ -80,7 +80,7 @@ def test_empty_outputs(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_retain_graph(): +def test_retain_graph() -> None: """Tests that the `Grad` transform behaves as expected with the `retain_graph` flag.""" x = tensor_(5.0) @@ -100,7 +100,7 @@ def test_retain_graph(): grad_discard_graph(input) -def test_single_input_two_levels(): +def test_single_input_two_levels() -> None: """ Tests that the Grad transform works correctly when composed with another Grad transform. Here, the function considered is: `z = a * x1 * x2`, which is computed in 2 parts: `y = a * x1` @@ -125,7 +125,7 @@ def test_single_input_two_levels(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_empty_inputs_two_levels(): +def test_empty_inputs_two_levels() -> None: """ Tests that the Grad transform works correctly when the `inputs` parameter is an empty `Iterable`, with 2 composed Grad transforms. @@ -148,7 +148,7 @@ def test_empty_inputs_two_levels(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_vector_output(): +def test_vector_output() -> None: """ Tests that the Grad transform works correctly when the `outputs` contains a single vector. The input (grad_outputs) is not the same for both values of the output, so that this test also @@ -168,7 +168,7 @@ def test_vector_output(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_multiple_outputs(): +def test_multiple_outputs() -> None: """ Tests that the Grad transform works correctly when the `outputs` contains 2 scalars. The input (grad_outputs) is not the same for both outputs, so that this test also checks that @@ -189,7 +189,7 @@ def test_multiple_outputs(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_multiple_tensor_outputs(): +def test_multiple_tensor_outputs() -> None: """ Tests that the Grad transform works correctly when the `outputs` contains several tensors of different shapes. The input (grad_outputs) is not the same for all values of the outputs, so @@ -216,7 +216,7 @@ def test_multiple_tensor_outputs(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_composition_of_grads_is_grad(): +def test_composition_of_grads_is_grad() -> None: """ Tests that the composition of 2 Grad transforms is equivalent to computing the Grad directly in a single transform. @@ -243,7 +243,7 @@ def test_composition_of_grads_is_grad(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_conjunction_of_grads_is_grad(): +def test_conjunction_of_grads_is_grad() -> None: """ Tests that the conjunction of 2 Grad transforms is equivalent to computing the Grad directly in a single transform. @@ -267,7 +267,7 @@ def test_conjunction_of_grads_is_grad(): assert_tensor_dicts_are_close(gradients, expected_gradients) -def test_create_graph(): +def test_create_graph() -> None: """Tests that the Grad transform behaves correctly when `create_graph` is set to `True`.""" a = tensor_(2.0, requires_grad=True) @@ -281,7 +281,7 @@ def test_create_graph(): assert gradients[a].requires_grad -def test_check_keys(): +def test_check_keys() -> None: """ Tests that the `check_keys` method works correctly: the input_keys should match the stored outputs. diff --git a/tests/unit/autojac/_transform/test_init.py b/tests/unit/autojac/_transform/test_init.py index 38e4a29e7..1833c1459 100644 --- a/tests/unit/autojac/_transform/test_init.py +++ b/tests/unit/autojac/_transform/test_init.py @@ -5,7 +5,7 @@ from torchjd.autojac._transform import Init, RequirementError -def test_single_input(): +def test_single_input() -> None: """ Tests that when there is a single key to initialize, the Init transform creates a TensorDict whose value is a tensor full of ones, of the same shape as its key. @@ -22,7 +22,7 @@ def test_single_input(): assert_tensor_dicts_are_close(output, expected_output) -def test_multiple_inputs(): +def test_multiple_inputs() -> None: """ Tests that when there are several keys to initialize, the Init transform creates a TensorDict whose values are tensors full of ones, of the same shape as their corresponding keys. @@ -42,7 +42,7 @@ def test_multiple_inputs(): assert_tensor_dicts_are_close(output, expected) -def test_conjunction_of_inits_is_init(): +def test_conjunction_of_inits_is_init() -> None: """ Tests that the conjunction of 2 Init transforms is equivalent to a single Init transform with multiple keys. @@ -63,7 +63,7 @@ def test_conjunction_of_inits_is_init(): assert_tensor_dicts_are_close(output, expected_output) -def test_check_keys(): +def test_check_keys() -> None: """Tests that the `check_keys` method works correctly: the input_keys should be empty.""" key = tensor_([1.0]) diff --git a/tests/unit/autojac/_transform/test_interactions.py b/tests/unit/autojac/_transform/test_interactions.py index 470f5d6d9..3f9a725a3 100644 --- a/tests/unit/autojac/_transform/test_interactions.py +++ b/tests/unit/autojac/_transform/test_interactions.py @@ -18,7 +18,7 @@ ) -def test_jac_is_stack_of_grads(): +def test_jac_is_stack_of_grads() -> None: """ Tests that the Jac transform (composed with a Diagonalize) is equivalent to a Stack of Grad and Select transforms. @@ -52,7 +52,7 @@ def test_jac_is_stack_of_grads(): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_single_differentiation(): +def test_single_differentiation() -> None: """ Tests that we can perform a single scalar differentiation with the composition of a Grad and an Init transform. @@ -72,7 +72,7 @@ def test_single_differentiation(): assert_tensor_dicts_are_close(output, expected_output) -def test_multiple_differentiations(): +def test_multiple_differentiations() -> None: """ Tests that we can perform multiple scalar differentiations with the conjunction of multiple Grad transforms, composed with an Init transform. @@ -100,7 +100,7 @@ def test_multiple_differentiations(): assert_tensor_dicts_are_close(output, expected_output) -def test_str(): +def test_str() -> None: """Tests that the __str__ method works correctly even for a complex transform.""" init = Init(set()) diag = Diagonalize(OrderedSet([])) @@ -110,7 +110,7 @@ def test_str(): assert str(transform) == "Jac ∘ Diagonalize ∘ Init" -def test_simple_conjunction(): +def test_simple_conjunction() -> None: """ Tests that the Conjunction transform works correctly with a simple example involving several Select transforms, whose keys form a partition of the keys of the input tensor dict. @@ -133,7 +133,7 @@ def test_simple_conjunction(): assert_tensor_dicts_are_close(output, expected_output) -def test_conjunction_is_commutative(): +def test_conjunction_is_commutative() -> None: """ Tests that the Conjunction transform gives the same result no matter the order in which its transforms are given. @@ -154,7 +154,7 @@ def test_conjunction_is_commutative(): assert_tensor_dicts_are_close(output, expected_output) -def test_conjunction_is_associative(): +def test_conjunction_is_associative() -> None: """ Tests that the Conjunction transform gives the same result no matter how it is parenthesized. """ @@ -184,7 +184,7 @@ def test_conjunction_is_associative(): assert_tensor_dicts_are_close(output, expected_output) -def test_conjunction_accumulate_select(): +def test_conjunction_accumulate_select() -> None: """ Tests that it is possible to conjunct an AccumulateGrad and a Select in this order. It is not trivial since the type of the TensorDict returned by the first transform @@ -206,7 +206,7 @@ def test_conjunction_accumulate_select(): assert_tensor_dicts_are_close(output, expected_output) -def test_equivalence_jac_grads(): +def test_equivalence_jac_grads() -> None: """ Tests that differentiation in parallel using `_jac` is equivalent to sequential differentiation using several calls to `_grad` and stacking the resulting gradients. @@ -248,7 +248,7 @@ def test_equivalence_jac_grads(): assert_close(jac_c, torch.stack([grad_1_c, grad_2_c])) -def test_stack_check_keys(): +def test_stack_check_keys() -> None: """ Tests that the `check_keys` method works correctly for a stack of transforms: all of them should successfully check their keys. diff --git a/tests/unit/autojac/_transform/test_jac.py b/tests/unit/autojac/_transform/test_jac.py index e1efecf53..7e21f5fc9 100644 --- a/tests/unit/autojac/_transform/test_jac.py +++ b/tests/unit/autojac/_transform/test_jac.py @@ -7,7 +7,7 @@ @mark.parametrize("chunk_size", [1, 3, None]) -def test_single_input(chunk_size: int | None): +def test_single_input(chunk_size: int | None) -> None: """ Tests that the Jac transform works correctly for an example of multiple differentiation. Here, the function considered is: `y = [a1 * x, a2 * x]`. We want to compute the jacobians of `y` with @@ -32,7 +32,7 @@ def test_single_input(chunk_size: int | None): @mark.parametrize("chunk_size", [1, 3, None]) -def test_empty_inputs_1(chunk_size: int | None): +def test_empty_inputs_1(chunk_size: int | None) -> None: """ Tests that the Jac transform works correctly when the `inputs` parameter is an empty `Iterable`. """ @@ -51,7 +51,7 @@ def test_empty_inputs_1(chunk_size: int | None): @mark.parametrize("chunk_size", [1, 3, None]) -def test_empty_inputs_2(chunk_size: int | None): +def test_empty_inputs_2(chunk_size: int | None) -> None: """ Tests that the Jac transform works correctly when the `inputs` parameter is an empty `Iterable`. """ @@ -73,7 +73,7 @@ def test_empty_inputs_2(chunk_size: int | None): @mark.parametrize("chunk_size", [1, 3, None]) -def test_empty_outputs(chunk_size: int | None): +def test_empty_outputs(chunk_size: int | None) -> None: """ Tests that the Jac transform works correctly when the `outputs` parameter is an empty `Iterable`. @@ -94,7 +94,7 @@ def test_empty_outputs(chunk_size: int | None): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_retain_graph(): +def test_retain_graph() -> None: """Tests that the `Jac` transform behaves as expected with the `retain_graph` flag.""" x = tensor_(5.0) @@ -127,7 +127,7 @@ def test_retain_graph(): jac_discard_graph(input) -def test_two_levels(): +def test_two_levels() -> None: """ Tests that the Jac transform works correctly for an example of chained differentiation. Here, the function considered is: `z = a * x1 * x2`, which is computed in 2 parts: `y = a * x1` and @@ -167,7 +167,7 @@ def test_two_levels(): @mark.parametrize("chunk_size", [1, 3, None]) -def test_multiple_outputs_1(chunk_size: int | None): +def test_multiple_outputs_1(chunk_size: int | None) -> None: """ Tests that the Jac transform works correctly when the `outputs` contains 3 vectors. The input (jac_outputs) is not the same for all outputs, so that this test also checks that the @@ -201,7 +201,7 @@ def test_multiple_outputs_1(chunk_size: int | None): @mark.parametrize("chunk_size", [1, 3, None]) -def test_multiple_outputs_2(chunk_size: int | None): +def test_multiple_outputs_2(chunk_size: int | None) -> None: """ Same as test_multiple_outputs_1 but with different jac_outputs, so the returned jacobians are of different shapes. @@ -232,7 +232,7 @@ def test_multiple_outputs_2(chunk_size: int | None): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_composition_of_jacs_is_jac(): +def test_composition_of_jacs_is_jac() -> None: """ Tests that the composition of 2 Jac transforms is equivalent to computing the Jac directly in a single transform. @@ -268,7 +268,7 @@ def test_composition_of_jacs_is_jac(): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_conjunction_of_jacs_is_jac(): +def test_conjunction_of_jacs_is_jac() -> None: """ Tests that the conjunction of 2 Jac transforms is equivalent to computing the Jac directly in a single transform. @@ -294,7 +294,7 @@ def test_conjunction_of_jacs_is_jac(): assert_tensor_dicts_are_close(jacobians, expected_jacobians) -def test_create_graph(): +def test_create_graph() -> None: """Tests that the Jac transform behaves correctly when `create_graph` is set to `True`.""" x = tensor_(5.0) @@ -318,7 +318,7 @@ def test_create_graph(): assert jacobians[a2].requires_grad -def test_check_keys(): +def test_check_keys() -> None: """ Tests that the `check_keys` method works correctly: the input_keys should match the stored outputs. diff --git a/tests/unit/autojac/_transform/test_select.py b/tests/unit/autojac/_transform/test_select.py index 041eefc9a..897ad8175 100644 --- a/tests/unit/autojac/_transform/test_select.py +++ b/tests/unit/autojac/_transform/test_select.py @@ -6,7 +6,7 @@ from torchjd.autojac._transform import RequirementError, Select -def test_partition(): +def test_partition() -> None: """ Tests that the Select transform works correctly by applying 2 different Selects to a TensorDict, whose keys form a partition of the keys of the TensorDict. @@ -34,7 +34,7 @@ def test_partition(): assert_tensor_dicts_are_close(output2, expected_output2) -def test_conjunction_of_selects_is_select(): +def test_conjunction_of_selects_is_select() -> None: """ Tests that the conjunction of 2 Select transforms is equivalent to directly using a Select with the union of the keys of the 2 Selects. @@ -56,7 +56,7 @@ def test_conjunction_of_selects_is_select(): assert_tensor_dicts_are_close(output, expected_output) -def test_check_keys(): +def test_check_keys() -> None: """ Tests that the `check_keys` method works correctly: the set of keys to select should be a subset of the set of required_keys. diff --git a/tests/unit/autojac/_transform/test_stack.py b/tests/unit/autojac/_transform/test_stack.py index 35e617d1b..28515fa97 100644 --- a/tests/unit/autojac/_transform/test_stack.py +++ b/tests/unit/autojac/_transform/test_stack.py @@ -12,7 +12,7 @@ class FakeGradientsTransform(Transform): """Transform that produces gradients filled with ones, for testing purposes.""" - def __init__(self, keys: Iterable[Tensor]): + def __init__(self, keys: Iterable[Tensor]) -> None: self.keys = set(keys) def __call__(self, _input: TensorDict, /) -> TensorDict: @@ -22,7 +22,7 @@ def check_keys(self, _input_keys: set[Tensor], /) -> set[Tensor]: return self.keys -def test_single_key(): +def test_single_key() -> None: """ Tests that the Stack transform correctly stacks gradients into a jacobian, in a very simple example with 2 transforms sharing the same key. @@ -40,7 +40,7 @@ def test_single_key(): assert_tensor_dicts_are_close(output, expected_output) -def test_disjoint_key_sets(): +def test_disjoint_key_sets() -> None: """ Tests that the Stack transform correctly stacks gradients into a jacobian, in an example where the output key sets of all of its transforms are disjoint. The missing values should be replaced @@ -64,7 +64,7 @@ def test_disjoint_key_sets(): assert_tensor_dicts_are_close(output, expected_output) -def test_overlapping_key_sets(): +def test_overlapping_key_sets() -> None: """ Tests that the Stack transform correctly stacks gradients into a jacobian, in an example where the output key sets all of its transforms are overlapping (non-empty intersection, but not @@ -90,7 +90,7 @@ def test_overlapping_key_sets(): assert_tensor_dicts_are_close(output, expected_output) -def test_empty(): +def test_empty() -> None: """Tests that the Stack transform correctly handles an empty list of transforms.""" stack = Stack([]) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index a0398c42d..806eb545d 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -9,7 +9,7 @@ @mark.parametrize("default_jac_tensors", [True, False]) -def test_check_create_transform(default_jac_tensors: bool): +def test_check_create_transform(default_jac_tensors: bool) -> None: """Tests that _create_transform creates a valid Transform.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -37,7 +37,7 @@ def test_check_create_transform(default_jac_tensors: bool): assert output_keys == set() -def test_jac_is_populated(): +def test_jac_is_populated() -> None: """Tests that backward correctly fills the .jac field.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -59,7 +59,7 @@ def test_value_is_correct( shape: tuple[int, int], manually_specify_inputs: bool, chunk_size: int | None, -): +) -> None: """ Tests that the .jac value filled by backward is correct in a simple example of matrix-vector product. @@ -81,7 +81,7 @@ def test_value_is_correct( @mark.parametrize("rows", [1, 2, 5]) -def test_jac_tensors_value_is_correct(rows: int): +def test_jac_tensors_value_is_correct(rows: int) -> None: """ Tests that backward correctly computes the product of jac_tensors and the Jacobian. result = jac_tensors @ Jacobian(tensors, inputs). @@ -107,7 +107,7 @@ def test_jac_tensors_value_is_correct(rows: int): @mark.parametrize("rows", [1, 3]) -def test_jac_tensors_multiple_components(rows: int): +def test_jac_tensors_multiple_components(rows: int) -> None: """ Tests that jac_tensors works correctly when tensors is a list of multiple tensors. The jac_tensors must match the structure of tensors. @@ -132,7 +132,7 @@ def test_jac_tensors_multiple_components(rows: int): assert_jac_close(input, expected) -def test_jac_tensors_length_mismatch(): +def test_jac_tensors_length_mismatch() -> None: """Tests that backward raises a ValueError early if len(jac_tensors) != len(tensors).""" x = tensor_([1.0, 2.0], requires_grad=True) y1 = x * 2 @@ -147,7 +147,7 @@ def test_jac_tensors_length_mismatch(): backward([y1, y2], jac_tensors=[J1], inputs=[x]) -def test_jac_tensors_shape_mismatch(): +def test_jac_tensors_shape_mismatch() -> None: """ Tests that backward raises a ValueError early if the shape of a tensor in jac_tensors is incompatible with the corresponding tensor. @@ -171,7 +171,7 @@ def test_jac_tensors_shape_mismatch(): (1, 2), ], ) -def test_jac_tensors_inconsistent_first_dimension(rows_y1: int, rows_y2: int): +def test_jac_tensors_inconsistent_first_dimension(rows_y1: int, rows_y2: int) -> None: """ Tests that backward raises a ValueError early when the provided jac_tensors have inconsistent first dimensions. @@ -190,7 +190,7 @@ def test_jac_tensors_inconsistent_first_dimension(rows_y1: int, rows_y2: int): backward([y1, y2], jac_tensors=[j1, j2], inputs=[x]) -def test_empty_inputs(): +def test_empty_inputs() -> None: """Tests that backward does not fill the .jac values if no input is specified.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -205,7 +205,7 @@ def test_empty_inputs(): assert_has_no_jac(a) -def test_partial_inputs(): +def test_partial_inputs() -> None: """ Tests that backward fills the right .jac values when only a subset of the actual inputs are specified as inputs. @@ -223,7 +223,7 @@ def test_partial_inputs(): assert_has_no_jac(a2) -def test_empty_tensors_fails(): +def test_empty_tensors_fails() -> None: """Tests that backward raises an error when called with an empty list of tensors.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -233,7 +233,7 @@ def test_empty_tensors_fails(): backward([], inputs=[a1, a2]) -def test_multiple_tensors(): +def test_multiple_tensors() -> None: """ Tests that giving multiple tensors to backward is equivalent to giving a single tensor containing all the values of the original tensors. @@ -268,7 +268,7 @@ def test_multiple_tensors(): @mark.parametrize("chunk_size", [None, 1, 2, 4]) -def test_various_valid_chunk_sizes(chunk_size): +def test_various_valid_chunk_sizes(chunk_size: int | None) -> None: """Tests that backward works for various valid values of parallel_chunk_size.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -284,7 +284,7 @@ def test_various_valid_chunk_sizes(chunk_size): @mark.parametrize("chunk_size", [0, -1]) -def test_non_positive_chunk_size_fails(chunk_size: int): +def test_non_positive_chunk_size_fails(chunk_size: int) -> None: """Tests that backward raises an error when using invalid chunk sizes.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -297,7 +297,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int): backward([y1, y2], parallel_chunk_size=chunk_size) -def test_input_retaining_grad_fails(): +def test_input_retaining_grad_fails() -> None: """ Tests that backward raises an error when some input in the computation graph of the ``tensors`` parameter retains grad and vmap has to be used. @@ -317,7 +317,7 @@ def test_input_retaining_grad_fails(): _ = -b.grad # type: ignore[unsupported-operator] -def test_non_input_retaining_grad_fails(): +def test_non_input_retaining_grad_fails() -> None: """ Tests that backward fails to fill a valid `.grad` when some tensor in the computation graph of the ``tensors`` parameter retains grad and vmap has to be used. @@ -337,7 +337,7 @@ def test_non_input_retaining_grad_fails(): @mark.parametrize("chunk_size", [1, 3, None]) -def test_tensor_used_multiple_times(chunk_size: int | None): +def test_tensor_used_multiple_times(chunk_size: int | None) -> None: """ Tests that backward works correctly when one of the inputs is used multiple times. In this setup, the autograd graph is still acyclic, but the graph of tensors used becomes cyclic. @@ -356,7 +356,7 @@ def test_tensor_used_multiple_times(chunk_size: int | None): assert_jac_close(a, J) -def test_repeated_tensors(): +def test_repeated_tensors() -> None: """ Tests that backward does not allow repeating tensors. @@ -375,7 +375,7 @@ def test_repeated_tensors(): backward([y1, y1, y2]) -def test_repeated_inputs(): +def test_repeated_inputs() -> None: """ Tests that backward correctly works when some inputs are repeated. In this case, since torch.autograd.backward ignores the repetition of the inputs, it is natural for autojac to diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index 3ee6561fb..75c68cb9f 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -9,7 +9,7 @@ @mark.parametrize("default_jac_outputs", [True, False]) -def test_check_create_transform(default_jac_outputs: bool): +def test_check_create_transform(default_jac_outputs: bool) -> None: """Tests that _create_transform creates a valid Transform.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -37,7 +37,7 @@ def test_check_create_transform(default_jac_outputs: bool): assert output_keys == {a1, a2} -def test_jac(): +def test_jac() -> None: """Tests that jac works.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -65,7 +65,7 @@ def test_value_is_correct( chunk_size: int | None, outputs_is_list: bool, inputs_is_list: bool, -): +) -> None: """ Tests that the jacobians returned by jac are correct in a simple example of matrix-vector product. @@ -85,7 +85,7 @@ def test_value_is_correct( @mark.parametrize("rows", [1, 2, 5]) -def test_jac_outputs_value_is_correct(rows: int): +def test_jac_outputs_value_is_correct(rows: int) -> None: """ Tests that jac correctly computes the product of jac_outputs and the Jacobian. result = jac_outputs @ Jacobian(outputs, inputs). @@ -111,7 +111,7 @@ def test_jac_outputs_value_is_correct(rows: int): @mark.parametrize("rows", [1, 3]) -def test_jac_outputs_multiple_components(rows: int): +def test_jac_outputs_multiple_components(rows: int) -> None: """ Tests that jac_outputs works correctly when outputs is a list of multiple tensors. The jac_outputs must match the structure of outputs. @@ -136,7 +136,7 @@ def test_jac_outputs_multiple_components(rows: int): assert_close(jacobians[0], expected) -def test_jac_outputs_length_mismatch(): +def test_jac_outputs_length_mismatch() -> None: """Tests that jac raises a ValueError early if len(jac_outputs) != len(outputs).""" x = tensor_([1.0, 2.0], requires_grad=True) y1 = x * 2 @@ -151,7 +151,7 @@ def test_jac_outputs_length_mismatch(): jac([y1, y2], x, jac_outputs=[J1]) -def test_jac_outputs_shape_mismatch(): +def test_jac_outputs_shape_mismatch() -> None: """ Tests that jac raises a ValueError early if the shape of a tensor in jac_outputs is incompatible with the corresponding output tensor. @@ -175,7 +175,7 @@ def test_jac_outputs_shape_mismatch(): (1, 2), ], ) -def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int): +def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int) -> None: """ Tests that jac raises a ValueError early when the provided jac_outputs have inconsistent first dimensions. @@ -194,7 +194,7 @@ def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int): jac([y1, y2], x, jac_outputs=[j1, j2]) -def test_empty_inputs(): +def test_empty_inputs() -> None: """Tests that jac does not return any jacobian no input is specified.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -207,7 +207,7 @@ def test_empty_inputs(): assert len(jacobians) == 0 -def test_partial_inputs(): +def test_partial_inputs() -> None: """ Tests that jac returns the right jacobians when only a subset of the actual inputs are specified as inputs. @@ -223,7 +223,7 @@ def test_partial_inputs(): assert len(jacobians) == 1 -def test_empty_tensors_fails(): +def test_empty_tensors_fails() -> None: """Tests that jac raises an error when called with an empty list of tensors.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -233,7 +233,7 @@ def test_empty_tensors_fails(): jac([], inputs=[a1, a2]) -def test_multiple_tensors(): +def test_multiple_tensors() -> None: """ Tests that giving multiple tensors to jac is equivalent to giving a single tensor containing all the values of the original tensors. @@ -268,7 +268,7 @@ def test_multiple_tensors(): @mark.parametrize("chunk_size", [None, 1, 2, 4]) -def test_various_valid_chunk_sizes(chunk_size): +def test_various_valid_chunk_sizes(chunk_size: int | None) -> None: """Tests that jac works for various valid values of parallel_chunk_size.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -282,7 +282,7 @@ def test_various_valid_chunk_sizes(chunk_size): @mark.parametrize("chunk_size", [0, -1]) -def test_non_positive_chunk_size_fails(chunk_size: int): +def test_non_positive_chunk_size_fails(chunk_size: int) -> None: """Tests that jac raises an error when using invalid chunk sizes.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -295,7 +295,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int): jac([y1, y2], [a1, a2], parallel_chunk_size=chunk_size) -def test_input_retaining_grad_fails(): +def test_input_retaining_grad_fails() -> None: """ Tests that jac raises an error when some input in the computation graph of the ``tensors`` parameter retains grad and vmap has to be used. @@ -315,7 +315,7 @@ def test_input_retaining_grad_fails(): _ = -b.grad # type: ignore[unsupported-operator] -def test_non_input_retaining_grad_fails(): +def test_non_input_retaining_grad_fails() -> None: """ Tests that jac fails to fill a valid `.grad` when some tensor in the computation graph of the ``tensors`` parameter retains grad and vmap has to be used. @@ -335,7 +335,7 @@ def test_non_input_retaining_grad_fails(): @mark.parametrize("chunk_size", [1, 3, None]) -def test_tensor_used_multiple_times(chunk_size: int | None): +def test_tensor_used_multiple_times(chunk_size: int | None) -> None: """ Tests that jac works correctly when one of the inputs is used multiple times. In this setup, the autograd graph is still acyclic, but the graph of tensors used becomes cyclic. @@ -355,7 +355,7 @@ def test_tensor_used_multiple_times(chunk_size: int | None): assert_close(jacobians[0], J) -def test_repeated_tensors(): +def test_repeated_tensors() -> None: """ Tests that jac does not allow repeating tensors. @@ -374,7 +374,7 @@ def test_repeated_tensors(): jac([y1, y1, y2], [a1, a2]) -def test_repeated_inputs(): +def test_repeated_inputs() -> None: """ Tests that jac correctly works when some inputs are repeated. In this case, since torch.autograd.grad repeats the output gradients, it is natural for autojac to also repeat the diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index a3f830972..b8ea5c6cd 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -7,7 +7,7 @@ @mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()]) -def test_various_aggregators(aggregator: Aggregator): +def test_various_aggregators(aggregator: Aggregator) -> None: """Tests that jac_to_grad works for various aggregators.""" t1 = tensor_(1.0, requires_grad=True) @@ -25,7 +25,7 @@ def test_various_aggregators(aggregator: Aggregator): assert_grad_close(t2, g2) -def test_single_tensor(): +def test_single_tensor() -> None: """Tests that jac_to_grad works when a single tensor is provided.""" aggregator = UPGrad() @@ -39,7 +39,7 @@ def test_single_tensor(): assert_grad_close(t, g) -def test_no_jac_field(): +def test_no_jac_field() -> None: """Tests that jac_to_grad fails when a tensor does not have a jac field.""" aggregator = UPGrad() @@ -52,7 +52,7 @@ def test_no_jac_field(): jac_to_grad([t1, t2], aggregator) -def test_no_requires_grad(): +def test_no_requires_grad() -> None: """Tests that jac_to_grad fails when a tensor does not require grad.""" aggregator = UPGrad() @@ -66,7 +66,7 @@ def test_no_requires_grad(): jac_to_grad([t1, t2], aggregator) -def test_row_mismatch(): +def test_row_mismatch() -> None: """Tests that jac_to_grad fails when the number of rows of the .jac is not constant.""" aggregator = UPGrad() @@ -79,14 +79,14 @@ def test_row_mismatch(): jac_to_grad([t1, t2], aggregator) -def test_no_tensors(): +def test_no_tensors() -> None: """Tests that jac_to_grad correctly does nothing when an empty list of tensors is provided.""" jac_to_grad([], aggregator=UPGrad()) @mark.parametrize("retain_jac", [True, False]) -def test_jacs_are_freed(retain_jac: bool): +def test_jacs_are_freed(retain_jac: bool) -> None: """Tests that jac_to_grad frees the jac fields if an only if retain_jac is False.""" aggregator = UPGrad() @@ -103,7 +103,7 @@ def test_jacs_are_freed(retain_jac: bool): check(t2) -def test_noncontiguous_jac(): +def test_noncontiguous_jac() -> None: """Tests that jac_to_grad works when the .jac field is non-contiguous.""" aggregator = UPGrad() diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index 9f15e6fdd..5ca138827 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -18,7 +18,7 @@ @mark.parametrize("default_grad_tensors", [True, False]) -def test_check_create_transform(default_grad_tensors: bool): +def test_check_create_transform(default_grad_tensors: bool) -> None: """Tests that _create_transform creates a valid Transform.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -48,7 +48,7 @@ def test_check_create_transform(default_grad_tensors: bool): assert output_keys == set() -def test_shape_is_correct(): +def test_shape_is_correct() -> None: """Tests that mtl_backward works correctly.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -76,7 +76,7 @@ def test_value_is_correct( manually_specify_shared_params: bool, manually_specify_tasks_params: bool, chunk_size: int | None, -): +) -> None: """ Tests that the .jac value filled by mtl_backward is correct in a simple example of matrix-vector product for three tasks whose loss are given by a simple inner product of the @@ -116,7 +116,7 @@ def test_value_is_correct( assert_jac_close(p0, expected_jacobian) -def test_empty_tasks_fails(): +def test_empty_tasks_fails() -> None: """Tests that mtl_backward raises an error when called with an empty list of tasks.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -128,7 +128,7 @@ def test_empty_tasks_fails(): mtl_backward([], features=[f1, f2]) -def test_single_task(): +def test_single_task() -> None: """Tests that mtl_backward works correctly with a single task.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -144,7 +144,7 @@ def test_single_task(): assert_has_grad(p1) -def test_incoherent_task_number_fails(): +def test_incoherent_task_number_fails() -> None: """ Tests that mtl_backward raises an error when called with the number of tasks losses different from the number of tasks parameters. @@ -175,7 +175,7 @@ def test_incoherent_task_number_fails(): ) -def test_empty_params(): +def test_empty_params() -> None: """Tests that mtl_backward does not fill the .jac/.grad values if no parameter is specified.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -199,7 +199,7 @@ def test_empty_params(): assert_has_no_grad(p) -def test_multiple_params_per_task(): +def test_multiple_params_per_task() -> None: """Tests that mtl_backward works correctly when the tasks each have several parameters.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -234,7 +234,7 @@ def test_multiple_params_per_task(): [(5, 4, 3, 2), (5, 4, 3, 2)], ], ) -def test_various_shared_params(shared_params_shapes: list[tuple[int]]): +def test_various_shared_params(shared_params_shapes: list[tuple[int]]) -> None: """Tests that mtl_backward works correctly with various kinds of shared_params.""" shared_params = [rand_(shape, requires_grad=True) for shape in shared_params_shapes] @@ -258,7 +258,7 @@ def test_various_shared_params(shared_params_shapes: list[tuple[int]]): assert_has_grad(p) -def test_partial_params(): +def test_partial_params() -> None: """ Tests that mtl_backward fills the right .jac/.grad values when only a subset of the parameters are specified as inputs. @@ -285,7 +285,7 @@ def test_partial_params(): assert_has_no_grad(p2) -def test_empty_features_fails(): +def test_empty_features_fails() -> None: """Tests that mtl_backward expectedly raises an error when no there is no feature.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -310,7 +310,7 @@ def test_empty_features_fails(): (5, 4, 3, 2), ], ) -def test_various_single_features(shape: tuple[int, ...]): +def test_various_single_features(shape: tuple[int, ...]) -> None: """Tests that mtl_backward works correctly with various kinds of feature tensors.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -342,7 +342,7 @@ def test_various_single_features(shape: tuple[int, ...]): [(5, 4, 3, 2), (5, 4, 3, 2)], ], ) -def test_various_feature_lists(shapes: list[tuple[int]]): +def test_various_feature_lists(shapes: list[tuple[int]]) -> None: """Tests that mtl_backward works correctly with various kinds of feature lists.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -361,7 +361,7 @@ def test_various_feature_lists(shapes: list[tuple[int]]): assert_has_grad(p) -def test_non_scalar_loss_fails(): +def test_non_scalar_loss_fails() -> None: """Tests that mtl_backward raises an error when used with a non-scalar loss.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -378,7 +378,7 @@ def test_non_scalar_loss_fails(): @mark.parametrize("chunk_size", [None, 1, 2, 4]) -def test_various_valid_chunk_sizes(chunk_size): +def test_various_valid_chunk_sizes(chunk_size: int | None) -> None: """Tests that mtl_backward works for various valid values of parallel_chunk_size.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -402,7 +402,7 @@ def test_various_valid_chunk_sizes(chunk_size): @mark.parametrize("chunk_size", [0, -1]) -def test_non_positive_chunk_size_fails(chunk_size: int): +def test_non_positive_chunk_size_fails(chunk_size: int) -> None: """Tests that mtl_backward raises an error when using invalid chunk sizes.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -422,7 +422,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int): ) -def test_shared_param_retaining_grad_fails(): +def test_shared_param_retaining_grad_fails() -> None: """ Tests that mtl_backward fails to fill a valid `.grad` when some shared param in the computation graph of the ``features`` parameter retains grad and vmap has to be used. @@ -451,7 +451,7 @@ def test_shared_param_retaining_grad_fails(): _ = -a.grad # type: ignore[unsupported-operator] -def test_shared_activation_retaining_grad_fails(): +def test_shared_activation_retaining_grad_fails() -> None: """ Tests that mtl_backward fails to fill a valid `.grad` when some tensor in the computation graph of the ``features`` parameter retains grad and vmap has to be used. @@ -480,7 +480,7 @@ def test_shared_activation_retaining_grad_fails(): _ = -a.grad # type: ignore[unsupported-operator] -def test_tasks_params_overlap(): +def test_tasks_params_overlap() -> None: """Tests that mtl_backward works correctly when the tasks' parameters have some overlap.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -502,7 +502,7 @@ def test_tasks_params_overlap(): assert_jac_close(p0, J) -def test_tasks_params_are_the_same(): +def test_tasks_params_are_the_same() -> None: """Tests that mtl_backward works correctly when the tasks have the same params.""" p0 = tensor_([1.0, 2.0], requires_grad=True) @@ -520,7 +520,7 @@ def test_tasks_params_are_the_same(): assert_jac_close(p0, J) -def test_task_params_is_subset_of_other_task_params(): +def test_task_params_is_subset_of_other_task_params() -> None: """ Tests that mtl_backward works correctly when one task's params is a subset of another task's params. @@ -543,7 +543,7 @@ def test_task_params_is_subset_of_other_task_params(): assert_jac_close(p0, J) -def test_shared_params_overlapping_with_tasks_params_fails(): +def test_shared_params_overlapping_with_tasks_params_fails() -> None: """ Tests that mtl_backward raises an error when the set of shared params overlaps with the set of task-specific params. @@ -566,7 +566,7 @@ def test_shared_params_overlapping_with_tasks_params_fails(): ) -def test_default_shared_params_overlapping_with_default_tasks_params_fails(): +def test_default_shared_params_overlapping_with_default_tasks_params_fails() -> None: """ Tests that mtl_backward raises an error when the set of shared params obtained by default overlaps with the set of task-specific params obtained by default. @@ -587,7 +587,7 @@ def test_default_shared_params_overlapping_with_default_tasks_params_fails(): ) -def test_repeated_losses(): +def test_repeated_losses() -> None: """ Tests that mtl_backward does not allow repeating losses. @@ -610,7 +610,7 @@ def test_repeated_losses(): mtl_backward(losses, features=[f1, f2], retain_graph=True) -def test_repeated_features(): +def test_repeated_features() -> None: """ Tests that mtl_backward does not allow repeating features. @@ -633,7 +633,7 @@ def test_repeated_features(): mtl_backward([y1, y2], features=features) -def test_repeated_shared_params(): +def test_repeated_shared_params() -> None: """ Tests that mtl_backward correctly works when some shared are repeated. Since these are tensors with respect to which we differentiate, to match the behavior of torch.autograd.backward, this @@ -661,7 +661,7 @@ def test_repeated_shared_params(): assert_grad_close(p2, g2) -def test_repeated_task_params(): +def test_repeated_task_params() -> None: """ Tests that mtl_backward correctly works when some task-specific params are repeated for some task. Since these are tensors with respect to which we differentiate, to match the behavior of @@ -689,7 +689,7 @@ def test_repeated_task_params(): assert_grad_close(p2, g2) -def test_grad_tensors_value_is_correct(): +def test_grad_tensors_value_is_correct() -> None: """ Tests that mtl_ackward correctly computes the element-wise product of grad_tensors and the tensors. @@ -724,7 +724,7 @@ def test_grad_tensors_value_is_correct(): assert_jac_close(p0, expected_jacobian) -def test_grad_tensors_length_mismatch(): +def test_grad_tensors_length_mismatch() -> None: """Tests that mtl_backward raises a ValueError early if len(grad_tensors) != len(tensors).""" p0 = randn_(3, requires_grad=True) @@ -747,7 +747,7 @@ def test_grad_tensors_length_mismatch(): ) -def test_grad_tensors_shape_mismatch(): +def test_grad_tensors_shape_mismatch() -> None: """ Tests that mtl_backward raises a ValueError early if the shape of a tensor in grad_tensors is incompatible with the corresponding tensor. diff --git a/tests/unit/autojac/test_utils.py b/tests/unit/autojac/test_utils.py index f4dbf7a40..bb1450d82 100644 --- a/tests/unit/autojac/test_utils.py +++ b/tests/unit/autojac/test_utils.py @@ -6,7 +6,7 @@ from torchjd.autojac._utils import get_leaf_tensors -def test_simple_get_leaf_tensors(): +def test_simple_get_leaf_tensors() -> None: """Tests that _get_leaf_tensors works correctly in a very simple setting.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -19,7 +19,7 @@ def test_simple_get_leaf_tensors(): assert set(leaves) == {a1, a2} -def test_get_leaf_tensors_excluded_1(): +def test_get_leaf_tensors_excluded_1() -> None: """ Tests that _get_leaf_tensors works correctly when some tensors are excluded from the search. @@ -40,7 +40,7 @@ def test_get_leaf_tensors_excluded_1(): assert set(leaves) == {a1} -def test_get_leaf_tensors_excluded_2(): +def test_get_leaf_tensors_excluded_2() -> None: """ Tests that _get_leaf_tensors works correctly when some tensors are excluded from the search. @@ -61,7 +61,7 @@ def test_get_leaf_tensors_excluded_2(): assert set(leaves) == {a1, a2} -def test_get_leaf_tensors_leaf_not_requiring_grad(): +def test_get_leaf_tensors_leaf_not_requiring_grad() -> None: """ Tests that _get_leaf_tensors does not include tensors that do not require grad in its results. """ @@ -76,7 +76,7 @@ def test_get_leaf_tensors_leaf_not_requiring_grad(): assert set(leaves) == {a1} -def test_get_leaf_tensors_model(): +def test_get_leaf_tensors_model() -> None: """ Tests that _get_leaf_tensors works correctly when the autograd graph is generated by a simple sequential model. @@ -95,7 +95,7 @@ def test_get_leaf_tensors_model(): assert set(leaves) == set(model.parameters()) -def test_get_leaf_tensors_model_excluded_2(): +def test_get_leaf_tensors_model_excluded_2() -> None: """ Tests that _get_leaf_tensors works correctly when the autograd graph is generated by a simple sequential model, and some intermediate values are excluded. @@ -116,7 +116,7 @@ def test_get_leaf_tensors_model_excluded_2(): assert set(leaves) == set(model2.parameters()) -def test_get_leaf_tensors_single_root(): +def test_get_leaf_tensors_single_root() -> None: """Tests that _get_leaf_tensors returns no leaves when roots is the empty set.""" p = tensor_([1.0, 2.0], requires_grad=True) @@ -126,14 +126,14 @@ def test_get_leaf_tensors_single_root(): assert set(leaves) == {p} -def test_get_leaf_tensors_empty_roots(): +def test_get_leaf_tensors_empty_roots() -> None: """Tests that _get_leaf_tensors returns no leaves when roots is the empty set.""" leaves = get_leaf_tensors(tensors=[], excluded=set()) assert set(leaves) == set() -def test_get_leaf_tensors_excluded_root(): +def test_get_leaf_tensors_excluded_root() -> None: """Tests that _get_leaf_tensors correctly excludes the root.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -147,7 +147,7 @@ def test_get_leaf_tensors_excluded_root(): @mark.parametrize("depth", [100, 1000, 10000]) -def test_get_leaf_tensors_deep(depth: int): +def test_get_leaf_tensors_deep(depth: int) -> None: """Tests that _get_leaf_tensors works when the graph is very deep.""" one = tensor_(1.0, requires_grad=True) @@ -159,7 +159,7 @@ def test_get_leaf_tensors_deep(depth: int): assert set(leaves) == {one} -def test_get_leaf_tensors_leaf(): +def test_get_leaf_tensors_leaf() -> None: """Tests that _get_leaf_tensors raises an error some of the provided tensors are leaves.""" a = tensor_(1.0, requires_grad=True) @@ -167,7 +167,7 @@ def test_get_leaf_tensors_leaf(): _ = get_leaf_tensors(tensors=[a], excluded=set()) -def test_get_leaf_tensors_tensor_not_requiring_grad(): +def test_get_leaf_tensors_tensor_not_requiring_grad() -> None: """ Tests that _get_leaf_tensors raises an error some of the provided tensors do not require grad. """ @@ -177,7 +177,7 @@ def test_get_leaf_tensors_tensor_not_requiring_grad(): _ = get_leaf_tensors(tensors=[a], excluded=set()) -def test_get_leaf_tensors_excluded_leaf(): +def test_get_leaf_tensors_excluded_leaf() -> None: """Tests that _get_leaf_tensors raises an error some of the excluded tensors are leaves.""" a = tensor_(1.0, requires_grad=True) * 2 @@ -186,7 +186,7 @@ def test_get_leaf_tensors_excluded_leaf(): _ = get_leaf_tensors(tensors=[a], excluded={b}) -def test_get_leaf_tensors_excluded_not_requiring_grad(): +def test_get_leaf_tensors_excluded_not_requiring_grad() -> None: """ Tests that _get_leaf_tensors raises an error some of the excluded tensors do not require grad. """ diff --git a/tests/unit/linalg/test_gramian.py b/tests/unit/linalg/test_gramian.py index 53373822c..2c8c1eec2 100644 --- a/tests/unit/linalg/test_gramian.py +++ b/tests/unit/linalg/test_gramian.py @@ -20,13 +20,13 @@ [6, 7, 9], ], ) -def test_gramian_is_psd(shape: list[int]): +def test_gramian_is_psd(shape: list[int]) -> None: matrix = randn_(shape) gramian = compute_gramian(matrix) assert_is_psd_matrix(gramian) -def test_compute_gramian_scalar_input_0(): +def test_compute_gramian_scalar_input_0() -> None: t = tensor_(5.0) gramian = compute_gramian(t, contracted_dims=0) expected = tensor_(25.0) @@ -34,7 +34,7 @@ def test_compute_gramian_scalar_input_0(): assert_close(gramian, expected) -def test_compute_gramian_vector_input_0(): +def test_compute_gramian_vector_input_0() -> None: t = tensor_([2.0, 3.0]) gramian = compute_gramian(t, contracted_dims=0) expected = tensor_([[4.0, 6.0], [6.0, 9.0]]) @@ -42,7 +42,7 @@ def test_compute_gramian_vector_input_0(): assert_close(gramian, expected) -def test_compute_gramian_vector_input_1(): +def test_compute_gramian_vector_input_1() -> None: t = tensor_([2.0, 3.0]) gramian = compute_gramian(t, contracted_dims=1) expected = tensor_(13.0) @@ -50,7 +50,7 @@ def test_compute_gramian_vector_input_1(): assert_close(gramian, expected) -def test_compute_gramian_matrix_input_0(): +def test_compute_gramian_matrix_input_0() -> None: t = tensor_([[1.0, 2.0], [3.0, 4.0]]) gramian = compute_gramian(t, contracted_dims=0) expected = tensor_( @@ -63,7 +63,7 @@ def test_compute_gramian_matrix_input_0(): assert_close(gramian, expected) -def test_compute_gramian_matrix_input_1(): +def test_compute_gramian_matrix_input_1() -> None: t = tensor_([[1.0, 2.0], [3.0, 4.0]]) gramian = compute_gramian(t, contracted_dims=1) expected = tensor_([[5.0, 11.0], [11.0, 25.0]]) @@ -71,7 +71,7 @@ def test_compute_gramian_matrix_input_1(): assert_close(gramian, expected) -def test_compute_gramian_matrix_input_2(): +def test_compute_gramian_matrix_input_2() -> None: t = tensor_([[1.0, 2.0], [3.0, 4.0]]) gramian = compute_gramian(t, contracted_dims=2) expected = tensor_(30.0) @@ -89,7 +89,7 @@ def test_compute_gramian_matrix_input_2(): [5, 0], ], ) -def test_normalize_yields_psd(shape: list[int]): +def test_normalize_yields_psd(shape: list[int]) -> None: matrix = randn_(shape) assert is_matrix(matrix) gramian = compute_gramian(matrix) @@ -107,7 +107,7 @@ def test_normalize_yields_psd(shape: list[int]): [5, 0], ], ) -def test_regularize_yields_psd(shape: list[int]): +def test_regularize_yields_psd(shape: list[int]) -> None: matrix = randn_(shape) assert is_matrix(matrix) gramian = compute_gramian(matrix) diff --git a/tests/unit/test_deprecations.py b/tests/unit/test_deprecations.py index f11214787..7d38a04d5 100644 --- a/tests/unit/test_deprecations.py +++ b/tests/unit/test_deprecations.py @@ -2,7 +2,7 @@ # deprecated since 2025-08-18 -def test_deprecate_imports_from_torchjd(): +def test_deprecate_imports_from_torchjd() -> None: with pytest.deprecated_call(): from torchjd import backward # noqa: F401 diff --git a/tests/utils/architectures.py b/tests/utils/architectures.py index cf3261f8a..20ce71cdd 100644 --- a/tests/utils/architectures.py +++ b/tests/utils/architectures.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar import torch import torchvision @@ -14,7 +14,7 @@ class ModuleFactory(Generic[_T]): - def __init__(self, architecture: type[_T], *args, **kwargs): + def __init__(self, architecture: type[_T], *args: Any, **kwargs: Any) -> None: self.architecture: type[_T] = architecture self.args = args self.kwargs = kwargs @@ -36,7 +36,7 @@ class ShapedModule(nn.Module): INPUT_SHAPES: PyTree # meant to be overridden OUTPUT_SHAPES: PyTree # meant to be overridden - def __init_subclass__(cls): + def __init_subclass__(cls) -> None: super().__init_subclass__() if getattr(cls, "INPUT_SHAPES", None) is None: raise TypeError(f"{cls.__name__} must define INPUT_SHAPES") @@ -63,7 +63,7 @@ class OverlyNested(ShapedModule): INPUT_SHAPES = (9,) OUTPUT_SHAPES = (14,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.seq = nn.Sequential( nn.Sequential( @@ -95,7 +95,7 @@ class MultiInputSingleOutput(ShapedModule): INPUT_SHAPES = ((50,), (50,)) OUTPUT_SHAPES = (60,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix1 = nn.Parameter(torch.randn(50, 60)) self.matrix2 = nn.Parameter(torch.randn(50, 60)) @@ -112,7 +112,7 @@ class MultiInputMultiOutput(ShapedModule): INPUT_SHAPES = ((50,), (50,)) OUTPUT_SHAPES = ((60,), (70,)) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix1_1 = nn.Parameter(torch.randn(50, 60)) self.matrix2_1 = nn.Parameter(torch.randn(50, 60)) @@ -136,7 +136,7 @@ class SingleInputPyTreeOutput(ShapedModule): "third": ([((90,),)],), } - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix1 = nn.Parameter(torch.randn(50, 50)) self.matrix2 = nn.Parameter(torch.randn(50, 60)) @@ -161,7 +161,7 @@ class PyTreeInputSingleOutput(ShapedModule): } OUTPUT_SHAPES = (350,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix1 = nn.Parameter(torch.randn(10, 50)) self.matrix2 = nn.Parameter(torch.randn(20, 60)) @@ -203,7 +203,7 @@ class PyTreeInputPyTreeOutput(ShapedModule): "third": ([((90,),)],), } - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix1 = nn.Parameter(torch.randn(10, 50)) self.matrix2 = nn.Parameter(torch.randn(20, 60)) @@ -231,7 +231,7 @@ class SimpleBranched(ShapedModule): INPUT_SHAPES = (9,) OUTPUT_SHAPES = (16,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.relu = nn.ReLU() self.fc0 = nn.Linear(9, 13) @@ -257,7 +257,7 @@ class MISOBranched(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = MultiInputSingleOutput.OUTPUT_SHAPES - def __init__(self): + def __init__(self) -> None: super().__init__() self.miso = MultiInputSingleOutput() @@ -274,7 +274,7 @@ class MIMOBranched(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (130,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.mimo = MultiInputMultiOutput() @@ -291,7 +291,7 @@ class SIPOBranched(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (350,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.sipo = SingleInputPyTreeOutput() @@ -314,7 +314,7 @@ class PISOBranched(ShapedModule): INPUT_SHAPES = (86,) OUTPUT_SHAPES = (350,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.piso = PyTreeInputSingleOutput() @@ -342,7 +342,7 @@ class PIPOBranched(ShapedModule): INPUT_SHAPES = (86,) OUTPUT_SHAPES = (350,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.pipo = PyTreeInputPyTreeOutput() @@ -379,7 +379,7 @@ class WithNoTensorOutput(ShapedModule): OUTPUT_SHAPES = (10,) class _NoneOutput(nn.Module): - def __init__(self, shape: tuple[int, ...]): + def __init__(self, shape: tuple[int, ...]) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(shape)) @@ -387,7 +387,7 @@ def forward(self, _: PyTree) -> None: pass class _NonePyTreeOutput(nn.Module): - def __init__(self, shape: tuple[int, ...]): + def __init__(self, shape: tuple[int, ...]) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(shape)) @@ -395,7 +395,7 @@ def forward(self, _: PyTree) -> PyTree: return {"one": [None, ()], "two": None} class _EmptyTupleOutput(nn.Module): - def __init__(self, shape: tuple[int, ...]): + def __init__(self, shape: tuple[int, ...]) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(shape)) @@ -403,14 +403,14 @@ def forward(self, _: PyTree) -> tuple: return () class _EmptyPytreeOutput(nn.Module): - def __init__(self, shape: tuple[int, ...]): + def __init__(self, shape: tuple[int, ...]) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(shape)) def forward(self, _: PyTree) -> PyTree: return {"one": [(), ()], "two": [[], []]} - def __init__(self): + def __init__(self) -> None: super().__init__() self.none_output = self._NoneOutput((27, 10)) self.none_pytree_output = self._NonePyTreeOutput((27, 10)) @@ -432,7 +432,7 @@ class IntraModuleParamReuse(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(50, 10)) @@ -452,14 +452,14 @@ class _MatMulModule(nn.Module): that this parameter can be used in other modules too. """ - def __init__(self, matrix: nn.Parameter): + def __init__(self, matrix: nn.Parameter) -> None: super().__init__() self.matrix = matrix - def forward(self, input: Tensor): + def forward(self, input: Tensor) -> Tensor: return input @ self.matrix - def __init__(self): + def __init__(self) -> None: super().__init__() matrix = nn.Parameter(torch.randn(50, 10)) self.module1 = self._MatMulModule(matrix) @@ -475,7 +475,7 @@ class ModuleReuse(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.module = nn.Linear(50, 10) @@ -489,7 +489,7 @@ class SomeUnusedParam(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.unused_param = nn.Parameter(torch.randn(50, 10)) self.matrix = nn.Parameter(torch.randn(50, 10)) @@ -507,7 +507,7 @@ class SomeFrozenParam(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.frozen_param = nn.Parameter(torch.randn(50, 10), requires_grad=False) self.matrix = nn.Parameter(torch.randn(50, 10)) @@ -524,7 +524,7 @@ class WithSomeFrozenModule(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.non_frozen = nn.Linear(50, 10) self.all_frozen = nn.Linear(50, 10) @@ -553,7 +553,7 @@ class SomeFrozenParamAndUnusedTrainableParam(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.frozen_param = nn.Parameter(torch.randn(50, 10), requires_grad=False) self.non_frozen_param = nn.Parameter(torch.randn(50, 10)) @@ -561,7 +561,7 @@ def __init__(self): def forward(self, input: Tensor) -> Tensor: return input @ self.frozen_param - def __init__(self): + def __init__(self) -> None: super().__init__() self.weird_module = self.SomeFrozenParamAndUnusedTrainableParam() self.normal_module = nn.Linear(10, 3) @@ -579,7 +579,7 @@ class MultiOutputWithFrozenBranch(ShapedModule): INPUT_SHAPES = (50,) OUTPUT_SHAPES = ((10,), (10,)) - def __init__(self): + def __init__(self) -> None: super().__init__() self.frozen_param = nn.Parameter(torch.randn(50, 10), requires_grad=False) self.matrix = nn.Parameter(torch.randn(50, 10)) @@ -597,14 +597,14 @@ class WithBuffered(ShapedModule): class _Buffered(nn.Module): buffer: Tensor - def __init__(self): + def __init__(self) -> None: super().__init__() self.register_buffer("buffer", torch.tensor(1.5)) def forward(self, input: Tensor) -> Tensor: return input * self.buffer - def __init__(self): + def __init__(self) -> None: super().__init__() self.module_with_buffer = self._Buffered() self.linear = nn.Linear(27, 10) @@ -619,7 +619,7 @@ class Randomness(ShapedModule): INPUT_SHAPES = (9,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(9, 10)) @@ -635,7 +635,7 @@ class WithSideEffect(ShapedModule): INPUT_SHAPES = (9,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(9, 10)) self.register_buffer("buffer", torch.zeros((9,))) @@ -654,7 +654,7 @@ class SomeUnusedOutput(ShapedModule): INPUT_SHAPES = (9,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear1 = nn.Linear(9, 12) self.linear2 = nn.Linear(9, 10) @@ -671,7 +671,7 @@ class Ndim0Output(ShapedModule): INPUT_SHAPES = (5,) OUTPUT_SHAPES = () - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear = nn.Linear(5, 1) @@ -685,7 +685,7 @@ class Ndim1Output(ShapedModule): INPUT_SHAPES = (5,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear = nn.Linear(5, 3) @@ -699,7 +699,7 @@ class Ndim2Output(ShapedModule): INPUT_SHAPES = (5,) OUTPUT_SHAPES = (2, 3) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear1 = nn.Linear(5, 3) self.linear2 = nn.Linear(5, 3) @@ -714,7 +714,7 @@ class Ndim3Output(ShapedModule): INPUT_SHAPES = (6,) OUTPUT_SHAPES = (2, 3, 4) - def __init__(self): + def __init__(self) -> None: super().__init__() self.tensor = nn.Parameter(torch.randn(6, 2, 3, 4)) @@ -728,7 +728,7 @@ class Ndim4Output(ShapedModule): INPUT_SHAPES = (6,) OUTPUT_SHAPES = (2, 3, 4, 5) - def __init__(self): + def __init__(self) -> None: super().__init__() self.tensor = nn.Parameter(torch.randn(6, 2, 3, 4, 5)) @@ -742,7 +742,7 @@ class WithRNN(ShapedModule): INPUT_SHAPES = (20, 8) # Size 20, dim input_size (8) OUTPUT_SHAPES = (20, 5) # Size 20, dim hidden_size (5) - def __init__(self): + def __init__(self) -> None: super().__init__() self.rnn = nn.RNN(input_size=8, hidden_size=5, batch_first=True) @@ -757,7 +757,7 @@ class WithDropout(ShapedModule): INPUT_SHAPES = (3, 6, 6) OUTPUT_SHAPES = (3, 4, 4) - def __init__(self): + def __init__(self) -> None: super().__init__() self.conv2d = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) self.dropout = nn.Dropout2d(p=0.5) @@ -775,7 +775,7 @@ class ModelUsingSubmoduleParamsDirectly(ShapedModule): INPUT_SHAPES = (2,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear = nn.Linear(2, 3) @@ -791,7 +791,7 @@ class ModelAlsoUsingSubmoduleParamsDirectly(ShapedModule): INPUT_SHAPES = (2,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear = nn.Linear(2, 3) @@ -800,7 +800,7 @@ def forward(self, input: Tensor) -> Tensor: class _WithStringArg(nn.Module): - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(2, 3)) @@ -816,7 +816,7 @@ class WithModuleWithStringArg(ShapedModule): INPUT_SHAPES = (2,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.with_string_arg = _WithStringArg() @@ -830,7 +830,7 @@ class WithModuleWithStringKwarg(ShapedModule): INPUT_SHAPES = (2,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.with_string_arg = _WithStringArg() @@ -839,7 +839,7 @@ def forward(self, input: Tensor) -> Tensor: class _WithHybridPyTreeArg(nn.Module): - def __init__(self): + def __init__(self) -> None: super().__init__() self.m0 = nn.Parameter(torch.randn(3, 3)) self.m1 = nn.Parameter(torch.randn(4, 3)) @@ -869,7 +869,7 @@ class WithModuleWithHybridPyTreeArg(ShapedModule): INPUT_SHAPES = (10,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear = nn.Linear(10, 18) self.with_string_arg = _WithHybridPyTreeArg() @@ -898,7 +898,7 @@ class WithModuleWithHybridPyTreeKwarg(ShapedModule): INPUT_SHAPES = (10,) OUTPUT_SHAPES = (3,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear = nn.Linear(10, 18) self.with_string_arg = _WithHybridPyTreeArg() @@ -925,14 +925,14 @@ class WithModuleWithStringOutput(ShapedModule): OUTPUT_SHAPES = (3,) class WithStringOutput(nn.Module): - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(2, 3)) def forward(self, input: Tensor) -> tuple[str, Tensor]: return "test", input @ self.matrix - def __init__(self): + def __init__(self) -> None: super().__init__() self.with_string_output = self.WithStringOutput() @@ -947,7 +947,7 @@ class WithMultiHeadAttention(ShapedModule): INPUT_SHAPES = ((20, 8), (10, 9), (10, 11)) OUTPUT_SHAPES = (20, 8) - def __init__(self): + def __init__(self) -> None: super().__init__() self.mha = nn.MultiheadAttention( embed_dim=8, @@ -970,7 +970,7 @@ class WithTransformer(ShapedModule): INPUT_SHAPES = ((10, 8), (20, 8)) OUTPUT_SHAPES = (20, 8) - def __init__(self): + def __init__(self) -> None: super().__init__() self.transformer = nn.Transformer( d_model=8, @@ -993,7 +993,7 @@ class WithTransformerLarge(ShapedModule): INPUT_SHAPES = ((10, 512), (20, 512)) OUTPUT_SHAPES = (20, 512) - def __init__(self): + def __init__(self) -> None: super().__init__() self.transformer = nn.Transformer( batch_first=True, @@ -1014,7 +1014,7 @@ class FreeParam(ShapedModule): INPUT_SHAPES = (15,) OUTPUT_SHAPES = (80,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.matrix = nn.Parameter(torch.randn(15, 16)) # Free parameter self.relu = nn.ReLU() @@ -1041,7 +1041,7 @@ class NoFreeParam(ShapedModule): INPUT_SHAPES = (15,) OUTPUT_SHAPES = (80,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.linear0 = nn.Linear(15, 16, bias=False) self.relu = nn.ReLU() @@ -1071,7 +1071,7 @@ class Body(ShapedModule): INPUT_SHAPES = (3, 32, 32) OUTPUT_SHAPES = (1024,) - def __init__(self): + def __init__(self) -> None: super().__init__() layers = [ nn.Conv2d(3, 32, 3), @@ -1092,7 +1092,7 @@ class Head(ShapedModule): INPUT_SHAPES = (1024,) OUTPUT_SHAPES = (10,) - def __init__(self): + def __init__(self) -> None: super().__init__() layers = [ nn.Linear(1024, 128), @@ -1107,7 +1107,7 @@ def forward(self, input: Tensor) -> Tensor: INPUT_SHAPES = Body.INPUT_SHAPES OUTPUT_SHAPES = Head.OUTPUT_SHAPES - def __init__(self): + def __init__(self) -> None: super().__init__() self.body = self.Body() self.head = self.Head() @@ -1128,7 +1128,7 @@ class AlexNet(ShapedModule): INPUT_SHAPES = (3, 224, 224) OUTPUT_SHAPES = (1000,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.alexnet = torchvision.models.alexnet() @@ -1145,7 +1145,7 @@ class InstanceNormResNet18(ShapedModule): INPUT_SHAPES = (3, 224, 224) OUTPUT_SHAPES = (1000,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.resnet18 = torchvision.models.resnet18( norm_layer=partial(nn.InstanceNorm2d, track_running_stats=False, affine=True), @@ -1161,7 +1161,7 @@ class GroupNormMobileNetV3Small(ShapedModule): INPUT_SHAPES = (3, 224, 224) OUTPUT_SHAPES = (1000,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.mobile_net = torchvision.models.mobilenet_v3_small( norm_layer=partial(nn.GroupNorm, 2, affine=True), @@ -1177,7 +1177,7 @@ class SqueezeNet(ShapedModule): INPUT_SHAPES = (3, 224, 224) OUTPUT_SHAPES = (1000,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.squeezenet = torchvision.models.squeezenet1_0() @@ -1191,7 +1191,7 @@ class InstanceNormMobileNetV2(ShapedModule): INPUT_SHAPES = (3, 224, 224) OUTPUT_SHAPES = (1000,) - def __init__(self): + def __init__(self) -> None: super().__init__() self.mobilenet = torchvision.models.mobilenet_v2( norm_layer=partial(nn.InstanceNorm2d, track_running_stats=False, affine=True), diff --git a/tests/utils/asserts.py b/tests/utils/asserts.py index 836110a8c..392422cff 100644 --- a/tests/utils/asserts.py +++ b/tests/utils/asserts.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from torch import Tensor from torch.testing import assert_close @@ -16,7 +18,7 @@ def assert_has_no_jac(t: Tensor) -> None: assert not is_tensor_with_jac(t) -def assert_jac_close(t: Tensor, expected_jac: Tensor, **kwargs) -> None: +def assert_jac_close(t: Tensor, expected_jac: Tensor, **kwargs: Any) -> None: assert is_tensor_with_jac(t) assert_close(t.jac, expected_jac, **kwargs) @@ -29,12 +31,12 @@ def assert_has_no_grad(t: Tensor) -> None: assert t.grad is None -def assert_grad_close(t: Tensor, expected_grad: Tensor, **kwargs) -> None: +def assert_grad_close(t: Tensor, expected_grad: Tensor, **kwargs: Any) -> None: assert t.grad is not None assert_close(t.grad, expected_grad, **kwargs) -def assert_is_psd_matrix(matrix: Tensor, **kwargs) -> None: +def assert_is_psd_matrix(matrix: Tensor, **kwargs: Any) -> None: assert is_psd_matrix(matrix) assert_close(matrix, matrix.mH, **kwargs) @@ -44,7 +46,7 @@ def assert_is_psd_matrix(matrix: Tensor, **kwargs) -> None: assert_close(eig_vals, expected_eig_vals, **kwargs) -def assert_is_psd_tensor(t: Tensor, **kwargs) -> None: +def assert_is_psd_tensor(t: Tensor, **kwargs: Any) -> None: assert is_psd_tensor(t) matrix = flatten(t) assert_is_psd_matrix(matrix, **kwargs) diff --git a/tests/utils/forward_backwards.py b/tests/utils/forward_backwards.py index 57f9b90a0..008fbad4b 100644 --- a/tests/utils/forward_backwards.py +++ b/tests/utils/forward_backwards.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from types import TracebackType import torch from torch import Tensor, nn, vmap @@ -162,7 +163,7 @@ class CloneParams: algorithm rather than a module-based algorithm. """ - def __init__(self, model: nn.Module): + def __init__(self, model: nn.Module) -> None: self.model = model self.clones = list[nn.Parameter]() self._module_to_original_params = dict[nn.Module, dict[str, nn.Parameter]]() @@ -192,7 +193,12 @@ def post_hook(module: nn.Module, _, __) -> None: return self.clones - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: """Remove hooks and restore parameters.""" for handle in self._handles: handle.remove() @@ -201,7 +207,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False # don't suppress exceptions - def _restore_original_params(self, module: nn.Module): + def _restore_original_params(self, module: nn.Module) -> None: original_params = self._module_to_original_params.pop(module, {}) for name, param in original_params.items(): self._set_module_param(module, name, param) diff --git a/tests/utils/tensors.py b/tests/utils/tensors.py index 7988157df..8febba429 100644 --- a/tests/utils/tensors.py +++ b/tests/utils/tensors.py @@ -37,7 +37,7 @@ def make_inputs_and_targets(model: nn.Module, batch_size: int) -> tuple[PyTree, def _make_tensors(batch_size: int, tensor_shapes: PyTree) -> PyTree: - def is_leaf(s): + def is_leaf(s: PyTree) -> bool: return isinstance(s, tuple) and all(isinstance(e, int) for e in s) return tree_map(lambda s: randn_((batch_size, *s)), tensor_shapes, is_leaf=is_leaf)