import typing
from mpi4py import MPI
import dolfinx
import numpy as np
import numpy.typing as npt
import ufl
from pyadjoint import Block, create_overloaded_object
from ufl.formatting.ufl2unicode import ufl2unicode
class _SpecialVector(dolfinx.la.Vector):
    """`dolfinx.la.Vector` with an added ``__iadd__``.

    Works around the base class lacking in-place addition, so that two
    vectors can be accumulated with ``a += b``.
    """

    def __iadd__(self, other):
        # Add every entry of the other vector's local array into ours, in place.
        values = self.array
        values[:] += other.array
        return self
def _vector(map, bs=1, dtype: npt.DTypeLike = np.float64) -> _SpecialVector:
    """Create a distributed vector wrapped as a `_SpecialVector`.

    Args:
        map: Index map that describes the size and distribution of the
            vector.
        bs: Block size.
        dtype: The scalar type.

    Returns:
        A distributed vector.

    Raises:
        NotImplementedError: If ``dtype`` is not a sub-dtype of any supported
            scalar type.
    """
    # Checked in this order; the first scalar type that `dtype` is a
    # sub-dtype of selects the C++ vector class.
    candidates = (
        (np.float32, "Vector_float32"),
        (np.float64, "Vector_float64"),
        (np.complex64, "Vector_complex64"),
        (np.complex128, "Vector_complex128"),
        (np.int8, "Vector_int8"),
        (np.int32, "Vector_int32"),
        (np.int64, "Vector_int64"),
    )
    for scalar_type, class_name in candidates:
        if np.issubdtype(dtype, scalar_type):
            cpp_vector_cls = getattr(dolfinx.cpp.la, class_name)
            return _SpecialVector(cpp_vector_cls(map, bs))
    raise NotImplementedError(f"Type {dtype} not supported.")
def _create_vector(L: dolfinx.fem.Form) -> _SpecialVector:
    """Create a `_SpecialVector` laid out to receive the assembly of ``L``.

    Args:
        L: A linear form.

    Returns:
        A vector built from the index map and block size of the first dofmap
        of the form's first function space, using the form's scalar type.
    """
    # In mixed-topology meshes all dofmaps share one index map, so the first
    # dofmap is representative.
    first_space = L.function_spaces[0]
    dofmap = first_space.dofmaps(0)  # type: ignore
    index_map, block_size = dofmap.index_map, dofmap.index_map_bs
    return _vector(index_map, block_size, dtype=L.dtype)
def assemble_compiled_form(
    form: dolfinx.fem.Form,
    tensor: typing.Optional[typing.Union[dolfinx.la.Vector, _SpecialVector, float]] = None,
) -> typing.Union[dolfinx.la.Vector, _SpecialVector, float]:
    """Assemble a compiled rank-0 or rank-1 form.

    No boundary conditions are applied.

    Args:
        form: Compiled form to assemble.
        tensor: For a rank-1 form, an optional vector to which the assembled
            contributions are added in place. If omitted, a new vector is
            created with ``dolfinx.fem.create_vector``. Ignored for rank-0
            forms.

    Returns:
        For a rank-1 form, the vector holding the result (``tensor`` itself if
        it was given), after a reverse (add) and a forward scatter. For a
        rank-0 form, the assembled scalar summed over the mesh communicator.

    Raises:
        TypeError: If ``tensor`` is supplied for a rank-1 form but is not a
            ``dolfinx.la.Vector``.
        NotImplementedError: If the form's rank is not 0 or 1.
    """
    if form.rank == 1:
        if tensor is None:
            tensor = dolfinx.fem.create_vector(form)
        elif not isinstance(tensor, dolfinx.la.Vector):
            raise TypeError(f"Expected a dolfinx.la.Vector for a 1-form, got {type(tensor).__name__}.")
        # Ghost entries are process-local copies of values owned elsewhere;
        # they are refreshed by the forward scatter below anyway. Zero them
        # first so the reverse (add) scatter only accumulates the new
        # contributions. Otherwise a non-zero input tensor would have its
        # ghosted values added to the owners a second time.
        num_owned = tensor.index_map.size_local * tensor.block_size
        tensor.array[num_owned:] = 0
        dolfinx.fem.assemble._assemble_vector_array(tensor.array, form)
        tensor.scatter_reverse(dolfinx.la.InsertMode.add)
        tensor.scatter_forward()
        return tensor
    if form.rank == 0:
        # Each process assembles its local cells; the sum gives the global value.
        local_value = dolfinx.fem.assemble_scalar(form)
        return form.mesh.comm.allreduce(local_value, op=MPI.SUM)
    raise NotImplementedError("Only 0- and 1-form assembly is currently supported.")
class AssembleBlock(Block):
    """Block for assembling a symbolic UFL form into a tensor.

    Every coefficient of ``form`` is registered as a dependency of the block.
    Recomputation assembles the form to a global scalar. Adjoint and Hessian
    actions are implemented for 0-forms (functionals) only.

    Args:
        form: The UFL form to assemble.
        ad_block_tag: Tag for the block in the adjoint tape.
        jit_options: Dictionary of options for JIT compilation.
        form_compiler_options: Dictionary of options for the form compiler.
        entity_maps: Dictionary mapping meshes to entity maps for assembly.
    """

    def __init__(
        self,
        form: ufl.Form,
        ad_block_tag: typing.Optional[str] = None,
        jit_options: typing.Optional[dict] = None,
        form_compiler_options: typing.Optional[dict] = None,
        entity_maps: typing.Optional[dict[dolfinx.mesh.Mesh, npt.NDArray[np.int32]]] = None,
    ):
        super().__init__(ad_block_tag=ad_block_tag)
        # Keep the code-generation options so that every derived form
        # (replaced coefficients, derivatives) is compiled the same way.
        self._jit_options = jit_options
        self._form_compiler_options = form_compiler_options
        self._entity_maps = entity_maps
        # Store the original symbolic form and its compiled counterpart.
        self.form = form
        self.compiled_form = self._compile_form(form)
        # NOTE: for shape optimisation the mesh
        # (`self.form.ufl_domain().ufl_cargo()`) would also have to be
        # registered as a dependency.
        for coefficient in self.form.coefficients():
            self.add_dependency(coefficient, no_duplicates=True)

    def __str__(self):
        return f"assemble({ufl2unicode(self.form)})"

    def _compile_form(self, form: ufl.Form) -> dolfinx.fem.Form:
        """JIT-compile ``form`` with the options this block was created with."""
        return dolfinx.fem.form(
            form,
            jit_options=self._jit_options,
            form_compiler_options=self._form_compiler_options,
            entity_maps=self._entity_maps,
        )

    @staticmethod
    def _form_arity(form: ufl.Form) -> int:
        """Return the number of arguments (test/trial functions) in ``form``."""
        from ufl.algorithms.analysis import extract_arguments

        return len(extract_arguments(form))

    def compute_action_adjoint(
        self,
        adj_input: typing.Union[float, dolfinx.la.Vector],
        arity_form: int,
        form: typing.Optional[ufl.Form] = None,
        c_rep: typing.Optional[typing.Union[ufl.Coefficient, ufl.Constant]] = None,
        space: typing.Optional[dolfinx.fem.FunctionSpace] = None,
        dform: typing.Optional[ufl.Form] = None,
    ):
        r"""Compute the action of the adjoint of the derivative of `form` wrt `c_rep` on `adj_input`.

        In other words, it returns:

        .. math::

            \left\langle\left(\frac{\partial form}{\partial c_{rep}}\right)^*, adj_{input} \right\rangle

        - If `form` has arity 0, then :math:`\frac{\partial form}{\partial c_{rep}}` is a 1-form
          and `adj_input` is a scalar, so the assembled vector is simply scaled by it.
        - Forms of arity 1 (whose derivative is a 2-form) are not supported yet.

        Args:
            adj_input: The scalar input to the adjoint operation.
            arity_form: The arity of the form, i.e., 0 for scalar, 1 for vector, 2 for matrix etc.
            form: The UFL form to differentiate if `dform` is not provided.
            c_rep: The coefficient or constant with respect to which the derivative is taken.
            space: The function space associated with `c_rep`, in which the `ufl.TestFunction`
                of the derivative is created. Required if `dform` is not provided.
            dform: Pre-computed derivative form, :math:`\frac{\partial form}{\partial c_{rep}}`.

        Returns:
            A tuple ``(vector, dform)``: the assembled, scaled adjoint action and the
            (possibly newly computed) symbolic derivative form.

        Raises:
            ValueError: If `arity_form` is not 0, if `dform` cannot be built from the
                given arguments, or if `dform` is not a 1-form.
        """
        if arity_form != 0:
            raise ValueError("Forms with arity > 0 are not handled yet!")
        assert arity_form == self.compiled_form.rank, "Inconsistent arity of input form and block form."
        if dform is None:
            if form is None or c_rep is None or space is None:
                raise ValueError("`form`, `c_rep` and `space` are required when `dform` is not supplied.")
            dform = ufl.derivative(form, c_rep, ufl.TestFunction(space))
        assert isinstance(dform, ufl.Form), "dform must be a UFL form."
        if len(dform.arguments()) != 1:
            raise ValueError("The derivative of a functional must be a 1-form.")

        compiled_adjoint = self._compile_form(dform)
        vector = _create_vector(compiled_adjoint)
        vector.array[:] = 0.0
        assemble_compiled_form(compiled_adjoint, vector)
        # Scale by the scalar adjoint input; the scaling is uniform, so the
        # forward scatter keeps ghost entries consistent with their owners.
        vector.array[:] *= adj_input
        vector.scatter_forward()
        return vector, dform

    def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies):
        """Return the block's form with each dependency replaced by its saved value."""
        form_coefficients = self.form.coefficients()
        replaced_coeffs = {}
        for block_variable in self.get_dependencies():
            coeff = block_variable.output
            if coeff in form_coefficients:
                replaced_coeffs[coeff] = block_variable.saved_output
        return ufl.replace(self.form, replaced_coeffs)

    def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None):
        """Return the adjoint contribution for one dependency.

        Raises:
            NotImplementedError: If the dependency is not a ``dolfinx.fem.Function``.
        """
        form = prepared
        adj_input = adj_inputs[0]
        c = block_variable.output
        c_rep = block_variable.saved_output
        arity_form = self._form_arity(form)
        # Only Function dependencies are supported (Constant and Mesh are not yet).
        if not isinstance(c, dolfinx.fem.Function):
            raise NotImplementedError(
                f"Adjoint of assembly with respect to {type(c).__name__} is not supported."
            )
        return self.compute_action_adjoint(adj_input, arity_form, form, c_rep, c.function_space)[0]

    def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs):
        return self.prepare_evaluate_adj(inputs, tlm_inputs, self.get_dependencies())

    def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
        """Assemble the directional derivative of the form along all TLM inputs.

        Returns:
            ``0.0`` if no dependency carries a TLM value or the derivative vanishes.
            Otherwise, the assembled scalar for a 0-form, or the assembled vector
            for a 1-form.
        """
        form = prepared
        arity_form = self._form_arity(form)
        dform = 0.0
        for bv in self.get_dependencies():
            tlm_value = bv.tlm_value
            if tlm_value is None:
                continue
            c_rep = bv.saved_output
            if isinstance(c_rep, dolfinx.mesh.Mesh):
                c_rep = ufl.SpatialCoordinate(c_rep)
            # `0.0 + Form` is a no-op in UFL, so the first term replaces the float.
            dform += ufl.derivative(form, c_rep, tlm_value)
        if isinstance(dform, float):
            return dform
        dform = ufl.algorithms.expand_derivatives(dform)
        if dform.empty():
            # Derivative is identically zero; there is nothing to compile.
            return 0.0
        result = assemble_compiled_form(self._compile_form(dform))
        if arity_form == 1:
            # NOTE(review): legacy dolfin-adjoint returned `vector.function` here;
            # `dolfinx.la.Vector` has no such attribute, so the vector is returned.
            return result
        return result

    def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_dependencies):
        return self.prepare_evaluate_adj(inputs, adj_inputs, relevant_dependencies)

    def evaluate_hessian_component(
        self,
        inputs,
        hessian_inputs,
        adj_inputs,
        block_variable,
        idx,
        relevant_dependencies,
        prepared=None,
    ):
        """Return the Hessian action contribution for one dependency, or ``None`` if unsupported."""
        form = prepared
        hessian_input = hessian_inputs[0]
        adj_input = adj_inputs[0]
        arity_form = self._form_arity(form)
        c1 = block_variable.output
        c1_rep = block_variable.saved_output
        if isinstance(c1, dolfinx.fem.Constant):
            # NOTE(review): relies on an `_ad_function_space` method that is not
            # defined in this file; presumably provided by an overloaded type - confirm.
            mesh = form.ufl_domain()
            space = c1._ad_function_space(mesh)
        elif isinstance(c1, dolfinx.fem.Function):
            space = c1.function_space
        elif isinstance(c1, dolfinx.mesh.Mesh):
            c1_rep = ufl.SpatialCoordinate(c1)
            space = c1._ad_function_space()
        else:
            return None

        # First-order term: adjoint action of d(form)/d(c1) on the Hessian input.
        hessian_outputs, dform = self.compute_action_adjoint(hessian_input, arity_form, form, c1_rep, space)

        # Second-order term: differentiate d(form)/d(c1) along every TLM direction.
        # Start from a float so the isinstance check below detects "no terms".
        ddform = 0.0
        for _, bv in relevant_dependencies:
            tlm_input = bv.tlm_value
            if tlm_input is None:
                continue
            c2_rep = bv.saved_output
            if isinstance(c2_rep, dolfinx.mesh.Mesh):
                c2_rep = ufl.SpatialCoordinate(c2_rep)
            ddform += ufl.derivative(dform, c2_rep, tlm_input)
        if isinstance(ddform, float):
            return hessian_outputs
        ddform = ufl.algorithms.expand_derivatives(ddform)
        if not ddform.empty():
            # FIXME: cross-check `ddform` (and the space it is assembled in)
            # against legacy dolfin-adjoint.
            adj_action = self.compute_action_adjoint(adj_input, arity_form, dform=ddform)[0]
            hessian_outputs.array[:] += adj_action.array
        return hessian_outputs

    def prepare_recompute_component(self, inputs, relevant_outputs):
        return self.prepare_evaluate_adj(inputs, None, None)

    def recompute_component(self, inputs, block_variable, idx, prepared):
        """Assemble the prepared form to a global scalar and wrap it as an overloaded object."""
        compiled_form = self._compile_form(prepared)
        local_value = dolfinx.fem.assemble_scalar(compiled_form)
        value = compiled_form.mesh.comm.allreduce(local_value, op=MPI.SUM)
        return create_overloaded_object(value)