🎲 Sampling from Static Distributions
This notebook demonstrates how to create distributions from a series of functions using `streamgen`.
First, we will create one background and two class signals using `numpy`.
Then, we will demonstrate how to use the `SamplingTree` class to sample from these distributions.
Finally, we will add `ClassLabelNode`s, which will allow us to sample labeled signals.
from pathlib import Path
import numpy as np
import seaborn as sns
from rich import print
from streamgen import visualizations
from streamgen.nodes import SampleBufferNode
from streamgen.parameter.store import ParameterStore
from streamgen.samplers.tree import SamplingTree
from streamgen.transforms import noop
SEED = 42
rng = np.random.default_rng(SEED)
sns.set_theme()
output_path = Path("./")
output_path.mkdir(parents=True, exist_ok=True)
🧮 Function based data
We want to generate a dataset of time series. This dataset consists of three classes with the following parameters:
- ðĩ "background" - background noise from our measurement equipment
- âïļ
signal_length
- how many measurement points are in our signal - âïļ
offset
- signal offset of the background noise - âïļ
strength
- signal strength of the background noise
- âïļ
- ðŠ "ramp" - a slowly climbing ramp signal
- âïļ
height
- height of the ramp - âïļ
length
- length of the ramp
- âïļ
- ð "step" - a step signal
- âïļ
length
- length of the step - âïļ
kernel_size
- kernel size for a moving average filtering
- âïļ
def background(signal_, signal_length: int, offset: float, strength: float) -> np.ndarray:  # noqa: D103, ANN001, ARG001
    # the incoming sample is ignored (ARG001): `background` is the first node and generates a fresh noise signal
    return rng.normal(offset, strength, signal_length)
def ramp(signal: np.ndarray, height: float, length: int) -> np.ndarray:  # noqa: D103
    ramp_signal = np.zeros(len(signal))
    ramp_start = rng.choice(range(len(signal) - length))
    ramp_signal[ramp_start : ramp_start + length] = np.linspace(0.0, height, length)
    return signal + ramp_signal
def step(signal: np.ndarray, length: int, kernel_size: int) -> np.ndarray:  # noqa: D103
    step_signal = np.zeros(len(signal))
    step_start = rng.choice(range(len(signal) - length))
    step_signal[step_start : step_start + length] = 1.0
    kernel = np.ones(kernel_size) / kernel_size
    step_signal = np.convolve(step_signal, kernel, mode="same")
    return signal + step_signal
%matplotlib inline
# ⚙️ parameters
signal_length = 256
offset = 0.0
noise_strength = 0.1
height = 1.0
length = 128
kernel_size = 10
# 🎲 sample signals
background_sample = background(None, signal_length, offset, noise_strength)
ramp_sample = ramp(background_sample, height, length)
step_sample = step(background_sample, length, kernel_size)
# 📈 plot signals
sns.lineplot(ramp_sample, label="ramp")
sns.lineplot(step_sample, label="step")
sns.lineplot(background_sample, label="background");
🌳 Sampling trees
In the last section, we manually sampled from our three signal generation functions.
This approach has a few shortcomings:
- for each sample, one has to call the corresponding function by hand 🖐️
- the parameter handling is awkward (the association between parameters and functions is not explicit)
- dependencies (like `step` and `ramp` requiring `background` as input) are not clear
To address all of these issues, `streamgen` provides the `SamplingTree` class, which organizes and calls functions very similarly to `torchvision.transforms.Compose` or `torch.nn.Sequential`:
from torchvision.transforms import v2
transforms = v2.Compose([
    v2.RandomHorizontalFlip(),
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # ...
])
while generating_data:
    sample = np.random.randn(...)
    augmented_sample = transforms(sample)
    # -> sample is passed to the first function in transforms.
    #    the output of this transform is then passed as input to the next transform ...
Instead of using a sequential organization of functions, `streamgen` uses trees of transformations.
We use the package `anytree` and a custom `Node` class (`streamgen.nodes.TransformNode`) to construct these trees.
For convenient construction, we also provide a shorthand description following these rules:
- Nodes are linked sequentially according to the ordering in the top-level list.
    - `TransformNode`s and sub-classes are not modified.
    - `Callable`s are cast into `TransformNode`s.
    - `str`s are cast into `ClassLabelNode`s.
- Dictionaries are interpreted as `BranchingNode`s, where each value represents a branch. The keys `name`, `probs` and `seed` are reserved to describe the node itself (see the sketch after the first sampling tree below).
- If there is a node after a `BranchingNode`, then every branch will be connected to a copy of this node. This ensures that the structure of the tree is preserved (otherwise we would create a more general directed acyclic graph, which is not supported by `anytree`).
The parameters of the functions inside a tree are stored in a `ParameterStore` object. A sampling tree knows which parameters need to be passed to which function based on scopes.
📖 More details about parameter handling are explained in the next notebook on 🌊 data streams.
# ⚙️ scoped parameters
params = ParameterStore(
    {
        "background": {
            "signal_length": {"value": 256},
            "offset": {"value": 0.0},
            "strength": {"value": 0.1},
        },
        "ramp": {
            "height": {"value": 1.0},
            "length": {"value": 128},
        },
        "step": {
            "length": {"value": 128},
            "kernel_size": {"value": 10},
        },
    },
)
# 🎲🌳 tree of transformations
tree = SamplingTree(
    [
        background,
        {
            "background": noop,
            "ramp": ramp,
            "step": step,
        },
    ],
    params,
)
print(tree)
🌳 ➡️ `background(offset=0.0, signal_length=256, strength=0.1)`
╰── 🪴 `branching_node()`
    ├── ➡️ `noop()`
    ├── ➡️ `ramp(height=1.0, length=128)`
    ╰── ➡️ `step(kernel_size=10, length=128)`
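The `branching_node` above picks one of its three branches at random each time the tree is sampled. As described in the construction rules, the branching dictionary also accepts the reserved keys `name`, `probs` and `seed`. The cell below is only a hedged sketch: the variable name `weighted_tree` and the exact value formats assumed for the reserved keys are illustrative assumptions, not taken from the output above.
# ⚠️ hedged sketch: the value formats for the reserved keys below are assumptions
weighted_tree = SamplingTree(
    [
        background,
        {
            "name": "signal type",  # reserved: names the branching node
            "probs": [0.6, 0.2, 0.2],  # reserved: assumed to be one probability per branch, in definition order
            "seed": SEED,  # reserved: assumed seed for the branch selection
            "background": noop,
            "ramp": ramp,
            "step": step,
        },
    ],
    params,
)
print(weighted_tree)
Without `probs`, the branches are presumably chosen uniformly at random. The rest of this notebook keeps using the unweighted `tree` from above.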
# 🎲 sample from tree
signal = tree.sample()
sns.lineplot(signal);
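Each call to `sample()` walks the tree from the root to a leaf once, so a batch of unlabeled signals can be collected with an ordinary loop and stacked with `numpy`. A small sketch:
# 🎲 draw a small batch of signals by sampling repeatedly
signals = np.stack([tree.sample() for _ in range(10)])
print(signals.shape)  # (10, 256) with the parameters defined above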
🏷️ Generating labels
Every path of transformations to a leaf in our sampling tree already corresponds to a "class".
To include the information about the branch in the sample, `streamgen` provides a `ClassLabelNode`, which is automatically created from strings in the tree definition.
# 🎲🌳 tree of transformations
tree = SamplingTree(
    [
        background,
        SampleBufferNode("background-samples"),
        {
            "background": "background",
            "ramp": [ramp, "ramp", SampleBufferNode("ramp-samples")],
            "step": [step, "step", SampleBufferNode("step-samples")],
        },
    ],
    params,
)
print(tree)
🌳 ➡️ `background(offset=0.0, signal_length=256, strength=0.1)`
╰── 🗃️ `background-samples()`
    ╰── 🪴 `branching_node()`
        ├── 🏷️ `background`
        ├── ➡️ `ramp(height=1.0, length=128)`
        │   ╰── 🏷️ `ramp`
        │       ╰── 🗃️ `ramp-samples()`
        ╰── ➡️ `step(kernel_size=10, length=128)`
            ╰── 🏷️ `step`
                ╰── 🗃️ `step-samples()`
%matplotlib notebook
def plotting_func(sample, ax):
    ax = sns.lineplot(sample, ax=ax)
    return ax
tree.to_svg(output_path / "tree", plotting_func)
The `SampleBufferNode`s don't contain any buffered samples yet, so there is nothing to show in the rendered graph.
After collecting some samples, the visualization will include sample plots:
%matplotlib notebook
samples = tree.collect(20)
tree.to_svg(output_path / "tree_with_samples", plotting_func)
🥲 Unfortunately, the example GIFs in the SVG graph are not rendered in the hosted documentation, so we include a screen capture from the IPython interface for reference:
%matplotlib inline
# 🎲 sample from tree
signal, target = tree.sample()
sns.lineplot(signal).set_title(target);
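With the label nodes in place, `sample()` returns a `(signal, target)` pair, and `collect` can gather a whole labeled dataset. A small sketch, assuming `collect` returns a list of such pairs (only the pair structure of `sample()` is confirmed above); the names `labeled_samples`, `X` and `y` are illustrative:
# 🏷️ assemble a labeled dataset (assumption: `collect` returns a list of (signal, target) pairs)
labeled_samples = tree.collect(100)
X = np.stack([signal for signal, target in labeled_samples])
y = np.array([target for signal, target in labeled_samples])
print(X.shape, y[:5])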
🖼️ There are many useful visualizations in `streamgen.visualizations`.
Because users can generate arbitrary data structures, they have to provide their own sample `plotting_func` to most of these functions.
%matplotlib notebook
visualizations.plot_labeled_samples_animation(tree, lambda sample, ax: sns.lineplot(sample, ax=ax))
%matplotlib inline
visualizations.plot_labeled_samples_grid(tree, lambda sample, ax: sns.lineplot(sample, ax=ax));