From 99817b574df158afaca0af2e8043d47aae8b18e4 Mon Sep 17 00:00:00 2001
From: Daniel Korchinski
Date: Thu, 19 Feb 2026 08:05:45 +0100
Subject: [PATCH] Fix potential issue with .view on noncontiguous memory and
 add test coverage for noncontiguous jacobians.

---
 src/torchjd/autojac/_jac_to_grad.py    |  2 +-
 tests/unit/autojac/test_jac_to_grad.py | 14 ++++++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py
index 9e46caa0..764b79bb 100644
--- a/src/torchjd/autojac/_jac_to_grad.py
+++ b/src/torchjd/autojac/_jac_to_grad.py
@@ -91,7 +91,7 @@ def _disunite_gradient(
     tensors: list[TensorWithJac],
 ) -> list[Tensor]:
     gradient_vectors = gradient_vector.split([t.numel() for t in tensors])
-    gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)]
+    gradients = [g.reshape(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)]
     return gradients
 
 
diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py
index 60ea6838..a3f83097 100644
--- a/tests/unit/autojac/test_jac_to_grad.py
+++ b/tests/unit/autojac/test_jac_to_grad.py
@@ -101,3 +101,17 @@ def test_jacs_are_freed(retain_jac: bool):
     check = assert_has_jac if retain_jac else assert_has_no_jac
     check(t1)
     check(t2)
+
+
+def test_noncontiguous_jac():
+    """Tests that jac_to_grad works when the .jac field is non-contiguous."""
+
+    aggregator = UPGrad()
+    t = tensor_([2.0, 3.0, 4.0], requires_grad=True)
+    jac_T = tensor_([[-4.0, 1.0], [1.0, 6.0], [1.0, 1.0]])
+    jac = jac_T.T
+    t.__setattr__("jac", jac)
+    g = aggregator(jac)
+
+    jac_to_grad([t], aggregator)
+    assert_grad_close(t, g)
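
Note (not part of the patch itself): the sketch below is a minimal, self-contained illustration of the behavior this change guards against. It does not use torchjd; the tensor values and variable names are purely illustrative. On a non-contiguous tensor (here obtained by transposing), Tensor.view raises a RuntimeError because it can only reinterpret the existing strides, while Tensor.reshape falls back to copying when needed and always succeeds.

    import torch

    # Transposing swaps the strides without moving any data, so the result
    # is non-contiguous in memory.
    jac = torch.arange(6.0).reshape(2, 3).T
    assert not jac.is_contiguous()

    # .view() can only reinterpret the current strides; on this tensor it raises.
    try:
        jac.view(-1)
    except RuntimeError as exc:
        print(f"view failed: {exc}")

    # .reshape() returns a view when the strides allow it and otherwise copies,
    # so it succeeds on the same tensor.
    flat = jac.reshape(-1)
    print(flat.shape)  # torch.Size([6])

Because reshape still returns a no-copy view whenever the input is already contiguous, swapping .view(t.shape) for .reshape(t.shape) in _disunite_gradient is a drop-in replacement that only pays for a copy in the non-contiguous case.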