jax_costs
Classes
BaseJaxCost | Base JAX cost function.
JaxGaussianLogLikelihoodKnownSigma | A JAX implementation of the Gaussian likelihood function.
JaxLogNormalLikelihood | A Log-Normal likelihood function.
JaxSumSquaredError | JAX-based Sum of Squared Error cost function.
Module Contents
- class jax_costs.BaseJaxCost(problem: pybop.BaseProblem)
Bases: pybop.BaseCost
Base JAX cost function.
This class implements a cost function that uses JAX automatic differentiation for efficient gradient computation. It is designed to work with problems of type pybop.BaseProblem.
- _update_solver_sensitivities(calculate_grad: bool) -> None
Updates the solver’s sensitivity calculation based on the gradient requirement.
- Parameters:
calculate_grad (bool) – Whether gradient calculation is required.
- single_call(inputs: pybop.Inputs, calculate_grad: bool = False) -> numpy.array | tuple[float, numpy.ndarray]
Compute the JAX cost function and (optionally) its gradient for given inputs.
- Parameters:
inputs (Inputs) – Input data for model evaluation.
calculate_grad (bool, optional) – Whether to calculate and return the gradient. Defaults to False.
- Returns:
The computed cost, or a tuple of (cost, gradient) when calculate_grad is True.
- Return type:
Union[np.ndarray, tuple[float, np.ndarray]]
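As a rough illustration of this return convention, the sketch below uses a toy quadratic cost in plain Python. The function name single_call_sketch, the parameter names, and the cost itself are hypothetical; the real method evaluates the model and obtains its gradient through JAX.

```python
def single_call_sketch(inputs, calculate_grad=False):
    """Hypothetical stand-in: return the cost, or (cost, gradient) when requested."""
    # Toy quadratic cost with its minimum at 1.0 for every parameter
    cost = sum((value - 1.0) ** 2 for value in inputs.values())
    if not calculate_grad:
        return cost
    # Analytic gradient of the toy cost, one entry per parameter
    grad = [2.0 * (value - 1.0) for value in inputs.values()]
    return cost, grad


inputs = {"a": 3.0, "b": 1.0}  # hypothetical parameter names and values
cost_only = single_call_sketch(inputs)
cost, grad = single_call_sketch(inputs, calculate_grad=True)
```

Callers that need only the cost can skip the gradient work entirely, which is why the flag defaults to False.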
- class jax_costs.JaxGaussianLogLikelihoodKnownSigma(problem: pybop.BaseProblem, sigma0: list[float] | float)
Bases: BaseJaxCost, pybop.BaseLikelihood
A JAX implementation of the Gaussian likelihood function. It assumes the underlying observed data are sampled from a Gaussian distribution with known noise, sigma0.
- Parameters:
problem (BaseProblem) – The problem to fit, of type pybop.BaseProblem.
sigma0 (float or list[float]) – The standard deviation of the noise in the measured data.
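For orientation, a Gaussian log-likelihood with known noise can be sketched in plain Python as follows. This assumes sigma is a single standard deviation shared by all data points, and the function name and sample values are illustrative only; it is not the library's JAX implementation.

```python
import math


def gaussian_log_likelihood(data, prediction, sigma):
    """Log-likelihood of data under independent Gaussian noise with known sigma."""
    n = len(data)
    # Sum of squared residuals between measurements and model output
    sse = sum((y - f) ** 2 for y, f in zip(data, prediction))
    # Normalisation term plus the scaled residual term
    return -0.5 * n * math.log(2.0 * math.pi * sigma**2) - sse / (2.0 * sigma**2)


data = [1.0, 2.0, 3.0]  # made-up measurements
good_fit = gaussian_log_likelihood(data, [1.0, 2.0, 3.0], sigma=0.5)
poor_fit = gaussian_log_likelihood(data, [1.5, 2.5, 3.5], sigma=0.5)
```

A closer fit gives a larger log-likelihood, so an optimiser that maximises it is driven towards smaller residuals.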
- class jax_costs.JaxLogNormalLikelihood(problem: pybop.BaseProblem, sigma0: list[float] | float)
Bases: BaseJaxCost, pybop.BaseLikelihood
A Log-Normal likelihood function. It assumes the underlying observed data are sampled from a Log-Normal distribution.
- Parameters:
problem (BaseProblem) – The problem to fit, of type pybop.BaseProblem.
sigma0 (float, optional (default=0.02)) – The standard deviation of the measured data.
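A log-normal log-likelihood can be sketched along the same lines. The version below is the textbook form, assuming strictly positive data, a scalar sigma, and a distribution whose location parameter is the logarithm of the model output; the library's convention may differ in its constant terms or in how the location is defined, and the function name is hypothetical.

```python
import math


def log_normal_log_likelihood(data, prediction, sigma):
    """Log-likelihood of positive data whose logarithm is Gaussian around log(prediction)."""
    total = 0.0
    for y, f in zip(data, prediction):
        # Residual measured in log space
        residual = math.log(y) - math.log(f)
        # Log of the log-normal density at y with location log(f) and scale sigma
        total += -math.log(y * sigma * math.sqrt(2.0 * math.pi))
        total -= residual**2 / (2.0 * sigma**2)
    return total


exact = log_normal_log_likelihood([1.0], [1.0], sigma=1.0)
shifted = log_normal_log_likelihood([1.0], [2.0], sigma=1.0)
```

Because residuals are taken in log space, this likelihood penalises relative rather than absolute errors.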
- class jax_costs.JaxSumSquaredError(problem: pybop.BaseProblem)
Bases: BaseJaxCost
JAX-based Sum of Squared Error cost function.
- Parameters:
problem (BaseProblem) – The problem to fit, of type pybop.BaseProblem.
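The quantity this cost is named after is simple to state. The plain-Python sketch below, with a hypothetical function name and made-up values, shows the sum of squared errors between measured data and a model prediction; it does not reproduce the JAX implementation.

```python
def sum_squared_error(data, prediction):
    """Sum of squared residuals between measurements and model output."""
    return sum((y - f) ** 2 for y, f in zip(data, prediction))


# Residuals are 0, 1 and 2, so the squared terms are 0, 1 and 4
sse = sum_squared_error([1.0, 2.0, 3.0], [1.0, 1.0, 1.0])
```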