diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py
index 3f1d9a5f..9467b4ef 100644
--- a/src/torchjd/autojac/_jac_to_grad.py
+++ b/src/torchjd/autojac/_jac_to_grad.py
@@ -92,7 +92,7 @@ def _disunite_gradient(
     tensors: list[TensorWithJac],
 ) -> list[Tensor]:
     gradient_vectors = gradient_vector.split([t.numel() for t in tensors])
-    gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)]
+    gradients = [g.reshape(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)]
     return gradients
 
 
diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py
index 60ea6838..a3f83097 100644
--- a/tests/unit/autojac/test_jac_to_grad.py
+++ b/tests/unit/autojac/test_jac_to_grad.py
@@ -101,3 +101,17 @@ def test_jacs_are_freed(retain_jac: bool):
     check = assert_has_jac if retain_jac else assert_has_no_jac
     check(t1)
     check(t2)
+
+
+def test_noncontiguous_jac():
+    """Tests that jac_to_grad works when the .jac field is non-contiguous."""
+
+    aggregator = UPGrad()
+    t = tensor_([2.0, 3.0, 4.0], requires_grad=True)
+    jac_T = tensor_([[-4.0, 1.0], [1.0, 6.0], [1.0, 1.0]])
+    jac = jac_T.T
+    t.__setattr__("jac", jac)
+    g = aggregator(jac)
+
+    jac_to_grad([t], aggregator)
+    assert_grad_close(t, g)
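
For context on the one-line change: Tensor.view only succeeds when the requested shape is compatible with the tensor's existing strides, whereas Tensor.reshape falls back to copying the data when it is not. Below is a minimal standalone PyTorch sketch of the failure mode that the new test exercises; it is illustrative only and does not use any torchjd helpers.

    import torch

    # A transposed tensor shares storage with its source and is non-contiguous.
    jac = torch.tensor([[-4.0, 1.0], [1.0, 6.0], [1.0, 1.0]]).T
    assert not jac.is_contiguous()

    # .view cannot express the flat shape with these strides and raises.
    try:
        jac.view(6)
    except RuntimeError:
        pass

    # .reshape succeeds by returning a contiguous copy instead.
    flat = jac.reshape(6)
    assert flat.is_contiguous()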