Smoothed TV image inpainting#

Download sources

This demo solves a variational image inpainting problem on the unit square. A synthetic image is masked on an irregular interior region, and the missing values are reconstructed using smoothed total variation (TV) regularization.

Problem Definition#

Let \(\Omega = [0,1]^2\) be the image domain. We define:

  • \(u_{\mathrm{true}}\): synthetic ground-truth image

  • \(m\): mask, equal to 1 on known data and 0 on the missing region

  • \(f = m u_{\mathrm{true}}\): observed incomplete image

  • \(u\): reconstructed image

We compute \(u\) by minimizing

\[ J(u)= {1 \over 2}\beta \int_\Omega m(u-f)^2\,\mathrm{d}x + \alpha \int_\Omega \sqrt{||\nabla u||^2 + \varepsilon^2}~\mathrm{d}x. \]

The first term enforces agreement with the known image data, while the second term is a smoothed total variation regularization term. It promotes piecewise smooth solution and preserves edges \(\alpha\) and \(\beta\) control the balance between the data fidelity (fit to f) and smoothness. The parameter \(\varepsilon>0\) smooths the TV function so that it is differentiable and can be solved with Newton type methods

Discretization#

We discretize the problem using

  • a first order Lagrange finite element space

  • a triangular mesh of the unit square

Implementation#

We use a first-order Lagrange space on a triangular mesh of the unit square. The nonlinear problem is solved with PETSc SNES through NonlinearProblem.

from mpi4py import MPI

import matplotlib.pyplot as plt
import matplotlib.tri as mtri
import numpy as np

import ufl
from dolfinx import fem, mesh
from dolfinx.fem.petsc import NonlinearProblem

We discretize the domain \(\Omega =[0,1]^2\) using a triangular mesh, where nx and ny control the resolution of the mesh.

nx = 128
ny = 128
msh = mesh.create_unit_square(MPI.COMM_WORLD, nx, ny)

We use first order Lagrange elements for discretizing the image. In this space, the DOFs are the values of u at mesh vertices the solution is continuous but has piecewise constant gradient

V = fem.functionspace(msh, ("Lagrange", 1))

Ground Truth image \(u_{true}\)#

We define a synthetic binary image

\[\begin{split} u_{true}=\begin{cases} 1 & \text{ if } (x,y) \text{ is inside a square}\\ 0 &\text{ otherwise} \end{cases} \end{split}\]

The square is defined as \(0.2<x<0.8, ~0.2<y<0.8\), which gives a piecewise-constant image with sharp edges

def true_image(x):
    """Define a binary image with a square in the center."""
    X = x[0]
    Y = x[1]

    # main square
    return ((X > 0.2) & (X < 0.8) & (Y > 0.2) & (Y < 0.8)).astype(np.float64)

Mask \(m(x,y)\)#

The mask defines which pixel are known and which are missing

\[\begin{split} m(x,y)=\begin{cases} 1& \text{ known data}\\ 0 & \text{ missing region} \end{cases} \end{split}\]

We construct a mask with random “holes” inside the square

  • small circular regions are removed and set to 0

  • everywhere else remains known (1)

This creates a challenging inpainting problem as:

  • many small missing regions

  • irregular geometry

The solver must reconstruct these missing values using smoothness (TV regularization)

def mask_function(x):
    """Create a mask with random circular holes inside the square."""
    X = x[0]
    Y = x[1]
    # all pixels known
    mask = np.ones_like(X, dtype=np.float64)
    # number of speckles
    num_speckles = 25
    # random centers
    generator = np.random.Generator(
        np.random.MT19937(0)
    )  # random seed for reproducibility

    cx = generator.uniform(0.25, 0.75, num_speckles)
    cy = generator.uniform(0.25, 0.75, num_speckles)
    # random radii (small + varied)
    radii = generator.uniform(0.012, 0.035, num_speckles)
    # create holes. mask =0 inside circles
    for i in range(num_speckles):
        r2 = (X - cx[i]) ** 2 + (Y - cy[i]) ** 2
        mask[r2 < radii[i] ** 2] = 0.0
    return mask

We interpolate the exact image and the mask into the finite element space, and construct the observed damaged image, where \(u_{true}\) is our true image, \(m: \Omega \to \mathbb{R}\) is the mask, \(f: \Omega \to \mathbb{R}\) is the observed damaged image, and \(u:\Omega \to \mathbb{R}\) is the reconstructed image.

u_true = fem.Function(V, name="true_image")
u_true.interpolate(true_image)
m = fem.Function(V, name="mask")
m.interpolate(mask_function)
f = fem.Function(V, name="observed_image")
f.x.array[:] = m.x.array * u_true.x.array
u = fem.Function(V, name="reconstructed_image")
u.x.array[:] = f.x.array.copy()

We now define the nonlinear variational problem corresponding to the smoothed total variation regularised inpainting model.

The Euler-Lagrange equation for \(J(u)\) leads to the weak form Find \(u\in V\) such that

\[ \int_\Omega m(u-f)v\,\mathrm{d}x + \alpha \int_\Omega {\nabla u\cdot\nabla v \over \sqrt{||\nabla u||^2+\varepsilon^2}} \,\mathrm{d}x = 0 \]

for all test functions \(v\). This is a nonlinear problem due to the TV term Total variation is usually defined as \(\vert\vert\nabla u\vert\vert\), but in practice one uses a smoothed version to allow for differentiation and Newton type solvers:

\[ TV = \sqrt{\vert\vert\nabla u\vert\vert^2 +\varepsilon^2} \]

where # \(\varepsilon\) is the smoothing of the TV:

  • large \(\varepsilon\) smoother more like quadratic diffusion

  • small \(\varepsilon\) closer to true TV edge preserving

alpha = fem.Constant(msh, 0.003)
beta = fem.Constant(msh, 1.0)
eps = fem.Constant(msh, 1.0e-4)

Smoothed TV inpainting energy functional. We define the energy J(u) and use ufl.derivative() to obtain the residual form \(F(u; v)=0 \quad\forall v\in V\):

\[ J(u) = {1 \over 2}\beta\int_\Omega m(u-f)^2\,dx + \alpha\int_\Omega \sqrt{||\nabla u||^2+\varepsilon^2}~\mathrm{d}x \]

Taking the first variation gives the weak form F(u; v).

\[ F(u; v) = \beta\int_\Omega m(u-f)v\,dx + \alpha\int_\Omega {\nabla u\cdot\nabla v \over \sqrt{||\nabla u||^2+\varepsilon^2}}~\mathrm{d}x = 0 \quad \forall v\in V. \]
v = ufl.TestFunction(V)

J_energy = (
    0.5 * beta * m * (u - f) ** 2 * ufl.dx
    + alpha * ufl.sqrt(ufl.inner(ufl.grad(u), ufl.grad(u)) + eps**2) * ufl.dx
)

F = ufl.derivative(J_energy, u, v)

This formulation is based on total variation (TV) regularization for image denoising and inpainting Rudin et al. [ROF92], Shen and Chan [SC02].

A nonlinear PETSc problem is created and solved with a Newton line-search method, with an LU factorization for the linearized system \(F'(u_k) s= -F(u_k)\).

petsc_options = {
    "snes_type": "newtonls",
    "snes_linesearch_type": "bt",
    "snes_rtol": 1.0e-8,
    "snes_atol": 1.0e-8,
    "snes_max_it": 1000,
    "ksp_type": "preonly",
    "pc_type": "lu",
}

problem = NonlinearProblem(
    F,
    u,
    bcs=[],
    petsc_options_prefix="tv_inpainting_",
    petsc_options=petsc_options,
)

problem.solve()
Coefficient(FunctionSpace(Mesh(blocked element (Basix element (P, triangle, 1, gll_warped, unset, False, float64, []), (2,)), 0), Basix element (P, triangle, 1, gll_warped, unset, False, float64, [])), 3)

Model Validation and Results#

These diagnostics asses

  1. whether the nonlinear Newton/SNES solve converged

  2. whether the variational objective decreased

  3. how accurate the reconstruction is globally and in the hole region

FEM Metrics Global number of degrees of freedom reports the size of the finite element discretization H1 seminorm error measures the gradient error

\[ \vert\vert\nabla(u-u_{true})\vert\vert_{L_2 (\Omega)} \]

This is useful as TV regularization is gradient based. Smaller values mean the reconstruction recovers edge structure better

num_dofs = V.dofmap.index_map.size_global
h1_semi_error = fem.assemble_scalar(
    fem.form(ufl.inner(ufl.grad(u - u_true), ufl.grad(u - u_true)) * ufl.dx)
)
h1_semi_error = np.sqrt(msh.comm.allreduce(h1_semi_error, op=MPI.SUM))

Reconstruction Errors Data fidelity (known region only):

\[ \sqrt{\vert\vert m(u-f) \vert\vert_{L_2 \Omega}} \]

measures the agreement with the known image data. Smaller values mean the reconstruction matches the observe pixels better.

data_error = fem.assemble_scalar(fem.form(m * (u - f) ** 2 * ufl.dx))
data_error = np.sqrt(msh.comm.allreduce(data_error, op=MPI.SUM))

TV seminorm

\[ \int_{\Omega}\sqrt{\vert\vert\nabla u \vert\vert^2 +\varepsilon^2}~\mathrm{d}x \]

This is the regularization term in the objective Smaller values mean a smoother reconstruction

tv_energy = fem.assemble_scalar(
    fem.form(ufl.sqrt(ufl.inner(ufl.grad(u), ufl.grad(u)) + eps**2) * ufl.dx)
)
tv_energy = msh.comm.allreduce(tv_energy, op=MPI.SUM)

True error

\[ \sqrt{\vert\vert u-u_{true} \vert\vert_{L_2 \Omega}} \]

Measures overall reconstruction accuracy

true_error = fem.assemble_scalar(fem.form((u - u_true) ** 2 * ufl.dx))
true_error = np.sqrt(msh.comm.allreduce(true_error, op=MPI.SUM))

Hole error

\[ \sqrt{\vert\vert (1-m)(u-u_{true}) \vert\vert_{L_2 \Omega}} \]
hole_error = fem.assemble_scalar(fem.form((1 - m) * (u - u_true) ** 2 * ufl.dx))
hole_error = np.sqrt(msh.comm.allreduce(hole_error, op=MPI.SUM))

Image quality metric PSNR (peak signal to noise ratio), standard imaging metric since the image range is [0,1], we use

\[ PSNR=10\log_{10}(1/MSE) \]

Larger PSNR means better reconstruction quality

mse = np.mean((u.x.array - u_true.x.array) ** 2)
if mse == 0:
    psnr = np.inf
else:
    psnr = 10.0 * np.log10(1.0 / mse)

Newton Linesearch metrics Measure whether the nonlinear solve succeeded

  • we want a positive converged reason

  • a small final residual norm

  • a reasonable number of iterations

snes = problem.solver
reason = snes.getConvergedReason()
iters = snes.getIterationNumber()
final_residual = snes.getFunctionNorm()

Objective values Comparing the initial objective J(f) with the final objective J(u)

\[ J(v)={1\over 2}\beta \int m(v-f)^2 dx +\alpha \int \sqrt{||\nabla v||^2+\varepsilon^2} \]

A decrease in the objective show that the nonlinear optimization improved the damaged image under the smoothed TV model

objective_value = 0.5 * float(beta) * data_error**2 + float(alpha) * tv_energy
if reason > 0:
    status = "converged"
else:
    status = "not converged"

u0 = fem.Function(V)
u0.x.array[:] = f.x.array.copy()

J0_data = fem.assemble_scalar(fem.form(m * (u0 - f) ** 2 * ufl.dx))
J0_data = msh.comm.allreduce(J0_data, op=MPI.SUM)

J0_tv = fem.assemble_scalar(
    fem.form(ufl.sqrt(ufl.inner(ufl.grad(u0), ufl.grad(u0)) + eps**2) * ufl.dx)
)
J0_tv = msh.comm.allreduce(J0_tv, op=MPI.SUM)

J0 = 0.5 * float(beta) * J0_data + float(alpha) * J0_tv

Printing statements for validation and metrics If on main process

if msh.comm.rank == 0:
    print("---Smoothed TV inpainting results---")

    print("--FEM Metrics--")
    print(f"Global DOFs: {num_dofs}")
    print(f"H1 seminorm error: {h1_semi_error}")

    print("--Newton Linesearch:--")
    print("-Optimization:-")
    print(f"Initial objective J(f): {J0:.4e}")
    print(f"Final objective J(u): {objective_value:.4e}")
    print(f"Relative decrease: {(J0 - objective_value) / J0:.2%}")

    print("-Solver convergence:-")
    print(f"SNES iteration: {iters}")
    print(f"SNES final residual norm: {final_residual:.4e}")
    print(f"SNES status: {status}")
    print(f"SNES converged reason: {reason}")

    print("---Reconstruction Quality:---")
    print(f"Data error (known region): {data_error:.4e}")
    print(f"TV seminorm: {tv_energy:.4e}")
    print(f"True L2 error: {true_error:.4e}")
    print(f"Hole error: {hole_error:.4e}")
    print(f"PSNR: {psnr:.2f} dB")
---Smoothed TV inpainting results---
--FEM Metrics--
Global DOFs: 16641
H1 seminorm error: 1.9251510135159864
--Newton Linesearch:--
-Optimization:-
Initial objective J(f): 1.7398e-02
Final objective J(u): 7.9366e-03
Relative decrease: 54.38%
-Solver convergence:-
SNES iteration: 430
SNES final residual norm: 2.6906e-09
SNES status: converged
SNES converged reason: 2
---Reconstruction Quality:---
Data error (known region): 4.7422e-02
TV seminorm: 2.2707e+00
True L2 error: 2.4220e-02
Hole error: 7.4629e-03
PSNR: 32.13 dB

Visualization#

We construct fields that allow us to visually asses the quality of the reconstruction \(u-u_{true}\) is the global reconstruction error \((1-m)(u-u_{true})\) is the hole error, restricted to the missing regions

u_minus_u_true = fem.Function(V)
u_minus_u_true.x.array[:] = u.x.array - u_true.x.array

hole_error_field = fem.Function(V)
hole_error_field.x.array[:] = (1.0 - m.x.array) * (u.x.array - u_true.x.array)

FEM to matplotlib#

The solution u in FEM is represented by values at degrees of freedom (DOFs), not on a regular grid To plot in matplotlib

  1. extract the coordinates of the DOFs

  2. extract the function-space dofmap connectivity (triangles)

  3. build a Triangulation object

This allows matplotlib to render the piecewise linear FEM solution

coords = V.tabulate_dof_coordinates()
x, y = coords[:, 0], coords[:, 1]

triangles = V.dofmap.list
triang = mtri.Triangulation(x, y, triangles)

Plotting We use matplotlib.pyplot.tripcolor() to plot scalar fields defined on a triangulated mesh shading= “flat” shows piecewise constant coloring per triangle which better reflects the discrete FEM representations

def plot_field(ax, data, title, fig, cmap="viridis", vmin=0.0, vmax=1.0):
    """Plot a scalar field on a triangulated mesh."""
    im = ax.tripcolor(triang, data, shading="flat", cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.set_aspect("equal")
    fig.colorbar(im, ax=ax)


fig, axes = plt.subplots(2, 3, figsize=(12, 8))
# $u_{true } $ground truth image
plot_field(axes[0, 0], u_true.x.array, "u_true", fig)
# m, mask with known (1) and missing (0) regions
plot_field(axes[0, 1], m.x.array, "mask", fig, cmap="gray")
# f is the damaged image
plot_field(axes[0, 2], f.x.array, "f", fig)
# u is the reconstructed image
plot_field(axes[1, 0], u.x.array, "u", fig)
# Global error
lim = np.max(np.abs(u_minus_u_true.x.array))
# $u-u_{true}$ is the global reconstruction error
plot_field(
    axes[1, 1],
    u_minus_u_true.x.array,
    "u - u_true",
    fig,
    cmap="coolwarm",
    vmin=-lim,
    vmax=lim,
)
# Hole only errors
lim = np.max(np.abs(hole_error_field.x.array))
# Hole only error restricted to the missing regions
plot_field(
    axes[1, 2],
    hole_error_field.x.array,
    "hole-only error",
    fig,
    cmap="coolwarm",
    vmin=-lim,
    vmax=lim,
)
plt.tight_layout()
plt.show()
../_images/f853d26cf12c65f9a2e9e9dcb2a35b8462db3de6a5bc51234b32dfbcf68be168.png

References#

[ROF92]

Leonid I. Rudin, Stanley Osher, and Emad Fatemi. Nonlinear total variation based noise removal algorithms. Physica D: Nonlinear Phenomena, 60(1):259–268, 1992. doi:10.1016/0167-2789(92)90242-F.

[SC02]

Jianhong Shen and Tony F. Chan. Mathematical Models for Local Nontexture Inpaintings. SIAM Journal on Applied Mathematics, 62(3):1019–1043, 2002. doi:10.1137/S0036139900368844.