Skip to content

Module streamgen.visualizations

πŸ–ΌοΈparameter visualization functions.

View Source
"""πŸ–ΌοΈparameter visualization functions."""

from collections.abc import Callable

from copy import deepcopy

from itertools import islice

from typing import Any

import IPython

import matplotlib as mpl

import matplotlib.pyplot as plt

import numpy as np

import seaborn as sns

from IPython.utils import io

from ipywidgets import widgets

from matplotlib import animation

from streamgen.parameter import Parameter

from streamgen.parameter.store import ParameterStore

from streamgen.samplers.tree import SamplingTree

def plot(values: list[int | float | list | np.ndarray], ax: plt.Axes | None = None, title: str | None = None) -> plt.Axes:

    """πŸ“ˆ plots the scheduled values of a parameter.

    This function currently supports plotting numeric parameters and probabilities.

    Args:

        values (list[int | float | list | np.ndarray]): list of values to plot.

        ax (plt.Axes | None, optional): matplotlib Axes to plot to. Defaults to None.

        title (str | None, optional): title of the plot. Defaults to None

    Raises:

        NotImplementedError: when the type of the parameter values are not yet supported by this function

    Returns:

        plt.Axes: parameter plot

    """

    match values[0]:

        case int() | float():

            ax = sns.lineplot(values, ax=ax, marker="o")

        case np.ndarray() | list():

            if ax is None:

                fig, ax = plt.subplots()

            sns.set_theme()

            indices = list(range(len(values)))

            values = np.array(values).T

            assert len(values.shape) == 2, "only arrays with two dimensions can be visualized here"  # noqa: S101, PLR2004

            ax.stackplot(indices, values)

        case _:

            raise NotImplementedError

    ax.set_title(title)

    return ax

def plot_parameter(param: Parameter, num_values: int | None = None, ax: plt.Axes | None = None) -> plt.Axes:

    """βš™οΈ plots the scheduled values of a parameter.

    This function currently supports plotting numeric parameters and probabilities.

    Args:

        param (Parameter): parameter to plot

        num_values (int | None, optional): number of values to plot.

            If None, collect all values from the schedule. Defaults to None.

        ax (plt.Axes | None, optional): matplotlib Axes to plot to. Defaults to None.

    Raises:

        NotImplementedError: when the type of the parameter values are not yet supported by this function

    Returns:

        plt.Axes: parameter plot

    """

    match num_values:

        case int():

            assert num_values > 1, "at least two value are needed for the plot."  # noqa: S101

            values = [param.value, *list(islice(deepcopy(param.schedule), num_values - 1))]

        case None:

            values = [param.value, *list(islice(deepcopy(param.schedule), None))]

    return plot(values, ax, title=param.name)

def plot_parameter_store(store: ParameterStore, num_values: int | None = None) -> mpl.figure.Figure:
    """🗄️ plots every parameter in a `ParameterStore` in one figure.

    Each parameter gets its own subplot. The subplots are stacked vertically in a single column,
    in the order of `store.parameter_names`.

    Args:
        store (ParameterStore): parameter store to plot
        num_values (int | None, optional): number of values to plot.
            If None, collect all values from the schedule. Defaults to None.

    Raises:
        NotImplementedError: when the type of the parameter values is not yet supported

    Returns:
        mpl.figure.Figure: matplotlib figure object
    """
    # one subplot row per parameter
    num_rows: int = len(store.parameter_names)

    sns.set_theme()
    fig = plt.figure()

    for row, name in enumerate(store.parameter_names, start=1):
        ax = fig.add_subplot(num_rows, 1, row)
        plot_parameter(store[name], num_values=num_values, ax=ax)

    # 3 inches per subplot; keep a non-zero height for an empty store
    fig.set_figheight(max(num_rows, 1) * 3.0)
    # lay out this figure explicitly instead of whatever figure pyplot considers current
    fig.tight_layout()

    return fig

def plot_parameter_store_widget(store: ParameterStore, num_values: int | None = None) -> widgets.Tab:
    """📂 plots every parameter of each scope in a `ParameterStore` in a separate `ipywidgets.widgets.Tab`.

    Every scope gets its own `widgets.Output` tab, titled with the scope name,
    which shows the figure created by `plot_parameter_store` for that scope.

    Args:
        store (ParameterStore): parameter store to plot
        num_values (int | None, optional): number of values to plot.
            If None, collect all values from the schedule. Defaults to None.

    Raises:
        NotImplementedError: when the type of the parameter values is not yet supported

    Returns:
        widgets.Tab: ipywidgets tab widget
    """
    scopes = list(store.scopes)
    tabs = [widgets.Output() for _ in scopes]
    widget = widgets.Tab(children=tabs)

    for idx, (scope, tab) in enumerate(zip(scopes, tabs)):
        widget.set_title(idx, scope)
        with tab:
            params = store.get_scope(scope)
            fig = plot_parameter_store(params, num_values=num_values)
            # `plt.show` takes no figure argument; a positional argument is interpreted by the backend
            # (e.g. as `close`, closing *all* figures) or rejected. Display exactly this figure inside
            # the tab and close only it, so that it is not rendered a second time outside the widget.
            IPython.display.display(fig)
            plt.close(fig)

    return widget

def plot_labeled_samples_grid(
    tree: SamplingTree,
    plotting_func: Callable[[Any, plt.Axes], plt.Axes],
    columns: int = 4,
) -> mpl.figure.Figure:
    """📟 plots a `columns`x`columns` grid of labeled samples generated from a `SamplingTree` with `ClassLabelNode`s.

    Each collected `(sample, target)` pair is drawn into its own subplot by `plotting_func`,
    with the target used as the subplot title.

    Args:
        tree (SamplingTree): tree to generate samples from
        plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.
            The function should take a sample and a `plt.Axes` as arguments.
        columns (int, optional): number of samples in the columns (and rows). Defaults to 4.

    Returns:
        mpl.figure.Figure: matplotlib figure object
    """
    num_samples = columns * columns

    # anything emitted while sampling is captured and discarded
    with io.capture_output() as _captured:
        samples = tree.collect(num_samples)

    sns.set_theme()
    fig = plt.figure()

    for position, (sample, target) in enumerate(samples, start=1):
        ax = fig.add_subplot(columns, columns, position)
        plotting_func(sample, ax)
        ax.set_title(target)

    fig.set_figheight(columns * 3.0)
    fig.set_figwidth(columns * 3.0)
    # `plotting_func` is arbitrary user code that may change pyplot's current figure,
    # so lay out `fig` explicitly instead of calling `plt.tight_layout()`
    fig.tight_layout()

    return fig

def plot_labeled_samples_animation(
    tree: SamplingTree,
    plotting_func: Callable[[Any, plt.Axes], plt.Axes],
    num_samples: int = 8,
    interval: int = 200,
) -> IPython.display.HTML:
    """🎞️ plots several labeled samples generated from a `SamplingTree` with `ClassLabelNode`s as an animation.

    Each frame shows one collected `(sample, target)` pair, drawn by `plotting_func` with the target as title.
    The animation is rendered to a self-contained JavaScript/HTML player, after which the figure is closed.

    Args:
        tree (SamplingTree): tree to generate samples from
        plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.
            The function should take a sample and a `plt.Axes` as arguments.
        num_samples (int, optional): number of samples to include in the animation. Defaults to 8.
        interval (int, optional): delay between frames in milliseconds. Defaults to 200.

    Returns:
        IPython.display.HTML: animation player as an ipython HTML output
    """
    with io.capture_output() as _captured:
        samples = tree.collect(num_samples)

        sns.set_theme()
        fig, ax = plt.subplots()

        def _draw_frame(idx: int, ax: plt.Axes) -> None:
            """Redraws `ax` with the `idx`-th sample and its target as title."""
            sample, target = samples[idx]
            ax.clear()
            plotting_func(sample, ax)
            ax.set_title(target)

        # one frame per collected sample
        anim = animation.FuncAnimation(fig, _draw_frame, frames=len(samples), fargs=(ax,), interval=interval)

    html = IPython.display.HTML(anim.to_jshtml())
    # the frames are already embedded in the HTML; close the figure so that notebook backends
    # (e.g. the inline backend) do not additionally render it as a static image
    plt.close(fig)
    return html

Functions

plot

def plot(
    values: list[int | float | list | numpy.ndarray],
    ax: matplotlib.axes._axes.Axes | None = None,
    title: str | None = None
) -> matplotlib.axes._axes.Axes

📈 plots the scheduled values of a parameter.

This function currently supports plotting numeric parameters and probabilities.

Parameters:

Name Type Description Default
values list[int | float | list | np.ndarray] list of values to plot. None
ax plt.Axes | None matplotlib Axes to plot to. Defaults to None. None
title str | None title of the plot. Defaults to None. None

Returns:

Type Description
plt.Axes parameter plot

Raises:

Type Description
NotImplementedError when the type of the parameter values is not yet supported by this function
View Source
def plot(values: list[int | float | list | np.ndarray], ax: plt.Axes | None = None, title: str | None = None) -> plt.Axes:

    """πŸ“ˆ plots the scheduled values of a parameter.

    This function currently supports plotting numeric parameters and probabilities.

    Args:

        values (list[int | float | list | np.ndarray]): list of values to plot.

        ax (plt.Axes | None, optional): matplotlib Axes to plot to. Defaults to None.

        title (str | None, optional): title of the plot. Defaults to None

    Raises:

        NotImplementedError: when the type of the parameter values are not yet supported by this function

    Returns:

        plt.Axes: parameter plot

    """

    match values[0]:

        case int() | float():

            ax = sns.lineplot(values, ax=ax, marker="o")

        case np.ndarray() | list():

            if ax is None:

                fig, ax = plt.subplots()

            sns.set_theme()

            indices = list(range(len(values)))

            values = np.array(values).T

            assert len(values.shape) == 2, "only arrays with two dimensions can be visualized here"  # noqa: S101, PLR2004

            ax.stackplot(indices, values)

        case _:

            raise NotImplementedError

    ax.set_title(title)

    return ax

plot_labeled_samples_animation

def plot_labeled_samples_animation(
    tree: streamgen.samplers.tree.SamplingTree,
    plotting_func: collections.abc.Callable[[typing.Any, matplotlib.axes._axes.Axes], matplotlib.axes._axes.Axes],
    num_samples: int = 8,
    interval: int = 200
) -> IPython.core.display.HTML

🎞️ plots several labeled samples generated from a SamplingTree with ClassLabelNodes as an animation.

Parameters:

Name Type Description Default
tree SamplingTree tree to generate samples from None
plotting_func Callable[[Any, plt.Axes], plt.Axes] function to visualize a single sample.
The function should take a sample and a plt.Axes as arguments.
None
num_samples int number of samples to include in the animation. Defaults to 8. 8
interval int delay between frames in milliseconds. Defaults to 200. 200

Returns:

Type Description
IPython.display.HTML animation player as an ipython HTML output
View Source
def plot_labeled_samples_animation(
    tree: SamplingTree,
    plotting_func: Callable[[Any, plt.Axes], plt.Axes],
    num_samples: int = 8,
    interval: int = 200,
) -> IPython.display.HTML:
    """🎞️ plots several labeled samples generated from a `SamplingTree` with `ClassLabelNode`s as an animation.

    Each frame shows one collected `(sample, target)` pair, drawn by `plotting_func` with the target as title.
    The animation is rendered to a self-contained JavaScript/HTML player, after which the figure is closed.

    Args:
        tree (SamplingTree): tree to generate samples from
        plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.
            The function should take a sample and a `plt.Axes` as arguments.
        num_samples (int, optional): number of samples to include in the animation. Defaults to 8.
        interval (int, optional): delay between frames in milliseconds. Defaults to 200.

    Returns:
        IPython.display.HTML: animation player as an ipython HTML output
    """
    with io.capture_output() as _captured:
        samples = tree.collect(num_samples)

        sns.set_theme()
        fig, ax = plt.subplots()

        def _draw_frame(idx: int, ax: plt.Axes) -> None:
            """Redraws `ax` with the `idx`-th sample and its target as title."""
            sample, target = samples[idx]
            ax.clear()
            plotting_func(sample, ax)
            ax.set_title(target)

        # one frame per collected sample
        anim = animation.FuncAnimation(fig, _draw_frame, frames=len(samples), fargs=(ax,), interval=interval)

    html = IPython.display.HTML(anim.to_jshtml())
    # the frames are already embedded in the HTML; close the figure so that notebook backends
    # (e.g. the inline backend) do not additionally render it as a static image
    plt.close(fig)
    return html

plot_labeled_samples_grid

def plot_labeled_samples_grid(
    tree: streamgen.samplers.tree.SamplingTree,
    plotting_func: collections.abc.Callable[[typing.Any, matplotlib.axes._axes.Axes], matplotlib.axes._axes.Axes],
    columns: int = 4
) -> matplotlib.figure.Figure

📟 plots a columns x columns grid of labeled samples generated from a SamplingTree with ClassLabelNodes.

Parameters:

Name Type Description Default
tree SamplingTree tree to generate samples from None
plotting_func Callable[[Any, plt.Axes], plt.Axes] function to visualize a single sample.
The function should take a sample and a plt.Axes as arguments.
None
columns int number of samples in the columns (and rows). Defaults to 4. 4

Returns:

Type Description
mpl.figure.Figure matplotlib figure object
View Source
def plot_labeled_samples_grid(
    tree: SamplingTree,
    plotting_func: Callable[[Any, plt.Axes], plt.Axes],
    columns: int = 4,
) -> mpl.figure.Figure:
    """📟 plots a `columns`x`columns` grid of labeled samples generated from a `SamplingTree` with `ClassLabelNode`s.

    Each collected `(sample, target)` pair is drawn into its own subplot by `plotting_func`,
    with the target used as the subplot title.

    Args:
        tree (SamplingTree): tree to generate samples from
        plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.
            The function should take a sample and a `plt.Axes` as arguments.
        columns (int, optional): number of samples in the columns (and rows). Defaults to 4.

    Returns:
        mpl.figure.Figure: matplotlib figure object
    """
    num_samples = columns * columns

    # anything emitted while sampling is captured and discarded
    with io.capture_output() as _captured:
        samples = tree.collect(num_samples)

    sns.set_theme()
    fig = plt.figure()

    for position, (sample, target) in enumerate(samples, start=1):
        ax = fig.add_subplot(columns, columns, position)
        plotting_func(sample, ax)
        ax.set_title(target)

    fig.set_figheight(columns * 3.0)
    fig.set_figwidth(columns * 3.0)
    # `plotting_func` is arbitrary user code that may change pyplot's current figure,
    # so lay out `fig` explicitly instead of calling `plt.tight_layout()`
    fig.tight_layout()

    return fig

plot_parameter

def plot_parameter(
    param: streamgen.parameter.Parameter,
    num_values: int | None = None,
    ax: matplotlib.axes._axes.Axes | None = None
) -> matplotlib.axes._axes.Axes

βš™οΈ plots the scheduled values of a parameter.

This function currently supports plotting numeric parameters and probabilities.

Parameters:

Name Type Description Default
param Parameter parameter to plot None
num_values int | None number of values to plot.
If None, collect all values from the schedule. Defaults to None.
None
ax plt.Axes | None matplotlib Axes to plot to. Defaults to None. None

Returns:

Type Description
plt.Axes parameter plot

Raises:

Type Description
NotImplementedError when the type of the parameter values is not yet supported by this function
View Source
def plot_parameter(param: Parameter, num_values: int | None = None, ax: plt.Axes | None = None) -> plt.Axes:

    """βš™οΈ plots the scheduled values of a parameter.

    This function currently supports plotting numeric parameters and probabilities.

    Args:

        param (Parameter): parameter to plot

        num_values (int | None, optional): number of values to plot.

            If None, collect all values from the schedule. Defaults to None.

        ax (plt.Axes | None, optional): matplotlib Axes to plot to. Defaults to None.

    Raises:

        NotImplementedError: when the type of the parameter values are not yet supported by this function

    Returns:

        plt.Axes: parameter plot

    """

    match num_values:

        case int():

            assert num_values > 1, "at least two value are needed for the plot."  # noqa: S101

            values = [param.value, *list(islice(deepcopy(param.schedule), num_values - 1))]

        case None:

            values = [param.value, *list(islice(deepcopy(param.schedule), None))]

    return plot(values, ax, title=param.name)

plot_parameter_store

def plot_parameter_store(
    store: streamgen.parameter.store.ParameterStore,
    num_values: int | None = None
) -> matplotlib.figure.Figure

πŸ—„οΈ plots every parameter in a ParameterStore in one figure.

Parameters:

Name Type Description Default
store ParameterStore parameter store to plot None
num_values int | None number of values to plot.
If None, collect all values from the schedule. Defaults to None.
None

Returns:

Type Description
mpl.figure.Figure matplotlib figure object

Raises:

Type Description
NotImplementedError when the type of the parameter values is not yet supported
View Source
def plot_parameter_store(store: ParameterStore, num_values: int | None = None) -> mpl.figure.Figure:
    """🗄️ plots every parameter in a `ParameterStore` in one figure.

    Each parameter gets its own subplot. The subplots are stacked vertically in a single column,
    in the order of `store.parameter_names`.

    Args:
        store (ParameterStore): parameter store to plot
        num_values (int | None, optional): number of values to plot.
            If None, collect all values from the schedule. Defaults to None.

    Raises:
        NotImplementedError: when the type of the parameter values is not yet supported

    Returns:
        mpl.figure.Figure: matplotlib figure object
    """
    # one subplot row per parameter
    num_rows: int = len(store.parameter_names)

    sns.set_theme()
    fig = plt.figure()

    for row, name in enumerate(store.parameter_names, start=1):
        ax = fig.add_subplot(num_rows, 1, row)
        plot_parameter(store[name], num_values=num_values, ax=ax)

    # 3 inches per subplot; keep a non-zero height for an empty store
    fig.set_figheight(max(num_rows, 1) * 3.0)
    # lay out this figure explicitly instead of whatever figure pyplot considers current
    fig.tight_layout()

    return fig

plot_parameter_store_widget

def plot_parameter_store_widget(
    store: streamgen.parameter.store.ParameterStore,
    num_values: int | None = None
) -> ipywidgets.widgets.widget_selectioncontainer.Tab

📂 plots every parameter of each scope in a ParameterStore in a separate ipywidgets.widgets.Tab.

Parameters:

Name Type Description Default
store ParameterStore parameter store to plot None
num_values int | None number of values to plot.
If None, collect all values from the schedule. Defaults to None.
None

Returns:

Type Description
widgets.Tab ipywidgets tab widget

Raises:

Type Description
NotImplementedError when the type of the parameter values is not yet supported
View Source
def plot_parameter_store_widget(store: ParameterStore, num_values: int | None = None) -> widgets.Tab:
    """📂 plots every parameter of each scope in a `ParameterStore` in a separate `ipywidgets.widgets.Tab`.

    Every scope gets its own `widgets.Output` tab, titled with the scope name,
    which shows the figure created by `plot_parameter_store` for that scope.

    Args:
        store (ParameterStore): parameter store to plot
        num_values (int | None, optional): number of values to plot.
            If None, collect all values from the schedule. Defaults to None.

    Raises:
        NotImplementedError: when the type of the parameter values is not yet supported

    Returns:
        widgets.Tab: ipywidgets tab widget
    """
    scopes = list(store.scopes)
    tabs = [widgets.Output() for _ in scopes]
    widget = widgets.Tab(children=tabs)

    for idx, (scope, tab) in enumerate(zip(scopes, tabs)):
        widget.set_title(idx, scope)
        with tab:
            params = store.get_scope(scope)
            fig = plot_parameter_store(params, num_values=num_values)
            # `plt.show` takes no figure argument; a positional argument is interpreted by the backend
            # (e.g. as `close`, closing *all* figures) or rejected. Display exactly this figure inside
            # the tab and close only it, so that it is not rendered a second time outside the widget.
            IPython.display.display(fig)
            plt.close(fig)

    return widget