import logging
import time
from dataclasses import dataclass
import numpy as np
import pints
from pybop import (
BaseSampler,
MultiChainProcessor,
PopulationEvaluator,
SingleChainProcessor,
)
from pybop._logging import Logger
from pybop._result import SamplingResult
from pybop.problems.problem import Problem
from pybop.samplers.base_sampler import SamplerOptions
@dataclass
class PintsSamplerOptions(SamplerOptions):
    """
    Options for PINTS-based samplers.

    Attributes
    ----------
    n_chains : int
        The number of chains to concurrently sample from (default: 1).
        Inherited from :class:`SamplerOptions`.
    cov : float | np.ndarray
        Covariance as a scalar, vector or square matrix (default: 0.05).
        Inherited from :class:`SamplerOptions`.
    verbose : bool
        If `True`, additional information will be printed (default: False).
        Inherited from :class:`SamplerOptions`.
    max_iterations : int
        Maximum number of iterations to run (default: 500).
    chains_in_memory : bool
        Whether to store the chains in memory (default: True).
    log_to_screen : bool
        If `True` (default), the sampler will print information during the sampling.
    log_filename : str | None
        The name of the file to save the sampler log to (default: None).
    initial_phase_iterations : int
        Number of iterations spent in the sampler's initial phase, for samplers
        that require one (default: 250).
    warm_up_iterations : int
        Number of leading iterations discarded from the returned chains
        (default: 0).
    chain_files : list[str] | None
        Names of the files to save the chains in (default: None).
    evaluation_files : list[str] | None
        Names of the files to save the evaluations in (default: None).
    """

    max_iterations: int = 500
    chains_in_memory: bool = True
    log_to_screen: bool = True
    log_filename: str | None = None
    initial_phase_iterations: int = 250
    warm_up_iterations: int = 0
    chain_files: list[str] | None = None
    evaluation_files: list[str] | None = None

    def validate(self):
        """
        Validate the options.

        Raises
        ------
        ValueError
            If the options are invalid.
        """
        super().validate()
        if self.cov is not None:
            cov = np.asarray(self.cov)
            # For a square covariance matrix only the variances (the diagonal)
            # must be positive: off-diagonal covariances may legitimately be zero
            # or negative. Scalars/vectors (and non-square input) are checked
            # element-wise. np.any avoids the "ambiguous truth value" error the
            # builtin any() raises when iterating over rows of a 2-D array.
            if cov.ndim == 2 and cov.shape[0] == cov.shape[1]:
                values = np.diag(cov)
            else:
                values = np.atleast_1d(cov)
            if np.any(values <= 0):
                raise ValueError("Covariance values must be positive.")
        if self.warm_up_iterations < 0:
            raise ValueError("Number of warm-up steps must be non-negative.")
        if self.max_iterations < 1:
            raise ValueError("Maximum number of iterations must be greater than 0.")
        if self.initial_phase_iterations < 1:
            raise ValueError(
                "Number of initial phase iterations must be greater than 0."
            )
class BasePintsSampler(BaseSampler):
    """
    Base class for PINTS samplers.

    This class extends the BaseSampler class to provide a common interface for
    PINTS samplers. Its ``run()`` method drives the wrapped PINTS sampler(s) to
    draw samples from the posterior distribution.

    Parameters
    ----------
    log_pdf : pybop.Problem
        The negative unnormalised posterior distribution.
    sampler : type[pints.SingleChainMCMC | pints.MultiChainMCMC]
        The PINTS sampler class to be used for sampling. A single-chain class is
        instantiated once per chain; a multi-chain class is instantiated once.
    options : PintsSamplerOptions, optional
        Options for the sampler, by default None (``default_options()`` is used).
    """

    def __init__(
        self,
        log_pdf: Problem,
        sampler: type[pints.SingleChainMCMC | pints.MultiChainMCMC],
        options: PintsSamplerOptions | None = None,
    ):
        options = options or self.default_options()
        super().__init__(log_pdf, options=options)
        self._sampler = sampler
        self._max_iterations = options.max_iterations
        self._chains_in_memory = options.chains_in_memory
        self._log_to_screen = options.log_to_screen
        self._log_filename = options.log_filename
        self._initial_phase_iterations = options.initial_phase_iterations
        self._verbose = options.verbose
        self._warm_up = options.warm_up_iterations
        self._n_parameters = len(self._log_pdf.parameters)
        self._chain_files = options.chain_files
        self._evaluation_files = options.evaluation_files

        # Single chain vs multiple chain samplers
        self._single_chain = issubclass(self._sampler, pints.SingleChainMCMC)

        # Construct the samplers object. x0 / cov0 are presumably prepared by
        # BaseSampler.__init__ (not visible here); in single-chain mode self.x0
        # is iterated to give one starting point per chain.
        if self._single_chain:
            self._n_samplers = self.options.n_chains
            self._samplers = [self._sampler(x0, sigma0=self.cov0) for x0 in self.x0]
        else:
            self._n_samplers = 1
            self._samplers = [self._sampler(self.options.n_chains, self.x0, self.cov0)]

        # Check for sensitivities from sampler and set evaluation
        self._needs_sensitivities = self._samplers[0].needs_sensitivities()

        # Check initial phase
        self._initial_phase = self._samplers[0].needs_initial_phase()
        if self._initial_phase:
            self.set_initial_phase_iterations()

    @staticmethod
    def default_options() -> PintsSamplerOptions:
        """Get the default options for the sampler."""
        return PintsSamplerOptions()

    def _initialise_chain_processor(self):
        """
        Initialise the appropriate chain processor based on configuration.
        """
        if self._single_chain:
            self._chain_processor = SingleChainProcessor(self)
        else:
            self._chain_processor = MultiChainProcessor(self)

    def run(self) -> SamplingResult | np.ndarray:
        """
        Executes the Monte Carlo sampling process and generates samples
        from the posterior distribution.

        This method orchestrates the entire sampling process, managing
        iterations, evaluations, logging, and stopping criteria. It
        initialises the necessary structures, handles both single and
        multi-chain scenarios, and evaluates proposals through a
        PopulationEvaluator.

        Returns
        -------
        SamplingResult | np.ndarray
            If chains are stored in memory, a SamplingResult whose chains have
            the first ``warm_up_iterations`` iterations discarded. Otherwise an
            empty numpy array of shape ``(0, max_iterations, n_parameters)``.

        Raises
        ------
        ValueError
            If no stopping criterion is set (i.e., _max_iterations is None).

        Details:
        - Checks and ensures at least one stopping criterion is set.
        - Initialises iterations, evaluations, and other required
          structures.
        - Handles the initial phase, if applicable, and manages
          intermediate steps in the sampling process.
        - Logs progress and relevant information based on the logging
          configuration.
        - Iterates through the sampling process, evaluating the log
          PDF, updating chains, and managing the stopping criteria.
        """
        self._start_time = time.time()
        self._initialise_logging()
        self._check_stopping_criteria()
        self._initialise_chain_processor()
        # Logger printing is disabled; the sampler prints its own progress.
        self._logger = Logger(minimising=self.log_pdf.minimising, verbose=False)
        evaluator = PopulationEvaluator(
            problem=self._log_pdf,
            minimise=False,
            with_sensitivities=self._needs_sensitivities,
            logger=self._logger,
        )

        self.iteration = 0
        self._check_initial_phase()
        self._initialise_storage()

        # Reference point for the iterations-per-second report. Both attributes
        # are read by the very first verbose progress line, so they must exist
        # before the loop starts.
        self._loop_iters = 0
        self.iter_time = time.time()

        running = True
        while running:
            if self._initial_phase and self.iteration == self._initial_phase_iterations:
                self._end_initial_phase()

            xs = self._ask_for_samples()
            self.fxs = evaluator.evaluate(xs)
            self._process_chains()

            if self._single_chain:
                # An intermediate step is one where no chain produced a new sample.
                self._intermediate_step = min(self._n_samples) <= self.iteration

            # Skip the remaining loop logic
            if self._intermediate_step:
                continue

            self.iteration += 1
            if self._log_to_screen and self._verbose:
                if self.iteration <= 10 or self.iteration % 50 == 0:
                    timing_iterations = self.iteration - self._loop_iters
                    elapsed_time = time.time() - self.iter_time
                    iterations_per_second = (
                        timing_iterations / elapsed_time if elapsed_time > 0 else 0
                    )
                    logging.info(
                        f"| Iteration: {self.iteration} | Iter/s: {iterations_per_second: .2f} |"
                    )
                    self.iter_time = time.time()
                    self._loop_iters = self.iteration

            if self._max_iterations and self.iteration >= self._max_iterations:
                running = False

        halt_message = (
            "Maximum number of iterations (" + str(self._max_iterations) + ") reached."
        )
        self._finalise_logging()

        if not self._chains_in_memory:
            return np.array([]).reshape((0, self._max_iterations, self._n_parameters))

        # Discard warm-up iterations along the iteration axis.
        if self._warm_up > 0:
            self._samples = self._samples[:, self._warm_up :, :]

        return SamplingResult(
            sampler=self,
            logger=self._logger,
            time=self._total_time,
            chains=self._samples,
            sampler_name=self._samplers[0].name(),
            message=halt_message,
        )

    def _process_chains(self):
        """
        Process chains using the appropriate processor.
        """
        self._chain_processor.process_chain()

    def _ask_for_samples(self):
        """
        Ask for the next point(s) to evaluate.

        In single-chain mode, one point is requested from each active sampler.
        In multi-chain mode, the single sampler's ``ask()`` result is returned.
        """
        if self._single_chain:
            return [self._samplers[i].ask() for i in self._active]
        return self._samplers[0].ask()

    def _check_initial_phase(self):
        """
        Set initial phase if needed
        """
        if self._initial_phase:
            for sampler in self._samplers:
                sampler.set_initial_phase(True)

    def _end_initial_phase(self):
        """Switch every sampler out of its initial phase."""
        for sampler in self._samplers:
            sampler.set_initial_phase(False)
        if self._log_to_screen:
            logging.info("Initial phase completed.")

    def _check_stopping_criteria(self):
        """
        Verify that at least one stopping criterion is defined.
        """
        if self._max_iterations is None:
            raise ValueError("At least one stopping criterion must be set.")

    def _initialise_storage(self):
        """Allocate per-chain buffers for samples and evaluations."""
        # Storage of the received samples
        n_chains = self.options.n_chains
        self._sampled_logpdf = np.zeros(n_chains)
        self._sampled_prior = np.zeros(n_chains)

        # Pre-allocate arrays for chain storage: the full history when chains
        # are kept in memory, otherwise only the latest sample per chain.
        storage_shape = (
            (n_chains, self._max_iterations, self._n_parameters)
            if self._chains_in_memory
            else (n_chains, self._n_parameters)
        )
        self._samples = np.zeros(storage_shape)
        self._evaluations = np.zeros((n_chains, self._max_iterations))

        # From PINTS:
        # Some samplers need intermediate steps, where `None` is returned instead
        # of a sample. But samplers can run asynchronously, so that one can return
        # `None` while another returns a sample. To deal with this, we maintain a
        # list of 'active' samplers that have not reached `max_iterations`,
        # and store the number of samples so far in each chain.
        if self._single_chain:
            self._active = list(range(n_chains))
            self._n_samples = [0] * n_chains

    def _initialise_logging(self):
        """Configure root logging and announce the sampling setup."""
        # NOTE(review): basicConfig configures the root logger process-wide.
        logging.basicConfig(format="%(message)s", level=logging.INFO)
        if self._log_to_screen:
            logging.info("Using " + str(self._samplers[0].name()))
            logging.info("Generating " + str(self.options.n_chains) + " chains.")
            if self._chain_files:
                logging.info("Writing chains to " + self._chain_files[0] + " etc.")
            if self._evaluation_files:
                logging.info(
                    "Writing evaluations to " + self._evaluation_files[0] + " etc."
                )

    def _finalise_logging(self):
        """Record the total run time and log the final summary."""
        self._total_time = time.time() - self._start_time
        if self._log_to_screen:
            logging.info(
                f"Halting: Maximum number of iterations ({self.iteration}) reached."
            )
            logging.info(f"Total time: {self._total_time} seconds.")
            logging.info(f"Total number of evaluations: ({self._logger.evaluations}).")

    @property
    def samplers(self):
        """The wrapped PINTS sampler instances."""
        return self._samplers

    @property
    def active(self):
        """Indices of active single-chain samplers (set in ``_initialise_storage``)."""
        return self._active

    @property
    def single_chain(self):
        """Whether the wrapped sampler class is a ``pints.SingleChainMCMC``."""
        return self._single_chain

    @property
    def sampled_logpdf(self):
        """Per-chain buffer allocated in ``_initialise_storage``."""
        return self._sampled_logpdf

    @property
    def sampled_prior(self):
        """Per-chain buffer allocated in ``_initialise_storage``."""
        return self._sampled_prior

    @property
    def iteration(self):
        """Current iteration, stored on the logger created in ``run()``."""
        return self._logger.iteration

    @iteration.setter
    def iteration(self, value):
        self._logger.iteration = value

    @property
    def needs_sensitivities(self):
        """Whether the wrapped sampler requires gradient information."""
        return self._needs_sensitivities

    @property
    def chains_in_memory(self):
        """Whether full chains are stored in memory."""
        return self._chains_in_memory

    @property
    def n_samples(self):
        """Number of samples so far in each chain (single-chain mode)."""
        return self._n_samples

    @property
    def max_iterations(self):
        """Maximum number of iterations to run."""
        return self._max_iterations