Module streamgen.visualizations
🖼️ parameter visualization functions.
View Source
"""πΌοΈparameter visualization functions."""
from collections.abc import Callable
from copy import deepcopy
from itertools import islice
from typing import Any

import IPython
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from IPython.utils import io
from ipywidgets import widgets
from matplotlib import animation

from streamgen.parameter import Parameter
from streamgen.parameter.store import ParameterStore
from streamgen.samplers.tree import SamplingTree
def plot(values: list[int | float | list | np.ndarray], ax: plt.Axes | None = None, title: str | None = None) -> plt.Axes:
"""π plots the scheduled values of a parameter.
This function currently supports plotting numeric parameters and probabilities.
values (list[int | float | list | np.ndarray]): list of values to plot.
ax (plt.Axes | None, optional): matplotlib Axes to plot to. Defaults to None.
title (str | None, optional): title of the plot. Defaults to None
NotImplementedError: when the type of the parameter values are not yet supported by this function
plt.Axes: parameter plot
match values[0]:
case int() | float():
ax = sns.lineplot(values, ax=ax, marker="o")
case np.ndarray() | list():
if ax is None:
fig, ax = plt.subplots()
indices = list(range(len(values)))
values = np.array(values).T
assert len(values.shape) == 2, "only arrays with two dimensions can be visualized here" # noqa: S101, PLR2004
ax.stackplot(indices, values)
case _:
raise NotImplementedError
return ax
def plot_parameter(param: Parameter, num_values: int | None = None, ax: plt.Axes | None = None) -> plt.Axes:
"""βοΈ plots the scheduled values of a parameter.
This function currently supports plotting numeric parameters and probabilities.
param (Parameter): parameter to plot
num_values (int | None, optional): number of values to plot.
If None, collect all values from the schedule. Defaults to None.
ax (plt.Axes | None, optional): matplotlib Axes to plot to. Defaults to None.
NotImplementedError: when the type of the parameter values are not yet supported by this function
plt.Axes: parameter plot
match num_values:
case int():
assert num_values > 1, "at least two value are needed for the plot." # noqa: S101
values = [param.value, *list(islice(deepcopy(param.schedule), num_values - 1))]
case None:
values = [param.value, *list(islice(deepcopy(param.schedule), None))]
return plot(values, ax,
def plot_parameter_store(store: ParameterStore, num_values: int | None = None) -> mpl.figure.Figure:
    """Plot every parameter in a `ParameterStore` into one figure.

    Parameters are drawn as vertically stacked subplots (a single column), in the order of
    `store.parameter_names`. The figure height is 3 inches per parameter.

    Args:
        store (ParameterStore): parameter store to plot
        num_values (int | None, optional): number of values to plot per parameter.
            If None, collect all values from the schedule. Defaults to None.

    Raises:
        NotImplementedError: when the type of the parameter values are not yet supported

    Returns:
        mpl.figure.Figure: matplotlib figure object
    """
    # one subplot row per parameter
    num_rows: int = len(store.parameter_names)
    fig = plt.figure()

    for position, name in enumerate(store.parameter_names, start=1):
        row_ax = fig.add_subplot(num_rows, 1, position)
        plot_parameter(store[name], num_values=num_values, ax=row_ax)

    fig.set_figheight(num_rows * 3.0)
    return fig
def plot_parameter_store_widget(store: ParameterStore, num_values: int | None = None) -> widgets.Tab:
    """Plot every parameter of each scope in a `ParameterStore` in a separate `ipywidgets.widgets.Tab`.

    One `widgets.Output` tab is created per entry of `store.scopes`. Each tab is titled with the scope
    name and shows the figure produced by `plot_parameter_store` for that scope's parameters.

    Args:
        store (ParameterStore): parameter store to plot
        num_values (int | None, optional): number of values to plot.
            If None, collect all values from the schedule. Defaults to None.

    Raises:
        NotImplementedError: when the type of the parameter values are not yet supported

    Returns:
        widgets.Tab: ipywidgets tab widget
    """
    scopes = list(store.scopes)
    tabs = [widgets.Output() for _ in scopes]
    widget = widgets.Tab(children=tabs)

    for idx, (scope, tab) in enumerate(zip(scopes, tabs)):
        widget.set_title(idx, scope)
        with tab:
            params = store.get_scope(scope)
            fig = plot_parameter_store(params, num_values=num_values)
            # display explicitly while this tab's Output context is active, so the figure is captured
            # by this tab; merely creating the figure does not route it into the widget
            IPython.display.display(fig)
        # the figure is now rendered into the tab; close it so pyplot does not keep it open
        # (and, with inline notebook backends, show it a second time below the widget)
        plt.close(fig)

    return widget
def plot_labeled_samples_grid(
    tree: SamplingTree,
    plotting_func: Callable[[Any, plt.Axes], plt.Axes],
    columns: int = 4,
) -> mpl.figure.Figure:
    """Plot a `columns`x`columns` grid of labeled samples generated from a `SamplingTree` with `ClassLabelNode`s.

    `columns * columns` samples are collected from `tree`. Each `(sample, target)` pair is drawn into its
    own subplot by `plotting_func`, and the target (label) is shown as the subplot title.

    Args:
        tree (SamplingTree): tree to generate samples from
        plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.
            The function should take a sample and a `plt.Axes` as arguments.
        columns (int, optional): number of samples in the columns (and rows). Defaults to 4.

    Returns:
        mpl.figure.Figure: matplotlib figure object
    """
    num_samples = columns * columns

    # capture (and discard) anything printed while the tree generates samples
    with io.capture_output() as _captured:
        samples = tree.collect(num_samples)

    fig = plt.figure()
    for idx, (sample, target) in enumerate(samples):
        ax = fig.add_subplot(columns, columns, idx + 1)
        plotting_func(sample, ax)
        # set after `plotting_func` so the label is what ends up as the subplot title
        ax.set_title(str(target))

    # 3 inches per grid cell in both directions
    fig.set_figheight(columns * 3.0)
    fig.set_figwidth(columns * 3.0)
    return fig
def plot_labeled_samples_animation(
    tree: SamplingTree,
    plotting_func: Callable[[Any, plt.Axes], plt.Axes],
    num_samples: int = 8,
    interval: int = 200,
) -> IPython.display.HTML:
    """Plot several labeled samples generated from a `SamplingTree` with `ClassLabelNode`s as an animation.

    Each collected `(sample, target)` pair becomes one animation frame. The frame shows the sample drawn by
    `plotting_func` and uses the target (label) as the title.

    Args:
        tree (SamplingTree): tree to generate samples from
        plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.
            The function should take a sample and a `plt.Axes` as arguments.
        num_samples (int, optional): number of samples to include in the animation. Defaults to 8.
        interval (int, optional): delay between frames in milliseconds. Defaults to 200.

    Returns:
        IPython.display.HTML: animation player as an ipython HTML output
    """
    # capture (and discard) anything printed while the tree generates samples
    with io.capture_output() as _captured:
        samples = tree.collect(num_samples)

    fig, ax = plt.subplots()

    def _plotting_func(idx: int, ax: plt.Axes) -> None:
        # all frames share one Axes: remove the previous sample before drawing the next one
        ax.clear()
        sample, target = samples[idx]
        plotting_func(sample, ax)
        ax.set_title(str(target))

    # one frame per collected sample, so indexing `samples` can never go out of range
    anim = animation.FuncAnimation(fig, _plotting_func, frames=len(samples), fargs=(ax,), interval=interval)
    html = IPython.display.HTML(anim.to_jshtml())
    # all frames are embedded in `html` now; close the figure so the static canvas is not left open
    plt.close(fig)
    return html
def plot(
values: list[int | float | list | numpy.ndarray],
ax: matplotlib.axes._axes.Axes | None = None,
title: str | None = None
) -> matplotlib.axes._axes.Axes
Plots the scheduled values of a parameter.
This function currently supports plotting numeric parameters and probabilities.
Name | Type | Description | Default |
values | list[int \| float \| list \| np.ndarray] | list of values to plot. | None |
ax | plt.Axes \| None | matplotlib Axes to plot to. Defaults to None. | None |
title | str \| None | title of the plot. Defaults to None. | None |
Type | Description |
plt.Axes | parameter plot |
Type | Description |
NotImplementedError | when the type of the parameter values are not yet supported by this function |
View Source
def plot(values: list[int | float | list | np.ndarray], ax: plt.Axes | None = None, title: str | None = None) -> plt.Axes:
"""π plots the scheduled values of a parameter.
This function currently supports plotting numeric parameters and probabilities.
values (list[int | float | list | np.ndarray]): list of values to plot.
ax (plt.Axes | None, optional): matplotlib Axes to plot to. Defaults to None.
title (str | None, optional): title of the plot. Defaults to None
NotImplementedError: when the type of the parameter values are not yet supported by this function
plt.Axes: parameter plot
match values[0]:
case int() | float():
ax = sns.lineplot(values, ax=ax, marker="o")
case np.ndarray() | list():
if ax is None:
fig, ax = plt.subplots()
indices = list(range(len(values)))
values = np.array(values).T
assert len(values.shape) == 2, "only arrays with two dimensions can be visualized here" # noqa: S101, PLR2004
ax.stackplot(indices, values)
case _:
raise NotImplementedError
return ax
def plot_labeled_samples_animation(
tree: streamgen.samplers.tree.SamplingTree,
plotting_func: collections.abc.Callable[[typing.Any, matplotlib.axes._axes.Axes], matplotlib.axes._axes.Axes],
num_samples: int = 8,
interval: int = 200
) -> IPython.core.display.HTML
Plots several labeled samples generated from a SamplingTree with ClassLabelNodes as an animation.
Name | Type | Description | Default |
tree | SamplingTree | tree to generate samples from | None |
plotting_func | Callable[[Any, plt.Axes], plt.Axes] | function to visualize a single sample. The function should take a sample and a plt.Axes as arguments. | None |
num_samples | int | number of samples to include in the animation. Defaults to 8. | 8 |
interval | int | delay between frames in milliseconds. Defaults to 200. | 200 |
Type | Description |
IPython.display.HTML | animation player as an ipython HTML output |
View Source
def plot_labeled_samples_animation(
    tree: SamplingTree,
    plotting_func: Callable[[Any, plt.Axes], plt.Axes],
    num_samples: int = 8,
    interval: int = 200,
) -> IPython.display.HTML:
    """Plot several labeled samples generated from a `SamplingTree` with `ClassLabelNode`s as an animation.

    Each collected `(sample, target)` pair becomes one animation frame. The frame shows the sample drawn by
    `plotting_func` and uses the target (label) as the title.

    Args:
        tree (SamplingTree): tree to generate samples from
        plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.
            The function should take a sample and a `plt.Axes` as arguments.
        num_samples (int, optional): number of samples to include in the animation. Defaults to 8.
        interval (int, optional): delay between frames in milliseconds. Defaults to 200.

    Returns:
        IPython.display.HTML: animation player as an ipython HTML output
    """
    # capture (and discard) anything printed while the tree generates samples
    with io.capture_output() as _captured:
        samples = tree.collect(num_samples)

    fig, ax = plt.subplots()

    def _plotting_func(idx: int, ax: plt.Axes) -> None:
        # all frames share one Axes: remove the previous sample before drawing the next one
        ax.clear()
        sample, target = samples[idx]
        plotting_func(sample, ax)
        ax.set_title(str(target))

    # one frame per collected sample, so indexing `samples` can never go out of range
    anim = animation.FuncAnimation(fig, _plotting_func, frames=len(samples), fargs=(ax,), interval=interval)
    html = IPython.display.HTML(anim.to_jshtml())
    # all frames are embedded in `html` now; close the figure so the static canvas is not left open
    plt.close(fig)
    return html
def plot_labeled_samples_grid(
tree: streamgen.samplers.tree.SamplingTree,
plotting_func: collections.abc.Callable[[typing.Any, matplotlib.axes._axes.Axes], matplotlib.axes._axes.Axes],
columns: int = 4
) -> matplotlib.figure.Figure
Plots a columns x columns grid of labeled samples generated from a SamplingTree with ClassLabelNodes.
Name | Type | Description | Default |
tree | SamplingTree | tree to generate samples from | None |
plotting_func | Callable[[Any, plt.Axes], plt.Axes] | function to visualize a single sample. The function should take a sample and a plt.Axes as arguments. | None |
columns | int | number of samples in the columns (and rows). Defaults to 4. | 4 |
Type | Description |
mpl.figure.Figure | matplotlib figure object |
View Source
def plot_labeled_samples_grid(
    tree: SamplingTree,
    plotting_func: Callable[[Any, plt.Axes], plt.Axes],
    columns: int = 4,
) -> mpl.figure.Figure:
    """Plot a `columns`x`columns` grid of labeled samples generated from a `SamplingTree` with `ClassLabelNode`s.

    `columns * columns` samples are collected from `tree`. Each `(sample, target)` pair is drawn into its
    own subplot by `plotting_func`, and the target (label) is shown as the subplot title.

    Args:
        tree (SamplingTree): tree to generate samples from
        plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.
            The function should take a sample and a `plt.Axes` as arguments.
        columns (int, optional): number of samples in the columns (and rows). Defaults to 4.

    Returns:
        mpl.figure.Figure: matplotlib figure object
    """
    num_samples = columns * columns

    # capture (and discard) anything printed while the tree generates samples
    with io.capture_output() as _captured:
        samples = tree.collect(num_samples)

    fig = plt.figure()
    for idx, (sample, target) in enumerate(samples):
        ax = fig.add_subplot(columns, columns, idx + 1)
        plotting_func(sample, ax)
        # set after `plotting_func` so the label is what ends up as the subplot title
        ax.set_title(str(target))

    # 3 inches per grid cell in both directions
    fig.set_figheight(columns * 3.0)
    fig.set_figwidth(columns * 3.0)
    return fig
def plot_parameter(
param: streamgen.parameter.Parameter,
num_values: int | None = None,
ax: matplotlib.axes._axes.Axes | None = None
) -> matplotlib.axes._axes.Axes
⚙️ plots the scheduled values of a parameter.
This function currently supports plotting numeric parameters and probabilities.
Name | Type | Description | Default |
param | Parameter | parameter to plot | None |
num_values | int \| None | number of values to plot. If None, collect all values from the schedule. Defaults to None. | None |
ax | plt.Axes \| None | matplotlib Axes to plot to. Defaults to None. | None |
Type | Description |
plt.Axes | parameter plot |
Type | Description |
NotImplementedError | when the type of the parameter values are not yet supported by this function |
View Source
def plot_parameter(param: Parameter, num_values: int | None = None, ax: plt.Axes | None = None) -> plt.Axes:
"""βοΈ plots the scheduled values of a parameter.
This function currently supports plotting numeric parameters and probabilities.
param (Parameter): parameter to plot
num_values (int | None, optional): number of values to plot.
If None, collect all values from the schedule. Defaults to None.
ax (plt.Axes | None, optional): matplotlib Axes to plot to. Defaults to None.
NotImplementedError: when the type of the parameter values are not yet supported by this function
plt.Axes: parameter plot
match num_values:
case int():
assert num_values > 1, "at least two value are needed for the plot." # noqa: S101
values = [param.value, *list(islice(deepcopy(param.schedule), num_values - 1))]
case None:
values = [param.value, *list(islice(deepcopy(param.schedule), None))]
return plot(values, ax,
def plot_parameter_store(
    store: streamgen.parameter.store.ParameterStore,
    num_values: int | None = None
) -> matplotlib.figure.Figure
Plots every parameter in a ParameterStore in one figure.
Name | Type | Description | Default |
store | ParameterStore | parameter store to plot | None |
num_values | int \| None | number of values to plot. If None, collect all values from the schedule. Defaults to None. | None |
Type | Description |
mpl.figure.Figure | matplotlib figure object |
Type | Description |
NotImplementedError | when the type of the parameter values are not yet supported |
View Source
def plot_parameter_store(store: ParameterStore, num_values: int | None = None) -> mpl.figure.Figure:
    """Plot every parameter in a `ParameterStore` into one figure.

    Parameters are drawn as vertically stacked subplots (a single column), in the order of
    `store.parameter_names`. The figure height is 3 inches per parameter.

    Args:
        store (ParameterStore): parameter store to plot
        num_values (int | None, optional): number of values to plot per parameter.
            If None, collect all values from the schedule. Defaults to None.

    Raises:
        NotImplementedError: when the type of the parameter values are not yet supported

    Returns:
        mpl.figure.Figure: matplotlib figure object
    """
    # one subplot row per parameter
    num_rows: int = len(store.parameter_names)
    fig = plt.figure()

    for position, name in enumerate(store.parameter_names, start=1):
        row_ax = fig.add_subplot(num_rows, 1, position)
        plot_parameter(store[name], num_values=num_values, ax=row_ax)

    fig.set_figheight(num_rows * 3.0)
    return fig
def plot_parameter_store_widget(
    store: streamgen.parameter.store.ParameterStore,
    num_values: int | None = None
) -> ipywidgets.widgets.widget_selectioncontainer.Tab
Plots every parameter of each scope in a ParameterStore in a separate ipywidgets.widgets.Tab.
Name | Type | Description | Default |
store | ParameterStore | parameter store to plot | None |
num_values | int \| None | number of values to plot. If None, collect all values from the schedule. Defaults to None. | None |
Type | Description |
widgets.Tab | ipywidgets tab widget |
Type | Description |
NotImplementedError | when the type of the parameter values are not yet supported |
View Source
def plot_parameter_store_widget(store: ParameterStore, num_values: int | None = None) -> widgets.Tab:
    """Plot every parameter of each scope in a `ParameterStore` in a separate `ipywidgets.widgets.Tab`.

    One `widgets.Output` tab is created per entry of `store.scopes`. Each tab is titled with the scope
    name and shows the figure produced by `plot_parameter_store` for that scope's parameters.

    Args:
        store (ParameterStore): parameter store to plot
        num_values (int | None, optional): number of values to plot.
            If None, collect all values from the schedule. Defaults to None.

    Raises:
        NotImplementedError: when the type of the parameter values are not yet supported

    Returns:
        widgets.Tab: ipywidgets tab widget
    """
    scopes = list(store.scopes)
    tabs = [widgets.Output() for _ in scopes]
    widget = widgets.Tab(children=tabs)

    for idx, (scope, tab) in enumerate(zip(scopes, tabs)):
        widget.set_title(idx, scope)
        with tab:
            params = store.get_scope(scope)
            fig = plot_parameter_store(params, num_values=num_values)
            # display explicitly while this tab's Output context is active, so the figure is captured
            # by this tab; merely creating the figure does not route it into the widget
            IPython.display.display(fig)
        # the figure is now rendered into the tab; close it so pyplot does not keep it open
        # (and, with inline notebook backends, show it a second time below the widget)
        plt.close(fig)

    return widget