diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py
index 78c10df2..404bb730 100644
--- a/src/torchjd/autojac/_jac.py
+++ b/src/torchjd/autojac/_jac.py
@@ -1,4 +1,4 @@
-from collections.abc import Iterable, Sequence
+from collections.abc import Sequence

 from torch import Tensor

@@ -13,13 +13,12 @@
     check_matching_jac_shapes,
     check_matching_length,
     check_optional_positive_chunk_size,
-    get_leaf_tensors,
 )


 def jac(
     outputs: Sequence[Tensor] | Tensor,
-    inputs: Iterable[Tensor] | None = None,
+    inputs: Sequence[Tensor] | Tensor,
     *,
     jac_outputs: Sequence[Tensor] | Tensor | None = None,
     retain_graph: bool = False,
@@ -32,9 +31,8 @@ def jac(
     ``[m] + t.shape``.

     :param outputs: The tensor or tensors to differentiate. Should be non-empty.
-    :param inputs: The tensors with respect to which the Jacobian must be computed. These must have
-        their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
-        that were used to compute the ``outputs`` parameter.
+    :param inputs: The tensor or tensors with respect to which the Jacobian must be computed. These
+        must have their ``requires_grad`` flag set to ``True``.
     :param jac_outputs: The initial Jacobians to backpropagate, analog to the ``grad_outputs``
         parameter of :func:`torch.autograd.grad`. If provided, it must have the same structure as
         ``outputs`` and each tensor in ``jac_outputs`` must match the shape of the corresponding
@@ -69,7 +67,7 @@ def jac(
         >>> y1 = torch.tensor([-1., 1.]) @ param
         >>> y2 = (param ** 2).sum()
         >>>
-        >>> jacobians = jac([y1, y2], [param])
+        >>> jacobians = jac([y1, y2], param)
         >>>
         >>> jacobians
         (tensor([[-1., 1.],
@@ -131,13 +129,13 @@ def jac(
         >>> jac_h = jac([y1, y2], [h])[0]  # Shape: [2, 2]
         >>>
         >>> # Step 2: Use chain rule to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx)
-        >>> jac_x = jac(h, [x], jac_outputs=jac_h)[0]
+        >>> jac_x = jac(h, x, jac_outputs=jac_h)[0]
         >>>
         >>> jac_x
         tensor([[ 2.,  4.],
                 [ 2., -4.]])

-    This two-step computation is equivalent to directly computing ``jac([y1, y2], [x])``.
+    This two-step computation is equivalent to directly computing ``jac([y1, y2], x)``.

     .. warning::
         To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some
@@ -155,12 +153,9 @@ def jac(
     if len(outputs_) == 0:
         raise ValueError("`outputs` cannot be empty.")

-    if inputs is None:
-        inputs_ = get_leaf_tensors(tensors=outputs_, excluded=set())
-        inputs_with_repetition = list(inputs_)
-    else:
-        inputs_with_repetition = list(inputs)  # Create a list to avoid emptying generator
-        inputs_ = OrderedSet(inputs_with_repetition)
+    # Preserve repetitions to duplicate jacobians at the return statement
+    inputs_with_repetition = (inputs,) if isinstance(inputs, Tensor) else inputs
+    inputs_ = OrderedSet(inputs_with_repetition)

     jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs)
     transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph)
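
Note on the API change above: ``inputs`` is now a required positional argument that accepts
either a single tensor or a sequence of tensors; a bare tensor is wrapped into a one-element
tuple before deduplication into an ``OrderedSet``. A minimal runnable sketch of the new
calling convention, adapted from the docstring example (``jac`` is imported here from the
module path shown in this diff; the public alias may differ):

    import torch
    from torchjd.autojac._jac import jac  # module path taken from this diff

    param = torch.tensor([1.0, 2.0], requires_grad=True)
    y1 = torch.tensor([-1.0, 1.0]) @ param
    y2 = (param ** 2).sum()

    # `inputs` may now be a bare tensor; this is equivalent to jac([y1, y2], [param]).
    jacobians = jac([y1, y2], param)
    assert jacobians[0].shape == (2, 2)  # one row per output, one column per element of param
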
diff --git a/tests/doc/test_jac.py b/tests/doc/test_jac.py
index 92e8745c..e195c233 100644
--- a/tests/doc/test_jac.py
+++ b/tests/doc/test_jac.py
@@ -14,7 +14,7 @@ def test_jac():
     # Compute arbitrary quantities that are function of param
     y1 = torch.tensor([-1.0, 1.0]) @ param
     y2 = (param**2).sum()
-    jacobians = jac([y1, y2], [param])
+    jacobians = jac([y1, y2], param)

     assert len(jacobians) == 1
     assert_close(jacobians[0], torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04)
@@ -57,6 +57,6 @@ def test_jac_3():
     # Step 1: Compute d[y1,y2]/dh
     jac_h = jac([y1, y2], [h])[0]  # Shape: [2, 2]
     # Step 2: Use jac_outputs to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx)
-    jac_x = jac(h, [x], jac_outputs=jac_h)[0]
+    jac_x = jac(h, x, jac_outputs=jac_h)[0]

     assert_close(jac_x, torch.tensor([[2.0, 4.0], [2.0, -4.0]]), rtol=0.0, atol=1e-04)
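
Since the ``inputs=None`` default (which fell back to ``get_leaf_tensors`` on the outputs'
graph) is gone, callers that relied on input inference must now list their leaf tensors
explicitly, as the updated unit tests below do. A hedged sketch of the migration, mirroring
the setup of ``test_multiple_tensors`` (tensor values are illustrative):

    import torch
    from torchjd.autojac._jac import jac  # module path taken from this diff

    a1 = torch.randn(2, requires_grad=True)
    a2 = torch.randn(3, requires_grad=True)
    y1 = torch.tensor([-1.0, 1.0]) @ a1 + a2.sum()
    y2 = (a1 ** 2).sum() + a2.norm()

    # Before this change, jac([y1, y2]) inferred [a1, a2]; now they are passed explicitly.
    jacobians = jac([y1, y2], [a1, a2])
    assert len(jacobians) == 2           # one Jacobian per input, in input order
    assert jacobians[0].shape == (2, 2)  # d[y1, y2]/da1
    assert jacobians[1].shape == (2, 3)  # d[y1, y2]/da2
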
diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py
index 9b108be5..3ee6561f 100644
--- a/tests/unit/autojac/test_jac.py
+++ b/tests/unit/autojac/test_jac.py
@@ -56,13 +56,15 @@ def test_jac():
     assert jacobian.shape[1:] == a.shape


-@mark.parametrize("shape", [(1, 3), (2, 3), (2, 6), (5, 8), (20, 55)])
-@mark.parametrize("manually_specify_inputs", [True, False])
+@mark.parametrize("shape", [(1, 1), (1, 3), (2, 1), (2, 6), (20, 55)])
 @mark.parametrize("chunk_size", [1, 2, None])
+@mark.parametrize("outputs_is_list", [True, False])
+@mark.parametrize("inputs_is_list", [True, False])
 def test_value_is_correct(
     shape: tuple[int, int],
-    manually_specify_inputs: bool,
     chunk_size: int | None,
+    outputs_is_list: bool,
+    inputs_is_list: bool,
 ):
     """
     Tests that the jacobians returned by jac are correct in a simple example of matrix-vector
@@ -73,13 +75,10 @@ def test_value_is_correct(
     product.
     """
     J = randn_(shape)
     input = randn_([shape[1]], requires_grad=True)
     output = J @ input  # Note that the Jacobian of output w.r.t. input is J.
-    inputs = [input] if manually_specify_inputs else None
+    outputs = [output] if outputs_is_list else output
+    inputs = [input] if inputs_is_list else input
-    jacobians = jac(
-        [output],
-        inputs=inputs,
-        parallel_chunk_size=chunk_size,
-    )
+    jacobians = jac(outputs, inputs, parallel_chunk_size=chunk_size)
     assert len(jacobians) == 1
     assert_close(jacobians[0], J)
@@ -103,7 +102,7 @@

     jacobians = jac(
         output,
-        inputs=[input],
+        input,
         jac_outputs=J_init,
     )

@@ -126,7 +125,7 @@
     J1 = randn_((rows, 2))
     J2 = randn_((rows, 3))

-    jacobians = jac([y1, y2], inputs=[input], jac_outputs=[J1, J2])
+    jacobians = jac([y1, y2], input, jac_outputs=[J1, J2])

     jac_y1 = eye_(2) * 2

@@ -149,7 +148,7 @@
         ValueError,
         match=r"`jac_outputs` should have the same length as `outputs`\. \(got 1 and 2\)",
     ):
-        jac([y1, y2], inputs=[x], jac_outputs=[J1])
+        jac([y1, y2], x, jac_outputs=[J1])


 def test_jac_outputs_shape_mismatch():
@@ -166,7 +165,7 @@
         ValueError,
         match=r"Shape mismatch: `jac_outputs\[0\]` has shape .* but `outputs\[0\]` has shape .*\.",
     ):
-        jac(y, inputs=[x], jac_outputs=J_bad)
+        jac(y, x, jac_outputs=J_bad)


 @mark.parametrize(
@@ -192,7 +191,7 @@ def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int):
     with raises(
         ValueError, match=r"All Jacobians in `jac_outputs` should have the same number of rows\."
     ):
-        jac([y1, y2], inputs=[x], jac_outputs=[j1, j2])
+        jac([y1, y2], x, jac_outputs=[j1, j2])


 def test_empty_inputs():
@@ -220,7 +219,7 @@ def test_partial_inputs():
     y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum()
     y2 = (a1**2).sum() + a2.norm()

-    jacobians = jac([y1, y2], inputs=[a1])
+    jacobians = jac([y1, y2], a1)

     assert len(jacobians) == 1

@@ -250,7 +249,7 @@ def test_multiple_tensors():
     y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum()
     y2 = (a1**2).sum() + a2.norm()

-    jacobians = jac([y1, y2])
+    jacobians = jac([y1, y2], [a1, a2])
     assert len(jacobians) == 2
     assert_close(jacobians[0], J1)
     assert_close(jacobians[1], J2)
@@ -262,7 +261,7 @@ def test_multiple_tensors():

     z1 = tensor_([-1.0, 1.0]) @ b1 + b2.sum()
     z2 = (b1**2).sum() + b2.norm()
-    jacobians = jac(torch.cat([z1.reshape(-1), z2.reshape(-1)]))
+    jacobians = jac(torch.cat([z1.reshape(-1), z2.reshape(-1)]), [b1, b2])
     assert len(jacobians) == 2
     assert_close(jacobians[0], J1)
     assert_close(jacobians[1], J2)
@@ -278,7 +277,7 @@ def test_various_valid_chunk_sizes(chunk_size):
     y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum()
     y2 = (a1**2).sum() + a2.norm()

-    jacobians = jac([y1, y2], parallel_chunk_size=chunk_size)
+    jacobians = jac([y1, y2], [a1, a2], parallel_chunk_size=chunk_size)

     assert len(jacobians) == 2

@@ -293,7 +292,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int):
     y2 = (a1**2).sum() + a2.norm()

     with raises(ValueError):
-        jac([y1, y2], parallel_chunk_size=chunk_size)
+        jac([y1, y2], [a1, a2], parallel_chunk_size=chunk_size)


 def test_input_retaining_grad_fails():
@@ -309,7 +308,7 @@ def test_input_retaining_grad_fails():

     # jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor (and it also
     # returns the correct Jacobian)
-    jac(y, inputs=[b])
+    jac(y, b)

     with raises(RuntimeError):
         # Using such a BatchedTensor should result in an error
@@ -328,7 +327,7 @@ def test_non_input_retaining_grad_fails():
     y = 3 * b

     # jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor
-    jac(y, inputs=[a])
+    jac(y, a)

     with raises(RuntimeError):
         # Using such a BatchedTensor should result in an error
@@ -348,7 +347,7 @@ def test_tensor_used_multiple_times(chunk_size: int | None):
     d = a * c
     e = a * d

-    jacobians = jac([d, e], parallel_chunk_size=chunk_size)
+    jacobians = jac([d, e], a, parallel_chunk_size=chunk_size)

     assert len(jacobians) == 1
     J = tensor_([2.0 * 3.0 * (a**2).item(), 2.0 * 4.0 * (a**3).item()])
@@ -372,7 +371,7 @@
     y2 = (a1**2).sum() + (a2**2).sum()

     with raises(ValueError):
-        jac([y1, y1, y2])
+        jac([y1, y1, y2], [a1, a2])


 def test_repeated_inputs():
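
The two-step chain-rule pattern from the updated docstring also accepts a bare tensor as
``inputs``. A minimal sketch with illustrative values (``h``, ``y1`` and ``y2`` differ from
the docstring's):

    import torch
    from torchjd.autojac._jac import jac  # module path taken from this diff

    x = torch.tensor([1.0, 2.0], requires_grad=True)
    h = 2 * x            # intermediate tensor; dh/dx = 2 * I
    y1 = h.sum()         # dy1/dh = [1, 1]
    y2 = (h ** 2).sum()  # dy2/dh = 2 * h = [4, 8]

    jac_h = jac([y1, y2], [h])[0]            # d[y1, y2]/dh, shape [2, 2]
    jac_x = jac(h, x, jac_outputs=jac_h)[0]  # chain rule: (d[y1, y2]/dh) @ (dh/dx)
    assert torch.allclose(jac_x, torch.tensor([[2.0, 2.0], [8.0, 16.0]]))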