Skip to content

Graph

The graph module provides the core data structures for defining probabilistic models.

Overview

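A Falcon graph consists of Node objects connected by parent-child relationships (the forward, simulation model) and evidence relationships (the backward, inference model). The Graph class manages these relationships and topologically sorts each direction to obtain a correct execution order.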

Classes

Node

Node(name, simulator_cls, estimator_cls=None, parents=None, evidence=None, scaffolds=None, observed=False, resample=False, simulator_config=None, estimator_config=None, actor_config=None, num_actors=1, sample_chunk_size=0)

Node definition for a graphical model.

Parameters:

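Name Type Description Default
name str

Name of the node.

required
simulator_cls class or str

Simulator used to generate the node's values in the forward model.

required
estimator_cls class or str

Estimator used in the inference model. If None, the node is not trained.

None
parents list

List of parent node names (forward model).

None
evidence list

List of evidence node names (inference model).

None
scaffolds list

Scaffold entries for the node.

None
observed bool

Whether the node is observed (observed nodes act as root nodes of the inference model).

False
resample bool

Whether to resample the node.

False
simulator_config dict

Configuration for the simulator.

None
estimator_config dict

Configuration for the estimator.

None
actor_config dict

Actor deployment options for the node.

None
num_actors int

Number of actors for the node.

1
sample_chunk_size int

Sample chunk size.

0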
Source code in falcon/core/graph.py
def __init__(
    self,
    name,
    simulator_cls,
    estimator_cls=None,
    parents=None,
    evidence=None,
    scaffolds=None,
    observed=False,
    resample=False,
    simulator_config=None,
    estimator_config=None,
    actor_config=None,
    num_actors=1,
    sample_chunk_size=0,
):
    """Node definition for a graphical model.

    Args:
        name (str): Name of the node.
        simulator_cls (class or str): Simulator used to generate the node's
            values in the forward model.
        estimator_cls (class or str, optional): Estimator used in the
            inference model. When ``None`` the node is not trained
            (``self.train`` is ``False``).
        parents (list, optional): Parent node names (forward model).
        evidence (list, optional): Evidence node names (inference model).
        scaffolds (list, optional): Scaffold entries for the node, stored as given.
        observed (bool, optional): Whether the node is observed (observed nodes
            act as root nodes of the inference model). Any truthy value marks
            the node as observed.
        resample (bool, optional): Whether to resample the node.
        simulator_config (dict, optional): Configuration for the simulator.
        estimator_config (dict, optional): Configuration for the estimator.
        actor_config (dict, optional): Actor deployment options for the node.
        num_actors (int, optional): Number of actors for the node. Defaults to 1.
        sample_chunk_size (int, optional): Sample chunk size. Defaults to 0.

    ``None`` (or otherwise falsy) container arguments are replaced by fresh
    empty lists/dicts; non-empty containers are stored as passed (not copied).
    """
    self.name = name

    self.simulator_cls = simulator_cls
    self.estimator_cls = estimator_cls

    # Fresh empty containers per node, so instances never share defaults.
    self.parents = parents or []
    self.evidence = evidence or []
    self.scaffolds = scaffolds or []
    self.observed = observed
    self.resample = resample
    # Only nodes with an estimator take part in training.
    self.train = self.estimator_cls is not None

    self.simulator_config = simulator_config or {}
    self.estimator_config = estimator_config or {}
    self.actor_config = actor_config or {}
    self.num_actors = num_actors
    self.sample_chunk_size = sample_chunk_size

Graph

Graph(node_list)
Source code in falcon/core/graph.py
def __init__(self, node_list):
    """Index the nodes and derive forward (simulation) and backward (inference) orders.

    Args:
        node_list: Sequence of ``Node`` objects. It is iterated several times,
            so pass a list/tuple rather than a one-shot iterator.

    The forward graph maps each node to its ``parents``. The backward graph
    starts from observed nodes and nodes that declare ``evidence``. It then adds
    every node reached by following those evidence references and, from there,
    forward ``parents`` links. The walk stops at nodes already included.

    Backward dependencies are:
    ``[]`` for observed nodes (their values are provided externally), the
    evidence list for nodes with evidence, and the forward parents for the
    remaining deterministic intermediates.
    """
    self.node_list = node_list

    def _by_name(attr):
        # Map node name -> getattr(node, attr) for every node.
        return {n.name: getattr(n, attr) for n in node_list}

    self.node_dict = {n.name: n for n in node_list}
    self.simulator_cls_dict = _by_name("simulator_cls")

    # Raw per-node configuration.
    self.name_list = [n.name for n in node_list]
    self.evidence_dict = _by_name("evidence")
    self.scaffolds_dict = _by_name("scaffolds")
    self.observed_dict = _by_name("observed")

    # Simulation direction: node -> parents it is generated from.
    self.forward_deps = _by_name("parents")
    self.forward_order = self._topological_sort(self.name_list, self.forward_deps)

    # Inference direction: seed with observed nodes and nodes that have evidence.
    included = {n.name for n in node_list if n.observed or len(n.evidence) > 0}

    # Depth-first walk from the evidence references through forward parents,
    # collecting deterministic ancestors that inference also needs.
    stack = [
        ev
        for seed in list(included)
        for ev in self.evidence_dict[seed]
        if ev not in included
    ]
    while stack:
        current = stack.pop()
        if current in included:
            continue
        included.add(current)
        stack.extend(p for p in self.forward_deps[current] if p not in included)

    # Keep the user's node order for the inference subset.
    backward_names = [n.name for n in node_list if n.name in included]

    self.backward_deps = {
        name: (
            []
            if self.observed_dict[name]
            else (self.evidence_dict[name] or self.forward_deps[name])
        )
        for name in backward_names
    }

    self.backward_order = self._topological_sort(backward_names, self.backward_deps)

get_parents

get_parents(node_name)
Source code in falcon/core/graph.py
def get_parents(self, node_name):
    """Return the forward-model parent names of ``node_name``.

    Raises:
        KeyError: If ``node_name`` is not a node of the graph.
    """
    deps_by_node = self.forward_deps
    return deps_by_node[node_name]

get_evidence

get_evidence(node_name)
Source code in falcon/core/graph.py
def get_evidence(self, node_name):
    """Return the evidence node names declared for ``node_name``.

    Raises:
        KeyError: If ``node_name`` is not a node of the graph.
    """
    evidence_by_node = self.evidence_dict
    return evidence_by_node[node_name]

get_simulator_cls

get_simulator_cls(node_name)
Source code in falcon/core/graph.py
def get_simulator_cls(self, node_name):
    """Return the simulator class (or target string) registered for ``node_name``.

    Raises:
        KeyError: If ``node_name`` is not a node of the graph.
    """
    simulators = self.simulator_cls_dict
    return simulators[node_name]

Functions

CompositeNode

CompositeNode(names, module, **kwargs)

Auxiliary function to create a composite node with multiple child nodes.

Source code in falcon/core/graph.py
def CompositeNode(names, module, **kwargs):
    """Auxiliary function to create a composite node with multiple child nodes.

    The composite node is named ``"comp_" + "_".join(names)`` and uses
    ``module`` as its simulator. Each child node ``names[i]`` uses ``Extractor``
    as its simulator, has the composite node as its only parent, and gets
    ``simulator_config={"index": i}`` to select its component.

    Args:
        names: Names of the child nodes, in component order. Any iterable is
            accepted; it is materialised into a list once.
        module: Passed to ``Node`` as the composite node's ``simulator_cls``.
        **kwargs: Extra keyword arguments forwarded to ``Node`` for the
            composite node only (child nodes do not receive them).

    Returns:
        tuple: ``(composite_node, child_0, child_1, ...)``. All of these must be
        added to the graph.
    """
    # ``names`` is used twice (join + enumerate); a generator or iterator would
    # be exhausted by the join and silently produce no child nodes.
    names = list(names)

    # Generate name of composite node from names of child nodes
    joined_names = "comp_" + "_".join(names)

    # Instantiate composite node
    node_comp = Node(joined_names, module, **kwargs)

    # Instantiate child nodes, which extract the individual components
    nodes = [
        Node(name, Extractor, parents=[joined_names], simulator_config=dict(index=i))
        for i, name in enumerate(names)
    ]

    # Return composite node and child nodes, which both must be added to the graph
    return (node_comp, *nodes)

create_graph_from_config

create_graph_from_config(graph_config, _cfg=None)

Create a computational graph from YAML configuration.

Parameters:

Name Type Description Default
graph_config

Dictionary containing graph node definitions

required
_cfg

Full Hydra configuration object (optional)

None

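Returns:

Name Type Description
tuple

(graph, observations): the computational Graph, and a dict mapping each node whose observed entry is a data path to the observation data loaded from that path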

Raises:

Type Description
ValueError

If configuration is invalid (missing required fields, unknown references)

FileNotFoundError

If an observation file referenced by a node's observed entry does not exist

Source code in falcon/core/graph.py
def _split_target(spec):
    """Split a simulator/estimator spec into ``(target, config)``.

    Args:
        spec: Either a target string, or an OmegaConf mapping whose
            ``_target_`` entry names the class and whose remaining entries
            form its configuration.

    Returns:
        tuple: ``(target, config)``. ``config`` is ``{}`` for a string spec.
        Otherwise it is a plain container with interpolations resolved and
        ``_target_`` removed.
    """
    if isinstance(spec, str):
        return spec, {}
    target = spec.get("_target_")
    config = OmegaConf.to_container(spec, resolve=True)
    config.pop("_target_", None)
    return target, config


def _load_observation(data_path):
    """Load observation data referenced by a node's ``observed`` entry.

    Args:
        data_path: Path to a NumPy file. It may use the NPZ key-extraction
            syntax ``"file.npz['key']"`` (split by ``_parse_observation_path``).

    Returns:
        The data selected as follows:
        - If a key is given, the entry at that key.
        - Otherwise, for a single-member NPZ archive, that member.
        - Otherwise, for a multi-member archive, the archive object itself.
        - Otherwise, the loaded array.

    Raises:
        FileNotFoundError: If the referenced file does not exist.
    """
    file_path, key = _parse_observation_path(data_path)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Observation file not found: {file_path}")
    data = np.load(file_path)
    if not hasattr(data, "files"):
        # Plain array file: index by key only if one was given.
        return data if key is None else data[key]
    if key is None and len(data.files) != 1:
        # Multi-member archive with no key: return the archive itself (left
        # open so the caller can still read its members lazily).
        return data
    # Read the single needed member into memory, then close the archive so its
    # file handle is not leaked (NPZ archives keep the file open until closed).
    with data:
        return data[data.files[0] if key is None else key]


def create_graph_from_config(graph_config, _cfg=None):
    """Create a computational graph from YAML configuration.

    Args:
        graph_config: Mapping of node name -> node configuration.
        _cfg: Full Hydra configuration object (optional; not used here).

    Returns:
        tuple: ``(graph, observations)``.
        - ``graph`` is the ``Graph`` built from all configured nodes.
        - ``observations`` maps the name of each node whose ``observed`` entry
          is a data path to the data loaded from that path.

    Raises:
        ValueError: If configuration is invalid (missing required fields, unknown references)
        FileNotFoundError: If an observation file does not exist.
    """
    nodes = []
    observations = {}

    for node_name, node_config in graph_config.items():
        # Validate configuration
        _validate_node_config(node_name, node_config)

        # Extract node parameters
        parents = node_config.get("parents", [])
        evidence = node_config.get("evidence", [])
        scaffolds = node_config.get("scaffolds", [])
        # ``observed`` doubles as the "observed" flag and the data path.
        observed = node_config.get("observed", False)  # TODO: Remove from internal logic
        data_path = node_config.get("observed", None)
        resample = node_config.get("resample", False)
        actor_config = node_config.get("ray", {})
        num_actors = node_config.get("num_actors", 1)
        sample_chunk_size = node_config.get("sample_chunk_size", 0)

        if actor_config != {}:
            actor_config = OmegaConf.to_container(actor_config, resolve=True)

        # An explicit ``observed: false`` is a flag, not a path; nothing to load.
        if data_path is not None and data_path is not False:
            observations[node_name] = _load_observation(data_path)

        simulator_cls, simulator_config = _split_target(node_config.get("simulator"))

        if "estimator" in node_config:
            estimator_cls, estimator_config = _split_target(node_config.get("estimator"))
        else:
            estimator_cls, estimator_config = None, {}

        nodes.append(
            Node(
                name=node_name,
                simulator_cls=simulator_cls,
                estimator_cls=estimator_cls,
                parents=parents,
                evidence=evidence,
                scaffolds=scaffolds,
                observed=observed,
                resample=resample,
                simulator_config=simulator_config,
                estimator_config=estimator_config,
                actor_config=actor_config,
                num_actors=num_actors,
                sample_chunk_size=sample_chunk_size,
            )
        )

    # Validate node references
    node_names = {node.name for node in nodes}
    _validate_node_references(nodes, node_names)

    # Create and return the graph
    return Graph(nodes), observations