jax_costs
Classes
BaseJaxCost | Base JAX cost function.
JaxGaussianLogLikelihoodKnownSigma | A JAX implementation of the Gaussian likelihood function.
JaxLogNormalLikelihood | A log-normal likelihood function.
JaxSumSquaredError | JAX-based sum of squared error cost function.
Module Contents
- class jax_costs.BaseJaxCost(problem: pybop.BaseProblem)
Bases: pybop.BaseCost
Base JAX cost function.
This class implements a cost function using JAX for automatic differentiation and efficient gradient computation. It works with problems defined through the pybop.BaseProblem framework and supports input transformations, gradient computation and minimisation in optimisation tasks.
- has_transform
Indicates whether input transformations are applied.
- Type:
bool
- __call__(inputs: pybop.Inputs, calculate_grad: bool = False, apply_transform: bool = False, for_optimiser: bool = False) → numpy.array | tuple[float, numpy.ndarray]
Compute the JAX cost function and, optionally, its gradient for the given inputs.
- Parameters:
inputs (Inputs) – Parameter values at which to evaluate the model.
calculate_grad (bool, optional) – Whether to calculate and return the gradient.
apply_transform (bool, optional) – Whether to apply transformation to the inputs.
for_optimiser (bool, optional) – Whether the function is being called for an optimiser.
- Returns:
The computed cost or a tuple of cost and gradient.
- Return type:
Union[np.ndarray, tuple[float, np.ndarray]]
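The calculate_grad switch can be illustrated with a minimal sketch, assuming the cost is a scalar JAX function of a parameter vector. `make_callable` and `cost_fn` are hypothetical names, not part of PyBOP, and the transformation and optimiser flags are omitted. The sketch uses `jax.value_and_grad` so the cost and gradient come from a single traced evaluation; BaseJaxCost may be implemented differently.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_callable(cost_fn):
    """Wrap a scalar JAX cost so it returns either the cost or (cost, gradient).

    Hypothetical helper for illustration only; not a PyBOP API.
    """
    # Compile once: one function for cost + gradient, one for cost only
    value_and_grad = jax.jit(jax.value_and_grad(cost_fn))
    cost_only = jax.jit(cost_fn)

    def call(params, calculate_grad=False):
        x = jnp.asarray(params, dtype=jnp.float32)
        if calculate_grad:
            value, grad = value_and_grad(x)
            # Mirror the documented shape of the result: tuple[float, np.ndarray]
            return float(value), np.asarray(grad)
        return float(cost_only(x))

    return call
```

For example, wrapping `lambda x: jnp.sum((x - 1.0) ** 2)` and calling it at `[2.0, 3.0]` with `calculate_grad=True` gives a cost of 5.0 and a gradient of [2.0, 4.0].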
- _update_solver_sensitivities(calculate_grad: bool) → None
Enables or disables the solver's sensitivity calculation depending on whether a gradient is required.
- Parameters:
calculate_grad (bool) – Whether gradient calculation is required.
- observed_fisher(inputs: pybop.Inputs)
Compute the observed Fisher Information Matrix (FIM) for the given inputs.
The FIM is approximated by the square of the gradient divided by the number of data points. This approximation is used because the Hessian is not available.
- Returns:
The observed Fisher Information Matrix.
- Return type:
jnp.ndarray
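A minimal sketch of this approximation, assuming "square of the gradient" means the element-wise square (which yields one value per parameter, i.e. a diagonal, rather than a full outer-product matrix). `observed_fisher_sketch` and `cost_fn` are hypothetical names; `cost_fn` stands in for any scalar, JAX-differentiable cost.

```python
import jax
import jax.numpy as jnp


def observed_fisher_sketch(cost_fn, params, n_data):
    """Approximate the FIM as the squared gradient divided by n_data.

    Hypothetical helper; the element-wise square is an assumption about
    how the documented "square of the gradient" is computed.
    """
    grad = jax.grad(cost_fn)(params)
    return jnp.square(grad) / n_data
```

For a cost of `jnp.sum(p**2)` at `p = [1.0, 2.0]` with `n_data = 4`, the gradient is [2, 4], so the result is [1.0, 4.0].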
- class jax_costs.JaxGaussianLogLikelihoodKnownSigma(problem: pybop.BaseProblem, sigma0: list[float] | float)
Bases: BaseJaxCost, pybop.BaseLikelihood
A JAX implementation of the Gaussian likelihood function. It models the observed data as sampled from a Gaussian distribution with a known noise level, sigma0.
- Parameters:
problem (BaseProblem) – The problem to fit, of type pybop.BaseProblem.
sigma0 (list[float] | float) – The known standard deviation of the noise in the measured data.
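The quantity this class evaluates can be sketched as the standard Gaussian log-likelihood with known noise, -(n/2) log(2π sigma0²) - Σ(data - prediction)² / (2 sigma0²). This is a minimal sketch, assuming a single scalar sigma0 interpreted as a standard deviation and n equal to the number of data points; the function name and argument order are hypothetical and do not reproduce PyBOP's implementation.

```python
import jax.numpy as jnp


def gaussian_log_likelihood_known_sigma(prediction, data, sigma0):
    """Log-likelihood of data under N(prediction, sigma0**2) with sigma0 known.

    Hypothetical helper; sigma0 is treated as a scalar standard deviation.
    """
    sigma2 = jnp.asarray(sigma0) ** 2
    n = data.size  # number of data points
    residual = data - prediction
    return -0.5 * n * jnp.log(2 * jnp.pi * sigma2) - jnp.sum(residual ** 2) / (2 * sigma2)
```

With residuals [1, 1] and sigma0 = 1, the result is -log(2π) - 1.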
- class jax_costs.JaxLogNormalLikelihood(problem: pybop.BaseProblem, sigma0: list[float] | float)
Bases: BaseJaxCost, pybop.BaseLikelihood
A log-normal likelihood function. It models the observed data as sampled from a log-normal distribution.
- Parameters:
problem (BaseProblem) – The problem to fit, of type pybop.BaseProblem.
sigma0 (list[float] | float) – The standard deviation of the measured data.
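A minimal sketch of a log-normal log-likelihood, assuming strictly positive data, a scalar sigma0 used as the scale of log(data), and the model prediction setting the log-location (μ = log(prediction)). That parameterisation is an assumption, not taken from PyBOP, and the function name is hypothetical. It follows from the log-normal density 1 / (y σ √(2π)) · exp(-(log y - μ)² / (2σ²)).

```python
import jax.numpy as jnp


def log_normal_log_likelihood(prediction, data, sigma0):
    """Log-likelihood of positive data under LogNormal(log(prediction), sigma0).

    Hypothetical helper; the log-location parameterisation is assumed.
    """
    log_data = jnp.log(data)
    residual = log_data - jnp.log(prediction)  # error on the log scale
    sigma2 = jnp.asarray(sigma0) ** 2
    n = data.size
    return (
        -0.5 * n * jnp.log(2 * jnp.pi * sigma2)
        - jnp.sum(log_data)  # Jacobian term from the 1/y factor in the density
        - jnp.sum(residual ** 2) / (2 * sigma2)
    )
```

The extra `-sum(log(data))` term is what distinguishes this from a Gaussian likelihood applied to log-transformed data.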
- class jax_costs.JaxSumSquaredError(problem: pybop.BaseProblem)
Bases: BaseJaxCost
JAX-based sum of squared error cost function.
- Parameters:
problem (BaseProblem) – The problem to fit, of type pybop.BaseProblem.
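As an illustration of a JAX sum of squared errors and its automatic gradient, the sketch below uses a hypothetical two-parameter linear model in place of a PyBOP simulation. `linear_model`, `sum_squared_error` and `sse_grad` are illustrative names, not PyBOP API.

```python
import jax
import jax.numpy as jnp


def linear_model(params, t):
    # Hypothetical model y = m * t + c, standing in for a simulation
    m, c = params
    return m * t + c


def sum_squared_error(params, t, data):
    residual = linear_model(params, t) - data
    return jnp.sum(residual ** 2)


# Gradient with respect to params only (jax.grad differentiates argument 0 by default)
sse_grad = jax.grad(sum_squared_error)
```

With `t = [0, 1, 2]` and data generated by m = 2, c = 1, the cost is zero at (2, 1), and `sse_grad` evaluated at (m, c) = (1, 1) gives [-10, -6].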