
🎲 Sampling from Static Distributions

This notebook demonstrates how to create distributions from a series of functions using streamgen.

First, we will create one background and two class signals using numpy. Then, we will demonstrate how to use the SamplingTree class to sample from these distributions. Finally, we will add ClassLabelNodes, which will allow us to sample labeled signals.


📄 Table of Contents

  1. 🧮 Function based data generation
  2. 🌳 Sampling trees
  3. 🏷️ Generating labels
from pathlib import Path

import numpy as np
import seaborn as sns
from rich import print

from streamgen import visualizations
from streamgen.nodes import SampleBufferNode
from streamgen.parameter.store import ParameterStore
from streamgen.samplers.tree import SamplingTree
from streamgen.transforms import noop

SEED = 42
rng = np.random.default_rng(SEED)

sns.set_theme()

output_path = Path("./")
output_path.mkdir(parents=True, exist_ok=True)

🧮 Function based data

We want to generate a dataset of time series. This dataset consists of three classes with the following parameters:

  1. 🎵 "background" - background noise from our measurement equipment
    • ⚙️ signal_length - how many measurement points are in our signal
    • ⚙️ offset - signal offset of the background noise
    • ⚙️ strength - signal strength of the background noise
  2. 🪜 "ramp" - a slowly climbing ramp signal
    • ⚙️ height - height of the ramp
    • ⚙️ length - length of the ramp
  3. 🏃 "step" - a step signal
    • ⚙️ length - length of the step
    • ⚙️ kernel_size - kernel size for a moving average filtering
def background(signal_, signal_length: int, offset: float, strength: float) -> np.ndarray:  # noqa: D103, ANN001, ARG001
    return rng.normal(offset, strength, signal_length)


def ramp(signal: np.ndarray, height: float, length: int) -> np.ndarray:  # noqa: D103
    ramp_signal = np.zeros(len(signal))
    ramp_start = rng.choice(range(len(signal) - length))
    ramp_signal[ramp_start : ramp_start + length] = np.linspace(0.0, height, length)
    return signal + ramp_signal


def step(signal: np.ndarray, length: int, kernel_size: int) -> np.ndarray:  # noqa: D103
    step_signal = np.zeros(len(signal))
    step_start = rng.choice(range(len(signal) - length))
    step_signal[step_start : step_start + length] = 1.0
    kernel = np.ones(kernel_size) / kernel_size
    step_signal = np.convolve(step_signal, kernel, mode="same")
    return signal + step_signal
%matplotlib inline

# ⚙️ parameters
signal_length = 256
offset = 0.0
noise_strength = 0.1
height = 1.0
length = 128
kernel_size = 10

# 🎲 sample signals
background_sample = background(None, signal_length, offset, noise_strength)
ramp_sample = ramp(background_sample, height, length)
step_sample = step(background_sample, length, kernel_size)

# 📈 plot signals
sns.lineplot(ramp_sample, label="ramp")
sns.lineplot(step_sample, label="step")
sns.lineplot(background_sample, label="background");


🌳 Sampling trees

In the last section, we manually sampled from our three signal generation functions.

This approach has a few shortcomings:

  • for each sample, one has to call the corresponding function by hand 🖐️
  • the parameter handling is clumsy (it is not clear which parameter belongs to which function)
  • dependencies (like step and ramp requiring background as input) are not clear

To address all of these issues, streamgen provides the SamplingTree class, which organizes and calls functions very similar to torchvision.transforms.Compose or torch.nn.Sequential:

from torchvision.transforms import v2

transforms = v2.Compose([
    v2.RandomHorizontalFlip(),
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # ...
])

while generating_data:
    sample = np.random.randn(...)
    augmented_sample = transforms(sample)
    # -> sample is passed to the first function in transforms.
    # the output of this transform is then passed as input to the next transform ...

Instead of using a sequential organization of functions, streamgen uses trees of transformations.

sampling tree

We use the package anytree and a custom Node class (streamgen.nodes.TransformNode) to construct these trees. For convenient construction, we also provide a shorthand notation that follows these rules:

  1. Nodes are linked sequentially according to the ordering in the top-level list.
  2. TransformNode and sub-classes are not modified.
  3. Callables are cast into TransformNodes.
  4. Strings are cast into ClassLabelNodes.
  5. Dictionaries are interpreted as BranchingNodes, where each value represents a branch. The keys name, probs and seed are reserved to describe the node itself.
  6. If a node follows a BranchingNode, every branch is connected to its own copy of that node. This preserves the tree structure; otherwise we would create a more general directed acyclic graph, which anytree does not support.
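
The rules above can be sketched in plain Python. This is only an illustration and not streamgen's implementation: SketchNode, build and normalize are hypothetical names, the default node name branching_node is taken from the printed trees in this notebook, and rule 2 (existing TransformNodes are kept unchanged) is not modeled.

```python
from dataclasses import dataclass, field
from typing import List, Optional

RESERVED_KEYS = {"name", "probs", "seed"}  # describe the branching node itself (rule 5)


@dataclass
class SketchNode:
    """Hypothetical stand-in for a tree node; not streamgen's TransformNode."""

    name: str
    kind: str  # "transform", "label" or "branching"
    children: List["SketchNode"] = field(default_factory=list)


def build(shorthand: list) -> Optional[SketchNode]:
    """Turn a shorthand list into linked nodes and return the first node."""
    if not shorthand:
        return None
    head, rest = shorthand[0], shorthand[1:]
    if isinstance(head, dict):  # rule 5: a dictionary becomes a branching node
        node = SketchNode(head.get("name", "branching_node"), "branching")
        for key, branch in head.items():
            if key in RESERVED_KEYS:
                continue
            branch_items = branch if isinstance(branch, list) else [branch]
            # rule 6: every branch continues with its own freshly built copy of `rest`
            subtree = build(branch_items + rest)
            if subtree is not None:
                node.children.append(subtree)
        return node
    if isinstance(head, str):  # rule 4: a string becomes a label node
        node = SketchNode(head, "label")
    else:  # rule 3: a callable becomes a transform node
        node = SketchNode(head.__name__, "transform")
    child = build(rest)  # rule 1: nodes are linked in list order
    if child is not None:
        node.children.append(child)
    return node


def render(node: SketchNode, depth: int = 0) -> List[str]:
    """Return one indented line per node, parents before children."""
    lines = ["    " * depth + f"{node.kind}: {node.name}"]
    for child in node.children:
        lines.extend(render(child, depth + 1))
    return lines


# placeholder transforms; only their names matter for this sketch
def background(signal): return signal
def ramp(signal): return signal
def step(signal): return signal
def normalize(signal): return signal  # hypothetical node placed after the branching


root = build([background, {"ramp": [ramp, "ramp"], "step": [step, "step"]}, normalize])
print("\n".join(render(root)))
```

In this sketch, both branches end in their own normalize node, which is the duplication that rule 6 describes.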

The parameters of the functions inside a tree are stored in a ParameterStore object. A sampling tree knows which parameters need to be passed to the functions based on scopes.
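
One way to picture scope-based lookup is the following sketch. It assumes that a scope is named after the function it belongs to, and it uses a plain dictionary instead of a ParameterStore; kwargs_for is a hypothetical helper, and streamgen's actual resolution logic may differ.

```python
import inspect

# plain-dict stand-in for a ParameterStore (assumption: scope name == function name)
scoped_params = {
    "background": {"signal_length": 256, "offset": 0.0, "strength": 0.1},
    "ramp": {"height": 1.0, "length": 128},
    "step": {"length": 128, "kernel_size": 10},
}


def kwargs_for(func, params: dict) -> dict:
    """Select the keyword arguments for func from the scope that carries its name."""
    scope = params.get(func.__name__, {})
    accepted = inspect.signature(func).parameters
    # keep only the parameters that the function signature actually accepts
    return {name: value for name, value in scope.items() if name in accepted}


# placeholder functions with the same signatures as in this notebook
def ramp(signal, height, length): return signal
def step(signal, length, kernel_size): return signal


print(kwargs_for(ramp, scoped_params))
print(kwargs_for(step, scoped_params))
```

Both ramp and step take a parameter called length. Keeping each function's parameters in its own scope lets the two values be set independently.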

👉 More details about parameter handling are explained in the next notebook on 🌌 data streams.

# ⚙️ scoped parameter
params = ParameterStore(
    {
        "background": {
            "signal_length": {"value": 256},
            "offset": {"value": 0.0},
            "strength": {"value": 0.1},
        },
        "ramp": {
            "height": {"value": 1.0},
            "length": {"value": 128},
        },
        "step": {
            "length": {"value": 128},
            "kernel_size": {"value": 10},
        },
    },
)

# 🎲🌳 tree of transformations
tree = SamplingTree(
    [
        background,
        {
            "background": noop,
            "ramp": ramp,
            "step": step,
        },
    ],
    params,
)

print(tree)
🌳
➡️ `background(offset=0.0, signal_length=256, strength=0.1)`
╰── 🪴 `branching_node()`
    ├── ➡️ `noop()`
    ├── ➡️ `ramp(height=1.0, length=128)`
    ╰── ➡️ `step(kernel_size=10, length=128)`

# 🎲 sample from tree
signal = tree.sample()
sns.lineplot(signal);


🏷️ Generating labels

Every path of transformations to a leaf in our sampling tree already corresponds to a "class".

To include the information about the branch into the sample, streamgen provides a ClassLabelNode, which is automatically created from strings in the tree definition.
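
The idea that each path to a leaf corresponds to a class can be illustrated without streamgen. The sketch below assumes that a branch is chosen uniformly at random, and it returns a (sample, label) pair, which mirrors the signal, target pair that a labeled tree yields. The function sample_labeled and the toy list-based signals are hypothetical.

```python
import random
from collections import Counter

rng = random.Random(42)

# toy branches operating on plain lists instead of numpy arrays
branches = {
    "background": lambda s: s,
    "ramp": lambda s: [v + i / (len(s) - 1) for i, v in enumerate(s)],
    "step": lambda s: [v + (1.0 if i >= len(s) // 2 else 0.0) for i, v in enumerate(s)],
}


def sample_labeled(branches: dict, rng: random.Random):
    """Pick one branch uniformly at random and return (sample, label)."""
    label = rng.choice(sorted(branches))
    sample = branches[label]([0.0] * 8)  # flat stand-in for the background signal
    return sample, label


# draw a few hundred labeled samples and count how often each class occurs
counts = Counter(sample_labeled(branches, rng)[1] for _ in range(300))
print(counts)
```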

# 🎲🌳 tree of transformations
tree = SamplingTree(
    [
        background,
        SampleBufferNode("background-samples"),
        {
            "background": "background",
            "ramp": [ramp, "ramp", SampleBufferNode("ramp-samples")],
            "step": [step, "step", SampleBufferNode("step-samples")],
        },
    ],
    params,
)

print(tree)
🌳
➡️ `background(offset=0.0, signal_length=256, strength=0.1)`
╰── 🗃️ `background-samples()`
    ╰── 🪴 `branching_node()`
        ├── 🏷️ `background`
        ├── ➡️ `ramp(height=1.0, length=128)`
        │   ╰── 🏷️ `ramp`
        │       ╰── 🗃️ `ramp-samples()`
        ╰── ➡️ `step(kernel_size=10, length=128)`
            ╰── 🏷️ `step`
                ╰── 🗃️ `step-samples()`

%matplotlib notebook


def plotting_func(sample, ax):
    ax = sns.lineplot(sample, ax=ax)
    return ax


tree.to_svg(output_path / "tree", plotting_func)


The SampleBufferNodes don't contain any buffered samples yet to show in the rendered graph. After collecting some samples, the visualization will include sample plots:

%matplotlib notebook
samples = tree.collect(20)

tree.to_svg(output_path / "tree_with_samples", plotting_func)

🥲 Unfortunately, the example GIFs in the SVG graph are not rendered in the hosted documentation, so we include a screen capture from the IPython interface for reference:

output of last cell

%matplotlib inline
# 🎲 sample from tree
signal, target = tree.sample()
sns.lineplot(signal).set_title(target);


🖼️ There are many useful visualizations in streamgen.visualizations.

Because users can generate any data structure, they have to provide their own sample plotting_func to most of these functions.

%matplotlib notebook
visualizations.plot_labeled_samples_animation(tree, lambda sample, ax: sns.lineplot(sample, ax=ax))

%matplotlib inline
visualizations.plot_labeled_samples_grid(tree, lambda sample, ax: sns.lineplot(sample, ax=ax));
