Source code for pybop.costs.base_cost

import numpy as np

from pybop import BaseProblem


[docs] class BaseCost: """ Base class for defining cost functions. This class is intended to be subclassed to create specific cost functions for evaluating model predictions against a set of data. The cost function quantifies the goodness-of-fit between the model predictions and the observed data, with a lower cost value indicating a better fit. Parameters ---------- problem : object A problem instance containing the data and functions necessary for evaluating the cost function. _target : array-like An array containing the target data to fit. x0 : array-like The initial guess for the model parameters. bounds : tuple The bounds for the model parameters. sigma0 : scalar or array Initial standard deviation around ``x0``. Either a scalar value (one standard deviation for all coordinates) or an array with one entry per dimension. Not all methods will use this information. _n_parameters : int The number of parameters in the model. n_outputs : int The number of outputs in the model. """ def __init__(self, problem=None, sigma=None): self.problem = problem self.x0 = None self.bounds = None self.sigma0 = sigma self._minimising = True if isinstance(self.problem, BaseProblem): self._target = problem._target self.parameters = problem.parameters self.x0 = problem.x0 self.bounds = problem.bounds self.n_outputs = problem.n_outputs self.signal = problem.signal self._n_parameters = problem.n_parameters self.sigma0 = sigma or problem.sigma0 or np.zeros(self._n_parameters) @property
[docs] def n_parameters(self): return self._n_parameters
[docs] def __call__(self, x, grad=None): """ Call the evaluate function for a given set of parameters. """ return self.evaluate(x, grad)
[docs] def evaluate(self, x, grad=None): """ Call the evaluate function for a given set of parameters. Parameters ---------- x : array-like The parameters for which to evaluate the cost. grad : array-like, optional An array to store the gradient of the cost function with respect to the parameters. Returns ------- float The calculated cost function value. Raises ------ ValueError If an error occurs during the calculation of the cost. """ try: if self._minimising: return self._evaluate(x, grad) else: # minimise the negative cost return -self._evaluate(x, grad) except NotImplementedError as e: raise e except Exception as e: raise ValueError(f"Error in cost calculation: {e}")
[docs] def _evaluate(self, x, grad=None): """ Calculate the cost function value for a given set of parameters. This method must be implemented by subclasses. Parameters ---------- x : array-like The parameters for which to evaluate the cost. grad : array-like, optional An array to store the gradient of the cost function with respect to the parameters. Returns ------- float The calculated cost function value. Raises ------ NotImplementedError If the method has not been implemented by the subclass. """ raise NotImplementedError
[docs] def evaluateS1(self, x): """ Call _evaluateS1 for a given set of parameters. Parameters ---------- x : array-like The parameters for which to compute the cost and gradient. Returns ------- tuple A tuple containing the cost and the gradient. The cost is a float, and the gradient is an array-like of the same length as `x`. Raises ------ ValueError If an error occurs during the calculation of the cost or gradient. """ try: if self._minimising: return self._evaluateS1(x) else: # minimise the negative cost L, dl = self._evaluateS1(x) return -L, -dl except NotImplementedError as e: raise e except Exception as e: raise ValueError(f"Error in cost calculation: {e}")
[docs] def _evaluateS1(self, x): """ Compute the cost and its gradient with respect to the parameters. Parameters ---------- x : array-like The parameters for which to compute the cost and gradient. Returns ------- tuple A tuple containing the cost and the gradient. The cost is a float, and the gradient is an array-like of the same length as `x`. Raises ------ NotImplementedError If the method has not been implemented by the subclass. """ raise NotImplementedError