Source code for pybop.plotting.quick_plot

import numpy as np
import textwrap
import pybop


[docs] class StandardPlot: """ A class for creating and displaying Plotly figures for model output comparison. Generates interactive plots comparing simulated model output with an optional target dataset and visualizes uncertainty. Parameters ---------- x : list or np.ndarray X-axis data points. y : list or np.ndarray Primary Y-axis data points for simulated model output. cost : float Cost associated with the model output. y2 : list or np.ndarray, optional Secondary Y-axis data points for the target dataset (default: None). title : str, optional Title of the plot (default: None). xaxis_title : str, optional Title for the x-axis (default: None). yaxis_title : str, optional Title for the y-axis (default: None). trace_name : str, optional Name for the primary trace (default: "Simulated"). width : int, optional Width of the figure in pixels (default: 1024). height : int, optional Height of the figure in pixels (default: 576). """ def __init__( self, x, y, cost, y2=None, title=None, xaxis_title=None, yaxis_title=None, trace_name=None, width=1024, height=576, ): """ Initialize the StandardPlot object with simulation and optional target data. Parameters ---------- x : list or np.ndarray X-axis data points. y : list or np.ndarray Primary Y-axis data points for simulated model output. cost : float Cost associated with the model output. y2 : list or np.ndarray, optional Secondary Y-axis data points for target dataset (default: None). title : str, optional Plot title (default: None). xaxis_title : str, optional X-axis title (default: None). yaxis_title : str, optional Y-axis title (default: None). trace_name : str, optional Name for the primary trace (default: "Simulated"). width : int, optional Figure width in pixels (default: 1024). height : int, optional Figure height in pixels (default: 576). """ self.x = x if isinstance(x, list) else x.tolist() self.y = y self.y2 = y2 self.cost = cost self.width = width self.height = height self.title = title self.xaxis_title = xaxis_title self.yaxis_title = yaxis_title self.trace_name = trace_name or "Simulated" if self.y2 is not None: self.sigma = np.std(self.y - self.y2) self.y_upper = (self.y + self.sigma).tolist() self.y_lower = (self.y - self.sigma).tolist() # Attempt to import plotly when an instance is created self.go = pybop.PlotlyManager().go @staticmethod
[docs] def wrap_text(text, width): """ Wrap text to a specified width with HTML line breaks. Parameters ---------- text : str The text to wrap. width : int The width to wrap the text to. Returns ------- str The wrapped text. """ wrapped_text = textwrap.fill(text, width=width, break_long_words=False) return wrapped_text.replace("\n", "<br>")
[docs] def create_layout(self): """ Create the layout for the Plotly figure. Returns ------- plotly.graph_objs.Layout The layout for the Plotly figure. """ return self.go.Layout( title=self.title, title_x=0.5, xaxis=dict(title=self.xaxis_title, titlefont_size=12, tickfont_size=12), yaxis=dict(title=self.yaxis_title, titlefont_size=12, tickfont_size=12), legend=dict(x=1, y=1, xanchor="right", yanchor="top", font_size=12), showlegend=True, autosize=False, width=self.width, height=self.height, margin=dict(l=10, r=10, b=10, t=75, pad=4), )
[docs] def create_traces(self): """ Create traces for the Plotly figure. Returns ------- list A list of plotly.graph_objs.Scatter objects to be used as traces. """ traces = [] wrapped_trace_name = self.wrap_text(self.trace_name, width=40) simulated_trace = self.go.Scatter( x=self.x, y=self.y, line=dict(width=4), mode="lines", name=wrapped_trace_name, ) if self.y2 is not None: target_trace = self.go.Scatter( x=self.x, y=self.y2, mode="markers", name="Target" ) fill_trace = self.go.Scatter( x=self.x + self.x[::-1], y=self.y_upper + self.y_lower[::-1], fill="toself", fillcolor="rgba(255,229,204,0.8)", line=dict(color="rgba(255,255,255,0)"), hoverinfo="skip", showlegend=False, ) traces.extend([fill_trace, target_trace]) traces.append(simulated_trace) return traces
[docs] def __call__(self): """ Generate the Plotly figure. Returns ------- plotly.graph_objs.Figure The generated Plotly figure. """ layout = self.create_layout() traces = self.create_traces() fig = self.go.Figure(data=traces, layout=layout) return fig
[docs] def quick_plot(params, cost, title="Scatter Plot", width=1024, height=576): """ Quickly plot the target dataset against minimized model output. Parameters ---------- params : array-like Optimized parameters. cost : object Cost object with problem, dataset, and signal attributes. title : str, optional Title of the plot (default: "Scatter Plot"). width : int, optional Width of the figure in pixels (default: 1024). height : int, optional Height of the figure in pixels (default: 576). Returns ------- plotly.graph_objs.Figure The Plotly figure object for the scatter plot. """ # Extract the time data and evaluate the model's output and target values time_data = cost.problem._dataset["Time [s]"].data model_output = cost.problem.evaluate(params) target_output = cost.problem.target() for i in range(0, cost.problem.n_outputs): # Create the figure using the StandardPlot class fig = pybop.StandardPlot( x=time_data, y=model_output[:, i], cost=cost, y2=target_output[:, i], xaxis_title="Time [s]", yaxis_title=cost.problem.signal[i], title=title, trace_name="Model", width=width, height=height, )() # Display the figure fig.show() return fig