Source code for pybop.optimisers.base_optimiser

from dataclasses import dataclass

from pybop._logging import Logger
from pybop._result import OptimisationResult
from pybop.problems.problem import Problem


@dataclass
class OptimiserOptions:
    """
    A base class for optimiser options.

    Attributes
    ----------
    multistart : int
        Number of times to multistart the optimiser (default: 1).
    verbose : bool
        Whether to produce verbose output (default: False).
    verbose_print_rate : int
        The distance between iterations to print verbose output (default: 50).
    """

    multistart: int = 1
    verbose: bool = False
    verbose_print_rate: int = 50

    def validate(self):
        """
        Validate the options.

        Both ``multistart`` and ``verbose_print_rate`` must be at least 1.

        Raises
        ------
        ValueError
            If the options are invalid.
        """
        if self.multistart < 1:
            raise ValueError("Multistart must be greater than or equal to 1.")
        if self.verbose_print_rate < 1:
            raise ValueError("Verbose print rate must be greater than or equal to 1.")
class BaseOptimiser:
    """
    A base class for defining optimisation methods.

    This class serves as a base class for creating optimisers. It provides a
    basic structure for an optimisation algorithm, including the initial setup
    and a method stub for performing the optimisation process. Child classes
    should override ``_set_up_optimiser`` and the ``_run`` method with a
    specific algorithm.

    Parameters
    ----------
    problem : pybop.Problem
        The problem to optimise.
    options : pybop.OptimiserOptions, optional
        Options for the optimiser, such as multistart. If None, the result of
        ``default_options()`` is used.

    Raises
    ------
    TypeError
        If ``problem`` is not a ``pybop.Problem`` instance.
    ValueError
        If the options are invalid, or if the optimiser needs sensitivities
        that the problem does not provide.
    """

    # NOTE(review): not referenced within this class; presumably the iteration
    # cap picked up by child classes — confirm against the subclasses.
    default_max_iterations = 1000

    def __init__(
        self,
        problem: Problem,
        options: OptimiserOptions | None = None,
    ):
        if not isinstance(problem, Problem):
            raise TypeError(f"Expected a pybop.Problem instance, got {type(problem)}")
        self._problem = problem
        self._logger: Logger | None = None

        # Explicit None check: only a missing argument falls back to defaults.
        if options is None:
            options = self.default_options()
        options.validate()
        self._options = options
        self.verbose = options.verbose
        self.verbose_print_rate = options.verbose_print_rate
        self._multistart = options.multistart
        self._needs_sensitivities: bool | None = (
            None  # to be overridden during set_up_optimiser
        )

        self._set_up_optimiser()
        if self._needs_sensitivities and not self._problem.has_sensitivities:
            raise ValueError(
                "This optimiser needs sensitivities, but they are not available from this problem."
            )

    @staticmethod
    def default_options() -> OptimiserOptions:
        """Returns the default options for the optimiser."""
        return OptimiserOptions()

    @property
    def problem(self) -> Problem:
        """Returns the optimisation problem object."""
        return self._problem

    @property
    def options(self) -> OptimiserOptions:
        """Returns the options for the optimiser."""
        return self._options

    def _set_up_optimiser(self):
        """
        Parse optimiser options and prepare the optimiser.

        This method should be implemented by child classes. It is called once
        from ``__init__`` and again before every additional multi-start run.

        Raises
        ------
        NotImplementedError
            If the method has not been implemented by the subclass.
        """
        raise NotImplementedError

    def _run(self) -> OptimisationResult:
        """
        Contains the logic for the optimisation algorithm.

        This method should be implemented by child classes to perform the
        actual optimisation.

        Raises
        ------
        NotImplementedError
            If the method has not been implemented by the subclass.
        """
        raise NotImplementedError

    def name(self) -> str:
        """
        Returns the name of the optimiser, to be overwritten by child classes.

        Returns
        -------
        str
            The name of the optimiser
        """
        raise NotImplementedError  # pragma: no cover

    def run(self) -> OptimisationResult:
        """
        Run the optimisation and return the optimised parameters and final cost.

        The optimiser is run ``multistart`` times. Every run after the first
        starts from initial values sampled from the parameter priors, and the
        optimiser is set up again before that run. The individual results are
        combined into a single result, which is printed if the ``verbose``
        option is set.

        Returns
        -------
        results: OptimisationResult
            The pybop optimisation result class.

        Raises
        ------
        RuntimeError
            If more than one start is requested but no priors are provided.
        """
        # Fail fast: check the multi-start precondition before the first run,
        # so a complete optimisation is not carried out and then thrown away.
        if self._multistart > 1 and not self.problem.parameters.priors():
            raise RuntimeError("Priors must be provided for multi-start")

        results = []
        for i in range(self._multistart):
            if i >= 1:
                # Restart from a fresh sample and rebuild the optimiser state.
                initial_values = self.problem.parameters.sample_from_priors(1)[0]
                self.problem.parameters.update(initial_values=initial_values)
                self._set_up_optimiser()
            results.append(self._run())

        result = OptimisationResult.combine(results)

        if self.options.verbose:
            print(result)

        return result

    @property
    def logger(self) -> Logger | None:
        """Returns the logger; this base class initialises it to None."""
        return self._logger