---
jupytext:
  formats: ipynb,md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.19.1
kernelspec:
  display_name: festim-workshop
  language: python
  name: python3
---

# Training a FESTIM surrogate model

+++

```{admonition} Objectives
:class: objectives
* Use AutoEmulate to build an emulator/surrogate model of a FESTIM problem
* Save the emulator to a file for reuse
```

+++

## Build a _simulator_

We start by building a high-fidelity simulator with FESTIM.
For demonstration purposes, we use a simple unit square split into two volume subdomains (top and bottom halves) with different diffusivities. The model describes hydrogen transport by diffusion across the two subdomains, driven by a volumetric source term in each. Hydrogen diffuses from regions of higher concentration towards the boundaries. The temperature is held constant at 400 K and we solve for the steady state.

The model is parametric: it takes the volumetric source terms in each subdomain as inputs and returns the total hydrogen inventory in each subdomain.

The hydrogen concentration on the outer boundary of the domain is fixed at zero.
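In equation form, the steady-state problem can be written as follows. Here $c$ is the hydrogen concentration, $S_i$ the source value in subdomain $\Omega_i$, $D_i$ the diffusivity of the material assigned to $\Omega_i$, and $I_i$ the inventory reported by the `F.TotalVolume` exports. This is a sketch of the standard diffusion-with-source formulation, assuming no trapping (none is defined in the model):

```{math}
\begin{aligned}
-\nabla \cdot \left( D_i \nabla c \right) &= S_i && \text{in } \Omega_i, \quad i \in \{\text{top}, \text{bottom}\}, \\
c &= 0 && \text{on } \partial\Omega, \\
I_i &= \int_{\Omega_i} c \, \mathrm{d}x .
\end{aligned}
```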


```{seealso}
Check out the [AutoEmulate tutorials](https://alan-turing-institute.github.io/autoemulate/tutorials/emulation/01_quickstart.html) for more information.
```

```{note}
This specific example isn't particularly expensive to run. Building a surrogate of a FESTIM model is especially useful for large and computationally expensive models, such as 3D simulations with complex geometries, coupled multiphysics (e.g., heat transfer and trapping), or parameter sweeps where a single full-fidelity FESTIM run could take minutes to hours.
```

```{code-cell} ipython3
import festim as F

from dolfinx.mesh import create_unit_square
from mpi4py import MPI


def make_model(source_bottom: float, source_top: float) -> F.HydrogenTransportProblem:
    fenics_mesh = create_unit_square(MPI.COMM_WORLD, 20, 20)

    festim_mesh = F.Mesh(fenics_mesh)

    material_top = F.Material(D_0=0.2, E_D=0)
    material_bot = F.Material(D_0=0.1, E_D=0)

    top_volume = F.VolumeSubdomain(
        id=1, material=material_top, locator=lambda x: x[1] >= 0.5
    )
    bottom_volume = F.VolumeSubdomain(
        id=2, material=material_bot, locator=lambda x: x[1] <= 0.5
    )

    boundary = F.SurfaceSubdomain(id=1)

    my_model = F.HydrogenTransportProblem()
    my_model.mesh = festim_mesh
    my_model.subdomains = [boundary, top_volume, bottom_volume]

    H = F.Species("H")
    my_model.species = [H]

    my_model.temperature = 400

    my_model.boundary_conditions = [
        F.FixedConcentrationBC(subdomain=boundary, value=0.0, species=H),
    ]

    my_model.sources = [
        F.ParticleSource(species=H, volume=bottom_volume, value=source_bottom),
        F.ParticleSource(species=H, volume=top_volume, value=source_top),
    ]

    my_model.settings = F.Settings(atol=1e-10, rtol=1e-10, transient=False)

    my_model.exports = [
        F.TotalVolume(field=H, volume=top_volume),
        F.TotalVolume(field=H, volume=bottom_volume),
    ]

    return my_model
```
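Each `F.Material` defines its diffusivity through an Arrhenius law, $D = D_0 \exp\left(-E_D / (k_B T)\right)$. The standalone sketch below (the helper `arrhenius_diffusivity` is hypothetical, not a FESTIM function, and assumes activation energies in eV as FESTIM uses) shows why setting `E_D=0` makes the fixed temperature of 400 K irrelevant here: the diffusivity reduces to `D_0`.

```python
import math

# Boltzmann constant in eV/K (activation energies are assumed to be in eV)
K_B = 8.617333262e-5


def arrhenius_diffusivity(D_0: float, E_D: float, T: float) -> float:
    """Hypothetical helper: D = D_0 * exp(-E_D / (k_B * T)), T in K."""
    return D_0 * math.exp(-E_D / (K_B * T))


# With E_D = 0, as in make_model, D does not depend on the temperature
for T in (300.0, 400.0, 1000.0):
    print(T, arrhenius_diffusivity(0.2, 0.0, T), arrhenius_diffusivity(0.1, 0.0, T))
```

With a non-zero `E_D`, the same model would become temperature dependent, and the temperature could then be exposed as a third input of the simulator.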

Let's first visualize the steady-state concentration profile for several combinations of top and bottom source values. Each mesh is colored and warped by the hydrogen concentration.

```{code-cell} ipython3
from dolfinx import plot
import pyvista

pyvista.set_jupyter_backend("html")


def make_ugrid(solution, label="c"):
    """Convert a dolfinx Function into a PyVista grid with its values as point data."""
    topology, cell_types, geometry = plot.vtk_mesh(solution.function_space)
    u_grid = pyvista.UnstructuredGrid(topology, cell_types, geometry)
    u_grid.point_data[label] = solution.x.array.real
    u_grid.set_active_scalars(label)
    return u_grid


u_plotter = pyvista.Plotter(shape=(2, 2))

source_pairs = [(0.0, 1.0), (1.0, 0.0), (2.0, 1.0), (1.0, 2.0)]
for i, (source_bottom, source_top) in enumerate(source_pairs):
    # Build and solve the FESTIM model for this pair of source values
    model = make_model(source_bottom, source_top)
    model.initialise()
    model.run()

    H = model.species[0]
    u_grid = make_ugrid(H.post_processing_solution)
    u_plotter.subplot(i // 2, i % 2)
    # Warp the mesh out of plane by the concentration value
    warped = u_grid.warp_by_scalar(factor=1)
    u_plotter.add_mesh(warped, cmap="viridis", show_edges=True)
    u_plotter.add_text(f"source_bottom={source_bottom}, source_top={source_top}", font_size=10)

# Link the cameras of the four subplots
u_plotter.link_views()

if not pyvista.OFF_SCREEN:
    u_plotter.show()
else:
    u_plotter.screenshot("concentration.png")
```
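The physics behind these profiles can also be explored without FESTIM. The sketch below is a hypothetical 1D analogue of the unit-square problem along the vertical direction (bottom layer $y < 0.5$ with $D = 0.1$, top layer with $D = 0.2$, zero concentration at both ends), discretized with conservative finite differences in NumPy. It is not what FESTIM solves (FESTIM uses 2D finite elements), so the numbers will differ; `solve_1d_two_layer` is an illustrative helper that only reproduces the qualitative behavior.

```python
import numpy as np


def solve_1d_two_layer(source_bottom, source_top, D_bottom=0.1, D_top=0.2, n=200):
    """Solve -(D c')' = S on [0, 1] with c(0) = c(1) = 0.

    Bottom layer: y < 0.5, top layer: y >= 0.5 (n must be even so that a
    node sits on the interface). Returns the nodes, the concentration and
    the (top, bottom) inventories.
    """
    if n % 2:
        raise ValueError("n must be even")
    h = 1.0 / n
    y = np.linspace(0.0, 1.0, n + 1)
    # Diffusivity at cell midpoints (conservative finite differences)
    y_mid = 0.5 * (y[:-1] + y[1:])
    D_mid = np.where(y_mid >= 0.5, D_top, D_bottom)
    # Piecewise-constant source evaluated at the nodes
    S = np.where(y >= 0.5, source_top, source_bottom)

    # Tridiagonal system for the n - 1 interior nodes
    m = n - 1
    A = np.zeros((m, m))
    for k in range(m):
        A[k, k] = (D_mid[k] + D_mid[k + 1]) / h**2
        if k > 0:
            A[k, k - 1] = -D_mid[k] / h**2
        if k < m - 1:
            A[k, k + 1] = -D_mid[k + 1] / h**2
    c = np.zeros(n + 1)  # c(0) = c(1) = 0 stay fixed
    c[1:-1] = np.linalg.solve(A, S[1:-1])

    # Trapezoidal rule over each layer
    mid = n // 2

    def trapezoid(values):
        return h * (values.sum() - 0.5 * (values[0] + values[-1]))

    return y, c, trapezoid(c[mid:]), trapezoid(c[: mid + 1])


y, c, inv_top, inv_bottom = solve_1d_two_layer(source_bottom=0.0, source_top=1.0)
print(f"top inventory: {inv_top:.4f}, bottom inventory: {inv_bottom:.4f}")
```

Because the equation is linear and the sources only appear on the right-hand side, the inventories of this toy model scale linearly with the source values, which is easy to check by calling the function with different inputs.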

## Wrapping the FESTIM model for AutoEmulate

To train a surrogate model with `AutoEmulate`, we expose our FESTIM model through a `Simulator` class: we subclass `autoemulate.simulations.base.Simulator` and implement its `_forward` method. This method receives a 2D [`torch.Tensor`](https://pytorch.org/docs/stable/tensors.html) of inputs `x` with one row per sample and one column per parameter (here a single sample per call), extracts the `source_top` and `source_bottom` values, sets up and solves the FESTIM model, and returns the outputs as a 2D `torch.Tensor` with one row per sample.

```{code-cell} ipython3
from autoemulate.simulations.base import Simulator
import torch


class FestimProblem(Simulator):
    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (1, 2): one sample with columns (source_top, source_bottom)
        source_top = x[:, 0]
        source_bottom = x[:, 1]

        # Convert the single-element tensors to Python floats
        source_top = source_top.item()
        source_bottom = source_bottom.item()
        model = make_model(source_bottom, source_top)

        # Solve the model
        model.initialise()
        model.run()

        # Extract the total amount of H in the top and bottom volumes
        total_top = model.exports[0].data
        total_bot = model.exports[1].data

        # With one-element lists (steady state) this is (n_outputs, 1): transpose it.
        # With plain floats it is 1D: add the sample dimension instead.
        y = torch.tensor([total_top, total_bot])
        y = y.T if y.ndim == 2 else y.unsqueeze(0)

        return y
```
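The transpose in `_forward` is what gives the output the layout AutoEmulate expects: one row per input sample and one column per output. A NumPy sketch of the same reshaping, assuming each export's `.data` is a one-element list (as expected for a steady-state run) and using made-up values:

```python
import numpy as np

# Assumption: in a steady-state run each export's `.data` is a one-element
# list holding the integrated value. The numbers are made up.
total_top = [0.123]
total_bot = [0.456]

stacked = np.array([total_top, total_bot])  # shape (2, 1): one row per output
y = stacked.T                               # shape (1, 2): one row per sample
print(stacked.shape, y.shape)
```

`torch.tensor` builds the same `(2, 1)` layout from nested lists, so the transpose plays the same role in the wrapper.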

## Generating the Training Data

Now we can create an instance of our wrapper `FestimProblem`. We define the ranges for the top and bottom sources, as well as the names of the two quantities of interest (the total hydrogen inventories in the top and bottom subdomains).

```{code-cell} ipython3
# The parameter order must match the column order used in `_forward` (source_top first)
simulator = FestimProblem(
    parameters_range={"source_top": (0.0, 10.0), "source_bottom": (0.0, 10.0)},
    output_names=["total_top", "total_bot"],
)
```

Let's run a single simulation to check that the wrapper works. The output tensor has shape `(1, 2)` and contains the `total_top` and `total_bot` values for this input (`source_top=0.0`, `source_bottom=3.0`).

```{code-cell} ipython3
simulator.forward(torch.tensor([[0.0, 3.0]]))
```

We now draw 20 input samples within the parameter ranges using AutoEmulate's `sample_inputs` method. These will be the inputs $X$ used to train the surrogate model.

```{code-cell} ipython3
n_samples = 20

X = simulator.sample_inputs(n_samples)

X.shape
```
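AutoEmulate takes care of the sampling for us. For intuition, here is a generic Latin hypercube sampler in NumPy, a common space-filling design for training emulators: each parameter range is split into `n_samples` equal strata and every stratum receives exactly one point. This is a sketch of the technique only; `latin_hypercube` is a hypothetical helper and not AutoEmulate's implementation.

```python
import numpy as np


def latin_hypercube(n_samples, bounds, seed=None):
    """Hypothetical helper: Latin hypercube sample within the given bounds."""
    rng = np.random.default_rng(seed)
    n_dims = len(bounds)
    # One random point inside each of the n_samples strata, per dimension
    u = (np.arange(n_samples)[:, None] + rng.random((n_samples, n_dims))) / n_samples
    # Shuffle the strata independently in each dimension
    for d in range(n_dims):
        u[:, d] = rng.permutation(u[:, d])
    lows = np.array([lo for lo, _ in bounds])
    highs = np.array([hi for _, hi in bounds])
    return lows + u * (highs - lows)


X_lhs = latin_hypercube(20, [(0.0, 10.0), (0.0, 10.0)], seed=0)
print(X_lhs.shape)  # (20, 2)
```

Compared with independent uniform draws, this design avoids leaving large gaps in any single parameter direction, which matters when only a few expensive simulations can be afforded.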

We then run FESTIM on each sampled input in a batch with `forward_batch` to obtain the corresponding outputs $Y$.

```{code-cell} ipython3
Y, _ = simulator.forward_batch(X, allow_failures=False)
Y.shape
```

```{code-cell} ipython3
Y
```
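Conceptually, a batch run loops over the rows of `X` and stacks the outputs. The NumPy sketch below illustrates that pattern, together with one possible meaning of an `allow_failures`-style switch (skip failed runs or stop at the first error). `run_batch` and `toy_simulator` are hypothetical; AutoEmulate's own `forward_batch` may behave differently, for instance in what it returns alongside `Y`.

```python
import numpy as np


def run_batch(func, X, allow_failures=True):
    """Hypothetical helper: evaluate func row by row, optionally skipping failures."""
    outputs, kept_rows = [], []
    for i, x in enumerate(X):
        try:
            outputs.append(func(x[None, :]))  # keep a (1, n_params) row shape
            kept_rows.append(i)
        except Exception:
            if not allow_failures:
                raise
    if not outputs:
        raise RuntimeError("all evaluations failed")
    return np.vstack(outputs), X[kept_rows]


def toy_simulator(x):
    # Hypothetical stand-in that fails when the first input is negative
    if x[0, 0] < 0:
        raise ValueError("simulation failed")
    return np.hstack([x.sum(axis=1, keepdims=True), x.prod(axis=1, keepdims=True)])


X_demo = np.array([[1.0, 2.0], [-1.0, 3.0], [4.0, 5.0]])
Y_demo, X_ok = run_batch(toy_simulator, X_demo)  # keeps rows 0 and 2 only
```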

Let's visualize the training dataset. We scatter the 20 sampled source combinations, coloring them by the target output quantities.

```{code-cell} ipython3
import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 2, figsize=(12, 5))

for i in range(2):
    plt.sca(axs[i])
    plt.scatter(X[:, 0], X[:, 1], c=Y[:, i], cmap='viridis', vmin=Y.min(), vmax=Y.max())

    plt.title(f'{simulator.output_names[i]}')

    plt.xlabel(f"{simulator.param_names[0]}")
    plt.ylabel(f"{simulator.param_names[1]}")

plt.colorbar(cax=fig.add_axes([0.92, 0.15, 0.02, 0.7]))

plt.show()
```

## Training the surrogate models

We train several candidate models by instantiating `AutoEmulate` with our $X$ and $Y$ tensors. Out of the box, the library fits and compares a range of common regression models (such as random forests, Gaussian processes and multi-layer perceptrons) on the provided data.

```{code-cell} ipython3
:tags: [hide-output]

from autoemulate import AutoEmulate
# Run AutoEmulate with default model settings, logging progress at info level
ae = AutoEmulate(X, Y, log_level="info")
```

```{code-cell} ipython3
ae.summarise()
```
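AutoEmulate handles the model comparison internally. As a rough illustration of the underlying idea, the sketch below scores two hypothetical candidate models (a constant mean predictor and least squares with an intercept) by k-fold cross-validated mean squared error on synthetic data. The data, models and helper names are illustrative assumptions, not AutoEmulate's internals.

```python
import numpy as np


def fit_mean(X, Y):
    # Hypothetical baseline: always predict the training mean
    mean = Y.mean(axis=0)
    return lambda Xq: np.tile(mean, (len(Xq), 1))


def fit_linear(X, Y):
    # Hypothetical candidate: least squares with an intercept column
    A = np.hstack([X, np.ones((len(X), 1))])
    coef, *_ = np.linalg.lstsq(A, Y, rcond=None)
    return lambda Xq: np.hstack([Xq, np.ones((len(Xq), 1))]) @ coef


def kfold_mse(fit, X, Y, k=5, seed=0):
    """Average held-out mean squared error over k folds."""
    idx = np.random.default_rng(seed).permutation(len(X))
    errors = []
    for fold in np.array_split(idx, k):
        train = np.setdiff1d(idx, fold)
        predict = fit(X[train], Y[train])
        errors.append(np.mean((predict(X[fold]) - Y[fold]) ** 2))
    return float(np.mean(errors))


# Synthetic data with two inputs and two outputs (illustrative only)
rng = np.random.default_rng(1)
X_syn = rng.uniform(0.0, 10.0, size=(20, 2))
Y_syn = X_syn @ np.array([[0.3, 0.1], [0.05, 0.4]]) + 0.01 * rng.normal(size=(20, 2))

scores = {name: kfold_mse(fit, X_syn, Y_syn) for name, fit in [("mean", fit_mean), ("linear", fit_linear)]}
print(scores)  # lower is better
```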

Here we select the `GaussianProcessRBF` model, a Gaussian process with a radial basis function (RBF) kernel:

```{code-cell} ipython3
# pick GaussianProcessRBF
emulator = [r for r in ae.results if r.model_name == "GaussianProcessRBF"][0]
print(f"Selected model: {emulator.model_name} with id: {emulator.id}")
```
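For intuition about what a Gaussian process emulator with an RBF kernel computes, here is a minimal NumPy sketch of GP regression: a squared-exponential kernel, a Cholesky solve for the posterior mean, and the posterior standard deviation. The length scale, signal variance and noise level are fixed by hand, whereas a full implementation would typically fit such hyperparameters to the data; `rbf_kernel` and `gp_fit_predict` are hypothetical helpers, not AutoEmulate functions.

```python
import numpy as np


def rbf_kernel(A, B, length_scale=1.0, variance=1.0):
    """Squared-exponential kernel k(a, b) = s^2 exp(-|a - b|^2 / (2 l^2))."""
    sq_dist = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return variance * np.exp(-0.5 * sq_dist / length_scale**2)


def gp_fit_predict(X_train, y_train, X_test, length_scale=1.0, variance=1.0, noise=1e-6):
    """Posterior mean and standard deviation of a zero-mean GP."""
    K = rbf_kernel(X_train, X_train, length_scale, variance) + noise * np.eye(len(X_train))
    L = np.linalg.cholesky(K)  # K = L L^T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))  # K^{-1} y
    K_s = rbf_kernel(X_test, X_train, length_scale, variance)
    mean = K_s @ alpha
    v = np.linalg.solve(L, K_s.T)
    var = variance - (v**2).sum(axis=0)  # prior variance minus explained part
    return mean, np.sqrt(np.clip(var, 0.0, None))


rng = np.random.default_rng(0)
X_train = rng.uniform(0.0, 10.0, size=(20, 2))
y_train = 0.3 * X_train[:, 0] + 0.1 * X_train[:, 1]
mean, std = gp_fit_predict(X_train, y_train, np.array([[5.0, 5.0]]), length_scale=3.0, variance=4.0)
print(mean, std)
```

The standard deviation is what makes GP emulators attractive for uncertainty-aware workflows: it shrinks near training points and grows back towards the prior far from them.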

To assess the emulator's quality, we plot predicted versus simulated values on held-out test data: a portion of the generated samples that AutoEmulate set aside and never showed to the surrogate during training.

Each plotted point represents a single `(source_top, source_bottom)` input combination. Points closer to the diagonal indicate that the emulator accurately matches the FESTIM high-fidelity predictions.

```{code-cell} ipython3
ae.plot_preds(emulator, output_names=simulator.output_names)
```
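The parity plot is a visual check; it is often summarized per output by the coefficient of determination $R^2 = 1 - \sum (y - \hat{y})^2 / \sum (y - \bar{y})^2$, which equals 1 when every point lies on the diagonal. A NumPy sketch with a hypothetical random hold-out split and made-up predictions (AutoEmulate's own split and metrics may be computed differently):

```python
import numpy as np


def holdout_split(n, test_fraction=0.2, seed=0):
    """Hypothetical helper: shuffled train and test indices."""
    idx = np.random.default_rng(seed).permutation(n)
    n_test = max(1, int(round(test_fraction * n)))
    return idx[n_test:], idx[:n_test]


def r2_per_output(y_true, y_pred):
    """R^2 for each output column."""
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot


train_idx, test_idx = holdout_split(20)
print(len(train_idx), len(test_idx))  # 16 4

# Made-up test values: output 0 has small errors, output 1 is exact
y_true = np.array([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]])
y_pred = y_true + np.array([[0.1, 0.0], [-0.1, 0.0], [0.1, 0.0], [-0.1, 0.0]])
print(r2_per_output(y_true, y_pred))
```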

Next, let's explore the continuous parameter space by plotting a 2D slice of the surrogate model's predictions over `source_top` and `source_bottom`. The sampled inputs are overlaid as points to show how well they cover the domain.

```{code-cell} ipython3
from autoemulate.core.plotting import create_and_plot_slice

for i in range(2):

    fig, axs = create_and_plot_slice(
        emulator.model,
        output_idx=i,
        parameters_range=simulator.parameters_range,
        quantile=0.5,
        param_pair=(0, 1),
    )
    plt.scatter(X[:, 0], X[:, 1])
    plt.suptitle(f'{simulator.output_names[i]}')
    plt.show()
```
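Conceptually, a 2D slice is the emulator evaluated on a regular grid spanning the two parameter ranges. A NumPy sketch of that idea, where `predict` is a hypothetical linear stand-in for the trained emulator (any function mapping an `(n, 2)` array to `(n, n_outputs)` predictions) and the grid resolution is an arbitrary choice:

```python
import numpy as np


def evaluate_on_grid(predict, parameters_range, n_points=50, output_idx=0):
    """Evaluate predict on an n_points x n_points grid over two parameters.

    Assumes parameters_range maps exactly two names to (low, high) tuples,
    in the same order as the columns expected by predict.
    """
    (lo0, hi0), (lo1, hi1) = parameters_range.values()
    p0, p1 = np.meshgrid(np.linspace(lo0, hi0, n_points), np.linspace(lo1, hi1, n_points))
    X_grid = np.column_stack([p0.ravel(), p1.ravel()])
    Z = predict(X_grid)[:, output_idx].reshape(p0.shape)
    return p0, p1, Z


def predict(X):
    # Hypothetical linear stand-in for emulator predictions with two outputs
    return np.column_stack([0.3 * X[:, 0] + 0.1 * X[:, 1], 0.1 * X[:, 0] + 0.4 * X[:, 1]])


ranges = {"source_top": (0.0, 10.0), "source_bottom": (0.0, 10.0)}
p0, p1, Z = evaluate_on_grid(predict, ranges)
print(Z.shape)  # (50, 50)
# plt.contourf(p0, p1, Z) would then draw the slice for output 0
```

With more than two parameters, the remaining ones would have to be fixed at some value (for example a quantile of their range) before building the grid.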

## Performance Comparison

Finally, we can compare the runtime between executing the full FESTIM physical model and evaluating the trained empirical surrogate using the `%timeit` notebook magic.

```{code-cell} ipython3
# Time the FESTIM simulator
print("Simulator runtime:")
%timeit simulator.forward(torch.tensor([[5.0, 5.0]]))
```

```{code-cell} ipython3
# Time the empirical surrogate model
print("Emulator runtime:")
%timeit emulator.model.predict(torch.tensor([[5.0, 5.0]]))
```
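Outside a notebook, where the `%timeit` magic is unavailable, the standard-library `timeit` module gives comparable measurements. A sketch with two hypothetical stand-in functions; passing `lambda: simulator.forward(...)` and `lambda: emulator.model.predict(...)` with the inputs used above would reproduce the comparison:

```python
import timeit


def slow_model():
    # Hypothetical stand-in for an expensive simulation
    return sum(i * i for i in range(200_000))


def fast_model():
    # Hypothetical stand-in for a cheap surrogate prediction
    return 42


def best_time(func, number=5, repeat=3):
    """Best average time per call, in seconds, over several repeats."""
    return min(timeit.repeat(func, number=number, repeat=repeat)) / number


t_slow = best_time(slow_model)
t_fast = best_time(fast_model)
print(f"speed-up ~ {t_slow / t_fast:.0f}x")
```

Taking the minimum over repeats, as `%timeit` also reports alongside its statistics, reduces the influence of background load on the measurement.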

As you can see, replacing FESTIM with the emulator yields a substantial speed-up per evaluation. This highlights the benefit of surrogate models in workflows where a model is evaluated many times, such as Bayesian inference, uncertainty quantification or sensitivity analysis.
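The speed-up only pays off once the emulator is queried often enough to repay the cost of generating training data and fitting it. A back-of-the-envelope estimate, assuming each simulation takes $t_\text{sim}$, each emulator call $t_\text{emu}$, and training costs $n_\text{train} \, t_\text{sim} + t_\text{fit}$: the emulator is cheaper once $N \, t_\text{sim} > n_\text{train} \, t_\text{sim} + t_\text{fit} + N \, t_\text{emu}$, i.e. $N > (n_\text{train} \, t_\text{sim} + t_\text{fit}) / (t_\text{sim} - t_\text{emu})$. The timings below are placeholders, not measurements from this notebook:

```python
import math


def break_even_calls(t_sim, t_emu, n_train, t_fit):
    """Smallest number of queries N for which the emulator route is cheaper.

    Emulator route:  n_train * t_sim + t_fit + N * t_emu
    Simulator route: N * t_sim
    """
    if t_emu >= t_sim:
        return math.inf  # the emulator never pays off
    return math.floor((n_train * t_sim + t_fit) / (t_sim - t_emu)) + 1


# Placeholder timings in seconds (not measured here)
print(break_even_calls(t_sim=1.0, t_emu=0.001, n_train=20, t_fit=30.0))  # 51
```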
