
BUG: matfree does not work with NLVS and Cofunction. #4876

@JHopeCollins

Describe the bug
Assembling a matrix-free Jacobian crashes in a NonlinearVariationalSolver (NLVS) when the residual form contains a Cofunction.

Steps to Reproduce

from firedrake import *
from ufl.algorithms import expand_derivatives

mesh = UnitIntervalMesh(2)
V = FunctionSpace(mesh, "CG", 1)

u = Function(V)
v = TestFunction(V)
b = Cofunction(V.dual())

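# Subtracting a Cofunction from a Form makes F a ufl.FormSum
F = inner(u, v)*dx - b

# Jacobian as NLVS would build it by default
J = derivative(F, u)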

print(f"{type(J) = }")
print(f"{str(J) = }\n")

print(f"{type(expand_derivatives(J)) = }")
print(f"{str(expand_derivatives(J)) = }\n")

print("Building matrix-free solver")
nlvs = NonlinearVariationalSolver(
    NonlinearVariationalProblem(F, u),
    solver_parameters={'mat_type': 'matfree'})

Error message

type(J) = <class 'ufl.form.FormSum'>
str(J) = '1*d/dfj { w₂ * (conj((v_0))) }, with fh=ExprList(*(w₂,)), dfh/dfj = ExprList(*(v_1,)), and coefficient derivatives ExprMapping(*()) * dx(<Mesh #0>[everywhere], {}, {})\n  +  -1*d/dfj { cofunction_3 }, with fh=ExprList(*(w₂,)), dfh/dfj = ExprList(*(v_1,)), and coefficient derivatives ExprMapping(*())'

type(expand_derivatives(J)) = <class 'ufl.form.Form'>
str(expand_derivatives(J)) = 'v_1 * (conj((v_0))) * dx(<Mesh #0>[everywhere], {}, {})'

Building matrix-free solver
Traceback (most recent call last):
  File "/home/jhc/codes/fd/fd-dev/wrk/auxsnes/formsum_mat_mfe/formsum_mat_mfe.py", line 26, in <module>
    nlvs = NonlinearVariationalSolver(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "petsc4py/PETSc/Log.pyx", line 250, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "petsc4py/PETSc/Log.pyx", line 251, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "/usr/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/jhc/codes/fd/fd-dev/src/firedrake/firedrake/adjoint_utils/variational_solver.py", line 47, in wrapper
    init(self, problem, *args, **kwargs)
  File "/home/jhc/codes/fd/fd-dev/src/firedrake/firedrake/variational_solver.py", line 307, in __init__
    ctx.set_jacobian(self.snes)
  File "/home/jhc/codes/fd/fd-dev/src/firedrake/firedrake/solving_utils.py", line 331, in set_jacobian
    snes.setJacobian(self.form_jacobian, J=self._jac.petscmat,
                                           ^^^^^^^^^
  File "/usr/lib/python3.12/functools.py", line 995, in __get__
    val = self.func(instance)
          ^^^^^^^^^^^^^^^^^^^
  File "/home/jhc/codes/fd/fd-dev/src/firedrake/firedrake/solving_utils.py", line 539, in _jac
    return self._assembler_jac.allocate()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jhc/codes/fd/fd-dev/src/firedrake/firedrake/assemble.py", line 362, in allocate
    return MatrixFreeAssembler(self._form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jhc/codes/fd/fd-dev/src/firedrake/firedrake/assemble.py", line 954, in __new__
    raise TypeError(f"The first positional argument must be of ufl.Form or slate.TensorBase: got {type(form)} ({form})")
TypeError: The first positional argument must be of ufl.Form or slate.TensorBase: got <class 'ufl.form.FormSum'> (1*d/dfj { w₂ * (conj((v_0))) }, with fh=ExprList(*(w₂,)), dfh/dfj = ExprList(*(v_1,)), and coefficient derivatives ExprMapping(*()) * dx(<Mesh #0>[everywhere], {}, {})
  +  -1*d/dfj { cofunction_3 }, with fh=ExprList(*(w₂,)), dfh/dfj = ExprList(*(v_1,)), and coefficient derivatives ExprMapping(*()))

Additional Info
It appears that derivative(F, u) returns a FormSum because it does not remove the derivative(Cofunction, u) term, even though that term is zero (b does not depend on u). The matrix-free assembler then rejects J, because it requires a ufl.Form (or slate.TensorBase), not a FormSum. Calling ufl.algorithms.expand_derivatives on the result of derivative(F, u) returns a Form without the Cofunction term, so we may need such a call somewhere in the NLVS setup.
