Source code for pybop.costs.base_cost

from pybop import BaseProblem
from pybop.parameters.parameter import Inputs, Parameters


class BaseCost:
    """
    Base class for defining cost functions.

    This class is intended to be subclassed to create specific cost functions
    for evaluating model predictions against a set of data. The cost function
    quantifies the goodness-of-fit between the model predictions and the
    observed data, with a lower cost value indicating a better fit.

    Parameters
    ----------
    problem : object, optional
        A problem instance containing the data and functions necessary for
        evaluating the cost function. When it is a ``BaseProblem``, the
        problem-derived attributes below are copied from it.

    Attributes
    ----------
    parameters : Parameters
        The parameters of the cost. When ``problem`` is a ``BaseProblem`` its
        parameters are joined into this collection.
    problem : object or None
        The problem passed to the constructor.
    _target : array-like
        An array containing the target data to fit. Only set when ``problem``
        is a ``BaseProblem``.
    n_outputs : int
        The number of outputs in the model. Only set when ``problem`` is a
        ``BaseProblem``.
    signal
        The signal taken from the problem. Only set when ``problem`` is a
        ``BaseProblem``.
    """

    def __init__(self, problem=None):
        self.parameters = Parameters()
        self.problem = problem
        # The problem-derived attributes are copied only from a BaseProblem;
        # for any other ``problem`` value they are left unset here.
        if isinstance(self.problem, BaseProblem):
            self._target = self.problem._target
            self.parameters.join(self.problem.parameters)
            self.n_outputs = self.problem.n_outputs
            self.signal = self.problem.signal

    @property
    def n_parameters(self):
        """The number of parameters in ``self.parameters``."""
        return len(self.parameters)

    def __call__(self, x, grad=None):
        """
        Evaluate the cost for a given set of parameters.

        This is shorthand for ``self.evaluate(x, grad)``.
        """
        return self.evaluate(x, grad)

    def evaluate(self, x, grad=None):
        """
        Evaluate the cost for a given set of parameters.

        The parameter values are verified/converted via
        ``self.parameters.verify`` and then passed to ``_evaluate``.

        Parameters
        ----------
        x : array-like
            The parameters for which to evaluate the cost.
        grad : array-like, optional
            An array to store the gradient of the cost function with respect
            to the parameters. It is forwarded unchanged to ``_evaluate``.

        Returns
        -------
        float
            The calculated cost function value.

        Raises
        ------
        NotImplementedError
            If ``_evaluate`` has not been implemented by the subclass.
        ValueError
            If any other error occurs during the calculation of the cost; the
            original exception is chained as ``__cause__``.
        """
        inputs = self.parameters.verify(x)
        try:
            return self._evaluate(inputs, grad)
        except NotImplementedError:
            # A missing subclass implementation is a programming error rather
            # than a calculation failure, so propagate it unchanged.
            raise
        except Exception as e:
            # Chain explicitly so the underlying error type and traceback are
            # kept as the direct cause of the ValueError.
            raise ValueError(f"Error in cost calculation: {e}") from e

    def _evaluate(self, inputs: Inputs, grad=None):
        """
        Calculate the cost function value for a given set of parameters.

        This method must be implemented by subclasses.

        Parameters
        ----------
        inputs : Inputs
            The parameters for which to evaluate the cost.
        grad : array-like, optional
            An array to store the gradient of the cost function with respect
            to the parameters.

        Returns
        -------
        float
            The calculated cost function value.

        Raises
        ------
        NotImplementedError
            If the method has not been implemented by the subclass.
        """
        raise NotImplementedError

    def evaluateS1(self, x):
        """
        Compute the cost and its gradient for a given set of parameters.

        The parameter values are verified/converted via
        ``self.parameters.verify`` and then passed to ``_evaluateS1``.

        Parameters
        ----------
        x : array-like
            The parameters for which to compute the cost and gradient.

        Returns
        -------
        tuple
            A tuple containing the cost and the gradient. The cost is a float,
            and the gradient is an array-like of the same length as `x`.

        Raises
        ------
        NotImplementedError
            If ``_evaluateS1`` has not been implemented by the subclass.
        ValueError
            If any other error occurs during the calculation of the cost or
            gradient; the original exception is chained as ``__cause__``.
        """
        inputs = self.parameters.verify(x)
        try:
            return self._evaluateS1(inputs)
        except NotImplementedError:
            # Same policy as ``evaluate``: unimplemented is not a calc error.
            raise
        except Exception as e:
            raise ValueError(f"Error in cost calculation: {e}") from e

    def _evaluateS1(self, inputs: Inputs):
        """
        Compute the cost and its gradient with respect to the parameters.

        This method must be implemented by subclasses.

        Parameters
        ----------
        inputs : Inputs
            The parameters for which to compute the cost and gradient.

        Returns
        -------
        tuple
            A tuple containing the cost and the gradient. The cost is a float,
            and the gradient is an array-like with one entry per parameter.

        Raises
        ------
        NotImplementedError
            If the method has not been implemented by the subclass.
        """
        raise NotImplementedError