Merged
23 commits
2d66897  Add ANN ruff rule (ValerianRey, Feb 19, 2026)
f27a630  Add missing -> None to __init__ (ValerianRey, Feb 19, 2026)
0cb67ca  Improve ANN rule configuration (ValerianRey, Feb 19, 2026)
f67f8d2  Fix untyped ctx (ValerianRey, Feb 19, 2026)
905c658  Add missing -> None in tests (ValerianRey, Feb 19, 2026)
f0330c1  Add missing -> str (ValerianRey, Feb 19, 2026)
ecc011b  Annotate obj as object in docs/source/conf.py (ValerianRey, Feb 19, 2026)
3c022e7  Add missing type annotations in conftest.py (ValerianRey, Feb 19, 2026)
0d7ea2e  Add type annotations in _make_tensors (ValerianRey, Feb 19, 2026)
0d88143  Add type annotations to CloneParams (ValerianRey, Feb 19, 2026)
eb182c4  Add type annotation to InterModuleParamReuse.forward (ValerianRey, Feb 19, 2026)
e11de01  Add -> None to __init_subclass__ (ValerianRey, Feb 19, 2026)
aacf671  Add missing type annotation for chunk_size (ValerianRey, Feb 19, 2026)
4c2b730  Add type annotations to time_call (ValerianRey, Feb 19, 2026)
2f53b64  Add type annotations in run_profiler.py (ValerianRey, Feb 19, 2026)
02f4980  Add -> None to main in plot_memory_timeline.py (ValerianRey, Feb 19, 2026)
040570d  Add return type annotation in MemoryFrame.from_event (ValerianRey, Feb 19, 2026)
7837c10  Add type annotations in static_plotter and rename a variable (ValerianRey, Feb 19, 2026)
0fe3b9b  Add noqa comment in test_lightning_integration (ValerianRey, Feb 19, 2026)
693db5f  Add type annotation to update_gradient_coordinate (ValerianRey, Feb 19, 2026)
802b777  Add annotation as Any for args and kwargs (ValerianRey, Feb 19, 2026)
bab8995  Add -> to test_noncontiguous_jac (ValerianRey, Feb 19, 2026)
bd1af4e  Merge branch 'main' into add-missing-annotations (ValerianRey, Feb 20, 2026)
6 changes: 3 additions & 3 deletions docs/source/conf.py
@@ -100,7 +100,7 @@ def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None:
return link


def _get_obj(_info: dict[str, str]):
def _get_obj(_info: dict[str, str]) -> object:
module_name = _info["module"]
full_name = _info["fullname"]
sub_module = sys.modules.get(module_name)
@@ -112,7 +112,7 @@ def _get_obj(_info: dict[str, str]):
return obj


def _get_file_name(obj) -> str | None:
def _get_file_name(obj: object) -> str | None:
try:
file_name = inspect.getsourcefile(obj)
file_name = os.path.relpath(file_name, start=_PATH_ROOT)
@@ -121,7 +121,7 @@ def _get_file_name(obj) -> str | None:
return file_name


def _get_line_str(obj) -> str:
def _get_line_str(obj: object) -> str:
source, start = inspect.getsourcelines(obj)
end = start + len(source) - 1
line_str = f"#L{start}-L{end}"
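A side note on the obj: object annotations above: object is the precise type for "any value" while keeping type checking meaningful, whereas Any switches checking off for that value (which is also why ANN401 is merely ignored in the pyproject.toml diff below rather than relied upon). A hypothetical illustration, with helpers that are not part of conf.py:

from typing import Any


def describe(obj: object) -> str:
    # `object` only exposes what every Python value supports; anything more
    # specific must be narrowed first, so type checkers still catch mistakes.
    return f"{type(obj).__name__}: {obj!r}"


def describe_any(obj: Any) -> str:
    # `Any` silences the checker: this attribute access type-checks fine even
    # though it fails at runtime for most inputs.
    return str(obj.made_up_attribute)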
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -134,6 +134,7 @@ select = [
"I", # isort
"UP", # pyupgrade
"ARG", # flake8-unused-arguments
"ANN", # flake-8-annotations
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"FIX", # flake8-fixme
@@ -156,12 +157,16 @@ ignore = [
"RUF012", # Mutable default value for class attribute (a bit tedious to fix)
"RET504", # Unnecessary assignment return statement
"COM812", # Trailing comma missing (conflicts with formatter, see https://github.com/astral-sh/ruff/issues/9216)
"ANN401", # Prevent annotating as Any (we rarely do that, and when we do it's hard to find an alternative)
]

[tool.ruff.lint.per-file-ignores]
"**/conftest.py" = ["ARG"] # Can't change argument names in the functions pytest expects
"tests/doc/test_rst.py" = ["ARG"] # For the lightning example

[tool.ruff.lint.flake8-annotations]
suppress-dummy-args = true

[tool.ruff.lint.isort]
combine-as-imports = true

6 changes: 3 additions & 3 deletions src/torchjd/aggregation/_aggregator_bases.py
@@ -13,7 +13,7 @@ class Aggregator(nn.Module, ABC):
:math:`m \times n` into row vectors of dimension :math:`n`.
"""

def __init__(self):
def __init__(self) -> None:
super().__init__()

@staticmethod
@@ -48,7 +48,7 @@ class WeightedAggregator(Aggregator):
:param weighting: The object responsible for extracting the vector of weights from the matrix.
"""

def __init__(self, weighting: Weighting[Matrix]):
def __init__(self, weighting: Weighting[Matrix]) -> None:
super().__init__()
self.weighting = weighting

@@ -77,6 +77,6 @@ class GramianWeightedAggregator(WeightedAggregator):
gramian.
"""

def __init__(self, gramian_weighting: Weighting[PSDMatrix]):
def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None:
super().__init__(gramian_weighting << compute_gramian)
self.gramian_weighting = gramian_weighting
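For context on the classes annotated above, a minimal usage sketch. It assumes Mean and UPGrad are exposed by torchjd.aggregation as in the project's documentation and are callable on a 2-D tensor; the matrix values are made up:

import torch

from torchjd.aggregation import Mean, UPGrad

# Two rows of gradients (m = 2 objectives) over n = 3 parameters.
matrix = torch.tensor([[1.0, 2.0, -1.0], [0.5, -1.0, 3.0]])

# An Aggregator maps the m x n matrix to a single row vector of dimension n.
print(Mean()(matrix))    # plain average of the rows
print(UPGrad()(matrix))  # combination designed to avoid conflict between rows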
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_aligned_mtl.py
@@ -61,7 +61,7 @@ def __init__(
self,
pref_vector: Tensor | None = None,
scale_mode: SUPPORTED_SCALE_MODE = "min",
):
) -> None:
self._pref_vector = pref_vector
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
super().__init__(AlignedMTLWeighting(pref_vector, scale_mode=scale_mode))
@@ -92,7 +92,7 @@ def __init__(
self,
pref_vector: Tensor | None = None,
scale_mode: SUPPORTED_SCALE_MODE = "min",
):
) -> None:
super().__init__()
self._pref_vector = pref_vector
self._scale_mode: SUPPORTED_SCALE_MODE = scale_mode
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_cagrad.py
@@ -34,7 +34,7 @@ class CAGrad(GramianWeightedAggregator):
To install it, use ``pip install "torchjd[cagrad]"``.
"""

def __init__(self, c: float, norm_eps: float = 0.0001):
def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
super().__init__(CAGradWeighting(c=c, norm_eps=norm_eps))
self._c = c
self._norm_eps = norm_eps
@@ -67,7 +67,7 @@ class CAGradWeighting(Weighting[PSDMatrix]):
function.
"""

def __init__(self, c: float, norm_eps: float = 0.0001):
def __init__(self, c: float, norm_eps: float = 0.0001) -> None:
super().__init__()

if c < 0.0:
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_config.py
@@ -50,7 +50,7 @@ class ConFIG(Aggregator):
<https://github.com/tum-pbs/ConFIG/tree/main/conflictfree>`_.
"""

def __init__(self, pref_vector: Tensor | None = None):
def __init__(self, pref_vector: Tensor | None = None) -> None:
super().__init__()
self.weighting = pref_vector_to_weighting(pref_vector, default=SumWeighting())
self._pref_vector = pref_vector
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_constant.py
@@ -15,7 +15,7 @@ class Constant(WeightedAggregator):
:param weights: The weights associated to the rows of the input matrices.
"""

def __init__(self, weights: Tensor):
def __init__(self, weights: Tensor) -> None:
super().__init__(weighting=ConstantWeighting(weights=weights))
self._weights = weights

@@ -35,7 +35,7 @@ class ConstantWeighting(Weighting[Matrix]):
:param weights: The weights to return at each call.
"""

def __init__(self, weights: Tensor):
def __init__(self, weights: Tensor) -> None:
if weights.dim() != 1:
raise ValueError(
"Parameter `weights` should be a 1-dimensional tensor. Found `weights.shape = "
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_dualproj.py
@@ -33,7 +33,7 @@ def __init__(
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
solver: SUPPORTED_SOLVER = "quadprog",
):
) -> None:
self._pref_vector = pref_vector
self._norm_eps = norm_eps
self._reg_eps = reg_eps
@@ -77,7 +77,7 @@ def __init__(
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
solver: SUPPORTED_SOLVER = "quadprog",
):
) -> None:
super().__init__()
self._pref_vector = pref_vector
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_flattening.py
@@ -20,7 +20,7 @@ class Flattening(GeneralizedWeighting):
:param weighting: The weighting to apply to the Gramian matrix.
"""

def __init__(self, weighting: Weighting):
def __init__(self, weighting: Weighting) -> None:
super().__init__()
self.weighting = weighting

2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_graddrop.py
@@ -26,7 +26,7 @@ class GradDrop(Aggregator):
through. Defaults to None, which means no leak.
"""

def __init__(self, f: Callable = _identity, leak: Tensor | None = None):
def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None:
if leak is not None and leak.dim() != 1:
raise ValueError(
"Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = "
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_imtl_g.py
@@ -16,7 +16,7 @@ class IMTLG(GramianWeightedAggregator):
<https://arxiv.org/pdf/2406.16232>`_, supports matrices with some linearly dependant rows.
"""

def __init__(self):
def __init__(self) -> None:
super().__init__(IMTLGWeighting())

# This prevents computing gradients that can be very wrong.
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_krum.py
@@ -19,7 +19,7 @@ class Krum(GramianWeightedAggregator):
:param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1.
"""

def __init__(self, n_byzantine: int, n_selected: int = 1):
def __init__(self, n_byzantine: int, n_selected: int = 1) -> None:
self._n_byzantine = n_byzantine
self._n_selected = n_selected
super().__init__(KrumWeighting(n_byzantine=n_byzantine, n_selected=n_selected))
@@ -44,7 +44,7 @@ class KrumWeighting(Weighting[PSDMatrix]):
:param n_selected: The number of selected rows in the context of Multi-Krum. Defaults to 1.
"""

def __init__(self, n_byzantine: int, n_selected: int = 1):
def __init__(self, n_byzantine: int, n_selected: int = 1) -> None:
super().__init__()
if n_byzantine < 0:
raise ValueError(
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_mean.py
@@ -13,7 +13,7 @@ class Mean(WeightedAggregator):
matrices.
"""

def __init__(self):
def __init__(self) -> None:
super().__init__(weighting=MeanWeighting())


4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_mgda.py
@@ -20,7 +20,7 @@ class MGDA(GramianWeightedAggregator):
:param max_iters: The maximum number of iterations of the optimization loop.
"""

def __init__(self, epsilon: float = 0.001, max_iters: int = 100):
def __init__(self, epsilon: float = 0.001, max_iters: int = 100) -> None:
super().__init__(MGDAWeighting(epsilon=epsilon, max_iters=max_iters))
self._epsilon = epsilon
self._max_iters = max_iters
@@ -38,7 +38,7 @@ class MGDAWeighting(Weighting[PSDMatrix]):
:param max_iters: The maximum number of iterations of the optimization loop.
"""

def __init__(self, epsilon: float = 0.001, max_iters: int = 100):
def __init__(self, epsilon: float = 0.001, max_iters: int = 100) -> None:
super().__init__()
self.epsilon = epsilon
self.max_iters = max_iters
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_nash_mtl.py
@@ -77,7 +77,7 @@ def __init__(
max_norm: float = 1.0,
update_weights_every: int = 1,
optim_niter: int = 20,
):
) -> None:
super().__init__(
weighting=_NashMTLWeighting(
n_tasks=n_tasks,
@@ -126,7 +126,7 @@ def __init__(
max_norm: float,
update_weights_every: int,
optim_niter: int,
):
) -> None:
super().__init__()

self.n_tasks = n_tasks
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_pcgrad.py
@@ -16,7 +16,7 @@ class PCGrad(GramianWeightedAggregator):
`Gradient Surgery for Multi-Task Learning <https://arxiv.org/pdf/2001.06782.pdf>`_.
"""

def __init__(self):
def __init__(self) -> None:
super().__init__(PCGradWeighting())

# This prevents running into a RuntimeError due to modifying stored tensors in place.
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_random.py
@@ -16,7 +16,7 @@ class Random(WeightedAggregator):
<https://arxiv.org/pdf/2111.10603.pdf>`_.
"""

def __init__(self):
def __init__(self) -> None:
super().__init__(RandomWeighting())


2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_sum.py
@@ -13,7 +13,7 @@ class Sum(WeightedAggregator):
matrices.
"""

def __init__(self):
def __init__(self) -> None:
super().__init__(weighting=SumWeighting())


2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_trimmed_mean.py
@@ -15,7 +15,7 @@ class TrimmedMean(Aggregator):
input matrix (note that ``2 * trim_number`` values are removed from each column).
"""

def __init__(self, trim_number: int):
def __init__(self, trim_number: int) -> None:
super().__init__()
if trim_number < 0:
raise ValueError(
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_upgrad.py
@@ -34,7 +34,7 @@ def __init__(
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
solver: SUPPORTED_SOLVER = "quadprog",
):
) -> None:
self._pref_vector = pref_vector
self._norm_eps = norm_eps
self._reg_eps = reg_eps
@@ -78,7 +78,7 @@ def __init__(
norm_eps: float = 0.0001,
reg_eps: float = 0.0001,
solver: SUPPORTED_SOLVER = "quadprog",
):
) -> None:
super().__init__()
self._pref_vector = pref_vector
self.weighting = pref_vector_to_weighting(pref_vector, default=MeanWeighting())
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_utils/non_differentiable.py
@@ -2,7 +2,7 @@


class NonDifferentiableError(RuntimeError):
def __init__(self, module: nn.Module):
def __init__(self, module: nn.Module) -> None:
super().__init__(f"Trying to differentiate through {module}, which is not differentiable.")


6 changes: 3 additions & 3 deletions src/torchjd/aggregation/_weighting_bases.py
@@ -20,7 +20,7 @@ class Weighting(nn.Module, ABC, Generic[_T]):
generally its Gramian, of dimension :math:`m \times m`.
"""

def __init__(self):
def __init__(self) -> None:
super().__init__()

@abstractmethod
@@ -46,7 +46,7 @@ class _Composition(Weighting[_T]):
output of the function.
"""

def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutputT]):
def __init__(self, weighting: Weighting[_FnOutputT], fn: Callable[[_T], _FnOutputT]) -> None:
super().__init__()
self.fn = fn
self.weighting = weighting
@@ -63,7 +63,7 @@ class GeneralizedWeighting(nn.Module, ABC):
:math:`m_1 \times \dots \times m_k \times m_k \times \dots \times m_1`.
"""

def __init__(self):
def __init__(self) -> None:
super().__init__()

@abstractmethod
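The split between Weighting and _Composition above is what makes the earlier gramian_weighting << compute_gramian call in GramianWeightedAggregator work: a weighting defined on Gramians is pre-composed with a function that turns a matrix into its Gramian. A rough, self-contained sketch of that pattern, not the library's actual implementation:

from collections.abc import Callable

import torch
from torch import Tensor


def compute_gramian(matrix: Tensor) -> Tensor:
    # m x n matrix -> m x m positive semi-definite Gramian.
    return matrix @ matrix.T


def uniform_weights(gramian: Tensor) -> Tensor:
    # Stand-in for a Gramian-based weighting: returns m equal weights.
    m = gramian.shape[0]
    return torch.full((m,), 1.0 / m)


class Composition:
    """Applies `fn` to the input, then the wrapped weighting to the result."""

    def __init__(self, weighting: Callable[[Tensor], Tensor], fn: Callable[[Tensor], Tensor]) -> None:
        self.weighting = weighting
        self.fn = fn

    def __call__(self, matrix: Tensor) -> Tensor:
        return self.weighting(self.fn(matrix))


# A weighting of Gramians becomes a weighting of matrices.
matrix_weighting = Composition(uniform_weights, compute_gramian)
print(matrix_weighting(torch.randn(3, 5)))  # tensor([0.3333, 0.3333, 0.3333])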
2 changes: 1 addition & 1 deletion src/torchjd/autogram/_engine.py
@@ -183,7 +183,7 @@ def __init__(
self,
*modules: nn.Module,
batch_dim: int | None,
):
) -> None:
self._gramian_accumulator = GramianAccumulator()
self._target_edges = EdgeRegistry()
self._batch_dim = batch_dim
4 changes: 2 additions & 2 deletions src/torchjd/autogram/_gramian_computer.py
@@ -29,7 +29,7 @@ def reset(self) -> None:


class JacobianBasedGramianComputer(GramianComputer, ABC):
def __init__(self, jacobian_computer: JacobianComputer):
def __init__(self, jacobian_computer: JacobianComputer) -> None:
self.jacobian_computer = jacobian_computer


@@ -39,7 +39,7 @@ class JacobianBasedGramianComputerWithCrossTerms(JacobianBasedGramianComputer):
the gramian.
"""

def __init__(self, jacobian_computer: JacobianComputer):
def __init__(self, jacobian_computer: JacobianComputer) -> None:
super().__init__(jacobian_computer)
self.remaining_counter = 0
self.summed_jacobian: Matrix | None = None
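As background for the Gramian computers annotated here, the identity below is what allows a Gramian to be assembled from per-block Jacobians; the snippet is a hypothetical illustration of the math, not code from this module:

import torch

# If the full Jacobian is a column-wise concatenation J = [J_1 | J_2 | J_3]
# (one block per parameter tensor), then J @ J.T == sum_i J_i @ J_i.T,
# so the Gramian can be built block by block without materializing J.
m = 4
blocks = [torch.randn(m, n) for n in (3, 5, 2)]
full_jacobian = torch.cat(blocks, dim=1)

gramian_direct = full_jacobian @ full_jacobian.T
gramian_accumulated = sum(block @ block.T for block in blocks)

assert torch.allclose(gramian_direct, gramian_accumulated)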
2 changes: 1 addition & 1 deletion src/torchjd/autogram/_jacobian_computer.py
@@ -26,7 +26,7 @@ class JacobianComputer(ABC):
:params module: The module to differentiate.
"""

def __init__(self, module: nn.Module):
def __init__(self, module: nn.Module) -> None:
self.module = module

self.rg_params = dict[str, Parameter]()