jax_costs#

Classes#

BaseJaxCost

Base JAX cost function.

JaxGaussianLogLikelihoodKnownSigma

A JAX implementation of the Gaussian Likelihood function.

JaxLogNormalLikelihood

A Log-Normal Likelihood function. This function represents the underlying observed data sampled from a Log-Normal distribution.

JaxSumSquaredError

JAX-based Sum of Squared Error cost function.

Module Contents#

class jax_costs.BaseJaxCost(problem: pybop.BaseProblem)[source]#

Bases: pybop.BaseCost

Base JAX cost function.

This class implements a cost function using JAX for automatic differentiation and efficient gradient computation. It is designed to work with problems defined in the BaseProblem framework and supports transformations, gradient computation, and minimisation for optimisation tasks.

problem[source]#

The problem object containing the model, data, and relevant configurations.

Type:

BaseProblem

model[source]#

The model associated with the problem.

Type:

BaseModel

n_data[source]#

The number of data points in the problem.

Type:

int

has_transform#

Indicates whether input transformations are applied.

Type:

bool

__call__(inputs: pybop.Inputs, calculate_grad: bool = False, apply_transform: bool = False, for_optimiser: bool = False) → numpy.array | tuple[float, numpy.ndarray][source]#

Compute the JAX cost function and (optionally) its gradient for given inputs.

Parameters:
  • inputs (Inputs) – Input data for model evaluation.

  • calculate_grad (bool, optional) – Whether to calculate and return the gradient.

  • apply_transform (bool, optional) – Whether to apply transformation to the inputs.

  • for_optimiser (bool, optional) – Whether the function is being called for an optimiser.

Returns:

The computed cost or a tuple of cost and gradient.

Return type:

Union[np.ndarray, tuple[float, np.ndarray]]
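
The following is a minimal sketch of the call pattern described above: a cost function that returns either the cost alone or a (cost, gradient) tuple, using jax.value_and_grad. It is not the PyBOP implementation. The exponential model, the data, and the names evaluate and call are hypothetical stand-ins chosen for illustration, and input transformations and solver handling are omitted.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a model: y = a * exp(-b * t)
t = jnp.linspace(0.0, 1.0, 5)
target = 2.0 * jnp.exp(-0.5 * t)


def evaluate(params):
    """Sum of squared errors between the toy model and the target."""
    prediction = params[0] * jnp.exp(-params[1] * t)
    return jnp.sum((prediction - target) ** 2)


def call(params, calculate_grad=False):
    """Return the cost, or (cost, gradient) when calculate_grad is True."""
    params = jnp.asarray(params)
    if calculate_grad:
        # value_and_grad evaluates the cost and its gradient in one pass
        return jax.value_and_grad(evaluate)(params)
    return evaluate(params)


cost_only = call([1.0, 1.0])
cost, grad = call([1.0, 1.0], calculate_grad=True)
```

Computing the value and gradient together avoids evaluating the model twice, which matters when each evaluation involves a simulation.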

_update_solver_sensitivities(calculate_grad: bool) → None[source]#

Updates the solver’s sensitivity calculation based on the gradient requirement.

Parameters:

calculate_grad (bool) – Whether gradient calculation is required.

static check_sigma0(sigma0)[source]#

Validates the sigma0 parameter.

observed_fisher(inputs: pybop.Inputs)[source]#

Compute the observed Fisher Information Matrix (FIM) for the given inputs.

The FIM is approximated by the square of the gradient divided by the number of data points, because the Hessian is not available.

Returns:

The observed Fisher Information Matrix.

Return type:

jnp.ndarray
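
As a rough illustration of the approximation described above, the sketch below squares the gradient of a cost and divides by the number of data points. It assumes the square is taken elementwise, which yields one value per parameter; the source does not state this explicitly. The linear model, the data, and the name observed_fisher_sketch are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical data generated from y = 2x + 1
x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = jnp.array([1.0, 3.0, 5.0, 7.0])
n_data = y.size


def cost(params):
    """Sum of squared errors for a straight-line fit."""
    slope, intercept = params[0], params[1]
    return jnp.sum((slope * x + intercept - y) ** 2)


def observed_fisher_sketch(params):
    """Squared gradient divided by the number of data points.

    Assumption: the square is elementwise, so the result has one
    entry per parameter rather than a full matrix.
    """
    grad = jax.grad(cost)(params)
    return jnp.square(grad) / n_data


fim = observed_fisher_sketch(jnp.array([1.0, 0.0]))
```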

class jax_costs.JaxGaussianLogLikelihoodKnownSigma(problem: pybop.BaseProblem, sigma0: list[float] | float)[source]#

Bases: BaseJaxCost, pybop.BaseLikelihood

A JAX implementation of the Gaussian Likelihood function. This function represents the underlying observed data sampled from a Gaussian distribution with known noise, sigma0.

Parameters:
  • problem (BaseProblem) – The problem to fit, of type pybop.BaseProblem.

  • sigma0 (float or list[float]) – The standard deviation of the measured data.

evaluate(inputs)[source]#

Computes the Gaussian log-likelihood.
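
For reference, a Gaussian log-likelihood with known standard deviation can be sketched in JAX as below. This is the textbook formula, not a copy of the class internals. The function name and the local variables offset and multiplier are illustrative, and any correspondence to the _offset and _multip attributes is an assumption.

```python
import jax.numpy as jnp


def gaussian_log_likelihood(prediction, target, sigma):
    """Log-likelihood of target under N(prediction, sigma^2)."""
    n = target.size
    sigma2 = sigma**2
    # Constant term, independent of the model prediction
    offset = -0.5 * n * jnp.log(2.0 * jnp.pi * sigma2)
    # Scale applied to the sum of squared residuals
    multiplier = -1.0 / (2.0 * sigma2)
    return offset + multiplier * jnp.sum((target - prediction) ** 2)


data = jnp.array([1.0, 2.0, 3.0])
ll_perfect = gaussian_log_likelihood(data, data, sigma=1.0)
ll_shifted = gaussian_log_likelihood(data + 1.0, data, sigma=1.0)
```

Both constants depend only on sigma and the number of data points, so they can be computed once and reused across evaluations.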

_multip[source]#
_offset[source]#
sigma[source]#
sigma2[source]#
class jax_costs.JaxLogNormalLikelihood(problem: pybop.BaseProblem, sigma0: list[float] | float)[source]#

Bases: BaseJaxCost, pybop.BaseLikelihood

A Log-Normal Likelihood function. This function represents the underlying observed data sampled from a Log-Normal distribution.

Parameters:
  • problem (BaseProblem) – The problem to fit, of type pybop.BaseProblem.

  • sigma0 (float or list[float]) – The standard deviation of the measured data.

_precompute()[source]#
evaluate(inputs)[source]#

Computes the log-normal likelihood.
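
A log-normal log-likelihood can be sketched as follows, assuming the logarithm of the model prediction is the location parameter of the distribution. That parameterisation is an assumption and may differ from the one used by the class. The function name and local variables are illustrative only.

```python
import jax.numpy as jnp


def log_normal_log_likelihood(prediction, target, sigma):
    """Log-likelihood of target under LogNormal(log(prediction), sigma^2).

    Both prediction and target must be strictly positive.
    """
    n = target.size
    sigma2 = sigma**2
    # Terms that do not depend on the prediction
    offset = -0.5 * n * jnp.log(2.0 * jnp.pi * sigma2)
    log_target_sum = jnp.sum(jnp.log(target))
    # Residual is measured in log space
    log_residual = jnp.log(target) - jnp.log(prediction)
    return offset - log_target_sum - jnp.sum(log_residual**2) / (2.0 * sigma2)


ll = log_normal_log_likelihood(
    prediction=jnp.array([1.0]), target=jnp.array([2.0]), sigma=0.5
)
```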

_log_target_sum[source]#
_offset[source]#
_target_as_array[source]#
sigma[source]#
sigma2[source]#
class jax_costs.JaxSumSquaredError(problem: pybop.BaseProblem)[source]#

Bases: BaseJaxCost

JAX-based Sum of Squared Error cost function.

Parameters:

problem (BaseProblem) – The problem to fit, of type pybop.BaseProblem.

evaluate(inputs)[source]#

Computes the sum of squared errors between predictions and targets.

evaluate(inputs)[source]#

Evaluates the sum of squared errors for the given predictions.
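
The quantity itself is straightforward; a minimal JAX sketch, with a hypothetical function name and example arrays, might look like this:

```python
import jax.numpy as jnp


def sum_squared_error(prediction, target):
    """Sum over all data points of the squared residual."""
    return jnp.sum((prediction - target) ** 2)


sse = sum_squared_error(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 1.0, 1.0]))
```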