Flow Estimator

Flow-based posterior estimation using normalizing flows.

Overview

Flow is the primary estimator in Falcon for learning posterior distributions. It uses dual normalizing flows (conditional and marginal) with importance sampling for adaptive proposal generation.

Key features:

  • Dual flow architecture for posterior and proposal sampling
  • Parameter space normalization via hypercube mapping
  • Importance sampling with effective sample size monitoring
  • Automatic learning rate scheduling and early stopping

Configuration

Flow is configured through four groups — loop, network, optimizer, and inference — plus an embedding block that sits as a sibling of network, not nested inside it.

estimator:
  _target_: falcon.estimators.Flow

  loop:
    # Training loop parameters

  network:
    # Neural network architecture

  embedding:
    # Observation embedding network

  optimizer:
    # Learning rate and scheduling

  inference:
    # Sampling and amortization settings

Configuration Reference

Training Loop (loop)

Controls the training process.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| num_epochs | int | 300 | Maximum training epochs |
| batch_size | int | 128 | Training batch size |
| early_stop_patience | int | 32 | Epochs without improvement before stopping |
| reset_network_after_pause | bool | false | Reset network weights when training resumes after a pause |
| cache_sync_every | int | 0 | Epochs between cache syncs with the buffer (0 = every epoch) |
| max_cache_samples | int | 0 | Maximum samples to cache (0 = cache all available) |
| cache_on_device | bool | false | Keep cached training data on the estimator's device (e.g. GPU) |

loop:
  num_epochs: 300
  batch_size: 128
  early_stop_patience: 32
  reset_network_after_pause: false
  cache_sync_every: 0
  max_cache_samples: 0
  cache_on_device: false

Data Caching

Training data is loaded into a local cache that is periodically synced with the shared simulation buffer. This avoids repeated remote data fetches and allows fast random-access batching.

  • cache_sync_every: Controls how often the cache pulls new samples from the buffer. A value of 0 (default) syncs every epoch. Higher values reduce sync overhead at the cost of slightly stale data, which can be useful when simulations are slow.
  • max_cache_samples: Caps the number of samples held in the cache. Set to 0 to cache everything. A positive value randomly subsamples, which helps limit GPU memory usage for very large buffers.
  • cache_on_device: When true, cached tensors are moved to the estimator's device (typically GPU) once during sync rather than per-batch. This eliminates CPU-to-GPU transfer overhead during training but increases device memory usage.
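
For example, a loop configuration tuned for slow simulators and ample GPU memory might look like this (the values are illustrative, not recommendations):

loop:
  cache_sync_every: 10      # sync with the buffer every 10 epochs
  max_cache_samples: 65536  # subsample the buffer to bound memory use
  cache_on_device: true     # move cached tensors to the GPU once per sync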

Network Architecture (network)

Defines the neural network structure.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| net_type | str | nsf | Flow architecture (see FlowDensity for all types) |
| theta_norm | bool | true | Normalize parameter space |
| norm_momentum | float | 0.003 | Momentum for online normalization updates |
| use_log_update | bool | false | Use log-space variance updates |
| adaptive_momentum | bool | false | Sample-dependent momentum |

network:
  net_type: nsf
  theta_norm: true
  norm_momentum: 0.003
  use_log_update: false
  adaptive_momentum: false

Embedding

The embedding network processes observations before they enter the flow. It is configured as a sibling of network (not nested inside it).

embedding:
  _target_: model.MyEmbedding
  _input_: [x]

See Embeddings for details on the declarative embedding system, including multi-input and nested pipeline configurations.

Optimizer (optimizer)

Controls learning rate and scheduling.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| lr | float | 0.01 | Initial learning rate |
| lr_decay_factor | float | 0.5 | LR multiplier when a plateau is detected |
| scheduler_patience | int | 16 | Epochs without improvement before LR decay |

optimizer:
  lr: 0.01
  lr_decay_factor: 0.5
  scheduler_patience: 16

Inference (inference)

Controls posterior sampling and amortization.

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| gamma | float | 0.5 | Amortization mixing coefficient (0 = focused, 1 = amortized) |
| discard_samples | bool | false | Discard low-likelihood samples during training |
| log_ratio_threshold | float | -20 | Log-likelihood ratio threshold for sample discarding |
| sample_reference_posterior | bool | false | Sample from the reference posterior |
| use_best_models_during_inference | bool | true | Use the best validation model for sampling |

inference:
  gamma: 0.5
  discard_samples: false
  log_ratio_threshold: -20
  sample_reference_posterior: false
  use_best_models_during_inference: true

Understanding gamma (Amortization)

The gamma parameter controls the trade-off between focused and amortized inference:

  • gamma=0: Fully focused on the specific observation (best for single-observation inference)
  • gamma=1: Fully amortized (network generalizes across observations)
  • gamma=0.5: Balanced (default, good for most cases)
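
For example, a single-observation analysis might pin the coefficient to fully focused inference (an illustrative setting, not a recommendation):

inference:
  gamma: 0.0   # focus entirely on the observed data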

Embedding Networks

Flow requires an embedding network to process observations. The embedding maps high-dimensional observations to a lower-dimensional summary statistic.

Basic Embedding

embedding:
  _target_: model.MyEmbedding
  _input_: [x]
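
Here model.MyEmbedding is an ordinary torch.nn.Module; the nodes listed in _input_ are passed to its forward. A minimal sketch of such a module (the class name and layer sizes are illustrative, not part of Falcon):

import torch
import torch.nn as nn

class MyEmbedding(nn.Module):
    """Map a flat observation vector to a low-dimensional summary."""

    def __init__(self, in_features: int = 100, out_features: int = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 64),
            nn.ReLU(),
            nn.Linear(64, out_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)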

Multi-Input Embedding

embedding:
  _target_: model.MyEmbedding
  _input_: [x, y]  # Multiple observation nodes

Nested Embedding Pipeline

embedding:
  _target_: model.Concatenate
  _input_:
    - _target_: timm.create_model
      model_name: resnet18
      pretrained: true
      num_classes: 0
      _input_:
        _target_: model.Unsqueeze
        _input_: [image]
    - _target_: torch.nn.Linear
      in_features: 64
      out_features: 32
      _input_: [metadata]
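
The modules referenced above (model.Concatenate, model.Unsqueeze) are user-supplied helpers; timm.create_model and torch.nn.Linear come from their respective libraries. Assuming the declarative system instantiates each nested _input_ branch and passes its output to the parent module, the helpers could be as simple as this sketch:

import torch
import torch.nn as nn

class Unsqueeze(nn.Module):
    """Insert a channel dimension so 2-D images fit a CNN backbone."""

    def __init__(self, dim: int = 1):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.unsqueeze(self.dim)

class Concatenate(nn.Module):
    """Concatenate branch outputs along the feature dimension."""

    def forward(self, *features: torch.Tensor) -> torch.Tensor:
        return torch.cat(features, dim=-1)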

Complete Example

graph:
  z:
    evidence: [x]

    simulator:
      _target_: falcon.priors.Hypercube
      priors:
        - ['uniform', -100.0, 100.0]
        - ['uniform', -100.0, 100.0]
        - ['uniform', -100.0, 100.0]

    estimator:
      _target_: falcon.estimators.Flow

      loop:
        num_epochs: 300
        batch_size: 128
        early_stop_patience: 32
        cache_sync_every: 0
        max_cache_samples: 0

      network:
        net_type: nsf
        theta_norm: true
        norm_momentum: 0.003

      embedding:
        _target_: model.E
        _input_: [x]

      optimizer:
        lr: 0.01
        lr_decay_factor: 0.5
        scheduler_patience: 16

      inference:
        gamma: 0.5
        discard_samples: false
        log_ratio_threshold: -20

    ray:
      num_gpus: 0

  x:
    parents: [z]
    simulator:
      _target_: model.Simulate
    observed: "./data/obs.npz['x']"

Training Strategies

Standard Training

Default configuration with continuous resampling:

buffer:
  min_training_samples: 4096
  max_training_samples: 32768
  resample_batch_size: 128
  keep_resampling: true
  resample_interval: 10

Amortized Training

Fixed dataset without resampling (for learning across many observations):

buffer:
  min_training_samples: 32000
  max_training_samples: 32000
  resample_batch_size: 0       # No resampling
  keep_resampling: false

# Higher gamma for amortization
inference:
  gamma: 0.8

Round-Based Training

Large batch renewal for sequential refinement:

buffer:
  min_training_samples: 8000
  max_training_samples: 8000
  resample_batch_size: 8000    # Full renewal
  keep_resampling: true
  resample_interval: 30        # Less frequent

inference:
  discard_samples: true        # Remove poor samples

Logged Metrics

Flow logs the following metrics during training:

| Metric | Description |
| --- | --- |
| loss/train | Training loss (negative log-likelihood) |
| loss/val | Validation loss |
| lr | Current learning rate |
| epoch | Training epoch |
| best_val_loss | Best validation loss seen |

Tips

  1. Start with defaults: The default configuration works well for most problems
  2. Increase num_epochs for complex posteriors
  3. Enable discard_samples if training becomes unstable with outliers
  4. Use GPU (ray.num_gpus: 1) for faster training with large embeddings
  5. Lower gamma for single-observation inference, higher for amortization
  6. Adjust early_stop_patience based on expected convergence time
  7. Set cache_on_device: true when GPU memory permits, to eliminate per-batch CPU-to-GPU transfers
  8. Increase cache_sync_every (e.g. 5-10) when simulations are slow and training data changes infrequently
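
Several of these tips combined into one node configuration might look like this (values illustrative):

estimator:
  _target_: falcon.estimators.Flow
  loop:
    num_epochs: 600          # more epochs for a complex posterior
    early_stop_patience: 64  # allow slower convergence
    cache_sync_every: 5      # tolerate slightly stale training data
    cache_on_device: true    # avoid per-batch CPU-to-GPU transfers
  inference:
    gamma: 0.8               # lean amortized
    discard_samples: true    # drop outlier samples

ray:
  num_gpus: 1                # train on GPU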

Class Reference

Flow

Flow(simulator_instance, theta_key=None, condition_keys=None, config=None)

Bases: StepwiseEstimator

Flow-based posterior estimation (formerly SNPE_A).

Implementation-specific features:

  • Dual flow architecture (conditional + marginal)
  • Parameter space normalization via hypercube mapping
  • Importance sampling for posterior/proposal sampling

Initialize Flow estimator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| simulator_instance | | Prior/simulator instance | required |
| theta_key | Optional[str] | Key for theta in batch data | None |
| condition_keys | Optional[List[str]] | Keys for condition data in batch | None |
| config | Optional[dict] | Configuration dict with loop, network, optimizer, inference sections | None |

Source code in falcon/estimators/flow.py
def __init__(
    self,
    simulator_instance,
    theta_key: Optional[str] = None,
    condition_keys: Optional[List[str]] = None,
    config: Optional[dict] = None,
):
    """
    Initialize Flow estimator.

    Args:
        simulator_instance: Prior/simulator instance
        theta_key: Key for theta in batch data
        condition_keys: Keys for condition data in batch
        config: Configuration dict with loop, network, optimizer, inference sections
    """
    # Merge user config with defaults using OmegaConf structured config
    schema = OmegaConf.structured(FlowConfig)
    config = OmegaConf.merge(schema, config or {})

    super().__init__(
        simulator_instance=simulator_instance,
        loop_config=config.loop,
        theta_key=theta_key,
        condition_keys=condition_keys,
    )

    self.config = config

    # Device setup
    self.device = self._setup_device(config.device)

    # Embedding network
    embedding_config = OmegaConf.to_container(config.embedding, resolve=True)
    self._embedding = instantiate_embedding(embedding_config).to(self.device)

    # Flow networks (initialized lazily)
    self._conditional_flow = None
    self._marginal_flow = None
    self._best_conditional_flow = None
    self._best_marginal_flow = None
    self._best_embedding = None
    self._init_parameters = None

    # Best loss tracking
    self.best_conditional_flow_val_loss = float("inf")
    self.best_marginal_flow_val_loss = float("inf")

    # Optimizer/scheduler (initialized lazily)
    self._optimizer = None
    self._scheduler = None

    # Extended history for Flow-specific tracking
    self.history.update({
        "theta_mins": [],
        "theta_maxs": [],
    })
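
A minimal instantiation sketch based on the signature above; the Hypercube arguments mirror the Complete Example, and exact constructor details may differ:

from falcon.estimators import Flow
from falcon.priors import Hypercube

# Illustrative prior matching the Complete Example above.
prior = Hypercube(priors=[["uniform", -100.0, 100.0]] * 3)

flow = Flow(
    simulator_instance=prior,
    theta_key="z",                          # parameter node
    condition_keys=["x"],                   # observed node(s)
    config={"inference": {"gamma": 0.5}},   # merged with FlowConfig defaults
)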

train_step

train_step(batch)

Flow training step with gradient update and optional sample discarding.

Source code in falcon/estimators/flow.py
def train_step(self, batch) -> Dict[str, float]:
    """Flow training step with gradient update and optional sample discarding."""
    ids, theta, theta_logprob, conditions, u, u_device, conditions_device = \
        self._unpack_batch(batch, "train")

    # Initialize networks on first batch
    if not self.networks_initialized:
        self._initialize_networks(u, conditions)

    # Embed conditions
    s = self._embed(conditions_device, train=True)

    # Track theta ranges
    with torch.no_grad():
        self.history["theta_mins"].append(theta.min(dim=0).values.cpu().numpy())
        self.history["theta_maxs"].append(theta.max(dim=0).values.cpu().numpy())

    # Forward and backward pass
    self._optimizer.zero_grad()
    loss_cond, loss_marg = self._compute_flow_losses(u_device, s, train=True)
    (loss_cond + loss_marg).backward()
    self._optimizer.step()

    # Discard samples based on log-likelihood ratio
    if self.config.inference.discard_samples:
        discard_mask = self._compute_discard_mask(theta, theta_logprob, conditions_device)
        batch.discard(discard_mask)

    return {"loss": loss_cond.item(), "loss_aux": loss_marg.item()}

val_step

val_step(batch)

Flow validation step without gradient computation.

Source code in falcon/estimators/flow.py
def val_step(self, batch) -> Dict[str, float]:
    """Flow validation step without gradient computation."""
    _, theta, theta_logprob, conditions, u, u_device, conditions_device = \
        self._unpack_batch(batch, "val")

    # Embed conditions (eval mode)
    s = self._embed(conditions_device, train=False)

    # Compute losses without gradients
    with torch.no_grad():
        loss_cond, loss_marg = self._compute_flow_losses(u_device, s, train=False)

    return {"loss": loss_cond.item(), "loss_aux": loss_marg.item()}

sample_prior

sample_prior(num_samples, conditions=None)

Sample from the prior distribution.

Source code in falcon/estimators/flow.py
def sample_prior(self, num_samples: int, conditions: Optional[Dict] = None) -> dict:
    """Sample from the prior distribution."""
    if conditions:
        raise ValueError("Conditions are not supported for sample_prior.")
    samples = self.simulator_instance.simulate_batch(num_samples)
    # Log probability of a uniform prior over the hypercube [-bound, bound]^d:
    # log p(theta) = -d * log(2 * bound)
    bound = self.config.inference.hypercube_bound
    log_prob = np.full(num_samples, -self.param_dim * np.log(2 * bound))
    return {'value': samples, 'log_prob': log_prob}

sample_posterior

sample_posterior(num_samples, conditions=None)

Sample from the posterior distribution q(theta|x).

Source code in falcon/estimators/flow.py
def sample_posterior(
    self,
    num_samples: int,
    conditions: Optional[Dict] = None,
) -> dict:
    """Sample from the posterior distribution q(theta|x)."""
    # Fall back to prior if networks not yet initialized (training hasn't started)
    if not self.networks_initialized:
        return self.sample_prior(num_samples)

    samples, logprob = self._importance_sample(num_samples, mode="posterior", conditions=conditions or {})
    return {'value': samples.numpy(), 'log_prob': logprob.numpy()}
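
A minimal usage sketch, assuming a trained estimator instance named flow (hypothetical):

# Draw 1,000 posterior samples; conditions default to the node's observations.
result = flow.sample_posterior(1000)
samples, log_prob = result["value"], result["log_prob"]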

sample_proposal

sample_proposal(num_samples, conditions=None)

Sample from the widened proposal distribution for adaptive resampling.

Source code in falcon/estimators/flow.py
def sample_proposal(
    self,
    num_samples: int,
    conditions: Optional[Dict] = None,
) -> dict:
    """Sample from the widened proposal distribution for adaptive resampling."""
    # Fall back to prior if networks not yet initialized (training hasn't started)
    if not self.networks_initialized:
        return self.sample_prior(num_samples)

    cfg_inf = self.config.inference
    conditions = conditions or {}

    if cfg_inf.sample_reference_posterior:
        post_samples, _ = self._importance_sample(cfg_inf.reference_samples, mode="posterior", conditions=conditions)
        mean, std = post_samples.mean(dim=0).cpu(), post_samples.std(dim=0).cpu()
        log({f"sample_proposal:posterior_mean_{i}": mean[i].item() for i in range(len(mean))})
        log({f"sample_proposal:posterior_std_{i}": std[i].item() for i in range(len(std))})

    samples, logprob = self._importance_sample(num_samples, mode="proposal", conditions=conditions)
    log({
        "sample_proposal:mean": samples.mean().item(),
        "sample_proposal:std": samples.std().item(),
        "sample_proposal:logprob": logprob.mean().item(),
    })
    return {'value': samples.numpy(), 'log_prob': logprob.numpy()}

save

save(node_dir)

Save Flow state.

Source code in falcon/estimators/flow.py
def save(self, node_dir: Path) -> None:
    """Save Flow state."""
    debug(f"Saving: {node_dir}")
    if not self.networks_initialized:
        raise RuntimeError("Networks not initialized.")

    torch.save(self._best_conditional_flow.state_dict(), node_dir / "conditional_flow.pth")
    torch.save(self._best_marginal_flow.state_dict(), node_dir / "marginal_flow.pth")
    torch.save(self._init_parameters, node_dir / "init_parameters.pth")

    # Save history
    torch.save(self.history["train_ids"], node_dir / "train_id_history.pth")
    torch.save(self.history["val_ids"], node_dir / "validation_id_history.pth")
    torch.save(self.history["theta_mins"], node_dir / "theta_mins_batches.pth")
    torch.save(self.history["theta_maxs"], node_dir / "theta_maxs_batches.pth")
    torch.save(self.history["epochs"], node_dir / "epochs.pth")
    torch.save(self.history["train_loss"], node_dir / "loss_train_posterior.pth")
    torch.save(self.history["val_loss"], node_dir / "loss_val_posterior.pth")
    torch.save(self.history["n_samples"], node_dir / "n_samples_total.pth")
    torch.save(self.history["elapsed_min"], node_dir / "elapsed_minutes.pth")

    if self._best_embedding is not None:
        torch.save(self._best_embedding.state_dict(), node_dir / "embedding.pth")

load

load(node_dir)

Load Flow state.

Source code in falcon/estimators/flow.py
def load(self, node_dir: Path) -> None:
    """Load Flow state."""
    debug(f"Loading: {node_dir}")
    init_parameters = torch.load(node_dir / "init_parameters.pth")
    self._initialize_networks(init_parameters[0], init_parameters[1])

    self._best_conditional_flow.load_state_dict(
        torch.load(node_dir / "conditional_flow.pth")
    )
    self._best_marginal_flow.load_state_dict(
        torch.load(node_dir / "marginal_flow.pth")
    )

    if (node_dir / "embedding.pth").exists() and self._best_embedding is not None:
        self._best_embedding.load_state_dict(torch.load(node_dir / "embedding.pth"))

Configuration Classes

FlowConfig dataclass

FlowConfig(loop=TrainingLoopConfig(), network=NetworkConfig(), optimizer=OptimizerConfig(), inference=InferenceConfig(), embedding=None, device=None)

Top-level Flow estimator configuration.

NetworkConfig dataclass

NetworkConfig(net_type='zuko_nice', theta_norm=True, norm_momentum=0.01, adaptive_momentum=False, use_log_update=False)

Neural network architecture parameters.

OptimizerConfig dataclass

OptimizerConfig(lr=0.01, lr_decay_factor=0.1, scheduler_patience=8)

Optimizer parameters (training-time).

InferenceConfig dataclass

InferenceConfig(gamma=0.5, discard_samples=True, log_ratio_threshold=-20.0, sample_reference_posterior=False, use_best_models_during_inference=True, num_proposals=256, reference_samples=128, hypercube_bound=2.0, out_of_bounds_penalty=100.0, nan_replacement=-100.0)

Inference and sampling parameters.

TrainingLoopConfig dataclass

TrainingLoopConfig(num_epochs=100, batch_size=128, early_stop_patience=16, reset_network_after_pause=False, cache_sync_every=0, max_cache_samples=0, cache_on_device=False)

Generic training loop parameters.