Flow
Normalizing flow networks for density estimation.
Overview
The `Flow` class wraps several normalizing flow architectures for use in
posterior estimation. The architecture is selected with the `net_type` argument; the supported types come from different libraries (see the table below).
Supported Flow Types
| Type (`net_type`) | Library | Description |
|-------------------|---------|-------------|
| `nsf` | sbi/nflows | Neural Spline Flow |
| `maf` | sbi/nflows | Masked Autoregressive Flow |
| `zuko_nice` | Zuko | NICE architecture |
| `zuko_naf` | Zuko | Neural Autoregressive Flow |
| `zuko_nsf` | Zuko | Neural Spline Flow (Zuko) |
| `zuko_maf` | Zuko | Masked Autoregressive Flow (Zuko) |
Class Reference
Flow
Flow(theta, s, theta_norm=False, norm_momentum=0.003, net_type='nsf', use_log_update=False, adaptive_momentum=False)
Bases: Module
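Normalizing flow network with optional parameter normalization. When `theta_norm=True`, `theta` is passed through a `LazyOnlineNorm` layer before entering the flow. In all cases it is then multiplied by a fixed scale factor (0.2). Reported log-densities are corrected for both transforms.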
Source code in falcon/contrib/flow.py
def __init__(
    self,
    theta,
    s,
    theta_norm=False,
    norm_momentum=3e-3,
    net_type="nsf",
    use_log_update=False,
    adaptive_momentum=False,
):
    """Build the conditional flow and, optionally, an online normalizer for ``theta``.

    Args:
        theta: Example parameter tensor. It is cast to float and handed to the
            network builder (presumably so the builder can infer the parameter
            dimension; TODO confirm). When ``theta_norm`` is enabled, it also
            initializes the normalizer's statistics.
        s: Example conditioning tensor, cast to float and handed to the
            network builder.
        theta_norm: If truthy, create a ``LazyOnlineNorm`` for ``theta``.
        norm_momentum: Momentum forwarded to ``LazyOnlineNorm``.
        net_type: Key into ``NET_BUILDERS`` selecting the flow architecture.
        use_log_update: Forwarded to ``LazyOnlineNorm``.
        adaptive_momentum: Forwarded to ``LazyOnlineNorm``.

    Raises:
        ValueError: If ``NET_BUILDERS`` has no builder for ``net_type``.
    """
    super().__init__()

    # The normalizer is created and assigned before the network, as before.
    if theta_norm:
        normalizer = LazyOnlineNorm(
            momentum=norm_momentum,
            use_log_update=use_log_update,
            adaptive_momentum=adaptive_momentum,
        )
    else:
        normalizer = None
    self.theta_norm = normalizer

    build_fn = NET_BUILDERS.get(net_type)
    if build_fn is None:
        raise ValueError(f"Unknown net_type: {net_type}. Available: {list(NET_BUILDERS.keys())}")
    # NOTE(review): z_score_x/z_score_y=None presumably disables the builder's
    # own standardization, since normalization is handled here; confirm.
    self.net = build_fn(theta.float(), s.float(), z_score_x=None, z_score_y=None)

    if normalizer is not None:
        # A first forward pass seeds the normalization statistics.
        normalizer(theta)

    # Fixed multiplicative factor applied to (normalized) theta before the
    # flow; the loss, log-probability and sampling code correct for it.
    self.scale = 0.2
|
loss
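Compute the negative log-likelihood loss of `theta` conditioned on `s`. The loss is corrected for the internal scaling and, if enabled, for the parameter normalization.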
Source code in falcon/contrib/flow.py
def loss(self, theta, s):
    """Compute negative log-likelihood loss.

    ``theta`` is optionally normalized, cast to float and multiplied by
    ``self.scale`` before it is scored by the network. The returned loss is
    corrected by the log-Jacobians of both transforms, so it refers to the
    original parameter space.
    """
    norm = self.theta_norm
    x = theta if norm is None else norm(theta)
    x = x.float() * self.scale

    nll = self.net.loss(x, condition=s.float())
    # The rescaling x -> scale * x contributes d * log(scale) to log p, so it
    # is subtracted from the NLL (d = trailing parameter dimension).
    nll = nll - x.shape[-1] * np.log(self.scale)

    if norm is not None:
        # Account for the normalizer's volume change.
        nll = nll + torch.log(norm.volume())
    return nll
|
sample
Draw `num_samples` samples conditioned on `s` and map them back to the original parameter space.
Source code in falcon/contrib/flow.py
def sample(self, num_samples, s):
    """Sample from the flow.

    Args:
        num_samples: Number of samples to draw. It is passed to the network as
            the sample shape ``(num_samples,)``.
        s: Conditioning tensor. It is cast to float before reaching the
            network, matching ``loss``/``log_prob`` and the float tensors the
            network was built from.

    Returns:
        Detached samples with the internal ``self.scale`` factor removed and,
        if a normalizer is configured, mapped through its ``inverse``.
    """
    # Cast s like loss/log_prob do; an un-cast float64 condition would
    # otherwise reach a network built from float32 inputs.
    samples = self.net.sample((num_samples,), condition=s.float()).detach()
    # Undo the fixed multiplicative scaling applied to theta during training.
    samples = samples / self.scale
    if self.theta_norm is not None:
        samples = self.theta_norm.inverse(samples).detach()
    return samples
|
log_prob
Compute the log probability of `theta` conditioned on `s`, expressed in the original parameter space.
Source code in falcon/contrib/flow.py
def log_prob(self, theta, s):
    """Compute log probability.

    ``theta`` is optionally normalized (with gradients detached from the
    normalizer output) and multiplied by ``self.scale``. It is then scored by
    the network. The result is corrected by the log-Jacobians of both
    transforms so it refers to the original parameter space.
    """
    norm = self.theta_norm
    if norm is not None:
        # NOTE(review): this runs the normalizer's forward pass; whether that
        # updates its running statistics depends on LazyOnlineNorm; confirm.
        theta = norm(theta).detach()

    scaled = theta * self.scale
    lp = self.net.log_prob(scaled.float(), condition=s.float())
    # Jacobian of x -> scale * x over the trailing parameter dimension.
    lp = lp + scaled.shape[-1] * np.log(self.scale)

    if norm is not None:
        lp = lp - torch.log(norm.volume().detach())
    return lp
|