Source code for jax_costs

from typing import Union

import jax
import jax.numpy as jnp
import numpy as np
from pybamm import IDAKLUSolver

from pybop import BaseCost, BaseLikelihood, BaseProblem, Inputs


class BaseJaxCost(BaseCost):
    """
    Base JAX cost function.

    This class implements a cost function using JAX for automatic
    differentiation and efficient gradient computation. It is designed to work
    with problems defined in the `BaseProblem` framework and supports
    transformations, gradient computation, and minimisation for optimisation
    tasks. Subclasses provide ``evaluate(inputs)``, which is the function that
    is evaluated and differentiated here.

    Attributes
    ----------
    problem : BaseProblem
        The problem object containing the model, data, and relevant
        configurations.
    model : BaseModel
        The model associated with the problem.
    n_data : int
        The number of data points in the problem.
    has_transform : bool
        Set on every call; True only when a transformation is configured and
        the caller passed ``apply_transform=True``.
    """

    def __init__(self, problem: BaseProblem):
        super().__init__(problem)
        self.model = self.problem.model
        self.n_data = self.problem.n_data

        # JAXify solver if the model uses the IDAKLUSolver
        if isinstance(self.model.solver, IDAKLUSolver):
            self.model.jaxify_solver(t_eval=self.problem.domain_data)

    def __call__(
        self,
        inputs: Inputs,
        calculate_grad: bool = False,
        apply_transform: bool = False,
        for_optimiser: bool = False,
    ) -> Union[np.ndarray, tuple[float, np.ndarray]]:
        """
        Compute the JAX cost function and (optionally) its gradient for given inputs.

        Parameters
        ----------
        inputs : Inputs
            Input data for model evaluation.
        calculate_grad : bool, optional
            Whether to calculate and return the gradient.
        apply_transform : bool, optional
            Whether to apply transformation to the inputs.
        for_optimiser : bool, optional
            Whether the function is being called for an optimiser. The sign of
            the result is flipped only when this is True and the cost is not
            a minimising one.

        Returns
        -------
        Union[np.ndarray, tuple[float, np.ndarray]]
            The computed cost, or a tuple of cost and gradient when
            ``calculate_grad`` is True. The gradient is a NumPy array built
            from the values of the gradient mapping returned by
            ``jax.value_and_grad``.
        """
        # Set-up transformation, inputs, minimising factor
        self.has_transform = bool(self.transformation and apply_transform)
        model_inputs = self.parameters.verify(self._apply_transformations(inputs))
        minimising_factor = 1 if (self.minimising or not for_optimiser) else -1

        # Re-jaxify the solver only when the requested sensitivity mode
        # differs from the one the model is currently configured with.
        if calculate_grad != self.model.calculate_sensitivities:
            self._update_solver_sensitivities(calculate_grad)

        if calculate_grad:
            y, dy = jax.value_and_grad(self.evaluate)(model_inputs)
            gradient = np.asarray(list(dy.values()))
            return minimising_factor * y, minimising_factor * gradient

        return minimising_factor * self.evaluate(model_inputs)

    def _update_solver_sensitivities(self, calculate_grad: bool) -> None:
        """
        Updates the solver's sensitivity calculation based on the gradient
        requirement.

        Parameters
        ----------
        calculate_grad : bool
            Whether gradient calculation is required.
        """
        self.model.jaxify_solver(
            t_eval=self.problem.domain_data, calculate_sensitivities=calculate_grad
        )

    @staticmethod
    def check_sigma0(sigma0):
        """
        Validate the sigma0 parameter and return it as a Python float.

        Parameters
        ----------
        sigma0 : int or float
            Candidate noise level. Python and NumPy real scalars are accepted.

        Returns
        -------
        float
            The validated value.

        Raises
        ------
        ValueError
            If ``sigma0`` is not a real scalar, is a boolean, is NaN or
            infinite, or is not strictly positive.
        """
        message = "sigma0 must be a positive number"

        # bool is a subclass of int, so it has to be excluded explicitly.
        if isinstance(sigma0, bool) or not isinstance(
            sigma0, (int, float, np.integer, np.floating)
        ):
            raise ValueError(message)

        # Written as "not > 0" so that NaN (for which every comparison is
        # False) is rejected together with zero and negative values.
        if not sigma0 > 0:
            raise ValueError(message)

        value = float(sigma0)
        if not np.isfinite(value):
            raise ValueError(message)
        return value

    def observed_fisher(self, inputs: Inputs):
        """
        Compute the observed Fisher Information Matrix (FIM) for the given inputs.

        The FIM is computed using the element-wise square of the gradient,
        divided by the number of data points. This is an approximation since
        the Hessian is not available.

        Parameters
        ----------
        inputs : Inputs
            Input data for model evaluation.

        Returns
        -------
        jnp.ndarray
            The observed Fisher Information Matrix approximation.
        """
        _, grad = self(inputs, calculate_grad=True)
        return jnp.square(grad) / self.n_data
class JaxSumSquaredError(BaseJaxCost):
    """
    Sum of Squared Error cost function implemented with JAX.

    Parameters
    ----------
    problem : BaseProblem
        The problem to fit, of type `pybop.BaseProblem`.

    Methods
    -------
    evaluate(inputs)
        Computes the sum of squared errors between predictions and targets.
    """

    def evaluate(self, inputs):
        """
        Return the sum of squared differences between the problem's
        prediction and the target, accumulated over every signal.
        """
        prediction = self.problem.evaluate(inputs)

        # One residual entry per signal, in the order given by self.signal.
        per_signal = []
        for name in self.signal:
            per_signal.append(prediction[name] - self._target[name])

        squared = jnp.square(jnp.asarray(per_signal))
        return jnp.sum(squared)
class JaxLogNormalLikelihood(BaseJaxCost, BaseLikelihood):
    """
    A Log-Normal likelihood function implemented with JAX.

    This function represents the underlying observed data as being sampled
    from a Log-Normal distribution.

    Parameters
    ----------
    problem : BaseProblem
        The problem to fit, of type `pybop.BaseProblem`.
    sigma0 : float
        The noise level, passed through ``check_sigma0``. It scales the
        squared differences between the logarithms of the prediction and of
        the target. This argument is required; the signature has no default.
    """

    def __init__(self, problem: BaseProblem, sigma0: Union[list[float], float]):
        super().__init__(problem)
        self.sigma = self.check_sigma0(sigma0)
        self.sigma2 = jnp.square(self.sigma)
        self._offset = 0.5 * self.n_data * jnp.log(2 * jnp.pi)

        # Stack the target signals once so that the data-only part of the
        # likelihood can be cached instead of recomputed on every evaluation.
        stacked_target = [self._target[name] for name in self.signal]
        self._target_as_array = jnp.asarray(stacked_target)
        self._log_target_sum = jnp.sum(jnp.log(self._target_as_array))
        self._precompute()

    def _precompute(self):
        """Cache the part of the log-likelihood that does not depend on the inputs."""
        sigma_term = self.n_data * jnp.log(self.sigma)
        self._constant_term = -self._offset - sigma_term - self._log_target_sum

    def evaluate(self, inputs):
        """
        Computes the log-normal likelihood.

        The residuals are differences of logarithms of the prediction and the
        target, taken signal by signal.
        """
        prediction = self.problem.evaluate(inputs)

        log_differences = []
        for name in self.signal:
            log_differences.append(
                jnp.log(prediction[name]) - jnp.log(self._target[name])
            )

        sum_of_squares = jnp.sum(jnp.square(jnp.asarray(log_differences)))
        return self._constant_term - sum_of_squares / (2 * self.sigma2)
class JaxGaussianLogLikelihoodKnownSigma(BaseJaxCost, BaseLikelihood):
    """
    A Jax implementation of the Gaussian log-likelihood function.

    This function represents the underlying observed data as being sampled
    from a Gaussian distribution with known noise, `sigma0`.

    Parameters
    ----------
    problem : BaseProblem
        The problem to fit, of type `pybop.BaseProblem`.
    sigma0 : float
        The standard deviation of the noise in the measured data. It is
        validated by ``check_sigma0`` and squared internally to obtain the
        variance. This argument is required; the signature has no default.
    """

    def __init__(self, problem: BaseProblem, sigma0: Union[list[float], float]):
        super().__init__(problem)
        self.sigma = self.check_sigma0(sigma0)
        self.sigma2 = jnp.square(self.sigma)

        # Input-independent pieces of the log-likelihood, cached once:
        #   offset = -n/2 * log(2 * pi * sigma^2),  multiplier = -1 / (2 * sigma^2)
        self._offset = -0.5 * self.n_data * jnp.log(2 * jnp.pi * self.sigma2)
        self._multip = -1 / (2.0 * self.sigma2)

    def evaluate(self, inputs):
        """
        Computes the Gaussian log-likelihood.

        Returns the cached offset plus the cached multiplier times the sum of
        squared residuals over every signal.
        """
        y = self.problem.evaluate(inputs)
        residuals = jnp.asarray([y[s] - self._target[s] for s in self.signal])
        # A single reduction is enough: the sum of squares is already a scalar.
        return self._offset + self._multip * jnp.sum(jnp.square(residuals))