25 changes: 10 additions & 15 deletions src/torchjd/autojac/_jac.py
@@ -1,4 +1,4 @@
-from collections.abc import Iterable, Sequence
+from collections.abc import Sequence

from torch import Tensor

@@ -13,13 +13,12 @@
    check_matching_jac_shapes,
    check_matching_length,
    check_optional_positive_chunk_size,
-    get_leaf_tensors,
)


def jac(
    outputs: Sequence[Tensor] | Tensor,
-    inputs: Iterable[Tensor] | None = None,
+    inputs: Sequence[Tensor] | Tensor,
    *,
    jac_outputs: Sequence[Tensor] | Tensor | None = None,
    retain_graph: bool = False,
@@ -32,9 +31,8 @@ def jac(
        ``[m] + t.shape``.

    :param outputs: The tensor or tensors to differentiate. Should be non-empty.
-    :param inputs: The tensors with respect to which the Jacobian must be computed. These must have
-        their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
-        that were used to compute the ``outputs`` parameter.
+    :param inputs: The tensor or tensors with respect to which the Jacobian must be computed. These
+        must have their ``requires_grad`` flag set to ``True``.
    :param jac_outputs: The initial Jacobians to backpropagate, analogous to the ``grad_outputs``
        parameter of :func:`torch.autograd.grad`. If provided, it must have the same structure as
        ``outputs`` and each tensor in ``jac_outputs`` must match the shape of the corresponding
@@ -69,7 +67,7 @@ def jac(
    >>> y1 = torch.tensor([-1., 1.]) @ param
    >>> y2 = (param ** 2).sum()
    >>>
-    >>> jacobians = jac([y1, y2], [param])
+    >>> jacobians = jac([y1, y2], param)
    >>>
    >>> jacobians
    (tensor([[-1., 1.],
@@ -131,13 +129,13 @@ def jac(
    >>> jac_h = jac([y1, y2], [h])[0]  # Shape: [2, 2]
    >>>
    >>> # Step 2: Use chain rule to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx)
-    >>> jac_x = jac(h, [x], jac_outputs=jac_h)[0]
+    >>> jac_x = jac(h, x, jac_outputs=jac_h)[0]
    >>>
    >>> jac_x
    tensor([[ 2.,  4.],
            [ 2., -4.]])

-    This two-step computation is equivalent to directly computing ``jac([y1, y2], [x])``.
+    This two-step computation is equivalent to directly computing ``jac([y1, y2], x)``.

    .. warning::
        To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some
@@ -155,12 +153,9 @@
    if len(outputs_) == 0:
        raise ValueError("`outputs` cannot be empty.")

-    if inputs is None:
-        inputs_ = get_leaf_tensors(tensors=outputs_, excluded=set())
-        inputs_with_repetition = list(inputs_)
-    else:
-        inputs_with_repetition = list(inputs)  # Create a list to avoid emptying generator
-        inputs_ = OrderedSet(inputs_with_repetition)
+    # Preserve repetitions to duplicate jacobians at the return statement
+    inputs_with_repetition = (inputs,) if isinstance(inputs, Tensor) else inputs
+    inputs_ = OrderedSet(inputs_with_repetition)

    jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs)
    transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph)
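Taken together, the changes above make `inputs` a required positional parameter that accepts either a single `Tensor` or a `Sequence[Tensor]`; the leaf-tensor default built on `get_leaf_tensors` is gone. A minimal sketch of the resulting call patterns (the `torchjd.autojac` import path is assumed from this file's location and is not part of the diff):

    import torch
    from torchjd.autojac import jac  # import path assumed from src/torchjd/autojac/_jac.py

    param = torch.tensor([1.0, 2.0], requires_grad=True)
    y1 = torch.tensor([-1.0, 1.0]) @ param
    y2 = (param ** 2).sum()

    # A bare Tensor is accepted and wrapped internally as (param,).
    jacobians = jac([y1, y2], param, retain_graph=True)

    # A sequence of tensors works as before.
    jacobians = jac([y1, y2], [param])

    # Omitting inputs now fails with a TypeError at call time, since the
    # parameter no longer has a default value.

Here `retain_graph=True` on the first call only keeps the autograd graph alive for the second call; it is unrelated to the signature change.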
4 changes: 2 additions & 2 deletions tests/doc/test_jac.py
@@ -14,7 +14,7 @@ def test_jac():
    # Compute arbitrary quantities that are functions of param
    y1 = torch.tensor([-1.0, 1.0]) @ param
    y2 = (param**2).sum()
-    jacobians = jac([y1, y2], [param])
+    jacobians = jac([y1, y2], param)

    assert len(jacobians) == 1
    assert_close(jacobians[0], torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04)
@@ -57,6 +57,6 @@ def test_jac_3():
    # Step 1: Compute d[y1,y2]/dh
    jac_h = jac([y1, y2], [h])[0]  # Shape: [2, 2]
    # Step 2: Use jac_outputs to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx)
-    jac_x = jac(h, [x], jac_outputs=jac_h)[0]
+    jac_x = jac(h, x, jac_outputs=jac_h)[0]

    assert_close(jac_x, torch.tensor([[2.0, 4.0], [2.0, -4.0]]), rtol=0.0, atol=1e-04)
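The `# Preserve repetitions` comment in `_jac.py` above implies a contract that none of these doc tests exercise: duplicate entries in `inputs` are deduplicated for differentiation (via `OrderedSet`) and then duplicated again in the returned tuple. A hedged sketch of that behavior, inferred from the comment rather than verified here:

    import torch
    from torchjd.autojac import jac  # import path assumed

    a = torch.tensor([1.0, 2.0], requires_grad=True)
    y = (a ** 2).sum()

    # Per the comment, the same tensor passed twice should yield its Jacobian twice.
    j1, j2 = jac(y, [a, a])
    assert torch.equal(j1, j2)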
45 changes: 22 additions & 23 deletions tests/unit/autojac/test_jac.py
@@ -56,13 +56,15 @@ def test_jac():
    assert jacobian.shape[1:] == a.shape


-@mark.parametrize("shape", [(1, 3), (2, 3), (2, 6), (5, 8), (20, 55)])
-@mark.parametrize("manually_specify_inputs", [True, False])
+@mark.parametrize("shape", [(1, 1), (1, 3), (2, 1), (2, 6), (20, 55)])
@mark.parametrize("chunk_size", [1, 2, None])
+@mark.parametrize("outputs_is_list", [True, False])
+@mark.parametrize("inputs_is_list", [True, False])
def test_value_is_correct(
    shape: tuple[int, int],
-    manually_specify_inputs: bool,
    chunk_size: int | None,
+    outputs_is_list: bool,
+    inputs_is_list: bool,
):
    """
    Tests that the jacobians returned by jac are correct in a simple example of matrix-vector
@@ -73,13 +75,10 @@
    input = randn_([shape[1]], requires_grad=True)
    output = J @ input  # Note that the Jacobian of output w.r.t. input is J.

-    inputs = [input] if manually_specify_inputs else None
+    outputs = [output] if outputs_is_list else output
+    inputs = [input] if inputs_is_list else input

-    jacobians = jac(
-        [output],
-        inputs=inputs,
-        parallel_chunk_size=chunk_size,
-    )
+    jacobians = jac(outputs, inputs, parallel_chunk_size=chunk_size)

    assert len(jacobians) == 1
    assert_close(jacobians[0], J)
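The inline comment above (`the Jacobian of output w.r.t. input is J`) rests on the standard fact that a linear map's Jacobian is its matrix, which is what makes `J` the expected value in the final assertion. A quick stand-alone check of that fact using only stock PyTorch, independent of torchjd:

    import torch
    from torch.autograd.functional import jacobian

    J = torch.randn(4, 3)
    x = torch.randn(3, requires_grad=True)

    # For output = J @ x, d(output)/dx is exactly J.
    assert torch.allclose(jacobian(lambda v: J @ v, x), J)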
@@ -103,7 +102,7 @@ def test_jac_outputs_value_is_correct(rows: int):

    jacobians = jac(
        output,
-        inputs=[input],
+        input,
        jac_outputs=J_init,
    )

@@ -126,7 +125,7 @@ def test_jac_outputs_multiple_components(rows: int):
    J1 = randn_((rows, 2))
    J2 = randn_((rows, 3))

-    jacobians = jac([y1, y2], inputs=[input], jac_outputs=[J1, J2])
+    jacobians = jac([y1, y2], input, jac_outputs=[J1, J2])

    jac_y1 = eye_(2) * 2

@@ -149,7 +148,7 @@ def test_jac_outputs_length_mismatch():
        ValueError,
        match=r"`jac_outputs` should have the same length as `outputs`\. \(got 1 and 2\)",
    ):
-        jac([y1, y2], inputs=[x], jac_outputs=[J1])
+        jac([y1, y2], x, jac_outputs=[J1])


def test_jac_outputs_shape_mismatch():
@@ -166,7 +165,7 @@ def test_jac_outputs_shape_mismatch():
        ValueError,
        match=r"Shape mismatch: `jac_outputs\[0\]` has shape .* but `outputs\[0\]` has shape .*\.",
    ):
-        jac(y, inputs=[x], jac_outputs=J_bad)
+        jac(y, x, jac_outputs=J_bad)


@mark.parametrize(
@@ -192,7 +191,7 @@ def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int):
    with raises(
        ValueError, match=r"All Jacobians in `jac_outputs` should have the same number of rows\."
    ):
-        jac([y1, y2], inputs=[x], jac_outputs=[j1, j2])
+        jac([y1, y2], x, jac_outputs=[j1, j2])


def test_empty_inputs():
@@ -220,7 +219,7 @@ def test_partial_inputs():
    y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum()
    y2 = (a1**2).sum() + a2.norm()

-    jacobians = jac([y1, y2], inputs=[a1])
+    jacobians = jac([y1, y2], a1)
    assert len(jacobians) == 1


@@ -250,7 +249,7 @@ def test_multiple_tensors():
    y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum()
    y2 = (a1**2).sum() + a2.norm()

-    jacobians = jac([y1, y2])
+    jacobians = jac([y1, y2], [a1, a2])
    assert len(jacobians) == 2
    assert_close(jacobians[0], J1)
    assert_close(jacobians[1], J2)
@@ -262,7 +261,7 @@ def test_multiple_tensors():
    z1 = tensor_([-1.0, 1.0]) @ b1 + b2.sum()
    z2 = (b1**2).sum() + b2.norm()

-    jacobians = jac(torch.cat([z1.reshape(-1), z2.reshape(-1)]))
+    jacobians = jac(torch.cat([z1.reshape(-1), z2.reshape(-1)]), [b1, b2])
    assert len(jacobians) == 2
    assert_close(jacobians[0], J1)
    assert_close(jacobians[1], J2)
@@ -278,7 +277,7 @@ def test_various_valid_chunk_sizes(chunk_size):
    y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum()
    y2 = (a1**2).sum() + a2.norm()

-    jacobians = jac([y1, y2], parallel_chunk_size=chunk_size)
+    jacobians = jac([y1, y2], [a1, a2], parallel_chunk_size=chunk_size)
    assert len(jacobians) == 2


@@ -293,7 +292,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int):
    y2 = (a1**2).sum() + a2.norm()

    with raises(ValueError):
-        jac([y1, y2], parallel_chunk_size=chunk_size)
+        jac([y1, y2], [a1, a2], parallel_chunk_size=chunk_size)


def test_input_retaining_grad_fails():
@@ -309,7 +308,7 @@ def test_input_retaining_grad_fails():

    # jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor (and it also
    # returns the correct Jacobian)
-    jac(y, inputs=[b])
+    jac(y, b)

    with raises(RuntimeError):
        # Using such a BatchedTensor should result in an error
@@ -328,7 +327,7 @@ def test_non_input_retaining_grad_fails():
    y = 3 * b

    # jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor
-    jac(y, inputs=[a])
+    jac(y, a)

    with raises(RuntimeError):
        # Using such a BatchedTensor should result in an error
@@ -348,7 +347,7 @@ def test_tensor_used_multiple_times(chunk_size: int | None):
    d = a * c
    e = a * d

-    jacobians = jac([d, e], parallel_chunk_size=chunk_size)
+    jacobians = jac([d, e], a, parallel_chunk_size=chunk_size)
    assert len(jacobians) == 1

    J = tensor_([2.0 * 3.0 * (a**2).item(), 2.0 * 4.0 * (a**3).item()])
@@ -372,7 +371,7 @@ def test_repeated_tensors():
    y2 = (a1**2).sum() + (a2**2).sum()

    with raises(ValueError):
-        jac([y1, y1, y2])
+        jac([y1, y1, y2], [a1, a2])


def test_repeated_inputs():