Source code for pybop._evaluation

import jax.numpy as jnp
import numpy as np
from pints import Evaluator as PintsEvaluator


class SequentialJaxEvaluator(PintsEvaluator):
    """
    Sequentially evaluates a function (or callable object) for either a
    single or multiple positions.

    This class is based off the PintsSequentialEvaluator class, with
    additions for PyBOP's JAX cost classes.

    Parameters
    ----------
    function : callable
        The function to evaluate. This function should accept an input and
        optionally additional arguments, returning either a single value or
        a tuple.
    args : sequence, optional
        A sequence containing extra arguments to be passed to the function.
        If specified, the function will be called as `function(x, *args)`.
    """

    def _evaluate(self, positions):
        """
        Evaluate the stored function at each entry of ``positions``, in order.

        The callable and its extra arguments are read from ``self._function``
        and ``self._args`` (expected to be set up by the PINTS base class).

        Parameters
        ----------
        positions : sequence
            The positions to evaluate; each entry is passed as the first
            argument of the function.

        Returns
        -------
        numpy.ndarray or list of tuple
            If the function returns plain values, an array built from all of
            the scores. If it returns tuples (value and gradient), a list of
            ``(value, gradient)`` pairs in which the value has been reduced
            to a Python scalar whenever it exposes ``.item()``.
        """
        scores = [self._function(x, *self._args) for x in positions]

        # An empty batch has no first element to inspect; return the empty
        # array rather than raising an IndexError on scores[0].
        if not scores:
            return np.asarray(scores)

        # If a gradient is provided, reduce the array-like cost to a Python
        # scalar. Costs that are already plain scalars (no `.item()`) are
        # passed through unchanged instead of raising an AttributeError.
        if isinstance(scores[0], tuple):
            return [
                (
                    score[0].item() if hasattr(score[0], "item") else score[0],
                    score[1],
                )
                for score in scores
            ]

        return np.asarray(scores)
class SciPyEvaluator(PintsEvaluator):
    """
    Evaluator used with the SciPy optimisers.

    The ``positions`` argument handed to this evaluator is forwarded to the
    function as one single input; it is not iterated over.

    Parameters
    ----------
    function : callable
        The function to evaluate. It should accept an input and optionally
        additional arguments, returning either a single value or a tuple.
    args : sequence, optional
        Extra arguments to be passed to the function. If specified, the
        function will be called as `function(x, *args)`.
    """

    def _evaluate(self, positions):
        """
        Evaluate the stored function once, at ``positions``.

        Parameters
        ----------
        positions : object
            The input forwarded as the first argument of the function.

        Returns
        -------
        object
            For a non-tuple result, the result after a round trip through
            ``numpy.asarray``. For a tuple result, a ``(value, gradient)``
            pair built from its first two entries, where a JAX array value
            is converted with ``.item()``.
        """
        result = self._function(positions, *self._args)

        if isinstance(result, tuple):
            value = result[0]
            gradient = result[1]
            # If gradient provided, convert a jnp value to a Python scalar
            if isinstance(value, jnp.ndarray):
                value = value.item()
            return (value, gradient)

        # Wrap and index so that the returned type matches what NumPy yields
        # for the first element of a one-element array of results.
        return np.asarray([result])[0]