Skip to content

Graph

The graph module provides the core data structures for defining probabilistic models.

Overview

A Falcon graph consists of Node objects connected by parent-child relationships. The Graph class manages these relationships and handles topological sorting for correct execution order.

Classes

Node

Node(name, simulator_cls, estimator_cls=None, parents=None, evidence=None, scaffolds=None, observed=False, resample=False, simulator_config=None, estimator_config=None, actor_config=None, num_actors=1)

Node definition for a graphical model.

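Parameters:

Name Type Description Default
name str

Name of the node.

required
simulator_cls str or class

Simulator for the node in the forward model (create_graph_from_config passes the `_target_` string of the node's `simulator` entry).

required
estimator_cls str or class

Estimator for the node in the inference model. If None, the node is not trained.

None
parents list

List of parent node names (forward model).

None
evidence list

List of evidence node names (inference model).

None
scaffolds list

Optional list of scaffold entries for the node.

None
observed bool

Whether the node is observed (acts as a root node for the inference model).

False
resample bool

Whether to resample the node.

False
simulator_config dict

Configuration for the simulator (the `simulator` entry without `_target_`).

None
estimator_config dict

Configuration for the estimator (the `estimator` entry without `_target_`).

None
actor_config dict

Configuration of the actor(s) deploying the node (the node's `ray` entry).

None
num_actors int

Number of actors to deploy for the node.

1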
Source code in falcon/core/graph.py
def __init__(
    self,
    name,
    simulator_cls,
    estimator_cls=None,
    parents=None,
    evidence=None,
    scaffolds=None,
    observed=False,
    resample=False,
    simulator_config=None,
    estimator_config=None,
    actor_config=None,
    num_actors=1,
):
    """Node definition for a graphical model.

    Args:
        name (str): Name of the node.
        simulator_cls (str | class): Simulator for the node in the forward
            model. ``create_graph_from_config`` passes the ``_target_``
            string of the node's ``simulator`` entry.
        estimator_cls (str | class, optional): Estimator for the node in the
            inference model. When None the node is not trained
            (``self.train`` is False).
        parents (list, optional): Parent node names (forward model).
        evidence (list, optional): Evidence node names (inference model).
        scaffolds (list, optional): Scaffold entries for the node.
        observed (bool): Whether the node is observed (acts as a root node
            for the inference model). ``create_graph_from_config`` may pass
            a non-bool truthy value here (the observation path).
        resample (bool): Whether to resample the node.
        simulator_config (dict, optional): Configuration for the simulator.
        estimator_config (dict, optional): Configuration for the estimator.
        actor_config (dict, optional): Configuration of the actor(s) that
            deploy the node.
        num_actors (int): Number of actors to deploy for the node.

    Falsy list/dict arguments (including None) are replaced by fresh empty
    containers, so instances never share a mutable default.
    """
    self.name = name

    self.simulator_cls = simulator_cls
    self.estimator_cls = estimator_cls

    self.parents = parents or []
    self.evidence = evidence or []
    self.scaffolds = scaffolds or []
    self.observed = observed
    self.resample = resample
    # Only nodes with an estimator take part in training.
    self.train = self.estimator_cls is not None

    self.simulator_config = simulator_config or {}
    self.estimator_config = estimator_config or {}
    self.actor_config = actor_config or {}
    self.num_actors = num_actors

Graph

Graph(node_list)
Source code in falcon/core/graph.py
def __init__(self, node_list):
    """Index ``node_list`` and precompute both execution orders.

    Builds name-keyed lookups for the nodes and for their simulator classes,
    parents, evidence, scaffolds and observed flags. It then calls
    ``self._topological_sort`` twice:

    * forward model: all node names, with the parents mapping as the
      dependency structure;
    * inference model: only names of nodes that are observed or carry
      evidence, with the evidence mapping as the dependency structure.

    Args:
        node_list: Sequence of Node objects; it is iterated several times.
    """

    def index_by_name(attr):
        # Map each node's name to one of its attributes.
        return {nd.name: getattr(nd, attr) for nd in node_list}

    # Forward (generative) model
    self.node_list = node_list
    self.node_dict = {nd.name: nd for nd in node_list}
    self.simulator_cls_dict = index_by_name("simulator_cls")
    self.name_list = [nd.name for nd in node_list]
    self.parents_dict = index_by_name("parents")
    self.sorted_node_names = self._topological_sort(self.name_list, self.parents_dict)

    # Inference model: restricted to observed nodes and nodes with evidence
    self.evidence_dict = index_by_name("evidence")
    self.scaffolds_dict = index_by_name("scaffolds")
    self.observed_dict = index_by_name("observed")
    self.inference_name_list = [
        nd.name for nd in node_list if nd.observed or nd.evidence
    ]
    self.sorted_inference_node_names = self._topological_sort(
        self.inference_name_list, self.evidence_dict
    )

get_parents

get_parents(node_name)
Source code in falcon/core/graph.py
def get_parents(self, node_name):
    """Return the forward-model parent names recorded for ``node_name``.

    Raises:
        KeyError: If ``node_name`` is not in ``self.parents_dict``.
    """
    lookup = self.parents_dict
    return lookup[node_name]

get_evidence

get_evidence(node_name)
Source code in falcon/core/graph.py
def get_evidence(self, node_name):
    """Return the inference-model evidence names recorded for ``node_name``.

    Raises:
        KeyError: If ``node_name`` is not in ``self.evidence_dict``.
    """
    lookup = self.evidence_dict
    return lookup[node_name]

get_simulator_cls

get_simulator_cls(node_name)
Source code in falcon/core/graph.py
def get_simulator_cls(self, node_name):
    """Return the simulator class (or target string) stored for ``node_name``.

    Raises:
        KeyError: If ``node_name`` is not in ``self.simulator_cls_dict``.
    """
    lookup = self.simulator_cls_dict
    return lookup[node_name]

Functions

CompositeNode

CompositeNode(names, module, **kwargs)

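Auxiliary function to create a composite node together with one Extractor child node per name in `names`; the composite node and all child nodes are returned and must all be added to the graph.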

Source code in falcon/core/graph.py
def CompositeNode(names, module, **kwargs):
    """Build a composite node plus one extractor child node per component.

    The composite node is named ``"comp_" + "_".join(names)``. It uses
    ``module`` as its simulator class, and ``kwargs`` are forwarded to its
    ``Node`` constructor. For the i-th entry of ``names``, a child ``Node`` is
    created with:

    * ``Extractor`` as its simulator class;
    * the composite node as its only parent;
    * ``simulator_config={"index": i}``.

    Returns:
        tuple: ``(composite_node, child_0, child_1, ...)``. Every element must
        be added to the graph.
    """
    composite_name = "comp_" + "_".join(names)
    composite = Node(composite_name, module, **kwargs)

    # One extractor per component, each selecting its slot by position.
    children = [
        Node(
            child_name,
            Extractor,
            parents=[composite_name],
            simulator_config={"index": position},
        )
        for position, child_name in enumerate(names)
    ]
    return (composite, *children)

create_graph_from_config

create_graph_from_config(graph_config, _cfg=None)

Create a computational graph from YAML configuration.

Parameters:

Name Type Description Default
graph_config

Dictionary containing graph node definitions

required
_cfg

Full Hydra configuration object (optional)

None

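Returns:

Name Type Description
tuple

A (graph, observations) pair: the computational Graph and a dict mapping each node that has an observation file to its loaded data.

Raises:

Type Description
ValueError

If configuration is invalid (missing required fields, unknown references)

FileNotFoundError

If an observation file referenced by a node's `observed` entry does not exist.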

Source code in falcon/core/graph.py
def create_graph_from_config(graph_config, _cfg=None):
    """Create a computational graph from YAML configuration.

    Each entry of ``graph_config`` maps a node name to its configuration.
    This function reads the keys below:

    * ``simulator`` (required): a target string, or a mapping with a
      ``_target_`` key plus simulator options.
    * ``estimator``: same forms as ``simulator``; optional.
    * ``parents``, ``evidence``, ``scaffolds``, ``observed``, ``resample``,
      ``ray`` (actor configuration) and ``num_actors``.

    A non-boolean ``observed`` value is the path of an observation file;
    ``"file.npz['key']"`` selects one array of an NPZ archive. The file's
    contents are loaded into the returned observations dict. A boolean
    ``observed`` only marks the node as observed.

    Args:
        graph_config: Mapping of node name to node configuration (e.g. an
            OmegaConf ``DictConfig``).
        _cfg: Full Hydra configuration object (optional, currently unused).

    Returns:
        tuple: ``(graph, observations)``. ``graph`` is the computational
        ``Graph``; ``observations`` maps each node with an observation file
        to its loaded data.

    Raises:
        ValueError: If configuration is invalid (missing required fields,
            unknown references).
        FileNotFoundError: If an observation file does not exist.
    """
    nodes = []
    observations = {}

    for node_name, node_config in graph_config.items():
        # Validate configuration
        _validate_node_config(node_name, node_config)

        # Extract node parameters
        parents = node_config.get("parents", [])
        evidence = node_config.get("evidence", [])
        scaffolds = node_config.get("scaffolds", [])
        resample = node_config.get("resample", False)
        num_actors = node_config.get("num_actors", 1)

        # ``ray`` holds the actor configuration; Node receives a plain container.
        actor_config = node_config.get("ray", {})
        actor_config = _to_plain_dict(actor_config) if actor_config else {}

        # ``observed`` doubles as the observation data path.
        # TODO: Remove from internal logic
        observed = node_config.get("observed", False)
        # Booleans only flag the node as observed, so ``observed: true`` /
        # ``observed: false`` must never reach the path parser.
        if observed is not None and not isinstance(observed, bool):
            observations[node_name] = _load_observation_data(observed)

        simulator = node_config.get("simulator")
        if simulator is None:
            raise ValueError(f"Node '{node_name}' has no 'simulator' entry")
        simulator_cls, simulator_config = _split_target_config(simulator)

        # A missing (or null) estimator means the node is not trained.
        estimator_cls, estimator_config = _split_target_config(
            node_config.get("estimator")
        )

        nodes.append(
            Node(
                name=node_name,
                simulator_cls=simulator_cls,
                estimator_cls=estimator_cls,
                parents=parents,
                evidence=evidence,
                scaffolds=scaffolds,
                observed=observed,
                resample=resample,
                simulator_config=simulator_config,
                estimator_config=estimator_config,
                actor_config=actor_config,
                num_actors=num_actors,
            )
        )

    # Validate node references
    node_names = {node.name for node in nodes}
    _validate_node_references(nodes, node_names)

    # Create and return the graph
    return Graph(nodes), observations


def _load_observation_data(data_path):
    """Load the observation referenced by ``data_path``.

    ``data_path`` may use the ``"file.npz['key']"`` syntax understood by
    ``_parse_observation_path`` to select one array. Archives holding a
    single array are unpacked automatically. When an array is extracted from
    an NPZ archive, the archive is closed afterwards.

    A multi-array archive with no key selected is returned as the open
    ``NpzFile`` itself, because it reads its members on access.

    Raises:
        FileNotFoundError: If the referenced file does not exist.
    """
    file_path, key = _parse_observation_path(data_path)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Observation file not found: {file_path}")
    data = np.load(file_path)
    if not hasattr(data, "files"):
        # Not an archive (e.g. a .npy array): nothing to close.
        return data if key is None else data[key]
    if key is None and len(data.files) != 1:
        return data
    # Read the selected (or only) array, then release the archive's handle.
    with data:
        return data[data.files[0] if key is None else key]


def _split_target_config(spec):
    """Split a simulator/estimator spec into ``(target, config)``.

    * ``None`` yields ``(None, {})``.
    * A string is the target itself, with an empty config.
    * A mapping yields its ``_target_`` value and a plain copy of the
      remaining entries, with OmegaConf interpolations resolved.
    """
    if spec is None:
        return None, {}
    if isinstance(spec, str):
        return spec, {}
    config = _to_plain_dict(spec)
    target = config.pop("_target_", None)
    return target, config


def _to_plain_dict(cfg):
    """Return ``cfg`` as a new plain container (OmegaConf configs are resolved)."""
    if OmegaConf.is_config(cfg):
        return OmegaConf.to_container(cfg, resolve=True)
    return dict(cfg)