import logging
import time
from functools import partial
from typing import Optional, Union
import numpy as np
from pints import (
MultiSequentialEvaluator,
ParallelEvaluator,
SequentialEvaluator,
SingleChainMCMC,
)
from pybop import (
BaseCost,
BaseSampler,
LogPosterior,
MultiChainProcessor,
SingleChainProcessor,
)
class BasePintsSampler(BaseSampler):
    """
    Base class for PINTS samplers.

    This class extends the BaseSampler class to provide a common interface for
    PINTS samplers. The ``run()`` method drives the configured PINTS sampler(s)
    to draw samples from the posterior distribution.

    Parameters
    ----------
    log_pdf : pybop.LogPosterior or List[pybop.LogPosterior]
        An object to be sampled, currently supports the pybop.LogPosterior class.
        When a list is given it must contain exactly one entry per chain, every
        entry must be a ``BaseCost`` and all must have the same number of
        parameters.
    sampler : type
        The PINTS sampler class to instantiate. If it is a subclass of
        ``pints.SingleChainMCMC`` one instance is created per chain, otherwise
        a single instance drives all chains.
    chains : int
        Number of chains to run concurrently. Each chain contains separate
        Markov samples from the log_pdf.
    warm_up : int, optional
        Number of leading samples removed from each chain before ``run()``
        returns them (only applies when chains are stored in memory).
    x0 : numpy.ndarray, optional
        Initial values of the parameters for sampling.
    cov0 : list-like
        Initial covariance for the chains in the parameters. Either a scalar value
        (same for all coordinates) or an array with one entry per dimension.
    **kwargs
        Additional keyword arguments read by this class:

        - ``max_iterations`` (default 500): number of iterations to run; must
          be at least 1.
        - ``log_to_screen`` (default True): emit progress messages via
          ``logging``.
        - ``log_filename`` (default None): stored as ``_log_filename``; not
          used within this class.
        - ``initial_phase_iterations`` (default 250): iteration at which the
          initial phase is ended, for samplers that need one.
        - ``chains_in_memory`` (default True): keep every sample in memory and
          return the chains from ``run()``.
        - ``chain_files`` / ``evaluation_files`` (default None): only
          referenced in the start-up log messages within this class.
        - ``parallel`` (default False): request parallel evaluation.
        - ``verbose`` (default False): together with ``log_to_screen``, log
          iteration throughput.
    """

    def __init__(
        self,
        log_pdf: Union[LogPosterior, list[LogPosterior]],
        sampler,
        chains: int = 1,
        warm_up=None,
        x0=None,
        cov0=0.1,
        **kwargs,
    ):
        super().__init__(log_pdf, x0, chains, cov0)

        # Keep the PINTS sampler class; it is instantiated further below.
        self.sampler = sampler

        # Set kwargs
        self._max_iterations = kwargs.get("max_iterations", 500)
        self._log_to_screen = kwargs.get("log_to_screen", True)
        self._log_filename = kwargs.get("log_filename", None)
        self._initial_phase_iterations = kwargs.get("initial_phase_iterations", 250)
        self._chains_in_memory = kwargs.get("chains_in_memory", True)
        self._chain_files = kwargs.get("chain_files", None)
        self._evaluation_files = kwargs.get("evaluation_files", None)
        self._parallel = kwargs.get("parallel", False)
        self._verbose = kwargs.get("verbose", False)
        self.iter_time = float(0)
        self._warm_up = warm_up

        # Check log_pdf: either a single cost, or one cost per chain.
        if isinstance(self._log_pdf, BaseCost):
            self._multi_log_pdf = False
        else:
            if len(self._log_pdf) != chains:
                raise ValueError("Number of log pdf's must match number of chains")
            # Validate types before touching any attribute of the entries.
            if not all(isinstance(pdf, BaseCost) for pdf in self._log_pdf):
                raise ValueError("All log pdf's must be instances of BaseCost")
            first_pdf_parameters = self._log_pdf[0].n_parameters
            if any(pdf.n_parameters != first_pdf_parameters for pdf in self._log_pdf):
                raise ValueError(
                    "All log pdf's must have the same number of parameters"
                )
            self._multi_log_pdf = True

        # Single chain vs multiple chain samplers
        self._single_chain = issubclass(self.sampler, SingleChainMCMC)

        # Construct the samplers object: one PINTS sampler per chain for
        # single-chain methods, otherwise one sampler that handles every chain.
        if self._single_chain:
            self._n_samplers = self._n_chains
            self._samplers = [
                self.sampler(chain_x0, sigma0=self._cov0) for chain_x0 in self._x0
            ]
        else:
            self._n_samplers = 1
            self._samplers = [self.sampler(self._n_chains, self._x0, self._cov0)]

        # Check for sensitivities from sampler and set evaluation
        self._needs_sensitivities = self._samplers[0].needs_sensitivities()

        # Check initial phase
        self._initial_phase = self._samplers[0].needs_initial_phase()
        if self._initial_phase:
            # Pass the configured value explicitly so the kwarg read above is
            # honoured. NOTE(review): assumes the setter's argument is the
            # iteration count — confirm against BaseSampler.
            self.set_initial_phase_iterations(self._initial_phase_iterations)

        # Set parallelisation
        self.set_parallel(self._parallel)

    def _initialise_chain_processor(self):
        """
        Initialise the appropriate chain processor based on configuration.
        """
        if self._single_chain:
            self._chain_processor = SingleChainProcessor(self)
        else:
            self._chain_processor = MultiChainProcessor(self)

    def run(self) -> Optional[np.ndarray]:
        """
        Run the Monte Carlo sampling process and return the posterior samples.

        Each iteration asks the sampler(s) for proposals, evaluates the log
        pdf(s) with a sequential or parallel PINTS evaluator, and hands the
        results to the chain processor. Iterations in which a sampler is in an
        intermediate step do not advance the iteration counter. Sampling stops
        once ``max_iterations`` iterations have completed.

        Returns
        -------
        numpy.ndarray or None
            Samples with shape ``(chains, max_iterations, n_parameters)``, with
            the first ``warm_up`` samples of each chain removed when
            ``warm_up`` is set, if chains are stored in memory; otherwise None.

        Raises
        ------
        ValueError
            If no stopping criterion is set (``max_iterations`` is None), if
            ``max_iterations`` is less than 1, or if parallel evaluation is
            requested together with multiple log pdfs.
        """
        # Validate and build everything before logging so that failures happen
        # early and the logged worker count is the one actually used.
        self._check_stopping_criteria()
        self._initialise_chain_processor()

        # Initialise iterations, throughput timing and evaluations
        self._iteration = 0
        self._loop_iters = 0
        self.iter_time = time.time()
        evaluator = self._create_evaluator()
        self._initialise_logging()
        self._check_initial_phase()
        self._initialise_storage()

        # Intermediate steps repeat an iteration number, so track locally
        # whether the initial phase still needs to be ended (end it only once).
        in_initial_phase = self._initial_phase

        while self._iteration < self._max_iterations:
            if in_initial_phase and self._iteration == self._initial_phase_iterations:
                self._end_initial_phase()
                in_initial_phase = False

            xs = self._ask_for_samples()
            self.fxs = evaluator.evaluate(xs)
            self._process_chains()

            if self._single_chain:
                self._intermediate_step = min(self._n_samples) <= self._iteration
            # NOTE(review): for multi-chain samplers `_intermediate_step` is not
            # assigned in this class; presumably the chain processor sets it.

            # Skip the remaining loop logic
            if self._intermediate_step:
                continue

            self._iteration += 1
            if self._log_to_screen and self._verbose:
                self._log_progress()

        self._finalise_logging()

        if not self._chains_in_memory:
            return None

        if self._warm_up:
            self._samples = self._samples[:, self._warm_up :, :]

        return self._samples

    def _log_progress(self):
        """
        Log iteration throughput for the first 10 iterations and every 50th.
        """
        if self._iteration > 10 and self._iteration % 50 != 0:
            return
        timing_iterations = self._iteration - self._loop_iters
        elapsed_time = time.time() - self.iter_time
        iterations_per_second = (
            timing_iterations / elapsed_time if elapsed_time > 0 else 0
        )
        logging.info(
            f"| Iteration: {self._iteration} | Iter/s: {iterations_per_second: .2f} |"
        )
        # Restart the timing window from this iteration.
        self.iter_time = time.time()
        self._loop_iters = self._iteration

    def _process_chains(self):
        """
        Process chains using the appropriate processor.
        """
        self._chain_processor.process_chain()

    def _ask_for_samples(self):
        """
        Ask the sampler(s) for the next proposals.

        For single-chain samplers only the still-active chains are asked;
        otherwise the single multi-chain sampler is asked directly.
        """
        if self._single_chain:
            return [self._samplers[i].ask() for i in self._active]
        return self._samplers[0].ask()

    def _check_initial_phase(self):
        """
        Set initial phase if needed
        """
        if self._initial_phase:
            for sampler in self._samplers:
                sampler.set_initial_phase(True)

    def _end_initial_phase(self):
        """
        Switch every sampler out of its initial phase.
        """
        for sampler in self._samplers:
            sampler.set_initial_phase(False)
        if self._log_to_screen:
            logging.info("Initial phase completed.")

    def _check_stopping_criteria(self):
        """
        Verify that at least one valid stopping criterion is defined.

        Raises
        ------
        ValueError
            If ``max_iterations`` is None or less than 1 (a value below 1
            would otherwise never stop the sampling loop).
        """
        if self._max_iterations is None:
            raise ValueError("At least one stopping criterion must be set.")
        if self._max_iterations < 1:
            raise ValueError("max_iterations must be a positive integer.")

    def _create_evaluator(self):
        """
        Create appropriate evaluator based on configuration settings.

        Raises
        ------
        ValueError
            If parallel evaluation is requested with multiple log pdfs; PINTS'
            ParallelEvaluator only accepts a single callable.
        """
        common_args = {"calculate_grad": self._needs_sensitivities}

        # Construct function for evaluation
        if not self._multi_log_pdf:
            f = partial(self.call_cost, cost=self._log_pdf, **common_args)
        else:
            f = [
                partial(self.call_cost, cost=log_pdf, **common_args)
                for log_pdf in self._log_pdf
            ]

        # Handle parallel case
        if self._parallel:
            if self._multi_log_pdf:
                raise ValueError(
                    "Parallel evaluation is not supported with multiple log pdf's."
                )
            # Never start more workers than there are chains to evaluate.
            self._n_workers = min(self._n_workers, self._n_chains)
            return ParallelEvaluator(f, n_workers=self._n_workers)

        # Construct a dict for various return types
        evaluator_map = {False: SequentialEvaluator, True: MultiSequentialEvaluator}
        return evaluator_map[self._multi_log_pdf](f)

    def _initialise_storage(self):
        """
        Allocate per-chain storage for samples, evaluations and bookkeeping.
        """
        # Only a single LogPosterior exposes a prior here. Otherwise keep any
        # value assigned elsewhere, defaulting to None so `if self._prior` is safe.
        if isinstance(self._log_pdf, LogPosterior):
            self._prior = self._log_pdf.prior
        elif not hasattr(self, "_prior"):
            self._prior = None

        # Storage of the received samples
        self._sampled_logpdf = np.zeros(self._n_chains)
        self._sampled_prior = np.zeros(self._n_chains)

        # Pre-allocate arrays for chain storage; without chains in memory only
        # one sample per chain is kept.
        storage_shape = (
            (self._n_chains, self._max_iterations, self.n_parameters)
            if self._chains_in_memory
            else (self._n_chains, self.n_parameters)
        )
        self._samples = np.zeros(storage_shape)

        # Pre-allocate arrays for evaluation storage
        if self._prior:
            # Store posterior, likelihood, prior
            self._evaluations = np.zeros((self._n_chains, self._max_iterations, 3))
        else:
            # Store pdf
            self._evaluations = np.zeros((self._n_chains, self._max_iterations))

        # From PINTS:
        # Some samplers need intermediate steps, where `None` is returned instead
        # of a sample. But samplers can run asynchronously, so that one can return
        # `None` while another returns a sample. To deal with this, we maintain a
        # list of 'active' samplers that have not reached `max_iterations`,
        # and store the number of samples so far in each chain.
        if self._single_chain:
            self._active = list(range(self._n_chains))
            self._n_samples = [0] * self._n_chains

    def _initialise_logging(self):
        """
        Configure root logging and, if enabled, log the run configuration.
        """
        logging.basicConfig(format="%(message)s", level=logging.INFO)

        if self._log_to_screen:
            logging.info(f"Using {self._samplers[0].name()}")
            logging.info(f"Generating {self._n_chains} chains.")
            if self._parallel:
                logging.info(
                    f"Running in parallel with {self._n_workers} worker processes."
                )
            else:
                logging.info("Running in sequential mode.")
            # f-strings accept str or path-like file names alike.
            if self._chain_files:
                logging.info(f"Writing chains to {self._chain_files[0]} etc.")
            if self._evaluation_files:
                logging.info(
                    f"Writing evaluations to {self._evaluation_files[0]} etc."
                )

    def _finalise_logging(self):
        """
        Log the halting message if logging to screen is enabled.
        """
        if self._log_to_screen:
            logging.info(
                f"Halting: Maximum number of iterations ({self._iteration}) reached."
            )

    @property
    def prior(self):
        """Prior used for recording prior values, or None if unavailable."""
        return self._prior

    @property
    def samplers(self):
        """List of instantiated PINTS sampler objects."""
        return self._samplers

    @property
    def active(self):
        """Indices of single-chain samplers that are still running."""
        return self._active

    @property
    def single_chain(self):
        """Whether the PINTS sampler is a single-chain method."""
        return self._single_chain

    @property
    def sampled_logpdf(self):
        """Per-chain log pdf storage array."""
        return self._sampled_logpdf

    @property
    def sampled_prior(self):
        """Per-chain prior storage array."""
        return self._sampled_prior

    @property
    def iteration(self):
        """Number of completed iterations."""
        return self._iteration

    @property
    def needs_sensitivities(self):
        """Whether the PINTS sampler requires gradient information."""
        return self._needs_sensitivities

    @property
    def chains_in_memory(self):
        """Whether full chains are kept in memory."""
        return self._chains_in_memory

    @property
    def n_samples(self):
        """Number of samples stored so far in each chain (single-chain only)."""
        return self._n_samples

    @property
    def max_iterations(self):
        """Maximum number of iterations to run."""
        return self._max_iterations