import multiprocessing as mp
import platform
import re
from dataclasses import dataclass, field
import numpy as np
import pybamm
def is_numeric(x):
    """
    Return True if ``x`` is a Python ``int``/``float`` or a NumPy scalar number.

    ``bool`` is a subclass of ``int``, so ``True``/``False`` are reported as
    numeric. Arrays, strings, ``None`` and Python ``complex`` are not.
    """
    numeric_types = (int, float, np.number)
    return isinstance(x, numeric_types)
@dataclass(frozen=True)
[docs]
class FailedVariable:
"""
Check if a variable is numeric.
Container for a failed PyBaMM variable that returns np.inf.
Args:
name: Variable name
data: Array data, defaults to [np.inf]
sensitivities: Sensitivity data mapping parameter names to arrays
"""
[docs]
data: np.ndarray = field(default_factory=lambda: np.asarray([np.inf]))
[docs]
sensitivities: {str, np.ndarray} = field(default_factory=dict)
[docs]
def __post_init__(self) -> None:
"""Validate inputs after initialisation."""
if not isinstance(self.name, str) or not self.name.strip():
raise ValueError("Variable name must be a non-empty string")
if not isinstance(self.data, np.ndarray):
object.__setattr__(self, "data", np.asarray(self.data))
class FailedSolution:
    """
    Container for a failed PyBaMM solution that returns [np.inf] for all processed variables.

    This class mimics the interface of a successful PyBaMM solution but returns
    infinity values to indicate failure while maintaining API compatibility.
    Unknown (non-dunder) attributes resolve to the solution itself, and calling
    the solution returns itself, so chained accesses such as
    ``solution.some_method().other`` do not raise.

    Args:
        variable_names: Variable names in the solution (any iterable of str).
        parameter_names: Parameter names for sensitivity analysis (any
            iterable of str).

    Raises:
        TypeError: If either argument is a single string or contains a
            non-string element.

    Example:
        >>> solution = FailedSolution(["Voltage [V]"], ["Negative particle radius [m]"])
        >>> voltage = solution["Voltage [V]"]
        >>> print(voltage.data)
        [inf]
    """

    def __init__(self, variable_names: list[str], parameter_names: list[str]):
        # Private list copies: later changes to the caller's sequences cannot
        # desynchronise the stored names from the stored variables.
        self._variable_names = self._validate_names(variable_names, "variable_names")
        self._parameter_names = self._validate_names(
            parameter_names, "parameter_names"
        )

        # Solution metadata
        self.cycles: int | None = None
        self.termination: str = "failure"
        self.solve_time: float = 0.0
        self.integration_time: float = 0.0
        self._t_eval: np.ndarray = np.asarray([0.0])

        # Initialise failed variables
        self._variables: dict[str, FailedVariable] = pybamm.FuzzyDict()
        self._initialise_variables()

    @staticmethod
    def _validate_names(names, label: str) -> list[str]:
        """Return ``names`` as a new list, raising TypeError unless it holds only str."""
        if isinstance(names, str):
            # A bare string would otherwise be split into single characters.
            raise TypeError(f"{label} must be a list of strings, not a string")
        names = list(names)
        if not all(isinstance(name, str) for name in names):
            raise TypeError(f"{label} must contain only strings")
        return names

    def _initialise_variables(self) -> None:
        """Initialise all variables with failed state."""
        inf_array = np.asarray([np.inf])
        for var_name in self._variable_names:
            if self._parameter_names:
                # One [inf] array per parameter, plus the list of all of them
                # under the "all" key.
                sensitivities = {p: inf_array.copy() for p in self._parameter_names}
                sensitivities["all"] = [inf_array.copy() for _ in self._parameter_names]
            else:
                sensitivities = {}
            self._variables[var_name] = FailedVariable(
                name=var_name, data=inf_array.copy(), sensitivities=sensitivities
            )

    def __getattr__(self, name):
        """
        Return ``self`` for unknown attributes so accesses can be chained.

        Only called when normal attribute lookup fails. Dunder names raise
        AttributeError so that protocol probes, such as ``copy.deepcopy``
        looking up ``__deepcopy__`` or ``pickle`` looking up ``__setstate__``
        (used by multiprocessing), fall back to default behaviour instead of
        receiving this instance.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return self

    def __call__(self, *args, **kwargs):
        """Return ``self`` so unknown method calls resolved by ``__getattr__`` chain."""
        return self

    def __getitem__(self, key):
        """Return the FailedVariable stored under ``key``."""
        return self._variables[key]

    def plot(self, *args, **kwargs):
        """Print a notice instead of plotting; always returns None."""
        print("Cannot plot a failed solution")
        return None

    def save(self, *args, **kwargs):
        """Print a notice instead of saving; always returns None."""
        print("Cannot save a failed solution")
        return None

    def copy(self):
        """
        Return a new FailedSolution with the same variable and parameter names.

        Metadata attributes start from their defaults rather than being copied.
        """
        return FailedSolution(self._variable_names, self._parameter_names)

    @property
    def t_eval(self) -> np.ndarray:
        """Time evaluation points (always ``[0.0]`` for failed solutions)."""
        return self._t_eval

    @property
    def variable_names(self) -> list[str]:
        """Get a copy of the variable names (read-only)."""
        return self._variable_names.copy()

    @property
    def parameter_names(self) -> list[str]:
        """Get a copy of the parameter names (read-only)."""
        return self._parameter_names.copy()

    def keys(self) -> list[str]:
        """Get all variable names."""
        return list(self._variables.keys())

    def values(self) -> list[FailedVariable]:
        """Get all variables."""
        return list(self._variables.values())

    def items(self) -> list[tuple[str, FailedVariable]]:
        """Get all variable name-value pairs."""
        return list(self._variables.items())
def add_spaces(string):
    """
    Insert spaces into a CamelCase name, e.g. ``"SymbolReplacer"`` -> ``"Symbol Replacer"``.

    Two passes are applied in order:

    1. Insert a space before any capital (other than the first character)
       that is followed by a non-capital.
    2. Insert a space between a non-capital, non-space character and a
       following capital.
    """
    before_word_start = re.sub(r"(?<!^)([A-Z])([^A-Z])", r" \1\2", string)
    return re.sub(r"([^A-Z ])([A-Z])", r"\1 \2", before_word_start)
class SymbolReplacer:
    """
    Helper class to replace all instances of one or more symbols in an expression tree
    with another symbol, as defined by the dictionary `symbol_replacement_map`.

    Originally developed by pybamm: https://github.com/pybamm-team/pybamm

    Parameters
    ----------
    symbol_replacement_map : dict {:class:`pybamm.Symbol` -> :class:`pybamm.Symbol`}
        Map of which symbols should be replaced by which.
    processed_symbols: dict {:class:`pybamm.Symbol` -> :class:`pybamm.Symbol`}, optional
        Cache of already-processed symbols. Any dict passed here, including an
        empty one, is used and filled in place, so it can be shared between
        replacers. A new dict is created when omitted.
    process_initial_conditions: bool, optional
        Whether to process initial conditions, default is True
    """

    def __init__(
        self,
        symbol_replacement_map: dict[pybamm.Symbol, pybamm.Symbol],
        processed_symbols: dict[pybamm.Symbol, pybamm.Symbol] | None = None,
        process_initial_conditions: bool = True,
    ):
        self._symbol_replacement_map = symbol_replacement_map
        # Test for None explicitly: an empty dict supplied by the caller is
        # still the cache they expect to be populated.
        self._processed_symbols = (
            {} if processed_symbols is None else processed_symbols
        )
        self._process_initial_conditions = process_initial_conditions

    def process_model(self, unprocessed_model, inplace=True):
        """
        Replace all instances of the mapped symbols in a PyBaMM model.

        Processes rhs and algebraic equations, initial conditions (keys
        always; values only when ``process_initial_conditions`` is True),
        boundary conditions, output variables and events.

        Parameters
        ----------
        unprocessed_model : :class:`pybamm.BaseModel`
            Model whose symbols are replaced.
        inplace: bool, optional
            If True, modify ``unprocessed_model`` in place. Otherwise, apply
            the replacements to ``unprocessed_model.new_copy()`` (default: True).

        Returns
        -------
        :class:`pybamm.BaseModel`
            The processed model (``unprocessed_model`` itself when ``inplace``).
        """
        model = unprocessed_model if inplace else unprocessed_model.new_copy()
        # NOTE(review): with inplace=True, `model` is `unprocessed_model`, so
        # (assuming these properties return the underlying dicts) the loops
        # below write into the dicts they iterate over. That is only safe while
        # process_symbol(variable) returns a key equal to `variable`; replacing
        # a state variable itself would change the dict size mid-iteration.
        for variable, equation in unprocessed_model.rhs.items():
            pybamm.logger.verbose(f"Replacing symbols in {variable!r} (rhs)")
            model.rhs[self.process_symbol(variable)] = self.process_symbol(equation)
        for variable, equation in unprocessed_model.algebraic.items():
            pybamm.logger.verbose(f"Replacing symbols in {variable!r} (algebraic)")
            model.algebraic[self.process_symbol(variable)] = self.process_symbol(
                equation
            )
        for variable, equation in unprocessed_model.initial_conditions.items():
            pybamm.logger.verbose(
                f"Replacing symbols in {variable!r} (initial conditions)"
            )
            if self._process_initial_conditions:
                model.initial_conditions[self.process_symbol(variable)] = (
                    self.process_symbol(equation)
                )
            else:
                model.initial_conditions[self.process_symbol(variable)] = equation
        model.boundary_conditions = self.process_boundary_conditions(unprocessed_model)
        for variable, equation in unprocessed_model.variables.items():
            pybamm.logger.verbose(f"Replacing symbols in {variable!r} (variables)")
            # Output variables keep their original (string) keys.
            model.variables[variable] = self.process_symbol(equation)
        model.events = self._process_events(unprocessed_model.events)
        pybamm.logger.info(f"Finish replacing symbols in {model.name}")
        return model

    def _process_events(self, events: list) -> list:
        """Return new events whose expressions have had symbols replaced.

        Event names and event types are carried over unchanged.
        """
        new_events = []
        for event in events:
            pybamm.logger.verbose(f"Replacing symbols in event '{event.name}'")
            new_events.append(
                pybamm.Event(
                    event.name, self.process_symbol(event.expression), event.event_type
                )
            )
        return new_events

    def process_boundary_conditions(self, model):
        """
        Process boundary conditions for a PyBaMM model class.

        Boundary conditions are dictionaries {"left": left bc, "right": right bc}
        in general, but may be imposed on the tabs (or *not* on the tab) for a
        small number of variables, e.g. {"negative tab": neg. tab bc,
        "positive tab": pos. tab bc, "no tab": no tab bc}.

        Parameters
        ----------
        model : :class:`pybamm.BaseModel`
            Model whose ``boundary_conditions`` are read; this method does not
            assign to the model.

        Returns
        -------
        dict
            New mapping ``{processed variable: {side: (processed bc, type)}}``
            containing only the sides present in the original conditions.
        """
        boundary_conditions = {}
        sides = ["left", "right", "negative tab", "positive tab", "no tab"]
        for variable, bcs in model.boundary_conditions.items():
            processed_variable = self.process_symbol(variable)
            processed_bcs = boundary_conditions[processed_variable] = {}
            for side in sides:
                # Skip sides this variable has no condition on. Checking
                # membership up front, rather than catching KeyError around the
                # whole body, lets genuine KeyErrors raised while processing a
                # symbol propagate instead of being matched by substring.
                if side not in bcs:
                    continue
                bc, typ = bcs[side]
                pybamm.logger.verbose(
                    f"Replacing symbols in {variable!r} ({side} bc)"
                )
                processed_bcs[side] = (self.process_symbol(bc), typ)
        return boundary_conditions

    def process_symbol(self, symbol):
        """
        This function recurses down the tree, replacing any symbols in
        self._symbol_replacement_map.keys() with their corresponding value.

        Results are memoised in ``self._processed_symbols``.

        Parameters
        ----------
        symbol : :class:`pybamm.Symbol`
            The symbol to replace

        Returns
        -------
        :class:`pybamm.Symbol`
            Symbol with all replacements performed
        """
        if symbol in self._processed_symbols:
            return self._processed_symbols[symbol]
        processed_symbol = self._process_symbol(symbol)
        self._processed_symbols[symbol] = processed_symbol
        return processed_symbol

    def _process_symbol(self, symbol: pybamm.Symbol) -> pybamm.Symbol:
        """Replace ``symbol`` or rebuild it from processed children (uncached)."""
        # A mapped symbol is swapped wholesale; its replacement is not recursed into.
        if symbol in self._symbol_replacement_map:
            return self._symbol_replacement_map[symbol]
        if isinstance(symbol, pybamm.BinaryOperator):
            # process children
            new_left = self.process_symbol(symbol.left)
            new_right = self.process_symbol(symbol.right)
            return symbol._binary_new_copy(new_left, new_right)  # noqa: SLF001
        if isinstance(symbol, pybamm.UnaryOperator):
            new_child = self.process_symbol(symbol.child)
            return symbol._unary_new_copy(new_child)  # noqa: SLF001
        if isinstance(symbol, pybamm.Function):
            new_children = [self.process_symbol(child) for child in symbol.children]
            # Return a new copy with the replaced symbols
            return symbol._function_new_copy(new_children)  # noqa: SLF001
        if isinstance(symbol, pybamm.Concatenation):
            new_children = [self.process_symbol(child) for child in symbol.children]
            return symbol._concatenation_new_copy(new_children)  # noqa: SLF001
        # Return leaf. NOTE(review): node types not listed above are also
        # returned as-is, without visiting any children they may have.
        return symbol
class RecommendedSolver(pybamm.IDAKLUSolver):
    """
    A shortcut for creating the PyBaMM solver recommended for optimisation.

    Configures :class:`pybamm.IDAKLUSolver` with ``on_failure="ignore"``,
    ``atol=1e-6`` and ``rtol=1e-6``. On every platform except Windows it also
    sets the ``"num_threads"`` option to the machine's CPU count.

    Parameters
    ----------
    output_variables : list[str], optional
        Forwarded unchanged to :class:`pybamm.IDAKLUSolver`.
    """

    def __init__(self, output_variables: list[str] | None = None):
        on_windows = platform.system() == "Windows"
        # The thread-count option is only supplied off Windows.
        options = {} if on_windows else {"num_threads": max(1, mp.cpu_count())}
        super().__init__(
            output_variables=output_variables,
            options=options,
            on_failure="ignore",
            rtol=1e-6,
            atol=1e-6,
        )