Source code for dolfinx_adjoint.types.function

from __future__ import annotations  # for Python<3.11

import dolfinx
import numpy
import ufl
from pyadjoint.overloaded_type import (
    FloatingType,
    create_overloaded_object,
    get_overloaded_class,
    register_overloaded_type,
)
from pyadjoint.tape import annotate_tape, get_working_tape, no_annotations, stop_annotating

from dolfinx_adjoint.blocks.assembly import assemble_compiled_form
from dolfinx_adjoint.utils import function_from_vector, gather

try:
    import typing_extensions as typing
except ModuleNotFoundError:
    import typing  # type: ignore[no-redef]
from dolfinx_adjoint.blocks.function_assigner import FunctionAssignBlock
from dolfinx_adjoint.utils import ad_kwargs


class Function(dolfinx.fem.Function, FloatingType):
    """A class overloading `dolfinx.fem.Function` to support it being used as a control variable
    in the adjoint framework.

    Args:
        V: The function space of the function.
        x: Optional vector to initialize the function with. If not provided, a zero vector is created.
        name: Optional name for the function.
        dtype: Data type of the function values, defaults to `dolfinx.default_scalar_type`.
        **kwargs: Additional keyword arguments to pass to the
            `pyadjoint.overloaded_type.FloatingType` constructor.
    """

    def __init__(
        self,
        V: dolfinx.fem.FunctionSpace,
        x: typing.Optional[dolfinx.la.Vector] = None,
        name: typing.Optional[str] = None,
        dtype: numpy.dtype = dolfinx.default_scalar_type,
        **kwargs,
    ):
        super().__init__(V, x, name, dtype)
        # The pyadjoint-specific keywords are popped one by one so that they are
        # passed explicitly (with their defaults) and not duplicated in the
        # trailing ``**kwargs`` expansion.
        FloatingType.__init__(
            self,
            V,
            x,
            name=name,
            dtype=dtype,
            block_class=kwargs.pop("block_class", None),
            _ad_floating_active=kwargs.pop("_ad_floating_active", False),
            _ad_args=kwargs.pop("_ad_args", None),
            output_block_class=kwargs.pop("output_block_class", None),
            _ad_output_args=kwargs.pop("_ad_output_args", None),
            _ad_outputs=kwargs.pop("_ad_outputs", None),
            annotate=kwargs.pop("annotate", True),
            **kwargs,
        )

    @classmethod
    def _ad_init_object(cls, obj):
        """Build an overloaded function from ``obj``, reusing its space, vector and name."""
        return cls(obj.function_space, obj.x, obj.name)

    @no_annotations
    def _ad_create_checkpoint(self):
        """Return an overloaded copy of this function named ``<name>_checkpoint``."""
        checkpoint = create_overloaded_object(self.copy())
        checkpoint.name = self.name + "_checkpoint"
        return checkpoint

    def _ad_restore_at_checkpoint(self, checkpoint):
        """Return the checkpoint itself; no data is copied back into ``self``."""
        return checkpoint

    @staticmethod
    def _ad_riesz_form(riesz_representation, u, v):
        """Return the UFL form of the ``"L2"`` or ``"H1"`` inner product of ``u`` and ``v``.

        Any value other than ``"H1"`` yields the plain mass form; callers are
        expected to pass only ``"L2"`` or ``"H1"``.
        """
        form = ufl.inner(u, v) * ufl.dx
        if riesz_representation == "H1":
            form = form + ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx
        return form

    @staticmethod
    def _ad_compile_form(form, options: dict):
        """Compile ``form`` using the ``jit_options``/``form_compiler_options`` entries of ``options``."""
        return dolfinx.fem.form(
            form,
            jit_options=options.get("jit_options", None),
            form_compiler_options=options.get("form_compiler_options", None),
        )

    def _ad_riesz_solve(self, riesz_representation, value: dolfinx.la.Vector, options: dict) -> dolfinx.fem.Function:
        """Solve ``M x = value`` where ``M`` is the ``"L2"``/``"H1"`` Riesz matrix on this space."""
        # Imported lazily, as in the original branches, so that the PETSc-backed
        # modules are only required when an "L2"/"H1" representation is requested.
        from dolfinx.fem.petsc import assemble_matrix

        from dolfinx_adjoint.petsc_utils import solve_linear_problem

        u = ufl.TrialFunction(self.function_space)
        v = ufl.TestFunction(self.function_space)
        compiled_riesz = self._ad_compile_form(self._ad_riesz_form(riesz_representation, u, v), options)
        # NOTE(review): unlike the "l2" branch of `_ad_convert_type`, the result is a
        # plain `dolfinx.fem.Function` that is not passed through `create_overloaded_object`.
        ret = dolfinx.fem.Function(self.function_space)
        M = assemble_matrix(compiled_riesz)
        try:
            M.assemble()
            petsc_options = options.get("petsc_options", {})
            solve_linear_problem(M, ret.x, value, petsc_options=petsc_options)
        finally:
            # Release the matrix even if assembly or the solve raises.
            M.destroy()
        return ret

    def _ad_dot(self, other: typing.Self, options: typing.Optional[dict] = None):
        """Compute the inner product of the current function with ``other`` in the Riesz representation.

        Args:
            other: Function to compute the inner product with.
            options: Optional dictionary. ``"riesz_representation"`` selects the inner
                product (``"l2"`` (default), ``"L2"`` or ``"H1"``); ``"jit_options"`` and
                ``"form_compiler_options"`` are forwarded to `dolfinx.fem.form` for the
                ``"L2"`` and ``"H1"`` cases.

        Raises:
            NotImplementedError: If the requested Riesz representation is not one of the above.
        """
        options = {} if options is None else options
        riesz_representation = options.get("riesz_representation", "l2")
        if riesz_representation == "l2":
            return dolfinx.cpp.la.inner_product(self.x._cpp_object, other.x._cpp_object)
        if riesz_representation in ("L2", "H1"):
            inner_form = self._ad_riesz_form(riesz_representation, self, other)
            return assemble_compiled_form(self._ad_compile_form(inner_form, options))
        raise NotImplementedError(f"Unknown Riesz representation {riesz_representation}")

    @no_annotations
    def _ad_mul(self, other: typing.Union[int, float]) -> typing.Self:
        """Multiplication of self with integer or floating value."""
        r = get_overloaded_class(dolfinx.fem.Function)(self.function_space)
        r.x.array[:] = self.x.array * other
        return r

    @no_annotations
    def _ad_add(self, other: typing.Self) -> typing.Self:
        """Return a new function whose array is the entry-wise sum of ``self`` and ``other``."""
        r = get_overloaded_class(dolfinx.fem.Function)(self.function_space)
        r.x.array[:] = self.x.array[:] + other.x.array[:]
        return r

    @no_annotations
    def _ad_convert_type(self, value: dolfinx.la.Vector, options: typing.Optional[dict] = None) -> dolfinx.fem.Function:
        """Convert a vector to a Riesz representation of the function.

        Args:
            value: The vector to convert.
            options: Optional dictionary. ``"riesz_representation"`` may be ``"l2"``
                (default), ``"L2"``, ``"H1"`` or a callable that is applied to ``value``.
                For ``"L2"``/``"H1"``, ``"jit_options"`` and ``"form_compiler_options"``
                are forwarded to `dolfinx.fem.form` and ``"petsc_options"`` to the
                linear solve.

        Raises:
            NotImplementedError: If the requested Riesz representation is not supported.
        """
        options = {} if options is None else options
        riesz_representation = options.get("riesz_representation", "l2")
        if riesz_representation == "l2":
            return create_overloaded_object(function_from_vector(self.function_space, value))
        if riesz_representation in ("L2", "H1"):
            return self._ad_riesz_solve(riesz_representation, value, options)
        if callable(riesz_representation):
            return riesz_representation(value)
        raise NotImplementedError(f"Unknown Riesz representation {riesz_representation}")

    @staticmethod
    def _ad_to_list(m):
        """Convert a function into a list of (global) values."""
        # Objects without an ``array`` attribute are treated as functions and
        # their vector ``m.x`` is gathered; anything else is gathered directly.
        m_v = m if hasattr(m, "array") else m.x
        m_a = gather(m_v)
        return m_a.tolist()

    def _ad_copy(self):
        """Create a (deep) copy of the function."""
        r = get_overloaded_class(dolfinx.fem.Function)(self.function_space)
        assign(self, r)
        return r

    @staticmethod
    def _ad_assign_numpy(dst: dolfinx.fem.Function, src: numpy.ndarray, offset: int):
        """Copy this process' slice of ``src`` into ``dst`` and advance ``offset``.

        The slice is ``src[offset + begin : offset + end]`` where ``begin``/``end`` are
        the local range of ``dst.x.index_map`` scaled by the block size.

        Returns:
            The tuple ``(dst, offset)`` with ``offset`` increased by the number of
            locally owned entries (``size_local * block_size``).
        """
        block_size = dst.x.block_size
        range_begin, range_end = dst.x.index_map.local_range
        range_begin *= block_size
        range_end *= block_size
        m_a_local = src[offset + range_begin : offset + range_end]
        dst.x.array[: len(m_a_local)] = m_a_local
        offset += dst.x.index_map.size_local * block_size
        # Push the updated owned values out via the vector's forward scatter.
        dst.x.scatter_forward()
        return dst, offset


# Register `Function` with pyadjoint for both `dolfinx.fem.Function` and `Function`
# itself; code in this module retrieves it again through
# `get_overloaded_class(dolfinx.fem.Function)` and `create_overloaded_object`.
register_overloaded_type(Function, (dolfinx.fem.Function, Function))
def assign(
    value: typing.Union[numpy.inexact, float, int, dolfinx.fem.Function],
    function: Function,
    **kwargs: typing.Unpack[ad_kwargs],
):
    """Assign a `value` to a :py:func:`dolfinx_adjoint.Function`.

    Args:
        value: The value to assign to the function. Either a scalar
            (`numpy.inexact`, `float` or `int`), which is written to every entry of
            the function's array, or a `dolfinx.fem.Function` on the same function
            space, whose array is copied.
        function: The function to assign the value to.
        **kwargs: Annotation keyword arguments. ``ad_block_tag`` is forwarded to the
            recorded block; the remaining entries are passed to
            `pyadjoint.tape.annotate_tape`.

    Raises:
        ValueError: If ``value`` is neither a supported scalar nor a `dolfinx.fem.Function`.
    """
    ad_block_tag = kwargs.pop("ad_block_tag", None)
    # do not annotate in case of self assignment
    annotate = annotate_tape(kwargs) and value != function
    if annotate:
        if not isinstance(value, ufl.core.operator.Operator):
            value = create_overloaded_object(value)
        block = FunctionAssignBlock(value, function, ad_block_tag=ad_block_tag)
        tape = get_working_tape()
        tape.add_block(block)

    # The actual data transfer must not be recorded on the tape.
    with stop_annotating():
        if isinstance(value, (numpy.inexact, float, int)):
            function.x.array[:] = value
        elif isinstance(value, dolfinx.fem.Function):
            assert value.function_space == function.function_space, (
                "Function spaces of the value and function must match for assignment."
            )
            function.x.array[:] = value.x.array[:]
        else:
            raise ValueError(f"Unsupported value type for assignment: {type(value)}")

    if annotate:
        # Register the updated function as the output of the recorded block.
        block.add_output(function.create_block_variable())