Skip to content

Module streamgen.samplers.tree

🌳 sampling trees are trees of transformations that you can traverse from root to leaf to create samples.

View Source
"""🌳 sampling trees are trees of transformations that you can traverse from root to leaf to create samples."""

import itertools

from collections.abc import Callable

from copy import deepcopy

from itertools import pairwise

from pathlib import Path

from typing import Any, Self

import anytree

import numpy as np

from anytree.exporter import UniqueDotExporter

from beartype import beartype

from graphviz import Source

from IPython.display import SVG

from IPython.utils import io

from matplotlib import animation

from matplotlib import pyplot as plt

from pandas import DataFrame

from rich.progress import track

from streamgen.enums import SamplingStrategy, SamplingStrategyLit

from streamgen.nodes import ClassLabelNode, SampleBufferNode, TransformNode

from streamgen.parameter import Parameter

from streamgen.parameter.store import ParameterStore

from streamgen.samplers import Sampler

from streamgen.transforms import noop

class BranchingNode(TransformNode):
    """🪴 node with multiple children/branches.

    When traversed, a random branch is selected based on the probabilities defined by `probs`.

    Args:
        branches (dict): dictionary, where each key:value pair represents label:branch.
            Every value is expanded into a list of linked nodes via `construct_tree`.
        probs (Parameter | list[float] | None, optional): parameter containing the probabilities for selecting each branch.
            `probs.value` is passed to `numpy.random.Generator.choice` as parameter `p`, which is documented as:
            (1-D array_like, optional) the probabilities associated with each entry in a.
            If not given the sample assumes a uniform distribution over all entries.
            A plain list is wrapped into a `Parameter` named "probs". Defaults to None.
        name (str | None, optional): name of the node. Important for fetching the `probs` if not present. Defaults to "branching_node".
        seed (int, optional): random number generator seed. Defaults to 42.
        string_node (Callable[[str], TransformNode], optional): `TransformNode` constructor from strings used in `construct_tree`.
            Defaults to `ClassLabelNode`.
    """

    def __init__(  # noqa: D107
        self,
        branches: dict,
        probs: Parameter | list[float] | None = None,
        name: str | None = None,
        seed: int = 42,
        string_node: Callable[[str], TransformNode] = ClassLabelNode,
    ) -> None:
        self.name = name if name else "branching_node"

        if isinstance(probs, list):
            probs = Parameter(name="probs", value=probs)
        self.probs = probs

        self.rng = np.random.default_rng(seed)

        # every branch is stored as the list of linked nodes returned by `construct_tree`;
        # the first node of each branch becomes a direct child of this node.
        self.branches = {branch_name: construct_tree(nodes, string_node) for branch_name, nodes in branches.items()}
        self.children = [branch[0] for branch in self.branches.values()]

        super().__init__(transform=noop, name=self.name, emoji="🪴")

    def traverse(self, input: Any) -> tuple[Any, anytree.NodeMixin]:  # noqa: A002, ANN401
        """🏃🎲 `streamgen.transforms.Traverse` protocol `(input: Any) -> (output, anytree.NodeMixin | None)`.

        During traversal, a branching node samples the next node from its children.
        The input is passed through unchanged.

        Args:
            input (Any): any input

        Returns:
            tuple[Any, anytree.NodeMixin | None]: output and next node to traverse
        """
        p = self.probs.value if self.probs is not None else None

        # Draw a position rather than a key: passing the keys themselves through numpy would
        # coerce them into an array (mixed key types turn into strings, tuple keys into rows),
        # and the coerced value would no longer be a valid dictionary key.
        # numpy documents that an integer `a` is sampled as if it were `np.arange(a)`.
        idx = int(self.rng.choice(len(self.branches), p=p))
        next_node = list(self.branches.values())[idx][0]

        return input, next_node

    def update(self) -> None:
        """🆙 updates every parameter."""
        if self.probs is not None:
            self.probs.update()

        for branch in self.branches.values():
            for node in branch:
                node.update()

    def set_update_step(self, idx: int) -> None:
        """🕐 updates every parameter to a certain update step.

        Args:
            idx (int): parameter update step

        Returns:
            None: this function mutates `self`
        """
        if self.probs is not None:
            # indexing is used purely for its effect, the result is discarded.
            # NOTE(review): presumably `Parameter.__getitem__` moves the parameter to step `idx` - confirm.
            self.probs[idx]

        for branch in self.branches.values():
            for node in branch:
                node.set_update_step(idx)

    def fetch_params(self, params: ParameterStore) -> None:
        """⚙️ fetches params from a ParameterStore.

        If this node has no `probs` yet and `params` contains a scope named like this node,
        the single parameter of that scope is used as `probs`.
        Afterwards, every node of every branch fetches its own parameters.

        Args:
            params (ParameterStore): parameter store to fetch the params from
        """
        if self.probs is None and self.name in params.scopes:
            probs = list(params.get_scope(self.name).parameters.values())
            assert (  # noqa: S101
                len(probs) == 1
            ), f'Make sure to only have a single parameter in the scope "{self.name}" when setting the parameters of a `BranchingNode` through `fetch_params`.'  # noqa: E501
            self.probs = probs[0]

        for branch in self.branches.values():
            for node in branch:
                node.fetch_params(params)

    def get_params(self) -> ParameterStore | None:
        """⚙️ collects parameters from every node.

        The parameters are scoped based on the node names.

        Returns:
            ParameterStore | None: parameters from every node. None if there are no parameters.
        """
        store = ParameterStore([])

        if self.probs is not None:
            # register `probs` under a scope named after this node
            store.scopes.add(self.name)
            store.parameters[self.name] = {}
            store.parameters[self.name][self.probs.name] = self.probs
            store.parameter_names.add(f"{self.name}.{self.probs.name}")

        for branch in self.branches.values():
            for node in branch:
                if params := node.get_params():
                    store |= params

        return store if len(store.parameter_names) > 0 else None

@beartype()

def construct_tree(

    nodes: Callable | TransformNode | dict | str | list[Callable | TransformNode | dict | str],

    string_node: Callable[[str], TransformNode] = ClassLabelNode,

) -> list[TransformNode]:

    """πŸ—οΈ assembles and links nodes into a tree.

    The following rules apply during construction:

    1. Nodes are linked sequentially according to the ordering in the top-level list.

    2. `TransformNode` and sub-classes are not modified.

    3. `Callable`s are cast into `TransformNode`s.

    4. `str` are passed to the `string_node` constructor, which allows to configure which Node type is used for them.

    5. dictionaries are interpreted as `BranchingNode`s, where each value represents a branch.

        The keys `name`, `probs` and `seed` are reserved to describe the node itself.

    6. If there is a node after a `BranchingNode`, then every branch will be connected to a **copy** of this node.

        This ensures that the structure of the tree is preserved (Otherwise we would create a more generic directed acyclic graph),

        which is not supported by `anytree`.

    Args:

        nodes (Callable | TransformNode | dict | str | list[Callable | TransformNode | dict | str]): short-hand description of a tree.

        string_node (Callable[[str], TransformNode], optional): `TransfromNode` constructor from strings used in `construct_tree`.

            Defaults to `ClassLabelNode`.

    Returns:

        list[TransformNode]: list of linked nodes

    """

    # We need the next two lines to handle single element branches gracefully in the recursion.

    if not isinstance(nodes, list):

        nodes = [nodes]

    graph = []

    for node in nodes:

        match node:

            case Callable():

                graph.append(TransformNode(node))

            case TransformNode():

                graph.append(node)

            case dict():

                name = node.pop("name", None)

                probs = node.pop("probs", None)

                seed = node.pop("seed", 42)

                graph.append(BranchingNode(node, name=name, probs=probs, seed=seed, string_node=string_node))

            case str():

                graph.append(string_node(node))

    # connect the nodes to enable traversal and parameter fetching and updating

    for node, next_node in pairwise(graph):

        match (node, next_node):

            case (BranchingNode(), _):

                # * This is a special shorthand conveninence behaviour:

                # * when we sequentially combine a `BranchingNode` with another node,

                # * we add the other node to every leaf of the branches in the `BranchingNode`

                for branch in node.branches.values():

                    # * the copy operation is needed, since `anytree` does not allow merged branches

                    # * (merged branches are different branches with a common child -> creates a DAG instead of a tree).

                    # * we have to add the copy to the leaf's children to enable traversal

                    for leaf in branch[-1].leaves:

                        next_node_copy = deepcopy(next_node)

                        leaf.children = [next_node_copy]

                        # * we have to add the copy to the branch to handle parameter fetching and updating

                        branch.append(next_node_copy)

            case (_, _):

                node.children = [next_node]

    return graph

class SamplingTree(Sampler):

    """🌳 a tree of `TransformNode`s, that can be sampled from.

    The tree will be constructed using `streamgen.nodes.construct_tree(nodes, string_node)`.

    Args:

        nodes (list[Callable | TransformNode | dict| str]): pythonic short-hand description of a graph/tree.

            `streamgen.samplers.tree.construct_tree` will be called to construct the tree.

        params (ParameterStore | DataFrame | None, optional): parameter store containing additional parameters

            that are passed to the nodes based on the scope. Dataframes will be converted to `ParameterStore`. Defaults to None.

        collate_func (Callable[[list[Any]], Any] | None, optional): function to collate samples when using `self.collect(num_samples)`.

            If None, return a list of samples. Defaults to None.

        string_node (Callable[[str], TransformNode], optional): `TransfromNode` constructor from strings used in `construct_tree`.

            Defaults to `ClassLabelNode`.

    """

    def __init__(  # noqa: D107

        self,

        nodes: list[Callable | TransformNode | dict | str],

        params: ParameterStore | DataFrame | None = None,

        collate_func: Callable[[list[Any]], Any] | None = None,

        string_node: Callable[[str], TransformNode] = ClassLabelNode,

    ) -> None:

        self.nodes = construct_tree(nodes, string_node)

        self.root = self.nodes[0]

        match params:

            case None:

                self.params = ParameterStore([])

            case DataFrame():

                self.params = ParameterStore.from_dataframe(params)

            case ParameterStore():

                self.params = params

        # pass parameters to nodes

        for node in self.nodes:

            node.fetch_params(self.params)

        self.collate_func = collate_func

    def sample(self) -> Any:  # noqa: ANN401

        """🎲 generates a sample by traversing the tree from root to one leaf.

        Returns:

            Any: sample

        """

        node = self.root

        out = None

        while node is not None:

            out, node = node.traverse(out)

        return out

    def __next__(self) -> Any:  # noqa: ANN401

        """πŸͺΊ returns the next element during iteration.

        The iterator never runs out of samples, so no `StopIteration` exception is raised.

        Returns:

            Any: a sample

        """

        return self.sample()

    def __iter__(self) -> Self:

        """🏭 turns self into an iterator.

        Required to loop over a `SamplingTree`.

        """

        return self

    def collect(self, num_samples: int, strategy: SamplingStrategy | SamplingStrategyLit = "stochastic") -> Any:  # noqa: ANN401

        """πŸͺΊ collect and concatenate `num_samples` using `sample() and `self.collate_func`.

        Args:

            num_samples (int): number of samples to collect.

                When using the "stochastic" (default) strategy, this refers to the total number of samples.

                When using the "balanced" strategies, this refers to the number of samples per path through the tree.

            strategy (SamplingStrategy | SamplingStrategyLit, optional): sampling strategy. Defaults to "stochastic".

        Returns:

            Any: collection of samples.

                If `self.collate_func` is defined, it will be mapped to the tuple elements in each sample.

                Otherwise this functions just returns a list of samples.

        """

        match strategy:

            case SamplingStrategy.STOCHASTIC:

                samples = [self.sample() for _ in track(range(num_samples), description="🎲 sampling...")]

            case SamplingStrategy.BALANCED:

                paths = self.get_paths()

                samples = list(itertools.chain(*[path.collect(num_samples) for path in paths]))

            case SamplingStrategy.BALANCED_PRUNED:

                paths = self.get_paths(prune=True)

                samples = list(itertools.chain(*[path.collect(num_samples) for path in paths]))

        return tuple(map(self.collate_func, zip(*samples, strict=True))) if self.collate_func else samples

    def update(self) -> None:

        """πŸ†™ updates every parameter."""

        for node in self.nodes:

            node.update()

    def set_update_step(self, idx: int) -> None:

        """πŸ• updates every parameter to a certain update step using `param[idx]`.

        Args:

            idx (int): parameter update step

        Returns:

            None: this function mutates `self`

        """

        for node in self.nodes:

            node.set_update_step(idx)

    def get_params(self) -> ParameterStore | None:

        """βš™οΈ collects parameters from every node.

        The parameters are scoped based on the node names.

        Returns:

            ParameterStore | None: parameters from every node. None is there are no parameters.

        """

        store = ParameterStore([])

        for node in self.nodes:

            if params := node.get_params():

                store |= params

        return store if len(store.parameter_names) > 0 else None

    def to_dotfile(

        self,

        file_path: Path = Path("./tree.dot"),

        plotting_func: Callable[[Any, plt.Axes], plt.Axes] | None = None,

        fps: int = 2,

    ) -> None:

        """πŸ•ΈοΈ exports the tree as a `dot` file using [graphviz](https://www.graphviz.org/).

        Args:

            file_path (Path, optional): path of the resulting file. Defaults to "./tree.dot".

            plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.

                The function should take a sample and a `plt.Axes` as arguments.

                It is used to create sample animations for `SampleBufferNode`s.

            fps (int, optional): frames per second for the sample animations. Defaults to 2.

        """

        output_path = file_path.parent

        def _nodeattrfunc(node) -> str:  # noqa: ANN001

            """Builds the node attribute list for graphviz."""

            a = f'label="{node.emoji} {node.name}"'

            match node:

                case BranchingNode():

                    probs = [round(1.0 / len(node.children), 3)] * len(node.children) if node.probs is None else str(node.probs)

                    return a + f' shape=diamond tooltip="{probs}"'

                case ClassLabelNode():

                    return a + " shape=cds"

                case SampleBufferNode():

                    # create animation

                    if plotting_func is None:

                        return a + " shape=box"

                    anim = node.plot(plotting_func, display=False)

                    if anim is None:

                        return a + " shape=box"

                    # save gif

                    gif_path = output_path / f"{node.name}.gif"

                    anim.save(gif_path, writer=animation.PillowWriter(fps=fps))

                    # add gif as background

                    return f'label="" shape=box image="{gif_path.name}" imagescale=true'

                case _:

                    return a + f' tooltip="{node.get_params()!s}"'

        dot = UniqueDotExporter(

            self.root,

            graph="digraph",

            nodeattrfunc=_nodeattrfunc,

        )

        dot.to_dotfile(file_path)

    def to_svg(

        self,

        file_path: Path = Path("./tree"),

        plotting_func: Callable[[Any, plt.Axes], plt.Axes] | None = None,

        fps: int = 2,

    ) -> SVG:

        """πŸ“Ή visualizes the tree as an svg using [graphviz](https://www.graphviz.org/).

        Args:

            file_path (Path, optional): path of the resulting file. Defaults to "./tree.dot".

            plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.

                The function should take a sample and a `plt.Axes` as arguments.

                It is used to create sample animations for `SampleBufferNode`s.

            fps (int, optional): frames per second for the sample animations. Defaults to 2.

        Returns:

            IPython.display.SVG: svg display of dot visualization

        """

        output_path = file_path.parent

        file_stem = file_path.stem

        dot_path = output_path / (file_stem + ".dot")

        with io.capture_output() as _captured:

            self.to_dotfile(dot_path, plotting_func, fps)

            Source.from_file(dot_path).render(dot_path, format="svg")

        return SVG(filename=(output_path / (file_stem + ".dot.svg")))

    def get_paths(self, prune: bool = False) -> list[Self]:  # noqa: FBT001, FBT002

        """πŸƒ constructs a deterministic path for each leaf in the tree.

        Args:

            prune (bool, optional): If true, only return paths with probabilities greater than zero. Defaults to false.

        Returns:

            list[Self]: list of `SamplingTree`s without branches.

        """

        paths = []

        for leaf in self.root.leaves:

            pruned_path = False

            if prune:

                # check if all probs leading to the leaf are greater than 0

                for node in leaf.ancestors:

                    if isinstance(node, BranchingNode):

                        # get probability for the path leading to the leaf

                        # to do this, we need to find the index of the branch

                        idx = next(idx for idx, branch in enumerate(node.children) if leaf in branch.leaves)

                        prob = node.probs.value[idx]

                        if prob == 0.0:

                            pruned_path = True

            if not pruned_path:

                path = [node for node in deepcopy(leaf.ancestors) if not isinstance(node, BranchingNode)] + [deepcopy(leaf)]

                for node, next_node in pairwise(path):

                    node.children = [next_node]

                paths.append(SamplingTree(path, self.params))

        return paths

    def __str__(self) -> str:

        """🏷️ Returns the string representation `str(self)`.

        Returns:

            str: string representation of self

        """

        s = "🌳\n"

        for pre, _, node in anytree.RenderTree(self.root, style=anytree.ContRoundStyle()):

            s += pre + str(node) + "\n"

        return s

Functions

construct_tree

def construct_tree(
    nodes: collections.abc.Callable | streamgen.nodes.TransformNode | dict | str | list[collections.abc.Callable | streamgen.nodes.TransformNode | dict | str],
    string_node: collections.abc.Callable[[str], streamgen.nodes.TransformNode] = <class 'streamgen.nodes.ClassLabelNode'>
) -> list[streamgen.nodes.TransformNode]

🏗️ assembles and links nodes into a tree.

The following rules apply during construction:

  1. Nodes are linked sequentially according to the ordering in the top-level list.
  2. TransformNode and sub-classes are not modified.
  3. Callables are cast into TransformNodes.
  4. str are passed to the string_node constructor, which allows to configure which Node type is used for them.
  5. dictionaries are interpreted as BranchingNodes, where each value represents a branch. The keys name, probs and seed are reserved to describe the node itself.
  6. If there is a node after a BranchingNode, then every branch will be connected to a copy of this node. This ensures that the structure of the tree is preserved (Otherwise we would create a more generic directed acyclic graph), which is not supported by anytree.

Parameters:

Name Type Description Default
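nodes Callable | TransformNode | dict | str | list[Callable | TransformNode | dict | str] short-hand description of a tree. required
string_node Callable[[str], TransformNode] TransformNode constructor from strings used in construct_tree.
Defaults to ClassLabelNode.
ClassLabelNode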

Returns:

Type Description
list[TransformNode] list of linked nodes
View Source
@beartype()

def construct_tree(

    nodes: Callable | TransformNode | dict | str | list[Callable | TransformNode | dict | str],

    string_node: Callable[[str], TransformNode] = ClassLabelNode,

) -> list[TransformNode]:

    """πŸ—οΈ assembles and links nodes into a tree.

    The following rules apply during construction:

    1. Nodes are linked sequentially according to the ordering in the top-level list.

    2. `TransformNode` and sub-classes are not modified.

    3. `Callable`s are cast into `TransformNode`s.

    4. `str` are passed to the `string_node` constructor, which allows to configure which Node type is used for them.

    5. dictionaries are interpreted as `BranchingNode`s, where each value represents a branch.

        The keys `name`, `probs` and `seed` are reserved to describe the node itself.

    6. If there is a node after a `BranchingNode`, then every branch will be connected to a **copy** of this node.

        This ensures that the structure of the tree is preserved (Otherwise we would create a more generic directed acyclic graph),

        which is not supported by `anytree`.

    Args:

        nodes (Callable | TransformNode | dict | str | list[Callable | TransformNode | dict | str]): short-hand description of a tree.

        string_node (Callable[[str], TransformNode], optional): `TransfromNode` constructor from strings used in `construct_tree`.

            Defaults to `ClassLabelNode`.

    Returns:

        list[TransformNode]: list of linked nodes

    """

    # We need the next two lines to handle single element branches gracefully in the recursion.

    if not isinstance(nodes, list):

        nodes = [nodes]

    graph = []

    for node in nodes:

        match node:

            case Callable():

                graph.append(TransformNode(node))

            case TransformNode():

                graph.append(node)

            case dict():

                name = node.pop("name", None)

                probs = node.pop("probs", None)

                seed = node.pop("seed", 42)

                graph.append(BranchingNode(node, name=name, probs=probs, seed=seed, string_node=string_node))

            case str():

                graph.append(string_node(node))

    # connect the nodes to enable traversal and parameter fetching and updating

    for node, next_node in pairwise(graph):

        match (node, next_node):

            case (BranchingNode(), _):

                # * This is a special shorthand conveninence behaviour:

                # * when we sequentially combine a `BranchingNode` with another node,

                # * we add the other node to every leaf of the branches in the `BranchingNode`

                for branch in node.branches.values():

                    # * the copy operation is needed, since `anytree` does not allow merged branches

                    # * (merged branches are different branches with a common child -> creates a DAG instead of a tree).

                    # * we have to add the copy to the leaf's children to enable traversal

                    for leaf in branch[-1].leaves:

                        next_node_copy = deepcopy(next_node)

                        leaf.children = [next_node_copy]

                        # * we have to add the copy to the branch to handle parameter fetching and updating

                        branch.append(next_node_copy)

            case (_, _):

                node.children = [next_node]

    return graph

Classes

BranchingNode

class BranchingNode(
    branches: dict,
    probs: streamgen.parameter.Parameter | list[float] | None = None,
    name: str | None = None,
    seed: int = 42,
    string_node: collections.abc.Callable[[str], streamgen.nodes.TransformNode] = <class 'streamgen.nodes.ClassLabelNode'>
)

🪴 node with multiple children/branches.

When traversed, a random branch is selected based on the probabilities defined by probs.

Attributes

Name Type Description Default
branches dict dictionary, where each key:value pair represents label:branch. required
probs Parameter | list[float] | None parameter containing the probabilities for selecting each branch.
probs.value is passed to numpy.random.choice as parameter p, which is documented as:
(1-D array_like, optional) the probabilities associated with each entry in a.
If not given the sample assumes a uniform distribution over all entries. Defaults to None. None
name str | None name of the node. Important for fetching the probs if not present. Defaults to "branching_node". None
seed int random number generator seed. Defaults to 42. 42
string_node Callable[[str], TransformNode] TransformNode constructor from strings used in construct_tree.
Defaults to ClassLabelNode.
ClassLabelNode
View Source
class BranchingNode(TransformNode):
    """🪴 node with multiple children/branches.

    When traversed, a random branch is selected based on the probabilities defined by `probs`.

    Args:
        branches (dict): dictionary, where each key:value pair represents label:branch.
            Every value is expanded into a list of linked nodes via `construct_tree`.
        probs (Parameter | list[float] | None, optional): parameter containing the probabilities for selecting each branch.
            `probs.value` is passed to `numpy.random.Generator.choice` as parameter `p`, which is documented as:
            (1-D array_like, optional) the probabilities associated with each entry in a.
            If not given the sample assumes a uniform distribution over all entries.
            A plain list is wrapped into a `Parameter` named "probs". Defaults to None.
        name (str | None, optional): name of the node. Important for fetching the `probs` if not present. Defaults to "branching_node".
        seed (int, optional): random number generator seed. Defaults to 42.
        string_node (Callable[[str], TransformNode], optional): `TransformNode` constructor from strings used in `construct_tree`.
            Defaults to `ClassLabelNode`.
    """

    def __init__(  # noqa: D107
        self,
        branches: dict,
        probs: Parameter | list[float] | None = None,
        name: str | None = None,
        seed: int = 42,
        string_node: Callable[[str], TransformNode] = ClassLabelNode,
    ) -> None:
        self.name = name if name else "branching_node"

        if isinstance(probs, list):
            probs = Parameter(name="probs", value=probs)
        self.probs = probs

        self.rng = np.random.default_rng(seed)

        # every branch is stored as the list of linked nodes returned by `construct_tree`;
        # the first node of each branch becomes a direct child of this node.
        self.branches = {branch_name: construct_tree(nodes, string_node) for branch_name, nodes in branches.items()}
        self.children = [branch[0] for branch in self.branches.values()]

        super().__init__(transform=noop, name=self.name, emoji="🪴")

    def traverse(self, input: Any) -> tuple[Any, anytree.NodeMixin]:  # noqa: A002, ANN401
        """🏃🎲 `streamgen.transforms.Traverse` protocol `(input: Any) -> (output, anytree.NodeMixin | None)`.

        During traversal, a branching node samples the next node from its children.
        The input is passed through unchanged.

        Args:
            input (Any): any input

        Returns:
            tuple[Any, anytree.NodeMixin | None]: output and next node to traverse
        """
        p = self.probs.value if self.probs is not None else None

        # Draw a position rather than a key: passing the keys themselves through numpy would
        # coerce them into an array (mixed key types turn into strings, tuple keys into rows),
        # and the coerced value would no longer be a valid dictionary key.
        # numpy documents that an integer `a` is sampled as if it were `np.arange(a)`.
        idx = int(self.rng.choice(len(self.branches), p=p))
        next_node = list(self.branches.values())[idx][0]

        return input, next_node

    def update(self) -> None:
        """🆙 updates every parameter."""
        if self.probs is not None:
            self.probs.update()

        for branch in self.branches.values():
            for node in branch:
                node.update()

    def set_update_step(self, idx: int) -> None:
        """🕐 updates every parameter to a certain update step.

        Args:
            idx (int): parameter update step

        Returns:
            None: this function mutates `self`
        """
        if self.probs is not None:
            # indexing is used purely for its effect, the result is discarded.
            # NOTE(review): presumably `Parameter.__getitem__` moves the parameter to step `idx` - confirm.
            self.probs[idx]

        for branch in self.branches.values():
            for node in branch:
                node.set_update_step(idx)

    def fetch_params(self, params: ParameterStore) -> None:
        """⚙️ fetches params from a ParameterStore.

        If this node has no `probs` yet and `params` contains a scope named like this node,
        the single parameter of that scope is used as `probs`.
        Afterwards, every node of every branch fetches its own parameters.

        Args:
            params (ParameterStore): parameter store to fetch the params from
        """
        if self.probs is None and self.name in params.scopes:
            probs = list(params.get_scope(self.name).parameters.values())
            assert (  # noqa: S101
                len(probs) == 1
            ), f'Make sure to only have a single parameter in the scope "{self.name}" when setting the parameters of a `BranchingNode` through `fetch_params`.'  # noqa: E501
            self.probs = probs[0]

        for branch in self.branches.values():
            for node in branch:
                node.fetch_params(params)

    def get_params(self) -> ParameterStore | None:
        """⚙️ collects parameters from every node.

        The parameters are scoped based on the node names.

        Returns:
            ParameterStore | None: parameters from every node. None if there are no parameters.
        """
        store = ParameterStore([])

        if self.probs is not None:
            # register `probs` under a scope named after this node
            store.scopes.add(self.name)
            store.parameters[self.name] = {}
            store.parameters[self.name][self.probs.name] = self.probs
            store.parameter_names.add(f"{self.name}.{self.probs.name}")

        for branch in self.branches.values():
            for node in branch:
                if params := node.get_params():
                    store |= params

        return store if len(store.parameter_names) > 0 else None

Ancestors (in MRO)

  • streamgen.nodes.TransformNode
  • anytree.node.nodemixin.NodeMixin

Class variables

separator

Instance variables

ancestors

All parent nodes and their parent nodes.

>>> from anytree import Node
>>> udo = Node("Udo")
>>> marc = Node("Marc", parent=udo)
>>> lian = Node("Lian", parent=marc)
>>> udo.ancestors
()
>>> marc.ancestors
(Node('/Udo'),)
>>> lian.ancestors
(Node('/Udo'), Node('/Udo/Marc'))

anchestors

All parent nodes and their parent nodes - see `ancestors`.

The attribute anchestors is just a typo of ancestors. Please use ancestors. This attribute will be removed in the 3.0.0 release.

children

All child nodes.

>>> from anytree import Node
>>> n = Node("n")
>>> a = Node("a", parent=n)
>>> b = Node("b", parent=n)
>>> c = Node("c", parent=n)
>>> n.children
(Node('/n/a'), Node('/n/b'), Node('/n/c'))

Modifying the children attribute modifies the tree.

Detach

The children attribute can be updated by setting to an iterable.

>>> n.children = [a, b]
>>> n.children
(Node('/n/a'), Node('/n/b'))

Node c is removed from the tree. In case of an existing reference, the node c does not vanish and is the root of its own tree.

>>> c
Node('/c')

Attach

>>> d = Node("d")
>>> d
Node('/d')
>>> n.children = [a, b, d]
>>> n.children
(Node('/n/a'), Node('/n/b'), Node('/n/d'))
>>> d
Node('/n/d')

Duplicate

A node can be listed among the children only once. Duplicates cause a `TreeError`:

>>> n.children = [a, b, d, a]
Traceback (most recent call last):
...
anytree.node.exceptions.TreeError: Cannot add node Node('/n/a') multiple times as child.

depth

Number of edges to the root Node.

from anytree import Node udo = Node("Udo") marc = Node("Marc", parent=udo) lian = Node("Lian", parent=marc) udo.depth 0 marc.depth 1 lian.depth 2

descendants

All child nodes and all their child nodes.

from anytree import Node udo = Node("Udo") marc = Node("Marc", parent=udo) lian = Node("Lian", parent=marc) loui = Node("Loui", parent=marc) soe = Node("Soe", parent=lian) udo.descendants (Node('/Udo/Marc'), Node('/Udo/Marc/Lian'), Node('/Udo/Marc/Lian/Soe'), Node('/Udo/Marc/Loui')) marc.descendants (Node('/Udo/Marc/Lian'), Node('/Udo/Marc/Lian/Soe'), Node('/Udo/Marc/Loui')) lian.descendants (Node('/Udo/Marc/Lian/Soe'),)

height

Number of edges on the longest path to a leaf Node.

from anytree import Node udo = Node("Udo") marc = Node("Marc", parent=udo) lian = Node("Lian", parent=marc) udo.height 2 marc.height 1 lian.height 0

is_leaf

Node has no children (External Node).

from anytree import Node udo = Node("Udo") marc = Node("Marc", parent=udo) lian = Node("Lian", parent=marc) udo.is_leaf False marc.is_leaf False lian.is_leaf True

is_root

Node is tree root.

from anytree import Node udo = Node("Udo") marc = Node("Marc", parent=udo) lian = Node("Lian", parent=marc) udo.is_root True marc.is_root False lian.is_root False

leaves

Tuple of all leaf nodes.

from anytree import Node udo = Node("Udo") marc = Node("Marc", parent=udo) lian = Node("Lian", parent=marc) loui = Node("Loui", parent=marc) lazy = Node("Lazy", parent=marc) udo.leaves (Node('/Udo/Marc/Lian'), Node('/Udo/Marc/Loui'), Node('/Udo/Marc/Lazy')) marc.leaves (Node('/Udo/Marc/Lian'), Node('/Udo/Marc/Loui'), Node('/Udo/Marc/Lazy'))

parent

Parent Node.

On set, the node is detached from any previous parent node and attached to the new node.

from anytree import Node, RenderTree udo = Node("Udo") marc = Node("Marc") lian = Node("Lian", parent=marc) print(RenderTree(udo)) Node('/Udo') print(RenderTree(marc)) Node('/Marc') └── Node('/Marc/Lian')

Attach

marc.parent = udo print(RenderTree(udo)) Node('/Udo') └── Node('/Udo/Marc') └── Node('/Udo/Marc/Lian')

Detach

To make a node to a root node, just set this attribute to None.

marc.is_root False marc.parent = None marc.is_root True

path

Path from root node down to this Node.

from anytree import Node udo = Node("Udo") marc = Node("Marc", parent=udo) lian = Node("Lian", parent=marc) udo.path (Node('/Udo'),) marc.path (Node('/Udo'), Node('/Udo/Marc')) lian.path (Node('/Udo'), Node('/Udo/Marc'), Node('/Udo/Marc/Lian'))

root

Tree Root Node.

from anytree import Node udo = Node("Udo") marc = Node("Marc", parent=udo) lian = Node("Lian", parent=marc) udo.root Node('/Udo') marc.root Node('/Udo') lian.root Node('/Udo')

siblings

Tuple of nodes with the same parent.

from anytree import Node udo = Node("Udo") marc = Node("Marc", parent=udo) lian = Node("Lian", parent=marc) loui = Node("Loui", parent=marc) lazy = Node("Lazy", parent=marc) udo.siblings () marc.siblings () lian.siblings (Node('/Udo/Marc/Loui'), Node('/Udo/Marc/Lazy')) loui.siblings (Node('/Udo/Marc/Lian'), Node('/Udo/Marc/Lazy'))

size

Tree size --- the number of nodes in tree starting at this node.

from anytree import Node udo = Node("Udo") marc = Node("Marc", parent=udo) lian = Node("Lian", parent=marc) loui = Node("Loui", parent=marc) soe = Node("Soe", parent=lian) udo.size 5 marc.size 4 lian.size 2 loui.size 1

Methods

fetch_params

def fetch_params(
    self,
    params: streamgen.parameter.store.ParameterStore
) -> None

βš™οΈ fetches params from a ParameterStore.

Parameters:

Name Type Description Default
params ParameterStore parameter store to fetch the params from None
View Source
    def fetch_params(self, params: ParameterStore) -> None:
        """⚙️ pulls the branching probabilities and all branch parameters out of a `ParameterStore`.

        If this node has no `probs` yet and `params` contains a scope named like this node,
        the single parameter found in that scope becomes `self.probs`.
        Afterwards, `params` is handed to every node of every branch.

        Args:
            params (ParameterStore): parameter store to fetch the params from

        Raises:
            AssertionError: if the scope named after this node does not contain exactly one parameter.

        """
        needs_probs = self.probs is None
        if needs_probs and self.name in params.scopes:
            candidates = [*params.get_scope(self.name).parameters.values()]
            # a branching node owns exactly one parameter (its probabilities), so the scope must be unambiguous
            assert (  # noqa: S101
                len(candidates) == 1
            ), f'Make sure to only have a single parameter in the scope "{self.name}" when setting the parameters of a `BranchingNode` through `fetch_params`.'  # noqa: E501
            self.probs = candidates[0]

        # forward the store to every node in every branch, branch by branch
        for child in itertools.chain.from_iterable(self.branches.values()):
            child.fetch_params(params)

get_params

def get_params(
    self
) -> streamgen.parameter.store.ParameterStore | None

βš™οΈ collects parameters from every node.

The parameters are scoped based on the node names.

Returns:

Type Description
ParameterStore | None parameters from every node. None if there are no parameters.
View Source
    def get_params(self) -> ParameterStore | None:
        """⚙️ gathers the parameters of this node and of every node in its branches.

        The branching probabilities (if present) are registered under a scope named after this node.
        The stores returned by the branch nodes are merged into the result.

        Returns:
            ParameterStore | None: parameters from every node. None if there are no parameters.

        """
        collected = ParameterStore([])

        if self.probs:
            # register the probabilities under this node's own scope
            scope, probs = self.name, self.probs
            collected.scopes.add(scope)
            collected.parameters[scope] = {probs.name: probs}
            collected.parameter_names.add(f"{scope}.{probs.name}")

        for child in itertools.chain.from_iterable(self.branches.values()):
            child_params = child.get_params()
            if child_params:
                collected |= child_params

        if len(collected.parameter_names) > 0:
            return collected
        return None

iter_path_reverse

def iter_path_reverse(
    self
)

Iterate up the tree from the current node to the root node.

from anytree import Node udo = Node("Udo") marc = Node("Marc", parent=udo) lian = Node("Lian", parent=marc) for node in udo.iter_path_reverse(): ... print(node) Node('/Udo') for node in marc.iter_path_reverse(): ... print(node) Node('/Udo/Marc') Node('/Udo') for node in lian.iter_path_reverse(): ... print(node) Node('/Udo/Marc/Lian') Node('/Udo/Marc') Node('/Udo')

View Source
    def iter_path_reverse(self):
        """
        Walk from this node up to the root, yielding each node on the way (this node first, root last).

        >>> from anytree import Node
        >>> udo = Node("Udo")
        >>> marc = Node("Marc", parent=udo)
        >>> lian = Node("Lian", parent=marc)
        >>> for node in udo.iter_path_reverse():
        ...     print(node)
        Node('/Udo')
        >>> for node in marc.iter_path_reverse():
        ...     print(node)
        Node('/Udo/Marc')
        Node('/Udo')
        >>> for node in lian.iter_path_reverse():
        ...     print(node)
        Node('/Udo/Marc/Lian')
        Node('/Udo/Marc')
        Node('/Udo')
        """
        current = self
        while True:
            # a `parent` of None marks the end of the chain
            if current is None:
                return
            yield current
            current = current.parent

set_update_step

def set_update_step(
    self,
    idx: int
) -> None

πŸ• updates every parameter to a certain update step.

Parameters:

Name Type Description Default
idx int parameter update step None

Returns:

Type Description
None this function mutates self
View Source
    def set_update_step(self, idx: int) -> None:
        """🕐 moves the parameters of this node and of every branch node to update step `idx`.

        Args:
            idx (int): parameter update step

        Returns:
            None: this function mutates `self`

        """
        if self.probs:
            # the indexing result is discarded on purpose; only its effect on the parameter is wanted
            # NOTE(review): presumably `Parameter.__getitem__` moves the parameter to step `idx` - confirm
            self.probs[idx]

        for child in itertools.chain.from_iterable(self.branches.values()):
            child.set_update_step(idx)

traverse

def traverse(
    self,
    input: Any
) -> tuple[typing.Any, anytree.node.nodemixin.NodeMixin]

πŸƒπŸŽ² streamgen.transforms.Traverse protocol (input: Any) -> (output, anytree.NodeMixin | None).

During traversal, a branching node samples the next node from its children.

Parameters:

Name Type Description Default
input Any any input None

Returns:

Type Description
tuple[Any, anytree.NodeMixin | None] output and next node to traverse
View Source
    def traverse(self, input: Any) -> tuple[Any, anytree.NodeMixin]:  # noqa: A002, ANN401
        """🏃🎲 `streamgen.transforms.Traverse` protocol `(input: Any) -> (output, anytree.NodeMixin | None)`.

        A branching node does not transform its input. It only draws one of its branches
        (weighted by `self.probs.value`, or without weights if `self.probs` is falsy)
        and hands the unchanged input to the first node of that branch.

        Args:
            input (Any): any input

        Returns:
            tuple[Any, anytree.NodeMixin | None]: the unchanged input and the first node of the drawn branch

        """
        labels = list(self.branches.keys())
        weights = self.probs.value if self.probs else None
        chosen = self.rng.choice(labels, p=weights)
        return input, self.branches[chosen][0]

update

def update(
    self
) -> None

🆙 updates every parameter.

View Source
    def update(self) -> None:
        """🆙 calls `update()` on the branching probabilities (if present) and on every node of every branch."""
        if self.probs:
            self.probs.update()

        for child in itertools.chain.from_iterable(self.branches.values()):
            child.update()

SamplingTree

class SamplingTree(
    nodes: list[collections.abc.Callable | streamgen.nodes.TransformNode | dict | str],
    params: streamgen.parameter.store.ParameterStore | pandas.core.frame.DataFrame | None = None,
    collate_func: collections.abc.Callable[[list[typing.Any]], typing.Any] | None = None,
    string_node: collections.abc.Callable[[str], streamgen.nodes.TransformNode] = <class 'streamgen.nodes.ClassLabelNode'>
)

🌳 a tree of TransformNodes, that can be sampled from.

The tree will be constructed using streamgen.nodes.construct_tree(nodes, string_node).

Attributes

Name Type Description Default
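nodes list[Callable | TransformNode | dict | str] pythonic short-hand description of a graph/tree.
construct_tree will be called to construct the tree.
params ParameterStore | DataFrame | None parameter store containing additional parameters that are passed to the nodes based on the scope.
Dataframes will be converted to ParameterStore. Defaults to None.
collate_func Callable[[list[Any]], Any] | None function to collate samples when using self.collect(num_samples).
If None, return a list of samples. Defaults to None.
string_node Callable[[str], TransformNode] TransformNode constructor from strings used in construct_tree.
Defaults to ClassLabelNode.
None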
View Source
class SamplingTree(Sampler):

    """🌳 a tree of `TransformNode`s, that can be sampled from.

    The tree will be constructed using `streamgen.nodes.construct_tree(nodes, string_node)`.

    Args:

        nodes (list[Callable | TransformNode | dict| str]): pythonic short-hand description of a graph/tree.

            `streamgen.samplers.tree.construct_tree` will be called to construct the tree.

        params (ParameterStore | DataFrame | None, optional): parameter store containing additional parameters

            that are passed to the nodes based on the scope. Dataframes will be converted to `ParameterStore`. Defaults to None.

        collate_func (Callable[[list[Any]], Any] | None, optional): function to collate samples when using `self.collect(num_samples)`.

            If None, return a list of samples. Defaults to None.

        string_node (Callable[[str], TransformNode], optional): `TransfromNode` constructor from strings used in `construct_tree`.

            Defaults to `ClassLabelNode`.

    """

    def __init__(  # noqa: D107

        self,

        nodes: list[Callable | TransformNode | dict | str],

        params: ParameterStore | DataFrame | None = None,

        collate_func: Callable[[list[Any]], Any] | None = None,

        string_node: Callable[[str], TransformNode] = ClassLabelNode,

    ) -> None:

        self.nodes = construct_tree(nodes, string_node)

        self.root = self.nodes[0]

        match params:

            case None:

                self.params = ParameterStore([])

            case DataFrame():

                self.params = ParameterStore.from_dataframe(params)

            case ParameterStore():

                self.params = params

        # pass parameters to nodes

        for node in self.nodes:

            node.fetch_params(self.params)

        self.collate_func = collate_func

    def sample(self) -> Any:  # noqa: ANN401

        """🎲 generates a sample by traversing the tree from root to one leaf.

        Returns:

            Any: sample

        """

        node = self.root

        out = None

        while node is not None:

            out, node = node.traverse(out)

        return out

    def __next__(self) -> Any:  # noqa: ANN401

        """πŸͺΊ returns the next element during iteration.

        The iterator never runs out of samples, so no `StopIteration` exception is raised.

        Returns:

            Any: a sample

        """

        return self.sample()

    def __iter__(self) -> Self:

        """🏭 turns self into an iterator.

        Required to loop over a `SamplingTree`.

        """

        return self

    def collect(self, num_samples: int, strategy: SamplingStrategy | SamplingStrategyLit = "stochastic") -> Any:  # noqa: ANN401

        """πŸͺΊ collect and concatenate `num_samples` using `sample() and `self.collate_func`.

        Args:

            num_samples (int): number of samples to collect.

                When using the "stochastic" (default) strategy, this refers to the total number of samples.

                When using the "balanced" strategies, this refers to the number of samples per path through the tree.

            strategy (SamplingStrategy | SamplingStrategyLit, optional): sampling strategy. Defaults to "stochastic".

        Returns:

            Any: collection of samples.

                If `self.collate_func` is defined, it will be mapped to the tuple elements in each sample.

                Otherwise this functions just returns a list of samples.

        """

        match strategy:

            case SamplingStrategy.STOCHASTIC:

                samples = [self.sample() for _ in track(range(num_samples), description="🎲 sampling...")]

            case SamplingStrategy.BALANCED:

                paths = self.get_paths()

                samples = list(itertools.chain(*[path.collect(num_samples) for path in paths]))

            case SamplingStrategy.BALANCED_PRUNED:

                paths = self.get_paths(prune=True)

                samples = list(itertools.chain(*[path.collect(num_samples) for path in paths]))

        return tuple(map(self.collate_func, zip(*samples, strict=True))) if self.collate_func else samples

    def update(self) -> None:

        """πŸ†™ updates every parameter."""

        for node in self.nodes:

            node.update()

    def set_update_step(self, idx: int) -> None:

        """πŸ• updates every parameter to a certain update step using `param[idx]`.

        Args:

            idx (int): parameter update step

        Returns:

            None: this function mutates `self`

        """

        for node in self.nodes:

            node.set_update_step(idx)

    def get_params(self) -> ParameterStore | None:

        """βš™οΈ collects parameters from every node.

        The parameters are scoped based on the node names.

        Returns:

            ParameterStore | None: parameters from every node. None is there are no parameters.

        """

        store = ParameterStore([])

        for node in self.nodes:

            if params := node.get_params():

                store |= params

        return store if len(store.parameter_names) > 0 else None

    def to_dotfile(

        self,

        file_path: Path = Path("./tree.dot"),

        plotting_func: Callable[[Any, plt.Axes], plt.Axes] | None = None,

        fps: int = 2,

    ) -> None:

        """πŸ•ΈοΈ exports the tree as a `dot` file using [graphviz](https://www.graphviz.org/).

        Args:

            file_path (Path, optional): path of the resulting file. Defaults to "./tree.dot".

            plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.

                The function should take a sample and a `plt.Axes` as arguments.

                It is used to create sample animations for `SampleBufferNode`s.

            fps (int, optional): frames per second for the sample animations. Defaults to 2.

        """

        output_path = file_path.parent

        def _nodeattrfunc(node) -> str:  # noqa: ANN001

            """Builds the node attribute list for graphviz."""

            a = f'label="{node.emoji} {node.name}"'

            match node:

                case BranchingNode():

                    probs = [round(1.0 / len(node.children), 3)] * len(node.children) if node.probs is None else str(node.probs)

                    return a + f' shape=diamond tooltip="{probs}"'

                case ClassLabelNode():

                    return a + " shape=cds"

                case SampleBufferNode():

                    # create animation

                    if plotting_func is None:

                        return a + " shape=box"

                    anim = node.plot(plotting_func, display=False)

                    if anim is None:

                        return a + " shape=box"

                    # save gif

                    gif_path = output_path / f"{node.name}.gif"

                    anim.save(gif_path, writer=animation.PillowWriter(fps=fps))

                    # add gif as background

                    return f'label="" shape=box image="{gif_path.name}" imagescale=true'

                case _:

                    return a + f' tooltip="{node.get_params()!s}"'

        dot = UniqueDotExporter(

            self.root,

            graph="digraph",

            nodeattrfunc=_nodeattrfunc,

        )

        dot.to_dotfile(file_path)

    def to_svg(

        self,

        file_path: Path = Path("./tree"),

        plotting_func: Callable[[Any, plt.Axes], plt.Axes] | None = None,

        fps: int = 2,

    ) -> SVG:

        """πŸ“Ή visualizes the tree as an svg using [graphviz](https://www.graphviz.org/).

        Args:

            file_path (Path, optional): path of the resulting file. Defaults to "./tree.dot".

            plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.

                The function should take a sample and a `plt.Axes` as arguments.

                It is used to create sample animations for `SampleBufferNode`s.

            fps (int, optional): frames per second for the sample animations. Defaults to 2.

        Returns:

            IPython.display.SVG: svg display of dot visualization

        """

        output_path = file_path.parent

        file_stem = file_path.stem

        dot_path = output_path / (file_stem + ".dot")

        with io.capture_output() as _captured:

            self.to_dotfile(dot_path, plotting_func, fps)

            Source.from_file(dot_path).render(dot_path, format="svg")

        return SVG(filename=(output_path / (file_stem + ".dot.svg")))

    def get_paths(self, prune: bool = False) -> list[Self]:  # noqa: FBT001, FBT002

        """πŸƒ constructs a deterministic path for each leaf in the tree.

        Args:

            prune (bool, optional): If true, only return paths with probabilities greater than zero. Defaults to false.

        Returns:

            list[Self]: list of `SamplingTree`s without branches.

        """

        paths = []

        for leaf in self.root.leaves:

            pruned_path = False

            if prune:

                # check if all probs leading to the leaf are greater than 0

                for node in leaf.ancestors:

                    if isinstance(node, BranchingNode):

                        # get probability for the path leading to the leaf

                        # to do this, we need to find the index of the branch

                        idx = next(idx for idx, branch in enumerate(node.children) if leaf in branch.leaves)

                        prob = node.probs.value[idx]

                        if prob == 0.0:

                            pruned_path = True

            if not pruned_path:

                path = [node for node in deepcopy(leaf.ancestors) if not isinstance(node, BranchingNode)] + [deepcopy(leaf)]

                for node, next_node in pairwise(path):

                    node.children = [next_node]

                paths.append(SamplingTree(path, self.params))

        return paths

    def __str__(self) -> str:

        """🏷️ Returns the string representation `str(self)`.

        Returns:

            str: string representation of self

        """

        s = "🌳\n"

        for pre, _, node in anytree.RenderTree(self.root, style=anytree.ContRoundStyle()):

            s += pre + str(node) + "\n"

        return s

Ancestors (in MRO)

  • streamgen.samplers.Sampler
  • collections.abc.Iterator
  • collections.abc.Iterable
  • typing.Protocol
  • typing.Generic

Methods

collect

def collect(
    self,
    num_samples: int,
    strategy: Union[streamgen.enums.SamplingStrategy, Literal['stochastic', 'balanced', 'balanced pruned']] = 'stochastic'
) -> Any

🪺 collect and concatenate num_samples using sample() and self.collate_func.

Parameters:

Name Type Description Default
num_samples int number of samples to collect.
When using the "stochastic" (default) strategy, this refers to the total number of samples.
When using the "balanced" strategies, this refers to the number of samples per path through the tree.
None
strategy SamplingStrategy | SamplingStrategyLit sampling strategy. Defaults to "stochastic". "stochastic"

Returns:

Type Description
Any collection of samples.
If self.collate_func is defined, it will be mapped to the tuple elements in each sample.
Otherwise this functions just returns a list of samples.
View Source
    def collect(self, num_samples: int, strategy: SamplingStrategy | SamplingStrategyLit = "stochastic") -> Any:  # noqa: ANN401

        """πŸͺΊ collect and concatenate `num_samples` using `sample() and `self.collate_func`.

        Args:

            num_samples (int): number of samples to collect.

                When using the "stochastic" (default) strategy, this refers to the total number of samples.

                When using the "balanced" strategies, this refers to the number of samples per path through the tree.

            strategy (SamplingStrategy | SamplingStrategyLit, optional): sampling strategy. Defaults to "stochastic".

        Returns:

            Any: collection of samples.

                If `self.collate_func` is defined, it will be mapped to the tuple elements in each sample.

                Otherwise this functions just returns a list of samples.

        """

        match strategy:

            case SamplingStrategy.STOCHASTIC:

                samples = [self.sample() for _ in track(range(num_samples), description="🎲 sampling...")]

            case SamplingStrategy.BALANCED:

                paths = self.get_paths()

                samples = list(itertools.chain(*[path.collect(num_samples) for path in paths]))

            case SamplingStrategy.BALANCED_PRUNED:

                paths = self.get_paths(prune=True)

                samples = list(itertools.chain(*[path.collect(num_samples) for path in paths]))

        return tuple(map(self.collate_func, zip(*samples, strict=True))) if self.collate_func else samples

get_params

def get_params(
    self
) -> streamgen.parameter.store.ParameterStore | None

βš™οΈ collects parameters from every node.

The parameters are scoped based on the node names.

Returns:

Type Description
ParameterStore | None parameters from every node. None if there are no parameters.
View Source
    def get_params(self) -> ParameterStore | None:
        """⚙️ merges the parameters reported by every node in `self.nodes` into one store.

        The parameters are scoped based on the node names.

        Returns:
            ParameterStore | None: parameters from every node. None if there are no parameters.

        """
        merged = ParameterStore([])

        for member in self.nodes:
            member_params = member.get_params()
            # nodes without parameters report a falsy value and are skipped
            if member_params:
                merged |= member_params

        has_params = len(merged.parameter_names) > 0
        return merged if has_params else None

get_paths

def get_paths(
    self,
    prune: bool = False
) -> list[typing.Self]

πŸƒ constructs a deterministic path for each leaf in the tree.

Parameters:

Name Type Description Default
prune bool If true, only return paths with probabilities greater than zero. Defaults to false. false

Returns:

Type Description
list[Self] list of SamplingTrees without branches.
View Source
    def get_paths(self, prune: bool = False) -> list[Self]:  # noqa: FBT001, FBT002
        """🍃 constructs a deterministic path for each leaf in the tree.

        Args:
            prune (bool, optional): If true, only return paths with probabilities greater than zero. Defaults to false.

        Returns:
            list[Self]: list of `SamplingTree`s without branches.

        """
        paths = []

        for leaf in self.root.leaves:
            pruned_path = False

            if prune:
                # check if all probs leading to the leaf are greater than 0
                for node in leaf.ancestors:
                    if not isinstance(node, BranchingNode):
                        continue
                    if node.probs is None:
                        # no explicit probabilities means uniform branching, so no branch has probability zero.
                        # Previously this case raised an AttributeError on `node.probs.value`.
                        continue
                    # get probability for the path leading to the leaf
                    # to do this, we need to find the index of the branch
                    idx = next(idx for idx, branch in enumerate(node.children) if leaf in branch.leaves)
                    if node.probs.value[idx] == 0.0:
                        pruned_path = True
                        break  # one zero-probability branch is enough to drop the path

            if pruned_path:
                continue

            # copy the nodes so that re-linking them below does not modify this tree
            path = [node for node in deepcopy(leaf.ancestors) if not isinstance(node, BranchingNode)] + [deepcopy(leaf)]
            # chain the copies so that every node has exactly one child
            for node, next_node in pairwise(path):
                node.children = [next_node]
            paths.append(SamplingTree(path, self.params))

        return paths

sample

def sample(
    self
) -> Any

🎲 generates a sample by traversing the tree from root to one leaf.

Returns:

Type Description
Any sample
View Source
    def sample(self) -> Any:  # noqa: ANN401
        """🎲 draws one sample by walking the tree from `self.root` until no next node is returned.

        Every visited node gets the previous output (None for the root)
        and returns its own output together with the node to visit next.

        Returns:
            Any: sample

        """
        result, current = None, self.root
        while True:
            # a next node of None ends the traversal
            if current is None:
                return result
            result, current = current.traverse(result)

set_update_step

def set_update_step(
    self,
    idx: int
) -> None

πŸ• updates every parameter to a certain update step using param[idx].

Parameters:

Name Type Description Default
idx int parameter update step None

Returns:

Type Description
None this function mutates self
View Source
    def set_update_step(self, idx: int) -> None:
        """🕐 forwards the update step `idx` to every node in `self.nodes`.

        Args:
            idx (int): parameter update step

        Returns:
            None: this function mutates `self`

        """
        for member in self.nodes:
            member.set_update_step(idx)

to_dotfile

def to_dotfile(
    self,
    file_path: pathlib.Path = WindowsPath('tree.dot'),
    plotting_func: collections.abc.Callable[[typing.Any, matplotlib.axes._axes.Axes], matplotlib.axes._axes.Axes] | None = None,
    fps: int = 2
) -> None

πŸ•ΈοΈ exports the tree as a dot file using graphviz.

Parameters:

Name Type Description Default
file_path Path path of the resulting file. Defaults to "./tree.dot". "./tree.dot"
plotting_func Callable[[Any, plt.Axes], plt.Axes] function to visualize a single sample.
The function should take a sample and a plt.Axes as arguments.
It is used to create sample animations for SampleBufferNodes.
None
fps int frames per second for the sample animations. Defaults to 2. 2
View Source
    def to_dotfile(
        self,
        file_path: Path = Path("./tree.dot"),
        plotting_func: Callable[[Any, plt.Axes], plt.Axes] | None = None,
        fps: int = 2,
    ) -> None:
        """🕸️ writes the tree to a `dot` file using [graphviz](https://www.graphviz.org/).

        Args:
            file_path (Path, optional): path of the resulting file. Defaults to "./tree.dot".
            plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.
                The function should take a sample and a `plt.Axes` as arguments.
                It is used to create sample animations for `SampleBufferNode`s,
                which are saved as gifs in the parent directory of `file_path`.
            fps (int, optional): frames per second for the sample animations. Defaults to 2.

        """
        gif_dir = file_path.parent

        def _nodeattrfunc(node) -> str:  # noqa: ANN001
            """Builds the node attribute list for graphviz."""
            label = f'label="{node.emoji} {node.name}"'

            if isinstance(node, BranchingNode):
                if node.probs is None:
                    # without explicit probabilities, show the uniform distribution over the children
                    tooltip = [round(1.0 / len(node.children), 3)] * len(node.children)
                else:
                    tooltip = str(node.probs)
                return label + f' shape=diamond tooltip="{tooltip}"'

            if isinstance(node, ClassLabelNode):
                return label + " shape=cds"

            if isinstance(node, SampleBufferNode):
                # an animation is only attempted when a plotting function was given
                anim = None if plotting_func is None else node.plot(plotting_func, display=False)
                if anim is None:
                    return label + " shape=box"
                gif_path = gif_dir / f"{node.name}.gif"
                anim.save(gif_path, writer=animation.PillowWriter(fps=fps))
                # the gif replaces the label and is referenced by file name only
                return f'label="" shape=box image="{gif_path.name}" imagescale=true'

            return label + f' tooltip="{node.get_params()!s}"'

        exporter = UniqueDotExporter(self.root, graph="digraph", nodeattrfunc=_nodeattrfunc)
        exporter.to_dotfile(file_path)

to_svg

def to_svg(
    self,
    file_path: pathlib.Path = WindowsPath('tree'),
    plotting_func: collections.abc.Callable[[typing.Any, matplotlib.axes._axes.Axes], matplotlib.axes._axes.Axes] | None = None,
    fps: int = 2
) -> IPython.core.display.SVG

📹 visualizes the tree as an svg using graphviz.

Parameters:

Name Type Description Default
file_path Path base path of the output; the dot file is written to "<stem>.dot" and the svg to "<stem>.dot.svg" in the same directory. Defaults to "./tree". "./tree"
plotting_func Callable[[Any, plt.Axes], plt.Axes] function to visualize a single sample.
The function should take a sample and a plt.Axes as arguments.
It is used to create sample animations for SampleBufferNodes.
None
fps int frames per second for the sample animations. Defaults to 2. 2

Returns:

Type Description
IPython.display.SVG svg display of dot visualization
View Source
    def to_svg(
        self,
        file_path: Path = Path("./tree"),
        plotting_func: Callable[[Any, plt.Axes], plt.Axes] | None = None,
        fps: int = 2,
    ) -> SVG:
        """📹 renders the tree to an svg using [graphviz](https://www.graphviz.org/).

        Args:
            file_path (Path, optional): base path of the output. The dot file is written to `<stem>.dot`
                and the rendered svg to `<stem>.dot.svg` in the parent directory of `file_path`. Defaults to "./tree".
            plotting_func (Callable[[Any, plt.Axes], plt.Axes]): function to visualize a single sample.
                The function should take a sample and a `plt.Axes` as arguments.
                It is used to create sample animations for `SampleBufferNode`s.
            fps (int, optional): frames per second for the sample animations. Defaults to 2.

        Returns:
            IPython.display.SVG: svg display of dot visualization

        """
        target_dir, stem = file_path.parent, file_path.stem
        dot_path = target_dir / (stem + ".dot")
        svg_path = target_dir / (stem + ".dot.svg")

        # output produced while exporting and rendering is captured and discarded
        with io.capture_output():
            self.to_dotfile(dot_path, plotting_func, fps)
            Source.from_file(dot_path).render(dot_path, format="svg")

        return SVG(filename=svg_path)

update

def update(
    self
) -> None

🆙 updates every parameter.

View Source
    def update(self) -> None:
        """🆙 calls `update()` on every node in `self.nodes`, in order."""
        for member in self.nodes:
            member.update()