Source code for pybop.samplers.base_pints_sampler

import logging
import time
import warnings
from dataclasses import dataclass

import numpy as np
import pints

from pybop import (
    BaseSampler,
    MultiChainProcessor,
    PopulationEvaluator,
    SingleChainProcessor,
)
from pybop._logging import Logger
from pybop.problems.log_pdf import LogPDF
from pybop.samplers.base_sampler import SamplerOptions, SamplingResult


@dataclass
class PintsSamplerOptions(SamplerOptions):
    """
    Pints sampler options.

    Attributes
    ----------
    n_chains : int
        The number of chains to concurrently sample from. This field is not
        declared here; it is inherited from ``SamplerOptions``, which also
        defines its default.
    max_iterations : int
        Maximum number of iterations to run (default: 500).
    chains_in_memory : bool
        Whether to store the chains in memory (default: True).
    log_to_screen : bool
        If `True` (default), the sampler will print information during the
        sampling.
    log_filename : str, optional
        The name of the file to save the sampler log to (default: None).
    initial_phase_iterations : int
        Number of iterations spent in the sampler's initial phase, for
        samplers that need one (default: 250).
    verbose : bool
        If `True`, additional information will be printed (default: False).
    warm_up_iterations : int
        Number of iterations to warm up the sampler (default: 0).
    chain_files : list[str], optional
        The names of the files to save the chains in (default: None).
    evaluation_files : list[str], optional
        The names of the files to save the evaluations in (default: None).
    """

    # NOTE: field order is part of the generated ``__init__`` signature and
    # must not be changed.
    max_iterations: int = 500
    chains_in_memory: bool = True
    log_to_screen: bool = True
    log_filename: str | None = None
    initial_phase_iterations: int = 250
    verbose: bool = False
    warm_up_iterations: int = 0
    chain_files: list[str] | None = None
    evaluation_files: list[str] | None = None

    def validate(self):
        """
        Validate the options.

        The parent class validation runs first, followed by the checks on
        the fields declared in this class.

        Raises
        ------
        ValueError
            If ``warm_up_iterations`` is negative, or if ``max_iterations``
            or ``initial_phase_iterations`` is smaller than 1.
        """
        super().validate()
        if self.warm_up_iterations < 0:
            raise ValueError("Number of warm-up steps must be non-negative.")
        if self.max_iterations < 1:
            raise ValueError("Maximum number of iterations must be greater than 0.")
        if self.initial_phase_iterations < 1:
            raise ValueError(
                "Number of initial phase iterations must be greater than 0."
            )
class BasePintsSampler(BaseSampler):
    """
    Base class for PINTS samplers.

    This class extends the BaseSampler class to provide a common interface
    for PINTS samplers. Sampling is driven by :meth:`run`, which repeatedly
    asks the underlying PINTS sampler(s) for proposals, evaluates them and
    hands the results to a chain processor.

    Parameters
    ----------
    log_pdf : pybop.LogPDF
        The negative unnormalised posterior distribution.
    sampler : type[pints.SingleChainMCMC | pints.MultiChainMCMC]
        The PINTS sampler class (not an instance) to be used for sampling.
    options : PintsSamplerOptions, optional
        Options for the sampler. If not given (or falsy), the result of
        :meth:`default_options` is used.
    """

    def __init__(
        self,
        log_pdf: LogPDF,
        sampler: type[pints.SingleChainMCMC | pints.MultiChainMCMC],
        options: PintsSamplerOptions | None = None,
    ):
        options = options or self.default_options()
        super().__init__(log_pdf, options=options)

        # Cache the option values used throughout the sampling loop
        self._sampler = sampler
        self._max_iterations = options.max_iterations
        self._chains_in_memory = options.chains_in_memory
        self._log_to_screen = options.log_to_screen
        self._log_filename = options.log_filename
        self._initial_phase_iterations = options.initial_phase_iterations
        self._verbose = options.verbose
        self._warm_up = options.warm_up_iterations
        self._chain_files = options.chain_files
        self._evaluation_files = options.evaluation_files

        # Book-keeping for the "Iter/s" progress output; both values are
        # reset at the start of every call to run()
        self._loop_iters = 0
        self.iter_time = 0.0

        if self.log_pdf.parameters.get_bounds(transformed=True):
            warnings.warn(
                "NOTE: Parameter bounds are ignored by PINTS samplers. \n"
                "Samples that lie outside the bounds will return an infinite cost.",
                stacklevel=2,
            )

        # Single chain vs multiple chain samplers
        self._single_chain = issubclass(self._sampler, pints.SingleChainMCMC)

        # Construct the samplers object: one sampler per chain for
        # single-chain methods, one sampler handling all chains otherwise
        if self._single_chain:
            self._n_samplers = self.options.n_chains
            self._samplers = [
                self._sampler(x0=mean0, sigma0=self.cov0) for mean0 in self.mean0
            ]
        else:
            self._n_samplers = 1
            self._samplers = [
                self._sampler(
                    chains=self.options.n_chains, x0=self.mean0, sigma0=self.cov0
                )
            ]

        # Check for sensitivities from sampler and set evaluation
        self._needs_sensitivities = self._samplers[0].needs_sensitivities()

        # Check initial phase
        self._initial_phase = self._samplers[0].needs_initial_phase()
        if self._initial_phase:
            # NOTE(review): set_initial_phase_iterations is not defined in
            # this class; presumably provided by BaseSampler or a subclass.
            self.set_initial_phase_iterations()

    @staticmethod
    def default_options() -> PintsSamplerOptions:
        """Get the default options for the sampler."""
        return PintsSamplerOptions()

    def _initialise_chain_processor(self):
        """
        Initialise the appropriate chain processor based on configuration.

        A ``SingleChainProcessor`` is used for single-chain samplers and a
        ``MultiChainProcessor`` otherwise; either one receives this sampler.
        """
        if self._single_chain:
            self._chain_processor = SingleChainProcessor(self)
        else:
            self._chain_processor = MultiChainProcessor(self)

    def run(self) -> SamplingResult | np.ndarray:
        """
        Execute the Monte Carlo sampling process and generate samples from
        the posterior distribution.

        This method orchestrates the entire sampling process, managing
        iterations, evaluations, logging, and the stopping criterion.

        Returns
        -------
        SamplingResult or numpy.ndarray
            If chains are stored in memory, a ``SamplingResult`` holding the
            chains with the first ``warm_up_iterations`` samples of every
            chain removed. Otherwise an empty array of shape
            ``(0, max_iterations, n_parameters)``.

        Raises
        ------
        ValueError
            If no stopping criterion is set (i.e., _max_iterations is None).

        Details:
            - Checks and ensures at least one stopping criterion is set.
            - Initialises iterations, evaluations, and other required
              structures.
            - Handles the initial phase, if applicable, and manages
              intermediate steps in the sampling process.
            - Logs progress and relevant information based on the logging
              configuration.
            - Iterates through the sampling process, evaluating the log PDF,
              updating chains, and checking the stopping criterion.
        """
        self._start_time = time.time()
        self._initialise_logging()
        self._check_stopping_criteria()
        self._initialise_chain_processor()

        self._logger = Logger(
            minimising=self.log_pdf.minimising, verbose=False
        )  # print sampler logging instead
        evaluator = PopulationEvaluator(
            problem=self._log_pdf,
            minimise=False,
            with_sensitivities=self._needs_sensitivities,
            logger=self._logger,
        )

        self.iteration = 0
        self._check_initial_phase()
        self._initialise_storage()

        # Start the iteration-rate clock now. Without this the first rate is
        # measured against the Unix epoch (iter_time == 0.0) and a repeated
        # run() would start from the previous run's iteration count.
        self._loop_iters = 0
        self.iter_time = time.time()

        # Single-chain runs overwrite this every iteration; for multi-chain
        # runs the chain processor is expected to update it. Defaulting to
        # False keeps the check below well-defined in either case.
        self._intermediate_step = False

        running = True
        while running:
            if self._initial_phase and self.iteration == self._initial_phase_iterations:
                self._end_initial_phase()

            xs = self._ask_for_samples()
            self.fxs = evaluator.evaluate(xs)
            self._process_chains()

            if self._single_chain:
                # An iteration only counts once every chain has produced a
                # sample for it
                self._intermediate_step = min(self._n_samples) <= self.iteration

            # Skip the remaining loop logic
            if self._intermediate_step:
                continue

            self.iteration += 1

            if self._log_to_screen and self._verbose:
                if self.iteration <= 10 or self.iteration % 50 == 0:
                    timing_iterations = self.iteration - self._loop_iters
                    elapsed_time = time.time() - self.iter_time
                    iterations_per_second = (
                        timing_iterations / elapsed_time if elapsed_time > 0 else 0
                    )
                    logging.info(
                        f"| Iteration: {self.iteration} | Iter/s: {iterations_per_second: .2f} |"
                    )
                    self.iter_time = time.time()
                    self._loop_iters = self.iteration

            if self._max_iterations and self.iteration >= self._max_iterations:
                running = False
                halt_message = (
                    "Maximum number of iterations ("
                    + str(self._max_iterations)
                    + ") reached."
                )

        self._finalise_logging()

        if not self._chains_in_memory:
            return np.array([]).reshape((0, self._max_iterations, self._n_parameters))

        # Discard the warm-up portion of every chain
        if self._warm_up > 0:
            self._samples = self._samples[:, self._warm_up :, :]

        return SamplingResult(
            sampler=self,
            time=self._total_time,
            chains=self._samples,
            method_name=self._samplers[0].name(),
            message=halt_message,
        )

    def _process_chains(self):
        """
        Process chains using the appropriate processor.
        """
        self._chain_processor.process_chain()

    def _ask_for_samples(self):
        """
        Ask the PINTS sampler(s) for the next proposals.

        For single-chain samplers a list with one proposal per active
        sampler is returned; otherwise the proposals of the single
        multi-chain sampler are returned.
        """
        if self._single_chain:
            return [self._samplers[i].ask() for i in self._active]
        return self._samplers[0].ask()

    def _check_initial_phase(self):
        """
        Set initial phase if needed
        """
        if self._initial_phase:
            for sampler in self._samplers:
                sampler.set_initial_phase(True)

    def _end_initial_phase(self):
        """
        Switch every sampler out of its initial phase and, if logging to
        screen is enabled, report that the phase has completed.
        """
        for sampler in self._samplers:
            sampler.set_initial_phase(False)
        if self._log_to_screen:
            logging.info("Initial phase completed.")

    def _check_stopping_criteria(self):
        """
        Verify that at least one stopping criterion is defined.

        Raises
        ------
        ValueError
            If ``_max_iterations`` is None.
        """
        if self._max_iterations is None:
            raise ValueError("At least one stopping criterion must be set.")

    def _initialise_storage(self):
        """
        Allocate the arrays and counters filled during sampling.
        """
        # Storage of the received samples
        n_chains = self.options.n_chains
        self._sampled_logpdf = np.zeros(n_chains)
        self._sampled_prior = np.zeros(n_chains)

        # Pre-allocate arrays for chain storage: the full history when
        # chains are kept in memory, otherwise one row per chain
        storage_shape = (
            (n_chains, self._max_iterations, self._n_parameters)
            if self._chains_in_memory
            else (n_chains, self._n_parameters)
        )
        self._samples = np.zeros(storage_shape)
        self._evaluations = np.zeros((n_chains, self._max_iterations))

        # From PINTS:
        # Some samplers need intermediate steps, where `None` is returned instead
        # of a sample. But samplers can run asynchronously, so that one can return
        # `None` while another returns a sample. To deal with this, we maintain a
        # list of 'active' samplers that have not reached `max_iterations`,
        # and store the number of samples so far in each chain.
        if self._single_chain:
            self._active = list(range(n_chains))
            self._n_samples = [0] * n_chains

    def _initialise_logging(self):
        """
        Configure the root logger and, if logging to screen is enabled,
        report the sampler name, the number of chains and any output files.
        """
        logging.basicConfig(format="%(message)s", level=logging.INFO)

        if self._log_to_screen:
            logging.info("Using " + str(self._samplers[0].name()))
            logging.info("Generating " + str(self.options.n_chains) + " chains.")
            if self._chain_files:
                logging.info("Writing chains to " + self._chain_files[0] + " etc.")
            if self._evaluation_files:
                logging.info(
                    "Writing evaluations to " + self._evaluation_files[0] + " etc."
                )

    def _finalise_logging(self):
        """
        Record the total run time and, if logging to screen is enabled,
        report the halting reason, the run time and the evaluation count.
        """
        self._total_time = time.time() - self._start_time

        if self._log_to_screen:
            logging.info(
                f"Halting: Maximum number of iterations ({self.iteration}) reached."
            )
            logging.info(f"Total time: {self._total_time} seconds.")
            logging.info(f"Total number of evaluations: ({self._logger.evaluations}).")

    @property
    def samplers(self):
        """The list of PINTS sampler instances created in ``__init__``."""
        return self._samplers

    @property
    def active(self):
        """Indices of the active samplers (only set for single-chain runs)."""
        return self._active

    @property
    def single_chain(self):
        """Whether the sampler class is a ``pints.SingleChainMCMC``."""
        return self._single_chain

    @property
    def sampled_logpdf(self):
        """Per-chain array allocated by ``_initialise_storage``."""
        return self._sampled_logpdf

    @property
    def sampled_prior(self):
        """Per-chain array allocated by ``_initialise_storage``."""
        return self._sampled_prior

    @property
    def iteration(self):
        """The iteration counter, stored on the logger created in ``run``."""
        return self._logger.iteration

    @iteration.setter
    def iteration(self, value):
        self._logger.iteration = value

    @property
    def needs_sensitivities(self):
        """Whether the first sampler reported that it needs sensitivities."""
        return self._needs_sensitivities

    @property
    def chains_in_memory(self):
        """Whether the chains are stored in memory."""
        return self._chains_in_memory

    @property
    def n_samples(self):
        """Samples stored so far per chain (only set for single-chain runs)."""
        return self._n_samples

    @property
    def max_iterations(self):
        """The maximum number of iterations."""
        return self._max_iterations