Skip to content

SNPE_A

Sequential Neural Posterior Estimation (SNPE-A) implementation.

Overview

SNPE_A is the primary estimator in Falcon for learning posterior distributions. It uses dual normalizing flows (conditional and marginal) with importance sampling for adaptive proposal generation.

Key features:

  • Dual flow architecture for posterior and proposal sampling
  • Parameter space normalization via hypercube mapping
  • Importance sampling with effective sample size monitoring
  • Automatic learning rate scheduling and early stopping

Configuration

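SNPE_A is configured via nested dataclasses (see Configuration Classes below). Values supplied in the configuration are merged over the dataclass defaults, for example: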

estimator:
  _target_: falcon.contrib.SNPE_A
  loop:
    num_epochs: 300
    batch_size: 128
  network:
    net_type: nsf
  optimizer:
    lr: 0.01
  inference:
    gamma: 0.5

Class Reference

SNPE_A

SNPE_A(simulator_instance, theta_key=None, condition_keys=None, config=None)

Bases: StepwiseEstimator

Sequential Neural Posterior Estimation (SNPE-A).

Implementation-specific features:

  • Dual flow architecture (conditional + marginal)
  • Parameter space normalization via hypercube mapping
  • Importance sampling for posterior/proposal

Initialize SNPE_A estimator.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| simulator_instance | | Prior/simulator instance | required |
| theta_key | Optional[str] | Key for theta in batch data | None |
| condition_keys | Optional[List[str]] | Keys for condition data in batch | None |
| config | Optional[dict] | Configuration dict with loop, network, optimizer, inference sections | None |
Source code in falcon/contrib/SNPE_A.py
def __init__(
    self,
    simulator_instance,
    theta_key: Optional[str] = None,
    condition_keys: Optional[List[str]] = None,
    config: Optional[dict] = None,
):
    """
    Build an SNPE_A estimator from a simulator and an optional config.

    The user config is merged over the structured ``SNPEConfig`` schema, the
    base class is initialized with the ``loop`` section, and the embedding
    network is instantiated and moved to the selected device. Flows,
    optimizer and scheduler are left as ``None`` here.

    Args:
        simulator_instance: Prior/simulator instance
        theta_key: Key for theta in batch data
        condition_keys: Keys for condition data in batch
        config: Configuration dict with loop, network, optimizer, inference sections
    """
    # Overlay whatever the caller supplied on top of the structured defaults.
    defaults = OmegaConf.structured(SNPEConfig)
    overrides = config or {}
    merged = OmegaConf.merge(defaults, overrides)

    super().__init__(
        simulator_instance=simulator_instance,
        loop_config=merged.loop,
        theta_key=theta_key,
        condition_keys=condition_keys,
    )

    self.config = merged
    self.device = self._setup_device(merged.device)

    # instantiate_embedding relies on isinstance(x, dict), so pass a plain
    # container rather than an OmegaConf node.
    plain_embedding_cfg = OmegaConf.to_container(merged.network.embedding, resolve=True)
    embedding = instantiate_embedding(plain_embedding_cfg)
    self._embedding = embedding.to(self.device)

    # Working flows and their best-so-far snapshots start out unset.
    self._conditional_flow = self._marginal_flow = None
    self._best_conditional_flow = self._best_marginal_flow = None
    self._best_embedding = None
    self._init_parameters = None

    # Best validation losses seen so far.
    self.best_conditional_flow_val_loss = float("inf")
    self.best_marginal_flow_val_loss = float("inf")

    # Optimizer and scheduler are also unset at construction time.
    self._optimizer = None
    self._scheduler = None

    # SNPE-specific additions to the history kept by the base class.
    self.history.update({key: [] for key in ("theta_mins", "theta_maxs")})

train_step

train_step(batch)

SNPE-A training step with gradient update and optional sample discarding.

Source code in falcon/contrib/SNPE_A.py
def train_step(self, batch) -> Dict[str, float]:
    """Run one gradient update on a training batch.

    Initializes the networks on the first call, records the per-dimension
    min/max of ``theta`` in the history, backpropagates the sum of the two
    losses returned by ``_compute_flow_losses`` and, when
    ``config.inference.discard_samples`` is set, asks the batch to discard
    the samples selected by ``_compute_discard_mask``.

    Returns:
        Dict with the first loss under ``"loss"`` and the second under
        ``"loss_aux"``, both as Python floats.
    """
    unpacked = self._unpack_batch(batch, "train")
    _ids, theta, theta_logprob, conditions, u, u_device, conditions_device = unpacked

    # The first batch seen triggers network construction.
    if not self.networks_initialized:
        self._initialize_networks(u, conditions)

    summary = self._embed(conditions_device, train=True)

    # Record the range covered by theta in this batch (no graph needed).
    with torch.no_grad():
        history = self.history
        per_dim_min = theta.min(dim=0).values
        history["theta_mins"].append(per_dim_min.cpu().numpy())
        per_dim_max = theta.max(dim=0).values
        history["theta_maxs"].append(per_dim_max.cpu().numpy())

    # Single optimizer step on the combined objective.
    self._optimizer.zero_grad()
    cond_loss, marg_loss = self._compute_flow_losses(u_device, summary, train=True)
    total_loss = cond_loss + marg_loss
    total_loss.backward()
    self._optimizer.step()

    # Optional sample discarding, evaluated after the parameter update.
    if self.config.inference.discard_samples:
        mask = self._compute_discard_mask(theta, theta_logprob, conditions_device)
        batch.discard(mask)

    metrics = {"loss": cond_loss.item(), "loss_aux": marg_loss.item()}
    return metrics

val_step

val_step(batch)

SNPE-A validation step without gradient computation.

Source code in falcon/contrib/SNPE_A.py
def val_step(self, batch) -> Dict[str, float]:
    """Evaluate both flow losses on a validation batch without a gradient update.

    Returns:
        Dict with the first loss under ``"loss"`` and the second under
        ``"loss_aux"``, both as Python floats.
    """
    unpacked = self._unpack_batch(batch, "val")
    _, theta, theta_logprob, conditions, u, u_device, conditions_device = unpacked

    # Embedding is computed with train=False for validation.
    summary = self._embed(conditions_device, train=False)

    # Loss evaluation only; no graph is recorded for it.
    with torch.no_grad():
        cond_loss, marg_loss = self._compute_flow_losses(u_device, summary, train=False)

    metrics = {"loss": cond_loss.item(), "loss_aux": marg_loss.item()}
    return metrics

sample_prior

sample_prior(num_samples, parent_conditions=None)

Sample from the prior distribution.

Source code in falcon/contrib/SNPE_A.py
def sample_prior(self, num_samples: int, parent_conditions: Optional[List] = None) -> RVBatch:
    """Sample from the prior distribution.

    Args:
        num_samples: Number of samples requested from the simulator instance.
        parent_conditions: Must be ``None`` or empty; conditioning is not
            supported for prior sampling.

    Returns:
        RVBatch with the simulated samples and one constant log-probability
        per sample.

    Raises:
        ValueError: If ``parent_conditions`` is non-empty.
    """
    if parent_conditions:
        raise ValueError("Conditions are not supported for sample_prior.")
    samples = self.simulator_instance.simulate_batch(num_samples)
    # Uniform prior over the hypercube [-bound, bound]^d has density
    # (2 * bound) ** -d, hence log-density -d * log(2 * bound).
    # The previous expression `-np.log(2 * bound) ** d` parsed as
    # -(log(2 * bound) ** d) because ** binds tighter than unary minus,
    # which is only correct for d == 1.
    bound = self.config.inference.hypercube_bound
    log_density = -self.param_dim * np.log(2 * bound)
    logprob = np.full(num_samples, log_density, dtype=float)
    return RVBatch(samples, logprob=logprob)

sample_posterior

sample_posterior(num_samples, parent_conditions=None, evidence_conditions=None)

Sample from the posterior distribution q(theta|x).

Source code in falcon/contrib/SNPE_A.py
def sample_posterior(
    self,
    num_samples: int,
    parent_conditions: Optional[List] = None,
    evidence_conditions: Optional[List] = None,
) -> RVBatch:
    """Draw samples from the posterior q(theta|x).

    While the networks are still uninitialized (training has not started),
    this delegates to ``sample_prior`` instead. Missing condition lists are
    passed on as empty lists.
    """
    if self.networks_initialized:
        draws, draw_logprob = self._importance_sample(
            num_samples,
            mode="posterior",
            parent_conditions=parent_conditions or [],
            evidence_conditions=evidence_conditions or [],
        )
        return RVBatch(draws.numpy(), logprob=draw_logprob.numpy())

    # No trained networks yet: the prior is the only thing we can sample.
    return self.sample_prior(num_samples, parent_conditions)

sample_proposal

sample_proposal(num_samples, parent_conditions=None, evidence_conditions=None)

Sample from the widened proposal distribution for adaptive resampling.

Source code in falcon/contrib/SNPE_A.py
def sample_proposal(
    self,
    num_samples: int,
    parent_conditions: Optional[List] = None,
    evidence_conditions: Optional[List] = None,
) -> RVBatch:
    """Draw samples from the widened proposal distribution for adaptive resampling.

    Delegates to ``sample_prior`` while the networks are uninitialized. When
    ``config.inference.sample_reference_posterior`` is set, a reference set of
    posterior samples is drawn first and its per-dimension mean and standard
    deviation are logged. Summary statistics of the proposal draw are always
    logged before returning.
    """
    # Training has not started yet, so there is nothing but the prior to sample.
    if not self.networks_initialized:
        return self.sample_prior(num_samples, parent_conditions)

    inference_cfg = self.config.inference
    # Normalize missing conditions once; both sampling calls share these lists.
    condition_kwargs = {
        "parent_conditions": parent_conditions or [],
        "evidence_conditions": evidence_conditions or [],
    }

    if inference_cfg.sample_reference_posterior:
        reference, _ = self._importance_sample(
            inference_cfg.reference_samples,
            mode="posterior",
            **condition_kwargs,
        )
        ref_mean = reference.mean(dim=0).cpu()
        ref_std = reference.std(dim=0).cpu()
        log({f"sample_proposal:posterior_mean_{i}": m.item() for i, m in enumerate(ref_mean)})
        log({f"sample_proposal:posterior_std_{i}": s.item() for i, s in enumerate(ref_std)})

    draws, draw_logprob = self._importance_sample(
        num_samples,
        mode="proposal",
        **condition_kwargs,
    )
    summary = {
        "sample_proposal:mean": draws.mean().item(),
        "sample_proposal:std": draws.std().item(),
        "sample_proposal:logprob": draw_logprob.mean().item(),
    }
    log(summary)
    return RVBatch(draws.numpy(), logprob=draw_logprob.numpy())

save

save(node_dir)

Save SNPE-A state.

Source code in falcon/contrib/SNPE_A.py
def save(self, node_dir: Path) -> None:
    """Write the best flows, init parameters, history and embedding to ``node_dir``.

    Raises:
        RuntimeError: If the networks have not been initialized yet.
    """
    debug(f"Saving: {node_dir}")
    if not self.networks_initialized:
        raise RuntimeError("Networks not initialized.")

    def _write(obj, filename):
        # Every artifact lands directly inside node_dir.
        torch.save(obj, node_dir / filename)

    # Model state: the best-so-far snapshots, not the working flows.
    _write(self._best_conditional_flow.state_dict(), "conditional_flow.pth")
    _write(self._best_marginal_flow.state_dict(), "marginal_flow.pth")
    _write(self._init_parameters, "init_parameters.pth")

    # History entries and the file each one is stored under.
    history_files = (
        ("train_ids", "train_id_history.pth"),
        ("val_ids", "validation_id_history.pth"),
        ("theta_mins", "theta_mins_batches.pth"),
        ("theta_maxs", "theta_maxs_batches.pth"),
        ("epochs", "epochs.pth"),
        ("train_loss", "loss_train_posterior.pth"),
        ("val_loss", "loss_val_posterior.pth"),
        ("n_samples", "n_samples_total.pth"),
        ("elapsed_min", "elapsed_minutes.pth"),
    )
    for key, filename in history_files:
        _write(self.history[key], filename)

    # The embedding snapshot is optional.
    best_embedding = self._best_embedding
    if best_embedding is not None:
        _write(best_embedding.state_dict(), "embedding.pth")

load

load(node_dir)

Load SNPE-A state.

Source code in falcon/contrib/SNPE_A.py
def load(self, node_dir: Path) -> None:
    """Load SNPE-A state.

    Re-creates the networks from the stored init parameters, then restores
    the best conditional and marginal flow weights and, when both the file
    and a best-embedding module exist, the embedding weights.

    State dicts are deserialized with ``map_location=self.device`` so that a
    checkpoint written on one device (e.g. a GPU) can be restored on a host
    where that device is unavailable.

    Args:
        node_dir: Directory previously populated by ``save``.
    """
    debug(f"Loading: {node_dir}")
    # NOTE(review): init parameters are loaded without map_location because
    # their intended device is not visible from here. Confirm they are stored
    # as CPU data; otherwise this call needs the same treatment.
    init_parameters = torch.load(node_dir / "init_parameters.pth")
    self._initialize_networks(init_parameters[0], init_parameters[1])

    target_device = self.device

    self._best_conditional_flow.load_state_dict(
        torch.load(node_dir / "conditional_flow.pth", map_location=target_device)
    )
    self._best_marginal_flow.load_state_dict(
        torch.load(node_dir / "marginal_flow.pth", map_location=target_device)
    )

    # The embedding checkpoint is optional on disk and in the model.
    embedding_path = node_dir / "embedding.pth"
    if embedding_path.exists() and self._best_embedding is not None:
        self._best_embedding.load_state_dict(
            torch.load(embedding_path, map_location=target_device)
        )

Configuration Classes

SNPEConfig dataclass

SNPEConfig(loop=TrainingLoopConfig(), network=NetworkConfig(), optimizer=OptimizerConfig(), inference=InferenceConfig(), device=None)

Top-level SNPE_A configuration.

NetworkConfig dataclass

NetworkConfig(net_type='zuko_nice', theta_norm=True, norm_momentum=0.01, adaptive_momentum=False, use_log_update=False, embedding=None)

Neural network architecture parameters.

OptimizerConfig dataclass

OptimizerConfig(lr=0.01, lr_decay_factor=0.1, scheduler_patience=8)

Optimizer parameters (training-time).

InferenceConfig dataclass

InferenceConfig(gamma=0.5, discard_samples=True, log_ratio_threshold=-20.0, sample_reference_posterior=False, use_best_models_during_inference=True, num_proposals=256, reference_samples=128, hypercube_bound=2.0, out_of_bounds_penalty=100.0, nan_replacement=-100.0)

Inference and sampling parameters.

TrainingLoopConfig dataclass

TrainingLoopConfig(num_epochs=100, batch_size=128, early_stop_patience=16, reset_network_after_pause=False)

Generic training loop parameters.