"""Source code for jax_costs: JAX-based cost and likelihood functions."""

from typing import Union

import jax
import jax.numpy as jnp
import numpy as np
from pybamm import IDAKLUSolver

from pybop import BaseCost, BaseLikelihood, BaseProblem, Inputs


class BaseJaxCost(BaseCost):
    """
    Base JAX cost function.

    This class implements a cost function using JAX for automatic
    differentiation and efficient gradient computation. It is designed to
    work with problems defined in the `BaseProblem` framework and supports
    transformations, gradient computation, and minimisation for
    optimisation tasks.

    Attributes
    ----------
    problem : BaseProblem
        The problem object containing the model, data, and relevant
        configurations.
    model : BaseModel
        The model associated with the problem.
    n_data : int
        The number of data points in the problem.
    has_transform : bool
        Whether input transformations were applied on the most recent call
        (assigned by ``__call__``).
    """

    def __init__(self, problem: BaseProblem):
        super().__init__(problem)
        self.model = self.problem.model
        self.n_data = self.problem.n_data

        # JAXify the solver if the model uses the IDAKLUSolver; `__call__`
        # differentiates `self.evaluate` with `jax.value_and_grad`.
        if isinstance(self.model.solver, IDAKLUSolver):
            self.model.jaxify_solver(t_eval=self.problem.domain_data)

    def __call__(
        self,
        inputs: Inputs,
        calculate_grad: bool = False,
        apply_transform: bool = False,
        for_optimiser: bool = False,
    ) -> Union[np.ndarray, tuple[float, np.ndarray]]:
        """
        Compute the JAX cost function and (optionally) its gradient.

        Parameters
        ----------
        inputs : Inputs
            Input data for model evaluation.
        calculate_grad : bool, optional
            Whether to calculate and return the gradient.
        apply_transform : bool, optional
            Whether to apply transformation to the inputs.
        for_optimiser : bool, optional
            Whether the function is being called for an optimiser. When True
            and the cost is not being minimised, the sign of the cost (and
            gradient) is flipped.

        Returns
        -------
        Union[np.ndarray, tuple[float, np.ndarray]]
            The computed cost, or a tuple of the cost and its gradient. The
            gradient entries follow the key order of the verified inputs.
        """
        # Set-up transformation, inputs and minimising factor
        self.has_transform = bool(self.transformation and apply_transform)
        model_inputs = self.parameters.verify(self._apply_transformations(inputs))
        # Only flip the sign when an optimiser is driving a non-minimising cost.
        minimising_factor = 1 if (self.minimising or not for_optimiser) else -1

        # Re-JAXify the solver only when the sensitivity requirement changes
        if calculate_grad != self.model.calculate_sensitivities:
            self._update_solver_sensitivities(calculate_grad)

        if calculate_grad:
            y, dy = jax.value_and_grad(self.evaluate)(model_inputs)
            # JAX pytree flattening sorts dict keys, so the iteration order of
            # `dy` need not match `model_inputs`; index by key to keep the
            # gradient aligned with the parameter order.
            # NOTE(review): the gradient is w.r.t. `model_inputs`; confirm
            # whether it should be mapped through the transformation's
            # Jacobian when `has_transform` is True.
            grad = np.asarray([dy[key] for key in model_inputs])
            return minimising_factor * y, minimising_factor * grad

        return minimising_factor * self.evaluate(model_inputs)

    def _update_solver_sensitivities(self, calculate_grad: bool) -> None:
        """
        Re-JAXify the model's solver with the requested sensitivity setting.

        Parameters
        ----------
        calculate_grad : bool
            Whether gradient calculation is required.
        """
        self.model.jaxify_solver(
            t_eval=self.problem.domain_data, calculate_sensitivities=calculate_grad
        )

    @staticmethod
    def check_sigma0(sigma0) -> float:
        """
        Validate the sigma0 parameter and return it as a float.

        Parameters
        ----------
        sigma0 : int or float
            Candidate standard deviation. Python and NumPy real scalars are
            accepted; booleans are not.

        Returns
        -------
        float
            ``sigma0`` converted to a Python float.

        Raises
        ------
        ValueError
            If ``sigma0`` is not a finite, strictly positive real scalar.
        """
        is_real_scalar = isinstance(
            sigma0, (int, float, np.integer, np.floating)
        ) and not isinstance(sigma0, bool)
        # The chained comparison is False for NaN, so NaN and inf are rejected.
        if not is_real_scalar or not 0 < sigma0 < np.inf:
            raise ValueError("sigma0 must be a positive number")
        return float(sigma0)
class JaxSumSquaredError(BaseJaxCost):
    """
    JAX-based sum of squared errors (SSE) cost function.

    Parameters
    ----------
    problem : BaseProblem
        The problem to fit, of type `pybop.BaseProblem`.

    Methods
    -------
    evaluate(inputs)
        Return the sum, over every signal, of the squared differences between
        the model prediction and the target data.
    """

    def evaluate(self, inputs):
        """
        Evaluate the sum of squared errors for the given inputs.

        The problem is evaluated at ``inputs``. For each signal, the
        prediction is compared against the stored target, and all squared
        differences are summed into a single value.
        """
        prediction = self.problem.evaluate(inputs)
        errors = jnp.asarray(
            [prediction[signal] - self._target[signal] for signal in self.signal]
        )
        return jnp.sum(errors**2)
class JaxLogNormalLikelihood(BaseJaxCost, BaseLikelihood):
    """
    A Log-Normal Likelihood function.

    This function represents the underlying observed data as sampled from a
    Log-Normal distribution whose log-mean is the log of the model
    prediction.

    Parameters
    ----------
    problem : BaseProblem
        The problem to fit, of type `pybop.BaseProblem`.
    sigma0 : float, optional (default=0.02)
        The standard deviation of the log of the measured data.
    """

    def __init__(self, problem: BaseProblem, sigma0: float = 0.02):
        super().__init__(problem)
        self.sigma = self.check_sigma0(sigma0)
        self.sigma2 = jnp.square(self.sigma)
        self._precompute()

    def _precompute(self):
        """
        Cache the target-dependent terms and the constant part of the
        log-likelihood.

        The normalisation counts every observation across all signals, so it
        agrees with the residual sum taken over all signals in `evaluate`.
        """
        self._target_as_array = jnp.asarray([self._target[s] for s in self.signal])
        # Total number of observations (all signals); equals n_data for a
        # single-signal problem.
        n_obs = self._target_as_array.size
        self._offset = 0.5 * n_obs * jnp.log(2 * jnp.pi)
        self._log_target_sum = jnp.sum(jnp.log(self._target_as_array))
        self._constant_term = (
            -self._offset - n_obs * jnp.log(self.sigma) - self._log_target_sum
        )

    def evaluate(self, inputs):
        """
        Computes the log-normal log-likelihood for the given inputs.
        """
        y = self.problem.evaluate(inputs)
        residuals = jnp.asarray(
            [jnp.log(y[s]) - jnp.log(self._target[s]) for s in self.signal]
        )
        return self._constant_term - jnp.sum(jnp.square(residuals)) / (2 * self.sigma2)
class JaxGaussianLogLikelihoodKnownSigma(BaseJaxCost, BaseLikelihood):
    """
    A JAX implementation of the Gaussian log-likelihood function.

    This function represents the underlying observed data as sampled from a
    Gaussian distribution with known noise, `sigma0`.

    Parameters
    ----------
    problem : BaseProblem
        The problem to fit, of type `pybop.BaseProblem`.
    sigma0 : float
        The standard deviation of the noise in the measured data.
    """

    def __init__(self, problem: BaseProblem, sigma0: float):
        super().__init__(problem)
        self.sigma = self.check_sigma0(sigma0)
        self.sigma2 = jnp.square(self.sigma)
        # Count every observation across all signals so the normalisation
        # matches the residual sum in `evaluate` (equals n_data for a single
        # signal).
        n_obs = jnp.asarray([self._target[s] for s in self.signal]).size
        self._offset = -0.5 * n_obs * jnp.log(2 * jnp.pi * self.sigma2)
        self._multip = -1 / (2.0 * self.sigma2)

    def evaluate(self, inputs):
        """
        Computes the Gaussian log-likelihood for the given inputs.
        """
        y = self.problem.evaluate(inputs)
        residuals = jnp.asarray([y[s] - self._target[s] for s in self.signal])
        return self._offset + self._multip * jnp.sum(jnp.square(residuals))