Skip to content

Flow

Normalizing flow networks for density estimation.

Overview

The Flow class wraps various normalizing flow architectures for use in posterior estimation. It supports multiple flow types from different libraries.

Supported Flow Types

| Type | Library | Description |
|------|---------|-------------|
| `nsf` | sbi/nflows | Neural Spline Flow |
| `maf` | sbi/nflows | Masked Autoregressive Flow |
| `zuko_nice` | Zuko | NICE architecture |
| `zuko_naf` | Zuko | Neural Autoregressive Flow |
| `zuko_nsf` | Zuko | Neural Spline Flow (Zuko) |
| `zuko_maf` | Zuko | Masked Autoregressive Flow (Zuko) |

Class Reference

Flow

Flow(theta, s, theta_norm=False, norm_momentum=0.003, net_type='nsf', use_log_update=False, adaptive_momentum=False)

Bases: Module

Normalizing flow network with optional parameter normalization.

Source code in falcon/contrib/flow.py
def __init__(
    self,
    theta,
    s,
    theta_norm=False,
    norm_momentum=3e-3,
    net_type="nsf",
    use_log_update=False,
    adaptive_momentum=False,
):
    """Build the flow network and, optionally, an online normaliser for theta.

    Args:
        theta: Example parameter tensor; cast to float and handed to the
            network builder, and fed once through the normaliser when
            normalisation is enabled.
        s: Example conditioning tensor; cast to float for the builder.
        theta_norm: When truthy, a ``LazyOnlineNorm`` is created and stored
            on ``self.theta_norm``; otherwise that attribute is ``None``.
        norm_momentum: Forwarded to ``LazyOnlineNorm`` as ``momentum``.
        net_type: Key looked up in ``NET_BUILDERS`` to pick the builder.
        use_log_update: Forwarded to ``LazyOnlineNorm``.
        adaptive_momentum: Forwarded to ``LazyOnlineNorm``.

    Raises:
        ValueError: If ``NET_BUILDERS`` has no builder for ``net_type``.
    """
    super().__init__()

    # The normaliser is assigned before the network so attribute
    # registration order stays the same.
    if theta_norm:
        self.theta_norm = LazyOnlineNorm(
            momentum=norm_momentum,
            use_log_update=use_log_update,
            adaptive_momentum=adaptive_momentum,
        )
    else:
        self.theta_norm = None

    make_net = NET_BUILDERS.get(net_type)
    if make_net is None:
        raise ValueError(f"Unknown net_type: {net_type}. Available: {list(NET_BUILDERS.keys())}")

    theta_as_float = theta.float()
    s_as_float = s.float()
    # Builder-side z-scoring is switched off via the two None arguments.
    self.net = make_net(theta_as_float, s_as_float, z_score_x=None, z_score_y=None)

    normaliser = self.theta_norm
    if normaliser is not None:
        # One pass over the example theta initialises the normalisation stats.
        normaliser(theta)

    # Constant multiplier applied to theta in loss/log_prob and undone in sample.
    self.scale = 0.2

loss

loss(theta, s)

Compute negative log-likelihood loss.

Source code in falcon/contrib/flow.py
def loss(self, theta, s):
    """Return the negative log-likelihood of ``theta`` given ``s``.

    The network is evaluated on theta after optional normalisation and
    multiplication by ``self.scale``; the result is then corrected for
    both transformations so it refers to the original theta.

    Args:
        theta: Parameter tensor; its last axis size is used as the
            dimension count for the scale correction.
        s: Conditioning tensor, cast to float before reaching the network.

    Returns:
        ``self.net.loss(...)`` minus ``log(scale) * dim`` and, when a
        normaliser is configured, plus ``log(volume())`` of that normaliser.
    """
    normaliser = self.theta_norm
    if normaliser is not None:
        theta = normaliser(theta)

    scaled_theta = theta.float() * self.scale
    nll = self.net.loss(scaled_theta, condition=s.float())

    # Change-of-variables term for the constant rescaling of every dimension.
    n_dims = scaled_theta.shape[-1]
    nll = nll - np.log(self.scale) * n_dims

    if self.theta_norm is None:
        return nll
    # Change-of-variables term for the normalisation step.
    return nll + torch.log(self.theta_norm.volume())

sample

sample(num_samples, s)

Sample from the flow.

Source code in falcon/contrib/flow.py
def sample(self, num_samples, s):
    """Draw samples from the flow and map them back to the original theta space.

    Args:
        num_samples: Number of samples requested; passed to the network as
            the shape ``(num_samples,)``.
        s: Conditioning tensor. It is cast to float before being handed to
            the network, matching what ``loss`` and ``log_prob`` do, so a
            non-float32 condition is treated the same way on every path.

    Returns:
        Detached samples, divided by ``self.scale`` and, when a normaliser
        is configured, passed through its ``inverse``.
    """
    # loss/log_prob feed the net ``s.float()``; do the same here so sampling
    # does not see a different condition dtype than training/evaluation.
    condition = s.float()
    samples = self.net.sample((num_samples,), condition=condition).detach()
    # Undo the constant rescaling applied to theta before it reaches the net.
    samples = samples / self.scale
    if self.theta_norm is not None:
        samples = self.theta_norm.inverse(samples).detach()
    return samples

log_prob

log_prob(theta, s)

Compute log probability.

Source code in falcon/contrib/flow.py
def log_prob(self, theta, s):
    """Return the log probability of ``theta`` given ``s``.

    Mirrors ``loss`` with opposite sign: the network is evaluated on the
    (optionally normalised) theta multiplied by ``self.scale``, and the
    result is corrected for both transformations.

    Args:
        theta: Parameter tensor; its last axis size is used as the
            dimension count for the scale correction.
        s: Conditioning tensor, cast to float before reaching the network.

    Returns:
        ``self.net.log_prob(...)`` plus ``log(scale) * dim`` and, when a
        normaliser is configured, minus ``log`` of its detached ``volume()``.
    """
    normaliser = self.theta_norm
    has_norm = normaliser is not None

    if has_norm:
        # NOTE(review): calling the normaliser may also update its running
        # statistics, as it does in __init__ - confirm this is intended here.
        theta = normaliser(theta).detach()

    scaled_theta = theta * self.scale
    n_dims = scaled_theta.shape[-1]
    result = self.net.log_prob(scaled_theta.float(), condition=s.float())

    # Change-of-variables term for the constant rescaling of every dimension.
    result = result + np.log(self.scale) * n_dims

    if has_norm:
        # Normaliser volume is detached so no gradient flows through it.
        result = result - torch.log(self.theta_norm.volume().detach())
    return result