Skip to content

📊 Data Drift Scenarios

This notebook demonstrates how to simulate three different data drift scenarios with streamgen.

✨ The sampling-tree abstraction is general enough to simulate all three drift types within a single stream.


📄 Table of Contents

  1. 📈 Covariate shift
  2. 📊 Prior probability shift
  3. 💡 Concept shift
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import HTML
from IPython.utils import io
from matplotlib.animation import FuncAnimation

from streamgen import visualizations
from streamgen.parameter.store import ParameterStore
from streamgen.samplers.tree import SamplingTree

# 🎲 one seeded generator shared by all transforms, so every stream is reproducible
SEED = 42
rng = np.random.default_rng(seed=SEED)


# ➡️ transforms and generators
def background(signal_, signal_length: int, offset: float, strength: float) -> np.ndarray:  # noqa: ANN001, ARG001
    """Draw a fresh Gaussian-noise background signal.

    Args:
        signal_: incoming value; ignored, a brand-new signal is generated.
        signal_length: number of samples to draw.
        offset: mean of the normal distribution.
        strength: standard deviation of the normal distribution.

    Returns:
        A 1-D array of ``signal_length`` samples drawn from the shared ``rng``.
    """
    noise = rng.normal(loc=offset, scale=strength, size=signal_length)
    return noise


def ramp(signal: np.ndarray, height: float, length: int) -> np.ndarray:
    """Add a linear ramp from 0 to ``height`` at a random position of ``signal``.

    Args:
        signal: input signal, treated as 1-D (its length is ``len(signal)``).
        height: value the ramp reaches at its last sample.
        length: number of samples the ramp spans; ``0 <= length <= len(signal)``.

    Returns:
        ``signal`` plus the ramp; samples outside the ramp are unchanged.

    Raises:
        ValueError: if ``length`` is negative or longer than ``signal``.
    """
    n = len(signal)
    if not 0 <= length <= n:
        msg = f"ramp length {length} must be between 0 and the signal length {n}"
        raise ValueError(msg)
    ramp_signal = np.zeros(n)
    # Every start in [0, n - length] keeps the ramp inside the signal; the upper
    # bound of `integers` is exclusive, hence the +1 (the previous
    # `choice(range(n - length))` could never place the ramp at the very end).
    ramp_start = rng.integers(0, n - length + 1)
    ramp_signal[ramp_start : ramp_start + length] = np.linspace(0.0, height, length)
    return signal + ramp_signal


def step(signal: np.ndarray, length: int, kernel_size: int) -> np.ndarray:
    """Add a unit-height step (plateau) of ``length`` samples at a random position.

    The rectangular pulse is smoothed with a moving-average kernel of
    ``kernel_size`` samples; ``kernel_size=1`` keeps the edges sharp.

    Args:
        signal: input signal, treated as 1-D (its length is ``len(signal)``).
        length: number of samples the plateau spans; ``0 <= length <= len(signal)``.
        kernel_size: width of the moving-average kernel. It should not exceed
            ``len(signal)``, since ``np.convolve(mode="same")`` returns the longer
            of the two inputs' lengths.

    Returns:
        ``signal`` plus the (smoothed) step.

    Raises:
        ValueError: if ``length`` is negative or longer than ``signal``.
    """
    n = len(signal)
    if not 0 <= length <= n:
        msg = f"step length {length} must be between 0 and the signal length {n}"
        raise ValueError(msg)
    step_signal = np.zeros(n)
    # Valid starts are 0..n-length inclusive; `integers` excludes its upper bound.
    step_start = rng.integers(0, n - length + 1)
    step_signal[step_start : step_start + length] = 1.0
    kernel = np.ones(kernel_size) / kernel_size
    step_signal = np.convolve(step_signal, kernel, mode="same")
    return signal + step_signal


# 📼 stream visualization
def animate(
    tree,  # noqa: ANN001
    num_experiences: int = 3,
    num_samples: int = 16,
    ylim: "tuple[float, float]" = (-2.2, 2.2),
) -> "HTML":
    """Render consecutive experiences of a sampling tree as an embeddable JS animation.

    For each of ``num_experiences`` experiences, ``num_samples`` samples are
    collected with ``tree.collect`` and the tree is then advanced with
    ``tree.update()``. The experiences are shown side by side in one row of
    subplots; animation frame ``i`` plots sample ``i`` of every experience.
    The defaults reproduce the original 3 experiences x 16 frames layout.

    Args:
        tree: sampling tree to draw from. It is mutated: ``tree.update()`` is
            called once per experience, including after the last one.
        num_experiences: number of experiences (subplots); must be >= 1.
        num_samples: samples collected per experience, which is also the
            number of animation frames.
        ylim: y-axis limits applied to every subplot.

    Returns:
        An ``IPython.display.HTML`` object wrapping the ``to_jshtml`` output.
    """
    # NOTE(review): output is captured, presumably to silence anything printed
    # while sampling / building the figure — confirm what needs suppressing.
    with io.capture_output():
        experiences = []
        for _ in range(num_experiences):
            experiences.append(tree.collect(num_samples))
            tree.update()

        # squeeze=False keeps `axes` 2-D, so a single subplot works as well
        fig, axes = plt.subplots(
            1,
            num_experiences,
            figsize=(5 * num_experiences, 5),
            sharey=True,
            squeeze=False,
        )
        row = axes[0]

        def create_frame(idx: int) -> None:
            for exp_idx, (ax, experience) in enumerate(zip(row, experiences)):
                ax.cla()
                signal, _target = experience[idx]  # the target is not plotted
                sns.lineplot(signal, ax=ax)
                ax.set_title(f"experience {exp_idx + 1}")
                ax.set_ylim(*ylim)

        animation = FuncAnimation(fig, create_frame, frames=num_samples)
    html = HTML(animation.to_jshtml())
    # Close the figure once its frames are rendered: otherwise every call leaves
    # an open figure behind and the inline backend may also show a static copy.
    plt.close(fig)
    return html

📈 Covariate shift

This type of data drift happens when the distribution of the independent variables (input features) changes.

When modelling distributions with trees of transformations, this corresponds to changing the transformations (or their parameters) in the trunk of the tree (from the root node until the first branching point).

covariate shift

%matplotlib notebook

# 📈 covariate shift: only the trunk (`background`) parameters differ between
# rows (presumably one row per experience, since animate() calls tree.update()
# between experiences); the ramp/step branch parameters stay fixed.
df = pd.DataFrame(
    {
        "background.signal_length": 256,
        "background.offset": [0.0, 0.2, 0.5],  # mean of the noise drifts upwards
        "background.strength": [0.1, 0.2, 0.4],  # std of the noise grows
        "ramp.height": 1.0,
        "ramp.length": 128,
        "step.length": 128,
        "step.kernel_size": 1,  # a 1-sample kernel leaves the step edges sharp
    },
)

# 🎲🌳 tree of transformations
# Trunk: `background`; then a branching node with three branches. Each key
# presumably names a branch/class — confirm against streamgen's SamplingTree.
tree = SamplingTree(
    [
        background,
        {
            "background": "background",
            "ramp": [ramp, "ramp"],
            "step": [step, "step"],
        },
    ],
    df,
)

animate(tree)

📊 Prior probability shift

This type of data drift happens when the class distribution (the class priors) changes over time. A more extreme form of this shift occurs when new classes are introduced over time.

When modelling distributions with trees of transformations, this corresponds to changing the probabilities of the branching points of the tree (or to adding new branches).

prior probability shift

%matplotlib inline

# 📊 prior probability shift: the signal parameters are constants; only the
# branch probabilities differ between rows (presumably one row per experience).
df = pd.DataFrame(
    {
        "background.signal_length": 256,
        "background.offset": 0.0,
        "background.strength": 0.1,
        # Branch probabilities of the branching node. Assuming they follow the
        # branch order (background, ramp, step) of the tree used with this df,
        # the step class only appears (p=0.9) in the third row.
        "branching_node.probs": [
            [0.5, 0.5, 0.0],
            [0.1, 0.9, 0.0],
            [0.0, 0.1, 0.9],
        ],
        "ramp.height": 1.0,
        "ramp.length": 128,
        "step.length": 128,
        "step.kernel_size": 1,
    },
)

# NOTE(review): presumably shows the first 3 values of each parameter in a
# tabbed widget — confirm against streamgen.visualizations.
visualizations.plot_parameter_store_widget(ParameterStore.from_dataframe(df), num_values=3)
Tab(children=(Output(), Output(), Output(), Output()), selected_index=0, titles=('ramp', 'background', 'branch…
%matplotlib notebook

# 🎲🌳 tree of transformations
# Same structure as in the covariate-shift cell; rebuilt here so that it reads
# the prior-shift `df` (with its `branching_node.probs` column).
tree = SamplingTree(
    [
        background,
        {
            "background": "background",
            "ramp": [ramp, "ramp"],
            "step": [step, "step"],
        },
    ],
    df,
)

animate(tree)

💡 Concept shift

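This type of data drift happens when the relationship between the inputs and the classes changes over time: what a given class looks like (its class-conditional distribution) drifts, while the set of classes stays the same.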

When modelling distributions with trees of transformations, this corresponds to changing the transformations (or their parameters) in the branches and leaves of the tree.

concept shift

# 💡 concept shift: the trunk parameters are constants and no branch
# probabilities are given; only the ramp branch changes between rows — the
# ramp gets taller and shorter (i.e. steeper).
df = pd.DataFrame(
    {
        "background.signal_length": 256,
        "background.offset": 0.0,
        "background.strength": 0.1,
        "ramp.height": [1.0, 1.25, 1.5],
        "ramp.length": [128, 64, 32],
        "step.length": 128,
        "step.kernel_size": 1,
    },
)

# 🎲🌳 tree of transformations
# Same structure as in the previous cells, bound to the concept-shift `df`.
tree = SamplingTree(
    [
        background,
        {
            "background": "background",
            "ramp": [ramp, "ramp"],
            "step": [step, "step"],
        },
    ],
    df,
)

animate(tree)