Source code for pybop.samplers.base_sampler

from dataclasses import dataclass

import numpy as np
import pints
import scipy

from pybop import plot
from pybop._logging import Logger
from pybop._result import Result
from pybop.problems.log_pdf import LogPDF


@dataclass
class SamplerOptions:
    """
    Configuration shared by all samplers.

    Attributes
    ----------
    n_chains : int
        How many chains are sampled concurrently. Defaults to 1.
    """

    n_chains: int = 1

    def validate(self):
        """
        Check that the stored options are usable.

        Raises
        ------
        ValueError
            If ``n_chains`` is smaller than 1.
        """
        too_few_chains = self.n_chains < 1
        if not too_few_chains:
            return
        raise ValueError("Number of chains must be greater than 0.")
class BaseSampler:
    """
    Common foundation for Monte Carlo samplers.

    On construction the sampler stores the log-PDF, resolves and validates
    its options, and prepares the initial mean (one row per chain) and the
    initial covariance matrix from the log-PDF's parameters.

    Parameters
    ----------
    log_pdf : pybop.LogPDF
        The negative unnormalised posterior distribution.
    options : SamplerOptions, optional
        Options for the sampler. When omitted, ``default_options()`` is used.
    """

    def __init__(
        self,
        log_pdf: LogPDF,
        options: SamplerOptions | None = None,
    ):
        self._log_pdf = log_pdf
        self._n_parameters = len(self._log_pdf.parameters)
        self._logger = None

        # Fall back to the class defaults when no options were supplied
        self._options = options if options else self.default_options()
        self._options.validate()

        # Get initial conditions: the transformed mean is repeated once per
        # chain, the covariance is normalised by the validation step below
        parameters = self._log_pdf.parameters
        per_chain = np.ones((self._options.n_chains, 1))
        self._mean0 = parameters.get_mean(transformed=True) * per_chain
        self._cov0 = parameters.get_covariance(transformed=True)
        self._validate_covariance_matrix()

    def _validate_covariance_matrix(self) -> None:
        """
        Normalise and check the initial covariance matrix.

        A missing covariance defaults to 0.05; a scalar is expanded to that
        value times the identity matrix; anything else is promoted to at
        least two dimensions.

        Raises
        ------
        ValueError
            If any entry of the resulting matrix is negative.
        """
        cov = 0.05 if self._cov0 is None else self._cov0
        if np.isscalar(cov):
            cov = np.eye(self._n_parameters) * cov
        else:
            cov = np.atleast_2d(cov)
        self._cov0 = cov

        if np.any(np.atleast_1d(cov) < 0):
            raise ValueError("Covariance values must be nonnegative.")

    @staticmethod
    def default_options() -> SamplerOptions:
        """Return a fresh ``SamplerOptions`` holding the default values."""
        return SamplerOptions()

    @property
    def mean0(self) -> np.ndarray:
        """Initial mean, one row per chain."""
        return self._mean0

    @property
    def cov0(self) -> np.ndarray:
        """Initial covariance matrix."""
        return self._cov0

    @property
    def log_pdf(self) -> LogPDF:
        """The log-PDF passed to the constructor."""
        return self._log_pdf

    @property
    def options(self) -> SamplerOptions:
        """The validated sampler options."""
        return self._options

    def run(self) -> "SamplingResult":
        """
        Sample from the posterior distribution.

        The base class provides no implementation; subclasses must override
        this method.

        Returns
        -------
        SamplingResult
            The outcome of the sampling run.

        Raises
        ------
        NotImplementedError
            Always, when called on the base class.
        """
        raise NotImplementedError

    def set_initial_phase_iterations(self, iterations: int = 250):
        """
        Set the number of iterations for the initial phase of the sampler.

        The value is stored as given, without conversion or validation.
        """
        self._initial_phase_iterations = iterations

    def set_max_iterations(self, iterations: int = 500):
        """
        Set the maximum number of iterations for the sampler.

        Raises
        ------
        ValueError
            If the value, converted to ``int``, is smaller than 1.
        """
        n_iterations = int(iterations)
        if n_iterations < 1:
            raise ValueError("Number of iterations must be greater than 0.")
        self._max_iterations = n_iterations

    def set_warm_up_iterations(self, iterations: int = 250):
        """
        Set the number of warm up iterations for the sampler.

        Raises
        ------
        ValueError
            If the value, converted to ``int``, is smaller than 1.
        """
        n_iterations = int(iterations)
        if n_iterations < 1:
            raise ValueError("Number of iterations must be greater than 0.")
        self._warm_up = n_iterations

    @property
    def logger(self) -> Logger | None:
        """The attached logger; ``None`` until one is assigned."""
        return self._logger
class SamplingResult(Result):
    """
    Stores the result of the sampling.

    Attributes
    ----------
    chains : np.ndarray
        The samples from the posterior distribution. The array is indexed
        as (chain, sample, parameter): the chains are concatenated along
        the first axis and the parameter count is read from ``shape[2]``.
    all_samples : np.ndarray
        The samples of all chains concatenated along the first axis.
    n_parameters : int
        The number of sampled parameters, i.e. ``chains.shape[2]``.

    Parameters
    ----------
    sampler : pybop.BaseSampler
        The sampler used to generate the results; its ``log_pdf`` and
        ``logger`` are forwarded to the base ``Result``.
    time : float
        The time taken.
    chains : np.ndarray
        An array containing the samples from the posterior distribution.
    method_name : str, optional
        The name of the sampler.
    message : str, optional
        The reason for stopping given by the sampler.
    """

    def __init__(
        self,
        sampler: "BaseSampler",
        time: float,
        chains: np.ndarray,
        method_name: str | None = None,
        message: str | None = None,
    ):
        super().__init__(
            problem=sampler.log_pdf,
            logger=sampler.logger,
            time=time,
            method_name=method_name,
            message=message,
        )
        self.chains = chains
        self.all_samples = np.concatenate(chains, axis=0)
        self.n_parameters = self.chains.shape[2]

    def signif(self, x, p: int):
        """
        Rounds array `x` to `p` significant digits.

        Zero and non-finite entries are given a placeholder magnitude so
        that the logarithm below is well defined; they pass through the
        rounding unchanged.
        """
        x = np.asarray(x)
        x_positive = np.where(np.isfinite(x) & (x != 0), np.abs(x), 10 ** (p - 1))
        mags = 10 ** (p - 1 - np.floor(np.log10(x_positive)))
        return np.round(x * mags) / mags

    def _calculate_statistics(self, fun, attr_name, *args, **kwargs):
        """
        Calculate a statistic of the pooled samples with callable `fun`.

        The unrounded statistic is stored on the instance under
        ``attr_name``; the returned value is rounded to ``self.sig_digits``
        significant digits (set by ``get_summary_statistics``).

        Parameters
        ----------
        fun : callable
            Called as ``fun(self.all_samples, *args, **kwargs)``.
        attr_name : str
            Name of the attribute that receives the unrounded statistic.

        Returns
        -------
        np.ndarray
            The statistic rounded to ``self.sig_digits`` significant digits.
        """
        stat = fun(self.all_samples, *args, **kwargs)
        if fun is scipy.stats.mode:
            # scipy.stats.mode returns a (mode, count) result object; only
            # the modal values are a statistic of the samples, so drop the
            # counts before both storing and rounding.
            stat = stat[0]
        setattr(self, attr_name, stat)
        return self.signif(stat, self.sig_digits)

    def get_summary_statistics(self, significant_digits: int = 4):
        """
        Calculate summary statistics for the posterior samples.

        Each statistic is computed over the pooled samples along axis 0 and
        is also stored, unrounded, as an attribute of the same name.

        Parameters
        ----------
        significant_digits : int
            Number of significant digits to display for summary statistics.

        Returns
        -------
        dict
            Summary statistics keyed by "mean", "median", "mode", "max",
            "min", "std", "ci_lower" and "ci_upper", where the last two are
            the 2.5th and 97.5th percentiles (95% credible interval).
        """
        self.sig_digits = significant_digits
        summary_funs = {
            "mean": np.mean,
            "median": np.median,
            "mode": scipy.stats.mode,
            "max": np.max,
            "min": np.min,
            "std": np.std,
            "ci_lower": lambda x, axis: np.percentile(x, 2.5, axis=axis),
            "ci_upper": lambda x, axis: np.percentile(x, 97.5, axis=axis),
        }

        return {
            key: self._calculate_statistics(func, key, axis=0)
            for key, func in summary_funs.items()
        }

    def plot_trace(self, **kwargs):
        """
        Plot trace plots for the posterior samples.
        """
        return plot.trace(result=self, **kwargs)

    def plot_chains(self, **kwargs):
        """
        Plot posterior distributions for each chain.
        """
        return plot.chains(result=self, **kwargs)

    def plot_posterior(self, **kwargs):
        """
        Plot the summed posterior distribution across chains.
        """
        return plot.posterior(result=self, **kwargs)

    def summary_table(self, **kwargs):
        """
        Display summary statistics in a table.
        """
        return plot.summary_table(result=self, **kwargs)

    def autocorrelation(self, x: np.ndarray) -> np.ndarray:
        """
        Computes the autocorrelation (Pearson correlation coefficient)
        of a numpy array representing samples.

        The samples are centred and scaled by ``std * sqrt(len(x))``; machine
        epsilon is added to the divisor so constant input does not divide
        by zero. The slice returned covers the positive lags 1 to
        ``len(x) - 2`` of the full correlation (lag 0 and the final lag are
        not included).
        """
        x = (x - x.mean()) / (x.std() * np.sqrt(len(x)) + np.finfo(float).eps)
        cor = np.correlate(x, x, mode="full")
        return cor[len(x) : -1]

    def _autocorrelate_negative(self, autocorrelation):
        """
        Returns the index of the first negative entry in ``autocorrelation``,
        or ``len(autocorrelation)`` if a negative entry is not found.
        """
        negative_indices = np.where(autocorrelation < 0)[0]
        return (
            negative_indices[0] if negative_indices.size > 0 else len(autocorrelation)
        )

    def effective_sample_size(self, mixed_chains=False):
        """
        Computes the effective sample size (ESS) for each parameter in each
        chain.

        For each parameter the autocorrelation is summed up to (not
        including) its first negative entry, and the ESS is the sample
        count divided by ``1 + 2 * sum``.

        Parameters
        ----------
        mixed_chains : bool, optional
            If True, the ESS is computed for all samplers mixed into a
            single chain. Defaults to False.

        Returns
        -------
        list
            A list of effective sample sizes for each parameter in each
            chain (chain by chain, parameters in order), or for the mixed
            chain if `mixed_chains` is True.

        Raises
        ------
        ValueError
            If there are fewer than two samples in the data.
        """
        if self.all_samples.shape[0] < 2:
            raise ValueError("At least two samples must be given.")

        def compute_ess(samples):
            """Helper function to compute the ESS for a single set of samples."""
            ess = []
            for j in range(self.n_parameters):
                rho = self.autocorrelation(samples[:, j])
                T = self._autocorrelate_negative(rho)
                ess.append(len(samples[:, j]) / (1 + 2 * rho[:T].sum()))
            return ess

        if mixed_chains:
            return compute_ess(self.all_samples)

        ess = []
        for chain in self.chains:
            ess.extend(compute_ess(chain))
        return ess

    def rhat(self):
        """
        Computes the Gelman-Rubin statistic which diagnoses MCMC convergence.

        For well-mixed and stationary chains R-hat will be close to one,
        otherwise it is higher. The computation is delegated to
        ``pints.rhat``.
        """
        return pints.rhat(self.chains)