Source code for dolfinx_adjoint.solvers

try:
    import typing_extensions as typing
except ModuleNotFoundError:
    import typing  # type: ignore[no-redef]

import dolfinx.fem.petsc
import numpy as np
import numpy.typing as npt
import pyadjoint
import ufl

from dolfinx_adjoint.types import Function

from .blocks.solvers import LinearProblemBlock


class LinearProblem(dolfinx.fem.petsc.LinearProblem):
    """A linear problem that can be used with adjoint methods.

    This class extends the `dolfinx.fem.petsc.LinearProblem` to support
    adjoint methods.

    Args:
        a: The bilinear form representing the left-hand side of the equation.
        L: The linear form representing the right-hand side of the equation.
        bcs: Boundary conditions to apply to the problem.
        u: Solution vector.
        P: Preconditioner for the linear problem.
        kind: Kind of PETSc Matrix to assemble the system into.
        petsc_options: Options dictionary for the PETSc Krylov subspace solver.
        form_compiler_options: Form compiler options for generating assembly
            kernels.
        jit_options: Options for just-in-time compilation of the forms.
        entity_maps: Mapping from meshes that coefficients and arguments are
            defined on to the integration domain of the forms.
        ad_block_tag: Tag for adjoint blocks in the tape.
        adjoint_petsc_options: PETSc options for adjoint problems.
        tlm_petsc_options: Optional PETSc options for TLM problems.
    """

    def __init__(
        self,
        a: typing.Union[ufl.Form, typing.Iterable[typing.Iterable[ufl.Form]]],
        L: typing.Union[ufl.Form, typing.Iterable[ufl.Form]],
        bcs: typing.Optional[typing.Iterable[dolfinx.fem.DirichletBC]] = None,
        u: typing.Optional[typing.Union[dolfinx.fem.Function, typing.Iterable[dolfinx.fem.Function]]] = None,
        P: typing.Optional[typing.Union[ufl.Form, typing.Iterable[typing.Iterable[ufl.Form]]]] = None,
        kind: typing.Optional[typing.Union[str, typing.Iterable[typing.Iterable[str]]]] = None,
        petsc_options: typing.Optional[dict] = None,
        form_compiler_options: typing.Optional[dict] = None,
        jit_options: typing.Optional[dict] = None,
        entity_maps: typing.Optional[dict[dolfinx.mesh.Mesh, npt.NDArray[np.int32]]] = None,
        ad_block_tag: typing.Optional[str] = None,
        adjoint_petsc_options: typing.Optional[dict] = None,
        tlm_petsc_options: typing.Optional[dict] = None,
    ) -> None:
        """Store the problem definition and initialise the underlying solver.

        See the class docstring for a description of the arguments.
        """
        self.ad_block_tag = ad_block_tag
        self._adj_options = adjoint_petsc_options
        self._tlm_options = tlm_petsc_options

        if u is None:
            # Extract the function space(s) for the unknown from the right-hand
            # side of the equation. An explicit type check is used instead of
            # ``assert`` inside ``try/except AttributeError``: the assertion
            # raised ``AssertionError`` for a sequence of forms, so the blocked
            # fallback was unreachable unless Python ran with ``-O``.
            if isinstance(L, ufl.Form):
                self._u = Function(L.arguments()[0].ufl_function_space())
            else:
                self._u = [Function(Li.arguments()[0].ufl_function_space()) for Li in L]  # type: ignore[assignment]
        elif isinstance(u, dolfinx.fem.Function):
            self._u = pyadjoint.create_overloaded_object(u)
        else:
            self._u = [pyadjoint.create_overloaded_object(ui) for ui in u]  # type: ignore[assignment]

        # Cache the inputs so that `solve` can hand them to the tape block.
        self._lhs = a
        self._rhs = L
        self._preconditioner = P
        self._jit_options = jit_options
        self._form_compiler_options = form_compiler_options
        self._entity_maps = entity_maps
        self._petsc_options = petsc_options
        self._kind = kind

        # Initialize linear solver
        dolfinx.fem.petsc.LinearProblem.__init__(
            self,
            a,
            L,
            bcs,
            self._u,
            P,
            kind,
            petsc_options,
            form_compiler_options,
            jit_options,
            entity_maps,
        )

    def solve(
        self, annotate: bool = True
    ) -> typing.Tuple[typing.Union[dolfinx.fem.Function, typing.Iterable[dolfinx.fem.Function]], int, int]:
        """Solve the linear problem and return the solution.

        Args:
            annotate: Requested annotation state; the value is passed through
                `pyadjoint.annotate_tape`, whose result decides whether a
                `LinearProblemBlock` is recorded on the working tape.

        Returns:
            The tuple ``(out, converged_reason, num_its)`` exactly as returned
            by `dolfinx.fem.petsc.LinearProblem.solve`.
        """
        annotate = pyadjoint.annotate_tape({"annotate": annotate})

        if annotate:
            # The block is created and registered before the solve is executed.
            block = LinearProblemBlock(
                self._lhs,  # type: ignore
                self._rhs,  # type: ignore
                bcs=self.bcs,
                u=self.u,
                P=self._preconditioner,
                kind=self._kind,
                petsc_options=self._petsc_options,
                form_compiler_options=self._form_compiler_options,
                jit_options=self._jit_options,
                entity_maps=self._entity_maps,
                ad_block_tag=self.ad_block_tag,
                adjoint_petsc_options=self._adj_options,
                tlm_petsc_options=self._tlm_options,
            )
            tape = pyadjoint.get_working_tape()
            tape.add_block(block)

        out, converged_reason, num_its = dolfinx.fem.petsc.LinearProblem.solve(self)

        if annotate:
            # Register every solution function as an output of the block.
            if isinstance(out, Function):
                block.add_output(out.create_block_variable())
            else:
                for ui in out:
                    assert isinstance(ui, Function)
                    block.add_output(ui.create_block_variable())

        return out, converged_reason, num_its