# Source code for pybop.plotting.quick_plot

import copy
import math
import sys
import textwrap

import numpy as np

from pybop import PlotlyManager

#: Default layout settings applied to figures created by StandardPlot.
DEFAULT_LAYOUT_OPTIONS = dict(
    title=None,
    title_x=0.5,
    # ``title_font_size`` replaces the deprecated ``titlefont_size`` alias,
    # which is no longer accepted by recent Plotly releases (6.x).
    xaxis=dict(title=None, title_font_size=12, tickfont_size=12),
    yaxis=dict(title=None, title_font_size=12, tickfont_size=12),
    legend=dict(x=1, y=1, xanchor="right", yanchor="top", font_size=12),
    showlegend=True,
    autosize=False,
    width=1024,
    height=576,
    margin=dict(l=10, r=10, b=10, t=75, pad=4),
)
#: Default keyword arguments forwarded to ``make_subplots`` by StandardSubplot.
DEFAULT_SUBPLOT_OPTIONS = dict(
    start_cell="bottom-left",
)
#: Default Scatter trace settings for single-panel plots.
DEFAULT_TRACE_OPTIONS = dict(line=dict(width=4), mode="lines")
#: Default Scatter trace settings for subplots (thinner lines than single plots).
DEFAULT_SUBPLOT_TRACE_OPTIONS = dict(line=dict(width=2), mode="lines")
class StandardPlot:
    """
    A class for creating and displaying interactive Plotly figures.

    Parameters
    ----------
    x : list or np.ndarray
        X-axis data points: either a single series shared by every trace, or
        one series per series in ``y``.
    y : list or np.ndarray
        Primary Y-axis data points for simulated model output: a single series
        or a list / 2D array of series (one trace is created per series).
    layout : Plotly layout, optional
        A layout for the figure (default: None). If given, the entries of
        ``layout_options`` are applied on top of it; if None, a new layout is
        built from ``layout_options``.
    layout_options : dict, optional
        Settings to modify the default layout (default: DEFAULT_LAYOUT_OPTIONS).
        The dictionary is copied, so the caller's dict (and the shared default)
        is never modified through this instance.
    trace_options : dict, optional
        Settings that update the default trace options (default: DEFAULT_TRACE_OPTIONS).
    trace_names : str or sequence of str, optional
        Name(s) for the primary trace(s) (default: None). The names are copied
        before wrapping, so the caller's sequence is left untouched.
    trace_name_width : int, optional
        Maximum length of the trace names before text wrapping is used (default: 40).

    Raises
    ------
    ValueError
        If ``x`` contains more than one series and the number of series
        differs from that of ``y``.

    Returns
    -------
    plotly.graph_objs.Figure
        The generated Plotly figure (returned when the instance is called).
    """

    def __init__(
        self,
        x,
        y,
        layout=None,
        layout_options=DEFAULT_LAYOUT_OPTIONS.copy(),
        trace_options=DEFAULT_TRACE_OPTIONS.copy(),
        trace_names=None,
        trace_name_width=40,
    ):
        self.x = x
        self.y = y
        self.layout = layout
        # The default argument is evaluated once and shared by every call, so
        # keep a private (deep) copy to stop mutations leaking between plots.
        self.layout_options = (
            copy.deepcopy(layout_options) if layout_options is not None else None
        )
        self.trace_options = copy.deepcopy(DEFAULT_TRACE_OPTIONS)
        if trace_options is not None:
            for arg, value in trace_options.items():
                self.trace_options[arg] = value

        # Always hold the names in a fresh list: they are replaced by their
        # wrapped form below, which must not mutate the caller's list, fail on
        # a tuple, or get truncated by a fixed-width NumPy string dtype.
        if isinstance(trace_names, str):
            self.trace_names = [trace_names]
        elif trace_names is not None:
            self.trace_names = list(trace_names)
        else:
            self.trace_names = None
        self.trace_name_width = trace_name_width

        # Normalise the data into lists of series that Plotly can take
        self.x = self._as_series_list(self.x)
        self.y = self._as_series_list(self.y)
        if len(self.x) > 1 and len(self.x) != len(self.y):
            raise ValueError(
                "Input x should have either one data series or the same number as y."
            )

        # Attempt to import plotly when an instance is created
        self.go = PlotlyManager().go

        # Create layout; layout_options always take precedence over `layout`
        if self.layout is None:
            self.layout = self.go.Layout(self.layout_options)
        if self.layout_options is not None:
            for arg, value in self.layout_options.items():
                self.layout[arg] = value

        # Wrap trace names
        if self.trace_names is not None:
            self.trace_names = [
                self.wrap_text(name, width=self.trace_name_width)
                for name in self.trace_names
            ]

        # Create a trace for each trajectory; a single x series is reused
        self.traces = []
        x = self.x[0]
        for i, y_series in enumerate(self.y):
            if len(self.x) > 1:
                x = self.x[i]
            if self.trace_names is not None:
                self.trace_options["name"] = self.trace_names[i]
            else:
                self.trace_options["showlegend"] = False
            self.traces.append(self.create_trace(x, y_series, **self.trace_options))

    @staticmethod
    def _as_series_list(data):
        """
        Normalise plotting data into a list of series.

        A flat (or empty) list, or an array that squeezes to at most one
        dimension, becomes a one-element list. A list whose first element is
        a list/array is returned as-is. A higher-dimensional array becomes
        nested lists via ``tolist()``. Any other type is returned unchanged.
        """
        if isinstance(data, list):
            if not data or not isinstance(data[0], (np.ndarray, list)):
                return [data]
            return data
        if isinstance(data, np.ndarray):
            data = np.squeeze(data)
            if data.ndim <= 1:
                # atleast_1d turns a 0-d (single value) array into a series
                return [np.atleast_1d(data)]
            return data.tolist()
        return data

    def __call__(self, show=True):
        """
        Generate and show the figure.

        Parameters
        ----------
        show : bool, optional
            If True, the figure is shown upon creation (default: True).

        Returns
        -------
        plotly.graph_objs.Figure
            The generated figure.
        """
        fig = self.go.Figure(data=self.traces, layout=self.layout)
        # Inside a Jupyter kernel, render as a static SVG image
        if "ipykernel" in sys.modules and show:
            fig.show("svg")
        elif show:
            fig.show()
        return fig

    def create_trace(self, x, y, **trace_options):
        """
        Create a trace for the Plotly figure.

        Returns
        -------
        plotly.graph_objs.Scatter
            A trace for a Plotly figure.
        """
        return self.go.Scatter(
            x=x,
            y=y,
            **trace_options,
        )

    @staticmethod
    def wrap_text(text, width):
        """
        Wrap text to a specified width with HTML line breaks.

        Parameters
        ----------
        text : str
            The text to wrap; non-string values are converted with ``str``.
        width : int
            The width to wrap the text to.

        Returns
        -------
        str
            The wrapped text, with ``<br>`` in place of newlines.
        """
        wrapped_text = textwrap.fill(str(text), width=width, break_long_words=False)
        return wrapped_text.replace("\n", "<br>")

    @staticmethod
    def remove_brackets(s):
        """
        Remove square brackets from a string and replace with forward slashes
        as per section 7.1 of the SI Handbook.

        Only the first ``[...]`` pair is replaced, e.g. ``"Time [s]"`` becomes
        ``"Time / s"``. Non-string iterables are processed element-wise and
        rebuilt with their own type; other values are returned unchanged.
        """
        # If s is an iterable (but not a string), apply the function recursively to each element
        if hasattr(s, "__iter__") and not isinstance(s, str):
            return type(s)(StandardPlot.remove_brackets(i) for i in s)
        elif isinstance(s, str):
            start = s.find("[")
            end = s.find("]")
            if start != -1 and end != -1:
                char_in_brackets = s[start + 1 : end]
                return s[:start] + " / " + char_in_brackets + s[end + 1 :]
        return s
class StandardSubplot(StandardPlot):
    """
    A class for creating and displaying a set of interactive Plotly figures in a grid layout.

    Each trajectory is placed in its own subplot, filling the grid row by row.

    Parameters
    ----------
    x : list or np.ndarray
        X-axis data points.
    y : list or np.ndarray
        Primary Y-axis data points for simulated model output.
    num_rows : int, optional
        Number of rows of subplots, can be set automatically (default: None).
    num_cols : int, optional
        Number of columns of subplots, can be set automatically (default: None).
    axis_titles : list of (str, str), optional
        ``(x_title, y_title)`` pairs applied to the subplots in trace order;
        subplots beyond the end of the list keep untitled axes (default: None).
    layout : Plotly layout, optional
        Passed to StandardPlot (default: None). Note that the figure produced
        by calling the instance is styled with ``layout_options`` only.
    layout_options : dict, optional
        Settings to modify the default layout (default: DEFAULT_LAYOUT_OPTIONS).
    subplot_options : dict, optional
        Keyword arguments that update the defaults passed to ``make_subplots``
        (default: DEFAULT_SUBPLOT_OPTIONS).
    trace_options : dict, optional
        Settings to modify the default trace type (default: DEFAULT_SUBPLOT_TRACE_OPTIONS).
    trace_names : str or sequence of str, optional
        Name(s) for the primary trace(s) (default: None).
    trace_name_width : int, optional
        Maximum length of the trace names before text wrapping is used (default: 40).

    Returns
    -------
    plotly.graph_objs.Figure
        The generated Plotly figure (returned when the instance is called).
    """

    def __init__(
        self,
        x,
        y,
        num_rows=None,
        num_cols=None,
        axis_titles=None,
        layout=None,
        layout_options=DEFAULT_LAYOUT_OPTIONS.copy(),
        subplot_options=DEFAULT_SUBPLOT_OPTIONS.copy(),
        trace_options=DEFAULT_SUBPLOT_TRACE_OPTIONS.copy(),
        trace_names=None,
        trace_name_width=40,
    ):
        super().__init__(
            x, y, layout, layout_options, trace_options, trace_names, trace_name_width
        )
        self.num_traces = len(self.traces)
        self.num_rows = num_rows
        self.num_cols = num_cols

        # Always lay out at least one cell: with zero traces the automatic
        # sizing would otherwise divide by zero or request a 0-row grid.
        num_cells = max(self.num_traces, 1)
        if self.num_rows is None and self.num_cols is None:
            # Work out a near-square grid of subplots
            self.num_cols = int(math.ceil(math.sqrt(num_cells)))
            self.num_rows = int(math.ceil(num_cells / self.num_cols))
        elif self.num_rows is None:
            self.num_rows = int(math.ceil(num_cells / self.num_cols))
        elif self.num_cols is None:
            self.num_cols = int(math.ceil(num_cells / self.num_rows))

        self.axis_titles = axis_titles
        self.subplot_options = DEFAULT_SUBPLOT_OPTIONS.copy()
        if subplot_options is not None:
            for arg, value in subplot_options.items():
                self.subplot_options[arg] = value

        # Attempt to import plotly when an instance is created
        self.make_subplots = PlotlyManager().make_subplots

    def __call__(self, show=True):
        """
        Generate and show the set of figures.

        Parameters
        ----------
        show : bool, optional
            If True, the figure is shown upon creation (default: True).

        Returns
        -------
        plotly.graph_objs.Figure
            The figure containing all subplots.
        """
        fig = self.make_subplots(
            rows=self.num_rows, cols=self.num_cols, **self.subplot_options
        )
        fig.update_layout(self.layout_options)

        for idx, trace in enumerate(self.traces):
            # Fill the grid row by row (1-based row/column indices)
            row = (idx // self.num_cols) + 1
            col = (idx % self.num_cols) + 1
            fig.add_trace(trace, row=row, col=col)

            if self.axis_titles and idx < len(self.axis_titles):
                x_title, y_title = self.axis_titles[idx]
                fig.update_xaxes(title_text=x_title, row=row, col=col)
                fig.update_yaxes(title_text=y_title, row=row, col=col)

        # Inside a Jupyter kernel, render as a static SVG image
        if "ipykernel" in sys.modules and show:
            fig.show("svg")
        elif show:
            fig.show()
        return fig
def plot_trajectories(x, y, trace_names=None, show=True, **layout_kwargs):
    """
    Quickly plot one or more trajectories using Plotly.

    Parameters
    ----------
    x : list or np.ndarray
        X-axis data points, shared by all trajectories or one series per trajectory.
    y : list or np.ndarray
        Y-axis data points for each trajectory.
    trace_names : list or str, optional
        Name(s) for the trace(s) (default: None).
    show : bool, optional
        If True, display the figure before returning it; inside a Jupyter
        kernel it is rendered as SVG (default: True).
    **layout_kwargs : optional
        Valid Plotly layout keys and their values, applied to the figure with
        ``update_layout``, e.g. `xaxis_title="Time / s"` or
        `xaxis={"title": "Time / s", "title_font_size": 18}`.

    Returns
    -------
    plotly.graph_objs.Figure
        The Plotly figure object for the scatter plot.
    """
    plotter = StandardPlot(x=x, y=y, trace_names=trace_names)

    # Build the figure without displaying it, so the layout can be
    # customised before it is shown.
    figure = plotter(show=False)
    figure.update_layout(**layout_kwargs)

    if show:
        if "ipykernel" in sys.modules:
            figure.show("svg")
        else:
            figure.show()
    return figure