Source code for dolfinx_adjoint.blocks.assembly

import typing

from mpi4py import MPI

import dolfinx
import numpy as np
import numpy.typing as npt
import ufl
from pyadjoint import Block, create_overloaded_object
from ufl.formatting.ufl2unicode import ufl2unicode


class _SpecialVector(dolfinx.la.Vector):
    """`dolfinx.la.Vector` subclass that supports the ``+=`` operator.

    The base class is used here without in-place addition, so this thin
    wrapper adds it by operating directly on the vector's ``array``.
    """

    def __iadd__(self, other):
        """Add the entries of ``other.array`` to this vector in place.

        Args:
            other: Object exposing an ``array`` attribute that can be added
                to ``self.array``.

        Returns:
            This vector, as required by the in-place operator protocol.
        """
        values = self.array
        values += other.array
        return self


def _vector(map, bs=1, dtype: npt.DTypeLike = np.float64) -> _SpecialVector:
    """Create a distributed vector.

    Args:
        map: Index map that describes the size and distribution of the
            vector.
        bs: Block size.
        dtype: The scalar type.

    Returns:
        A distributed vector wrapped as a `_SpecialVector`.

    Raises:
        NotImplementedError: If ``dtype`` matches none of the scalar types
            listed below.
    """
    # (NumPy scalar type, name of the matching class in ``dolfinx.cpp.la``).
    # The order is significant: the first matching entry wins.
    backends = (
        (np.float32, "Vector_float32"),
        (np.float64, "Vector_float64"),
        (np.complex64, "Vector_complex64"),
        (np.complex128, "Vector_complex128"),
        (np.int8, "Vector_int8"),
        (np.int32, "Vector_int32"),
        (np.int64, "Vector_int64"),
    )
    for scalar_type, class_name in backends:
        if np.issubdtype(dtype, scalar_type):
            # Look the class up lazily so only the selected backend is touched.
            vtype = getattr(dolfinx.cpp.la, class_name)
            return _SpecialVector(vtype(map, bs))

    raise NotImplementedError(f"Type {dtype} not supported.")


def _create_vector(L: dolfinx.fem.Form) -> _SpecialVector:
    """Build a vector whose layout matches the test space of a linear form.

    Args:
        L: A linear form.

    Returns:
        A `_SpecialVector` that the form can be assembled into.
    """
    # The first dofmap is sufficient: all dof maps share the same index map
    # in mixed-topology meshes.
    test_space = L.function_spaces[0]
    layout = test_space.dofmaps(0)  # type: ignore
    index_map = layout.index_map
    block_size = layout.index_map_bs
    return _vector(index_map, block_size, dtype=L.dtype)


def assemble_compiled_form(
    form: dolfinx.fem.Form,
    tensor: typing.Optional[typing.Union[dolfinx.la.Vector, _SpecialVector, float]] = None,
) -> typing.Union[dolfinx.la.Vector, _SpecialVector, float]:
    """Assemble a compiled rank-0 or rank-1 form.

    No boundary conditions are applied by this function.

    Args:
        form: Compiled form to assemble.
        tensor: Optional vector that a rank-1 form is assembled into. It is
            not zeroed by this function, so callers that want a fresh result
            must reset it first. Ignored for rank-0 forms.

    Returns:
        For a rank-1 form, the assembled vector (the input ``tensor`` if one
        was given, otherwise a newly created vector). For a rank-0 form, the
        locally assembled value summed over ``form.mesh.comm``.

    Raises:
        NotImplementedError: If the form's rank is not 0 or 1.
    """
    if form.rank == 1:
        tensor = dolfinx.fem.create_vector(form) if tensor is None else tensor
        assert isinstance(tensor, dolfinx.la.Vector)
        dolfinx.fem.assemble._assemble_vector_array(tensor.array, form)
        # Accumulate ghost contributions on the owning process, then push the
        # owned values back out to the ghosts.
        tensor.scatter_reverse(dolfinx.la.InsertMode.add)
        tensor.scatter_forward()
    elif form.rank == 0:
        local_val = dolfinx.fem.assemble_scalar(form)
        # Each process only holds its local contribution; reduce to the
        # global value.
        tensor = form.mesh.comm.allreduce(local_val, op=MPI.SUM)
    else:
        raise NotImplementedError(
            f"Only 0-form and 1-form assembly is currently supported, got a form of rank {form.rank}."
        )
    return tensor


class AssembleBlock(Block):
    """Block for assembling a symbolic UFL form into a tensor.

    Args:
        form: The UFL form to assemble.
        ad_block_tag: Tag for the block in the adjoint tape.
        jit_options: Dictionary of options for JIT compilation.
        form_compiler_options: Dictionary of options for the form compiler.
        entity_maps: Dictionary mapping meshes to entity maps for assembly.
    """

    def __init__(
        self,
        form: ufl.Form,
        ad_block_tag: typing.Optional[str] = None,
        jit_options: typing.Optional[dict] = None,
        form_compiler_options: typing.Optional[dict] = None,
        entity_maps: typing.Optional[dict[dolfinx.mesh.Mesh, npt.NDArray[np.int32]]] = None,
    ):
        super().__init__(ad_block_tag=ad_block_tag)

        # Store the options for code generation
        self._jit_options = jit_options
        self._form_compiler_options = form_compiler_options
        self._entity_maps = entity_maps

        # Store compiled and original form
        self.form = form
        self.compiled_form = self._compile_form(form)

        # NOTE: Add when we want to do shape optimization
        # mesh = self.form.ufl_domain().ufl_cargo()
        # self.add_dependency(mesh)
        for coefficient in self.form.coefficients():
            self.add_dependency(coefficient, no_duplicates=True)

        # TODO: cache the adjoint work vectors per function space so they can
        # be reused between calls to `compute_action_adjoint`.

    def __str__(self):
        """Return a readable ``assemble(<form>)`` description of the block."""
        return f"assemble({ufl2unicode(self.form)})"

    def _compile_form(self, form):
        """Compile `form` with the code-generation options stored on the block."""
        return dolfinx.fem.form(
            form,
            jit_options=self._jit_options,
            form_compiler_options=self._form_compiler_options,
            entity_maps=self._entity_maps,
        )

    @staticmethod
    def _form_arity(form) -> int:
        """Return the number of arguments (test/trial functions) in `form`."""
        from ufl.algorithms.analysis import extract_arguments

        return len(extract_arguments(form))

    def compute_action_adjoint(
        self,
        adj_input: typing.Union[float, dolfinx.la.Vector],
        arity_form: int,
        form: typing.Optional[ufl.Form] = None,
        c_rep: typing.Optional[typing.Union[ufl.Coefficient, ufl.Constant]] = None,
        space: typing.Optional[dolfinx.fem.FunctionSpace] = None,
        dform: typing.Optional[dolfinx.fem.Form] = None,
    ):
        """This computes the action of the adjoint of the derivative of `form` wrt `c_rep` on `adj_input`.

        In other words, it returns:

        .. math::

            \\left\\langle\\left(\\frac{\\partial form}{\\partial c_{rep}}\\right)^*, adj_{input} \\right\\rangle

        - If `form` has arity 0, then :math:`\\frac{\\partial form}{\\partial c_{rep}}` is a 1-form
          and `adj_input` a float, we can simply use the `*` operator.
        - If `form` has arity 1 then :math:`\\frac{\\partial form}{\\partial c_{rep}}` is a 2-form
          and we can symbolically take its adjoint and then apply the action on `adj_input`, to finally
          assemble the result. This case is not implemented yet.

        Args:
            adj_input: The input to the adjoint operation, typically a scalar or vector.
            arity_form: The arity of the form, i.e., 0 for scalar, 1 for vector, 2 for matrix etc.
            form: The UFL form to differentiate if `dform` is not provided.
            c_rep: The coefficient or constant with respect to which the derivative is taken.
            space: The function space associated with the `c_rep` to form an `ufl.Argument` in.
            dform: Pre-computed derivative form, :math:`\\frac{\\partial form}{\\partial c_{rep}}`.

        Returns:
            A tuple ``(vector, dform)``: the assembled derivative scaled by
            `adj_input`, and the symbolic derivative form that was assembled.

        Raises:
            ValueError: If `arity_form` is not 0.
        """
        if arity_form != 0:
            # TODO: arity 1. The legacy implementation assembled
            # action(adjoint(dform), adj_input) symbolically, and fell back to a
            # numerical Hermitian-transpose mat-vec when differentiating with
            # respect to the spatial coordinates.
            raise ValueError("Forms with arity > 0 are not handled yet!")

        assert arity_form == self.compiled_form.rank, "Inconsistent arity of input form and block form."
        if dform is None:
            assert space is not None
            dc = ufl.TestFunction(space)
            dform = ufl.derivative(form, c_rep, dc)
        assert isinstance(dform, ufl.Form), "dform must be a UFL form."

        compiled_adjoint = self._compile_form(dform)
        if space is None:
            # No space supplied: the derivative form must carry exactly one
            # argument for the result to be a vector.
            assert len(dform.arguments()) == 1

        vector = _create_vector(compiled_adjoint)
        vector.array[:] = 0.0
        assemble_compiled_form(compiled_adjoint, vector)

        # Return a vector scaled by the scalar `adj_input`
        vector.array[:] *= adj_input
        vector.scatter_forward()
        return vector, dform

    def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
        """Return the block's form with its coefficients replaced by the saved outputs.

        Every dependency whose output is a coefficient of ``self.form`` is
        substituted by that dependency's ``saved_output``.
        """
        form_coefficients = self.form.coefficients()
        replaced_coeffs = {}
        for block_variable in self.get_dependencies():
            coeff = block_variable.output
            if coeff in form_coefficients:
                replaced_coeffs[coeff] = block_variable.saved_output
        return ufl.replace(self.form, replaced_coeffs)

    def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None):
        """Compute the adjoint contribution with respect to one dependency.

        Args:
            inputs: Unused here.
            adj_inputs: Adjoint inputs; only the first entry is used.
            block_variable: The dependency to differentiate with respect to.
            idx: Unused here.
            prepared: The form returned by `prepare_evaluate_adj`.

        Returns:
            The first entry of the tuple returned by `compute_action_adjoint`.

        Raises:
            NotImplementedError: If the dependency is not a
                `dolfinx.fem.Function`.
        """
        form = prepared
        adj_input = adj_inputs[0]
        c = block_variable.output
        c_rep = block_variable.saved_output
        arity_form = self._form_arity(form)

        # TODO: support constants and meshes (shape derivatives) as controls.
        if not isinstance(c, dolfinx.fem.Function):
            raise NotImplementedError(
                f"Adjoint of assembly with respect to a control of type {type(c).__name__} is not supported."
            )
        space = c.function_space
        return self.compute_action_adjoint(adj_input, arity_form, form, c_rep, space)[0]

    def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs):
        """Prepare the tangent-linear evaluation; same substitution as the adjoint."""
        return self.prepare_evaluate_adj(inputs, tlm_inputs, self.get_dependencies())

    def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
        """Assemble the directional derivative of the form along the TLM values.

        The derivative is summed over all dependencies that carry a
        ``tlm_value``.

        Returns:
            ``0.0`` if no dependency has a TLM value or the derivative form is
            empty, otherwise the assembled derivative.
        """
        form = prepared
        dform = 0.0
        arity_form = self._form_arity(form)

        for bv in self.get_dependencies():
            c_rep = bv.saved_output
            tlm_value = bv.tlm_value
            if tlm_value is None:
                continue
            if isinstance(c_rep, dolfinx.mesh.Mesh):
                X = ufl.SpatialCoordinate(c_rep)
                dform += ufl.derivative(form, X, tlm_value)
            else:
                dform += ufl.derivative(form, c_rep, tlm_value)

        if not isinstance(dform, float):
            dform = ufl.algorithms.expand_derivatives(dform)
            if dform.empty():
                # Nothing left to assemble (mirrors the check in the Hessian).
                return 0.0
            dform = assemble_compiled_form(self._compile_form(dform))
            if arity_form == 1 and dform != 0:
                # Then dform is a Vector
                # NOTE(review): `.function` looks like a leftover from the legacy
                # implementation; confirm the assembled vector provides it.
                dform = dform.function
        return dform

    def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_dependencies):
        """Prepare the Hessian evaluation; same substitution as the adjoint."""
        return self.prepare_evaluate_adj(inputs, adj_inputs, relevant_dependencies)

    def evaluate_hessian_component(
        self,
        inputs,
        hessian_inputs,
        adj_inputs,
        block_variable,
        idx,
        relevant_dependencies,
        prepared=None,
    ):
        """Compute the Hessian contribution with respect to one dependency.

        The result is the adjoint action on ``hessian_inputs[0]`` plus, when
        TLM values are present, the adjoint action of the second derivative on
        ``adj_inputs[0]``.

        Returns:
            The accumulated vector, or ``None`` if the dependency is neither a
            constant, a function nor a mesh.
        """
        form = prepared
        hessian_input = hessian_inputs[0]
        adj_input = adj_inputs[0]
        arity_form = self._form_arity(form)

        c1 = block_variable.output
        c1_rep = block_variable.saved_output

        if isinstance(c1, dolfinx.fem.Constant):
            mesh = form.ufl_domain()
            space = c1._ad_function_space(mesh)
        elif isinstance(c1, dolfinx.fem.Function):
            space = c1.function_space
        elif isinstance(c1, dolfinx.mesh.Mesh):
            c1_rep = ufl.SpatialCoordinate(c1)
            space = c1._ad_function_space()
        else:
            return None

        hessian_outputs, dform = self.compute_action_adjoint(hessian_input, arity_form, form, c1_rep, space)

        ddform = 0
        for _, bv in relevant_dependencies:
            c2_rep = bv.saved_output
            tlm_input = bv.tlm_value
            if tlm_input is None:
                continue
            if isinstance(c2_rep, dolfinx.mesh.Mesh):
                X = ufl.SpatialCoordinate(c2_rep)
                ddform += ufl.derivative(dform, X, tlm_input)
            else:
                ddform += ufl.derivative(dform, c2_rep, tlm_input)

        # `ddform` is still the plain number 0 when no dependency carried a TLM
        # value; only a symbolic form can be expanded and assembled.
        if not isinstance(ddform, (int, float)):
            ddform = ufl.algorithms.expand_derivatives(ddform)
            if not ddform.empty():
                # FIXME: Compare ddform with legacy dolfin_adjoint here, as this is DG-0,
                # while hessian is in DG-0
                adj_action = self.compute_action_adjoint(adj_input, arity_form, dform=ddform)[0]
                try:
                    hessian_outputs += adj_action
                except TypeError:
                    hessian_outputs.array[:] += adj_action.array[:]
        return hessian_outputs

    def prepare_recompute_component(self, inputs, relevant_outputs):
        """Prepare recomputation; same substitution as the adjoint."""
        return self.prepare_evaluate_adj(inputs, None, None)

    def recompute_component(self, inputs, block_variable, idx, prepared):
        """Re-assemble the scalar form and return it as an overloaded object.

        The locally assembled value is summed over the communicator of the
        compiled form's mesh before being wrapped.
        """
        compiled_form = self._compile_form(prepared)
        local_output = dolfinx.fem.assemble_scalar(compiled_form)
        comm = compiled_form.mesh.comm
        output = comm.allreduce(local_output, op=MPI.SUM)
        return create_overloaded_object(output)