Source code for pybop.samplers.base_pints_sampler

import logging
import time
from dataclasses import dataclass

import numpy as np
import pints

from pybop import (
    BaseSampler,
    MultiChainProcessor,
    PopulationEvaluator,
    SingleChainProcessor,
)
from pybop._logging import Logger
from pybop._result import SamplingResult
from pybop.problems.problem import Problem
from pybop.samplers.base_sampler import SamplerOptions


@dataclass
class PintsSamplerOptions(SamplerOptions):
    """
    Pints sampler options.

    Attributes
    ----------
    n_chains : int
        The number of chains to concurrently sample from (default: 1).
        Not declared in this class; presumably inherited from
        ``SamplerOptions``.
    cov : float | np.ndarray
        Covariance used to initialise the sampler (default: 0.05). May be a
        scalar, a vector of variances or a full covariance matrix. Not
        declared in this class; presumably inherited from ``SamplerOptions``.
    max_iterations : int
        Maximum number of iterations to run (default: 500).
    chains_in_memory : bool
        Whether to store the chains in memory (default: True).
    log_to_screen : bool
        If `True` (default), the sampler will print information during the
        sampling.
    log_filename : str | None
        The name of the file to save the sampler log to (default: None).
    initial_phase_iterations : int
        Number of iterations spent in the sampler's initial phase, for
        samplers that need one (default: 250).
    verbose : bool
        If `True`, additional information will be printed (default: False).
    warm_up_iterations : int
        Number of leading iterations discarded from the returned chains
        (default: 0).
    chain_files : list[str] | None
        The names of the files to save the chains in (default: None).
    evaluation_files : list[str] | None
        The names of the files to save the evaluations in (default: None).
    """

    max_iterations: int = 500
    chains_in_memory: bool = True
    log_to_screen: bool = True
    log_filename: str | None = None
    initial_phase_iterations: int = 250
    verbose: bool = False
    warm_up_iterations: int = 0
    chain_files: list[str] | None = None
    evaluation_files: list[str] | None = None

    def validate(self):
        """
        Validate the options.

        Raises
        ------
        ValueError
            If the options are invalid.
        """
        super().validate()

        if self.cov is not None:
            cov = np.asarray(self.cov)
            # For a full covariance matrix only the variances on the diagonal
            # have to be positive; off-diagonal covariances may legitimately
            # be zero or negative. Scalars and vectors are checked entry-wise.
            variances = np.diagonal(cov) if cov.ndim == 2 else np.atleast_1d(cov)
            # np.any (rather than the builtin any) copes with any array shape.
            if np.any(variances <= 0):
                raise ValueError("Covariance values must be positive.")

        if self.warm_up_iterations < 0:
            raise ValueError("Number of warm-up steps must be non-negative.")
        if self.max_iterations < 1:
            raise ValueError("Maximum number of iterations must be greater than 0.")
        if self.initial_phase_iterations < 1:
            raise ValueError(
                "Number of initial phase iterations must be greater than 0."
            )
class BasePintsSampler(BaseSampler):
    """
    Base class for PINTS samplers.

    This class extends the BaseSampler class to provide a common interface for
    PINTS samplers. The class provides a run() method that can be used to
    sample from the posterior distribution using a PINTS sampler.

    Parameters
    ----------
    log_pdf: pybop.Problem
        The negative unnormalised posterior distribution.
    sampler: pints.MCMCSampler
        The PINTS sampler class (single- or multi-chain) to be used for
        sampling.
    options: Optional[PintsSamplerOptions]
        Options for the sampler, by default None, in which case
        ``default_options()`` is used.
    """

    def __init__(
        self,
        log_pdf: Problem,
        sampler: type[pints.SingleChainMCMC | pints.MultiChainMCMC],
        options: PintsSamplerOptions | None = None,
    ):
        options = options or self.default_options()
        super().__init__(log_pdf, options=options)

        # Cache the option values used during sampling
        self._sampler = sampler
        self._max_iterations = options.max_iterations
        self._chains_in_memory = options.chains_in_memory
        self._log_to_screen = options.log_to_screen
        self._log_filename = options.log_filename
        self._initial_phase_iterations = options.initial_phase_iterations
        self._verbose = options.verbose
        self._warm_up = options.warm_up_iterations
        self._n_parameters = len(self._log_pdf.parameters)
        self._chain_files = options.chain_files
        self._evaluation_files = options.evaluation_files

        # Bookkeeping for the verbose iterations-per-second report; both are
        # reset at the start of run()
        self._loop_iters = 0
        self.iter_time = 0.0

        # Defined up front so the sampling loop can always read it
        self._intermediate_step = False

        # Single chain vs multiple chain samplers
        self._single_chain = issubclass(self._sampler, pints.SingleChainMCMC)

        # Construct the samplers object: one sampler per chain for
        # single-chain methods, one sampler handling all chains otherwise
        if self._single_chain:
            self._n_samplers = self.options.n_chains
            self._samplers = [self._sampler(x0, sigma0=self.cov0) for x0 in self.x0]
        else:
            self._n_samplers = 1
            self._samplers = [self._sampler(self.options.n_chains, self.x0, self.cov0)]

        # Check for sensitivities from sampler and set evaluation
        self._needs_sensitivities = self._samplers[0].needs_sensitivities()

        # Check initial phase
        self._initial_phase = self._samplers[0].needs_initial_phase()
        if self._initial_phase:
            # NOTE(review): not defined in this class; presumably provided by
            # BaseSampler or a subclass -- confirm.
            self.set_initial_phase_iterations()

    @staticmethod
    def default_options() -> PintsSamplerOptions:
        """Get the default options for the sampler."""
        return PintsSamplerOptions()

    def _initialise_chain_processor(self):
        """
        Initialise the appropriate chain processor based on configuration.
        """
        if self._single_chain:
            self._chain_processor = SingleChainProcessor(self)
        else:
            self._chain_processor = MultiChainProcessor(self)

    def run(self) -> SamplingResult | np.ndarray:
        """
        Executes the Monte Carlo sampling process and generates samples from
        the posterior distribution.

        This method orchestrates the entire sampling process, managing
        iterations, evaluations, logging, and stopping criteria. It
        initialises the necessary structures and handles both single and
        multi-chain scenarios.

        Returns
        -------
        SamplingResult | np.ndarray
            The sampling result, holding the chains with the first
            ``warm_up_iterations`` samples of each chain removed. If chains
            are not stored in memory, an empty array of shape
            ``(0, max_iterations, n_parameters)`` is returned instead.

        Raises
        ------
        ValueError
            If no stopping criterion is set (i.e., _max_iterations is None).

        Details:
            - Checks and ensures at least one stopping criterion is set.
            - Initialises iterations, evaluations, and other required
              structures.
            - Handles the initial phase, if applicable, and manages
              intermediate steps in the sampling process.
            - Logs progress and relevant information based on the logging
              configuration.
            - Iterates through the sampling process, evaluating the log PDF,
              updating chains, and managing the stopping criteria.
        """
        self._start_time = time.time()
        self._initialise_logging()
        self._check_stopping_criteria()
        self._initialise_chain_processor()

        self._logger = Logger(
            minimising=self.log_pdf.minimising, verbose=False
        )  # print sampler logging instead
        evaluator = PopulationEvaluator(
            problem=self._log_pdf,
            minimise=False,
            with_sensitivities=self._needs_sensitivities,
            logger=self._logger,
        )

        # The iteration counter lives on the logger, so it can only be reset
        # once the logger exists
        self.iteration = 0
        self._check_initial_phase()
        self._initialise_storage()

        # Start the rate timer here; leaving it at its constructor value of
        # 0.0 would make the first reported iterations-per-second figure
        # meaningless (elapsed time measured from the epoch).
        self.iter_time = time.time()
        self._loop_iters = 0

        # Single-chain runs recompute this flag every pass. For multi-chain
        # runs the chain processor is expected to update it -- TODO confirm.
        # Resetting it here guarantees the check below is always defined.
        self._intermediate_step = False

        while True:
            if self._initial_phase and self.iteration == self._initial_phase_iterations:
                self._end_initial_phase()

            xs = self._ask_for_samples()
            self.fxs = evaluator.evaluate(xs)
            self._process_chains()

            if self._single_chain:
                # An iteration is only complete once every chain has
                # produced a sample for it
                self._intermediate_step = min(self._n_samples) <= self.iteration

            # Skip the remaining loop logic
            if self._intermediate_step:
                continue

            self.iteration += 1

            if self._log_to_screen and self._verbose:
                if self.iteration <= 10 or self.iteration % 50 == 0:
                    timing_iterations = self.iteration - self._loop_iters
                    elapsed_time = time.time() - self.iter_time
                    iterations_per_second = (
                        timing_iterations / elapsed_time if elapsed_time > 0 else 0
                    )
                    logging.info(
                        f"| Iteration: {self.iteration} | Iter/s: {iterations_per_second: .2f} |"
                    )
                    self.iter_time = time.time()
                    self._loop_iters = self.iteration

            if self._max_iterations and self.iteration >= self._max_iterations:
                halt_message = (
                    f"Maximum number of iterations ({self._max_iterations}) reached."
                )
                break

        self._finalise_logging()

        if not self._chains_in_memory:
            return np.array([]).reshape((0, self._max_iterations, self._n_parameters))

        # Discard the warm-up samples from the start of every chain
        if self._warm_up > 0:
            self._samples = self._samples[:, self._warm_up :, :]

        return SamplingResult(
            sampler=self,
            logger=self._logger,
            time=self._total_time,
            chains=self._samples,
            sampler_name=self._samplers[0].name(),
            message=halt_message,
        )

    def _process_chains(self):
        """
        Process chains using the appropriate processor.
        """
        self._chain_processor.process_chain()

    def _ask_for_samples(self):
        """
        Ask the sampler(s) for the next points to evaluate.

        Single-chain samplers are queried one by one (active chains only);
        a multi-chain sampler is queried once.
        """
        if self._single_chain:
            return [self._samplers[i].ask() for i in self._active]
        return self._samplers[0].ask()

    def _check_initial_phase(self):
        """
        Set initial phase if needed
        """
        if self._initial_phase:
            for sampler in self._samplers:
                sampler.set_initial_phase(True)

    def _end_initial_phase(self):
        """
        Switch every sampler out of its initial phase and log the change.
        """
        for sampler in self._samplers:
            sampler.set_initial_phase(False)
        if self._log_to_screen:
            logging.info("Initial phase completed.")

    def _check_stopping_criteria(self):
        """
        Verify that at least one stopping criterion is defined.

        Raises
        ------
        ValueError
            If ``_max_iterations`` is None.
        """
        if self._max_iterations is None:
            raise ValueError("At least one stopping criterion must be set.")

    def _initialise_storage(self):
        """
        Allocate the arrays that receive the samples and evaluations.
        """
        # Storage of the received samples
        n_chains = self.options.n_chains
        self._sampled_logpdf = np.zeros(n_chains)
        self._sampled_prior = np.zeros(n_chains)

        # Pre-allocate arrays for chain storage
        storage_shape = (
            (n_chains, self._max_iterations, self._n_parameters)
            if self._chains_in_memory
            else (n_chains, self._n_parameters)
        )
        self._samples = np.zeros(storage_shape)
        self._evaluations = np.zeros((n_chains, self._max_iterations))

        # From PINTS:
        # Some samplers need intermediate steps, where `None` is returned instead
        # of a sample. But samplers can run asynchronously, so that one can return
        # `None` while another returns a sample. To deal with this, we maintain a
        # list of 'active' samplers that have not reached `max_iterations`,
        # and store the number of samples so far in each chain.
        if self._single_chain:
            self._active = list(range(n_chains))
            self._n_samples = [0] * n_chains

    def _initialise_logging(self):
        """
        Configure the root logger and report the sampling set-up.
        """
        logging.basicConfig(format="%(message)s", level=logging.INFO)

        if self._log_to_screen:
            logging.info(f"Using {self._samplers[0].name()}")
            logging.info(f"Generating {self.options.n_chains} chains.")
            if self._chain_files:
                logging.info(f"Writing chains to {self._chain_files[0]} etc.")
            if self._evaluation_files:
                logging.info(
                    f"Writing evaluations to {self._evaluation_files[0]} etc."
                )

    def _finalise_logging(self):
        """
        Record the total run time and report the final statistics.
        """
        self._total_time = time.time() - self._start_time

        if self._log_to_screen:
            logging.info(
                f"Halting: Maximum number of iterations ({self.iteration}) reached."
            )
            logging.info(f"Total time: {self._total_time} seconds.")
            logging.info(f"Total number of evaluations: ({self._logger.evaluations}).")

    @property
    def samplers(self):
        """The list of underlying PINTS sampler instances."""
        return self._samplers

    @property
    def active(self):
        """Indices of the active chains (set by run() for single-chain samplers)."""
        return self._active

    @property
    def single_chain(self):
        """Whether the sampler class is a ``pints.SingleChainMCMC`` subclass."""
        return self._single_chain

    @property
    def sampled_logpdf(self):
        """Per-chain array allocated by ``_initialise_storage``."""
        return self._sampled_logpdf

    @property
    def sampled_prior(self):
        """Per-chain array allocated by ``_initialise_storage``."""
        return self._sampled_prior

    @property
    def iteration(self):
        """The current iteration count, stored on the logger."""
        return self._logger.iteration

    @iteration.setter
    def iteration(self, value):
        self._logger.iteration = value

    @property
    def needs_sensitivities(self):
        """Whether the first PINTS sampler reports that it needs sensitivities."""
        return self._needs_sensitivities

    @property
    def chains_in_memory(self):
        """Whether the chains are stored in memory."""
        return self._chains_in_memory

    @property
    def n_samples(self):
        """Number of samples stored so far in each chain (single-chain only)."""
        return self._n_samples

    @property
    def max_iterations(self):
        """The maximum number of iterations."""
        return self._max_iterations