jax_costs

Classes

BaseJaxCost

Base JAX cost function.

JaxGaussianLogLikelihoodKnownSigma

A Jax implementation of the Gaussian Likelihood function.

JaxLogNormalLikelihood

A Log-Normal Likelihood function. This function models the observed data as samples from a Log-Normal distribution.

JaxSumSquaredError

Jax-based Sum of Squared Error cost function.

Module Contents

class jax_costs.BaseJaxCost(problem: pybop.BaseProblem)

Bases: pybop.BaseCost

Base JAX cost function.

This class implements a cost function using JAX, whose automatic differentiation provides efficient gradient computation. It is designed to work with problems derived from pybop.BaseProblem.
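The pattern this class relies on, a cost written in JAX and differentiated automatically, can be sketched with `jax.value_and_grad`. This is a minimal standalone sketch, not the pybop implementation: the exponential-decay `model` and the `sum_squared_error` helper are hypothetical stand-ins for a `BaseProblem` and its cost.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in model: exponential decay y(t) = a * exp(-k * t).
def model(params, t):
    a, k = params
    return a * jnp.exp(-k * t)


# Hypothetical cost: sum of squared residuals between model output and data.
def sum_squared_error(params, t, target):
    residuals = model(params, t) - target
    return jnp.sum(residuals**2)


t = jnp.linspace(0.0, 1.0, 5)
target = model(jnp.array([2.0, 1.5]), t)  # synthetic data from known parameters

# value_and_grad returns the cost and its gradient w.r.t. the first argument;
# jit compiles both into a single function.
cost_and_grad = jax.jit(jax.value_and_grad(sum_squared_error))
cost, grad = cost_and_grad(jnp.array([1.8, 1.2]), t, target)
print(float(cost), grad)
```

Because the cost and its gradient are compiled together, repeated calls with same-shaped arguments (as during an optimisation) reuse the compiled function.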

problem

The problem object containing the model, data, and relevant configurations.

Type:

BaseProblem

model

The model associated with the problem.

Type:

BaseModel

n_data

The number of data points in the problem.

Type:

int

_update_solver_sensitivities(calculate_grad: bool) → None

Updates the solver’s sensitivity calculation based on the gradient requirement.

Parameters:

calculate_grad (bool) – Whether gradient calculation is required.

static check_sigma0(sigma0)

Validates the sigma0 parameter.
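A validation along the lines of `check_sigma0` might look as follows. This is a hypothetical sketch: the accepted shapes (a scalar or a 1-D sequence) follow the `list[float] | float` annotation on the likelihood constructors, while the finiteness and positivity checks and the error messages are assumptions rather than pybop's actual behaviour.

```python
import numpy as np


def check_sigma0(sigma0):
    """Hypothetical validator: accept a float or a 1-D sequence of floats > 0."""
    sigma0 = np.asarray(sigma0, dtype=float)
    # Assumption: one value overall, or one value per output.
    if sigma0.ndim > 1:
        raise ValueError("sigma0 must be a scalar or a 1-D sequence.")
    # A standard deviation must be finite and strictly positive.
    if sigma0.size == 0 or not np.all(np.isfinite(sigma0)) or np.any(sigma0 <= 0):
        raise ValueError("sigma0 must contain finite, strictly positive values.")
    return sigma0


print(check_sigma0([0.01, 0.02]))
```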

single_call(inputs: pybop.Inputs, calculate_grad: bool = False) → numpy.array | tuple[float, numpy.ndarray]

Compute the JAX cost function and, optionally, its gradient for the given inputs.

Parameters:
  • inputs (Inputs) – Parameter values at which to evaluate the model.

  • calculate_grad (bool, optional) – Whether to calculate and return the gradient.

Returns:

The computed cost, or a tuple of (cost, gradient) when calculate_grad is True.

Return type:

Union[np.ndarray, tuple[float, np.ndarray]]
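The return contract of `single_call` can be illustrated with a standalone function. This sketch assumes that `Inputs` is a mapping from parameter names to floats; the `cost_fn` model, the extra `t` and `target` arguments, and the gradient ordering (the iteration order of `inputs`) are illustrative assumptions, not pybop's API.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical cost over a dict of named parameters (a JAX pytree).
def cost_fn(inputs, t, target):
    prediction = inputs["a"] * jnp.exp(-inputs["k"] * t)
    return jnp.sum((prediction - target) ** 2)


def single_call(inputs, t, target, calculate_grad=False):
    """Return the cost, or (cost, gradient) when calculate_grad is True."""
    params = {name: jnp.asarray(value, dtype=jnp.float32) for name, value in inputs.items()}
    if not calculate_grad:
        return np.asarray(cost_fn(params, t, target))
    # value_and_grad differentiates w.r.t. the whole dict, returning a dict of gradients.
    cost, grads = jax.value_and_grad(cost_fn)(params, t, target)
    # Flatten the gradient pytree into an array ordered like `inputs`.
    gradient = np.array([grads[name] for name in inputs])
    return float(cost), gradient


t = jnp.linspace(0.0, 1.0, 5)
target = 2.0 * jnp.exp(-1.5 * t)
cost = single_call({"a": 1.0, "k": 1.5}, t, target)
cost, gradient = single_call({"a": 1.0, "k": 1.5}, t, target, calculate_grad=True)
print(cost, gradient)
```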

class jax_costs.JaxGaussianLogLikelihoodKnownSigma(problem: pybop.BaseProblem, sigma0: list[float] | float)

Bases: BaseJaxCost, pybop.BaseLikelihood

A JAX implementation of the Gaussian likelihood function. This function models the observed data as samples from a Gaussian distribution with known noise standard deviation, sigma0.

Parameters:
  • problem (BaseProblem) – The problem to fit, of type pybop.BaseProblem.

  • sigma0 (float or list[float]) – The standard deviation of the noise in the measured data.

evaluate(inputs)

Computes the Gaussian log-likelihood.

_multip
_offset
sigma
sigma2
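For independent Gaussian noise with known standard deviation sigma, the log-likelihood of n observations y given model predictions y_hat is -(n/2) * log(2 * pi * sigma^2) - sum((y_i - y_hat_i)^2) / (2 * sigma^2). The sketch below evaluates this for a single output with a scalar sigma, using NumPy for brevity. Mapping the constant term and the residual scale onto the `_offset` and `_multip` attributes is inferred from their names and is an assumption.

```python
import numpy as np


def gaussian_log_likelihood(prediction, target, sigma):
    """Log-likelihood of `target` under i.i.d. Gaussian noise with known std `sigma`."""
    prediction = np.asarray(prediction, dtype=float)
    target = np.asarray(target, dtype=float)
    n = target.size
    sigma2 = float(sigma) ** 2
    # Constant term, independent of the model parameters (presumably `_offset`).
    offset = -0.5 * n * np.log(2.0 * np.pi * sigma2)
    # Scale applied to the sum of squared residuals (presumably `_multip`).
    multip = -1.0 / (2.0 * sigma2)
    return offset + multip * np.sum((target - prediction) ** 2)


target = np.array([1.0, 2.0, 3.0])
prediction = np.array([1.1, 1.9, 3.2])
print(gaussian_log_likelihood(prediction, target, sigma=0.5))
```

Splitting the expression this way means only the sum of squared residuals changes between evaluations, so the other two factors can be computed once per sigma.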
class jax_costs.JaxLogNormalLikelihood(problem: pybop.BaseProblem, sigma0: list[float] | float)

Bases: BaseJaxCost, pybop.BaseLikelihood

A Log-Normal Likelihood function. This function models the observed data as samples from a Log-Normal distribution.

Parameters:
  • problem (BaseProblem) – The problem to fit, of type pybop.BaseProblem.

  • sigma0 (float or list[float]) – The standard deviation of the measured data.

_precompute()
evaluate(inputs)

Computes the log-normal likelihood.

_log_target_sum
_offset
_target_as_array
sigma
sigma2
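If each observation y_i is log-normally distributed with log-location log(y_hat_i) and scale sigma, the log-likelihood is -sum(log y_i) - (n/2) * log(2 * pi * sigma^2) - sum((log y_i - log y_hat_i)^2) / (2 * sigma^2). The first two terms depend only on the data and sigma, which is presumably why the class precomputes `_log_target_sum` and `_offset`; that correspondence, and the choice of log(y_hat_i) as the location, are assumptions in this NumPy sketch.

```python
import numpy as np


def log_normal_log_likelihood(prediction, target, sigma):
    """Hypothetical log-normal log-likelihood with location log(prediction)."""
    prediction = np.asarray(prediction, dtype=float)
    target = np.asarray(target, dtype=float)
    # The log-normal distribution is only defined for strictly positive values.
    if np.any(target <= 0) or np.any(prediction <= 0):
        return -np.inf
    n = target.size
    sigma2 = float(sigma) ** 2
    log_target = np.log(target)
    # Data-only terms: candidates for precomputation.
    log_target_sum = np.sum(log_target)
    offset = -0.5 * n * np.log(2.0 * np.pi * sigma2)
    # Residuals are taken on the log scale.
    residuals = log_target - np.log(prediction)
    return offset - log_target_sum - np.sum(residuals**2) / (2.0 * sigma2)


target = np.array([1.0, 2.0, 4.0])
prediction = np.array([1.2, 1.8, 4.5])
print(log_normal_log_likelihood(prediction, target, sigma=0.3))
```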
class jax_costs.JaxSumSquaredError(problem: pybop.BaseProblem)

Bases: BaseJaxCost

Jax-based Sum of Squared Error cost function.

Parameters:

problem (BaseProblem) – The problem to fit, of type pybop.BaseProblem.

evaluate(inputs)

Computes the sum of squared errors between predictions and targets.

evaluate(inputs)

Evaluates the sum of squared error for the given predictions.
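The sum of squared errors adds up the squared residuals over every time point and every output. The sketch below assumes a two-output target stored as an (n_data, n_outputs) array with illustrative values, and sums over all entries; whether pybop reduces multiple outputs the same way is an assumption. Its gradient with respect to the prediction is 2 * (prediction - target), which `jax.grad` reproduces.

```python
import jax
import jax.numpy as jnp


def sum_squared_error(prediction, target):
    """Sum of squared residuals over all time points and all outputs."""
    return jnp.sum((prediction - target) ** 2)


# Two hypothetical outputs (e.g. voltage and temperature) at three time points.
target = jnp.array([[3.7, 25.0], [3.6, 25.5], [3.5, 26.0]])
prediction = jnp.array([[3.8, 25.2], [3.6, 25.4], [3.4, 26.3]])

sse = sum_squared_error(prediction, target)
# Derivative of the cost with respect to each predicted value.
grad_pred = jax.grad(sum_squared_error)(prediction, target)
print(float(sse), grad_pred)
```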