# Source code for dolfinx_adjoint.blocks.function_assigner

import typing

import dolfinx
import numpy as np
import ufl
from pyadjoint import AdjFloat, Block, OverloadedType
from ufl.formatting.ufl2unicode import ufl2unicode

from dolfinx_adjoint.utils import function_from_vector


class FunctionAssignBlock(Block):
    """Tape block recording the assignment of a value to a :class:`dolfinx.fem.Function`.

    Args:
        other: The assigned value. Either an :class:`~pyadjoint.OverloadedType`
            (added as a dependency directly) or a plain ``float``/``int``
            (wrapped in an :class:`~pyadjoint.AdjFloat`). Point-wise UFL
            expressions (as supported by firedrake) are not implemented yet.
        func: The function that was assigned to. Stored on the tape as the
            block output by the caller; not used directly here.
        ad_block_tag: Optional tag forwarded to :class:`pyadjoint.Block`.

    Raises:
        NotImplementedError: If ``other`` is neither an ``OverloadedType``
            nor a plain number.
    """

    def __init__(
        self,
        other: typing.Union[np.inexact, int, float],
        func: dolfinx.fem.Function,
        ad_block_tag: typing.Optional[str] = None,
    ):
        super().__init__(ad_block_tag=ad_block_tag)
        # ``other``/``expr`` distinguish plain assignment from (future)
        # UFL-expression assignment; only plain assignment is supported, so
        # ``expr`` remains ``None`` throughout.
        self.other = None
        self.expr = None
        if isinstance(other, OverloadedType):
            self.add_dependency(other, no_duplicates=True)
        elif isinstance(other, (float, int)):
            # Wrap plain numbers so they participate in the adjoint tape.
            other = AdjFloat(other)
            self.add_dependency(other, no_duplicates=True)
        else:
            # Point-wise evaluated UFL expressions (firedrake-style): would
            # traverse ``ufl.traverse_unique_terminals(other)`` and add each
            # OverloadedType terminal as a dependency, storing ``other`` in
            # ``self.expr``.
            raise NotImplementedError("This should eventually be supported")

    def _replace_with_saved_output(self):
        """Return ``self.expr`` with each dependency replaced by its saved output.

        Returns ``None`` when no expression was recorded (the plain
        assignment case).
        """
        if self.expr is None:
            return None
        replace_map = {dep.output: dep.saved_output for dep in self.get_dependencies()}
        return ufl.replace(self.expr, replace_map)
[docs] def prepare_evaluate_adj(self, inputs, adj_inputs, relevant_dependencies): V = self.get_outputs()[0].output.function_space adj_input_func = function_from_vector(V, adj_inputs[0]) if self.expr is None: return adj_input_func expr = self._replace_with_saved_output() return expr, adj_input_func
    def evaluate_adj_component(self, inputs, adj_inputs, block_variable, idx, prepared=None):
        """Evaluate one component of the adjoint of the assignment.

        ``prepared`` is the output of :meth:`prepare_evaluate_adj` (a
        Function on the output space when no expression was recorded).
        Only the ``self.expr is None`` (plain assignment) path is
        implemented; the expression path is currently unreachable because
        ``__init__`` never sets ``self.expr``.
        """
        if self.expr is None:
            if isinstance(block_variable.output, AdjFloat):
                # Adjoint of a broadcast is just a sum
                if isinstance(adj_inputs[0], dolfinx.la.Vector):
                    vec = adj_inputs[0]
                    # Sum via an inner product with the all-ones vector so
                    # the reduction is correct in parallel (ghosts handled
                    # by dolfinx).
                    one = dolfinx.la.vector(
                        adj_inputs[0].index_map, adj_inputs[0].block_size, adj_inputs[0].array.dtype
                    )
                    one.array[:] = 1
                    return dolfinx.cpp.la.inner_product(vec._cpp_object, one._cpp_object)
                else:
                    try:
                        return adj_inputs[0].sum()
                    except AttributeError:
                        # Catch the case where adj_inputs[0] is just a float
                        return adj_inputs[0]
            elif isinstance(func := block_variable.output, dolfinx.fem.Function):
                # Function-to-function assignment: the adjoint is a plain
                # copy of the incoming adjoint values.
                adj_output = dolfinx.fem.Function(func.function_space)
                assert func.function_space == prepared.function_space
                adj_output.x.array[:] = prepared.x.array[:]
                return adj_output.x
            else:
                raise NotImplementedError(f"Adjoint for {block_variable=} not implemented.")
# elif isinstance(block_variable.output, dolfinx.fem.Constant): # R = block_variable.output._ad_function_space(prepared.function_space.mesh) # return self._adj_assign_constant(prepared, R) # else: # adj_output = dolfinx.fem.Function( # block_variable.output.function_space()) # adj_output.assign(prepared) # return adj_output.vector() # else: # # Linear combination # expr, adj_input_func = prepared # adj_output = dolfinx.fem.Function(adj_input_func.function_space) # if not isinstance(block_variable.output, dolfinx.fem.Constant): # diff_expr = ufl.algorithms.expand_derivatives( # ufl.derivative(expr, block_variable.saved_output, adj_input_func) # ) # adj_output.assign(diff_expr) # else: # mesh = adj_output.function_space().mesh() # diff_expr = ufl.algorithms.expand_derivatives( # ufl.derivative( # expr, # block_variable.saved_output, # create_constant(1., domain=mesh) # ) # ) # adj_output.assign(diff_expr) # return adj_output.vector().inner(adj_input_func.vector()) # if isinstance(block_variable.output, dolfin.Constant): # R = block_variable.output._ad_function_space(adj_output.function_space().mesh()) # return self._adj_assign_constant(adj_output, R) # else: # return adj_output.vector() def _adj_assign_constant(self, adj_output, constant_fs): r = dolfinx.fem.Function(constant_fs) shape = r.ufl_shape raise NotImplementedError("Not implemented for constants.") if shape == () or shape[0] == 1: # Scalar Constant raise NotImplementedError("Not implemented for scalar constants yet.") # r.vector()[:] = adj_output.vector().sum() # else: # # We assume the shape of the constant == shape of the output function if not scalar. # # This assumption is due to FEniCS not supporting products with non-scalar constants in assign. # values = [] # for i in range(shape[0]): # values.append(adj_output.sub(i, deepcopy=True).vector().sum()) # r.assign(dolfin.Constant(values)) return r.vector()
[docs] def prepare_evaluate_tlm(self, inputs, tlm_inputs, relevant_outputs): if self.expr is None: return None return self._replace_with_saved_output()
    def evaluate_tlm_component(self, inputs, tlm_inputs, block_variable, idx, prepared=None):
        """Evaluate one component of the tangent-linear model.

        For a plain assignment the TLM value passes straight through.

        NOTE(review): the expression branch below is currently unreachable
        (``self.expr`` is never set to a non-None value in ``__init__``) and
        still uses legacy dolfin APIs — ``Function.assign`` and
        ``Function.vector().axpy`` do not exist on
        ``dolfinx.fem.Function``. This must be ported (e.g. via
        ``dolfinx.fem.Expression`` interpolation and array arithmetic)
        before expression assignment is enabled.
        """
        if self.expr is None:
            return tlm_inputs[0]
        expr = prepared
        # Accumulate du/dm contributions from every dependency carrying a
        # TLM value: dudm += d(expr)/d(dep) * dep.tlm_value.
        dudm = dolfinx.fem.Function(block_variable.output.function_space)
        dudmi = dolfinx.fem.Function(block_variable.output.function_space)
        for dep in self.get_dependencies():
            if dep.tlm_value:
                dudmi.assign(ufl.algorithms.expand_derivatives(ufl.derivative(expr, dep.saved_output, dep.tlm_value)))
                dudm.vector().axpy(1.0, dudmi.vector())
        return dudm
[docs] def prepare_evaluate_hessian(self, inputs, hessian_inputs, adj_inputs, relevant_dependencies): return self.prepare_evaluate_adj(inputs, hessian_inputs, relevant_dependencies)
[docs] def evaluate_hessian_component( self, inputs, hessian_inputs, adj_inputs, block_variable, idx, relevant_dependencies, prepared=None, ): # Current implementation assumes lincom in hessian, # otherwise we need second-order derivatives here. return self.evaluate_adj_component(inputs, hessian_inputs, block_variable, idx, prepared)
[docs] def prepare_recompute_component(self, inputs, relevant_outputs): if self.expr is None: return None return self._replace_with_saved_output()
[docs] def recompute_component(self, inputs, block_variable, idx, prepared): if self.expr is None: prepared = inputs[0] output = dolfinx.fem.Function( block_variable.output.function_space, name="f{block_variable.output.name}_AssignBlockRecompute" ) try: if output.function_space == prepared.function_space: output.x.array[:] = prepared.x.array[:] except AttributeError: # Handling float value output.x.array[:] = prepared return output
def __str__(self): rhs = self.expr or self.other or self.get_dependencies()[0].output if isinstance(rhs, ufl.core.expr.Expr): rhs_str = ufl2unicode(rhs) else: rhs_str = str(rhs) return f"assign({rhs_str})"