SNPE_A
Sequential Neural Posterior Estimation (SNPE-A) implementation.
Overview
SNPE_A is the primary estimator in Falcon for learning posterior distributions.
It uses dual normalizing flows (conditional and marginal) with importance sampling
for adaptive proposal generation.
Key features:
- Dual flow architecture for posterior and proposal sampling
- Parameter space normalization via hypercube mapping
- Importance sampling with effective sample size monitoring
- Automatic learning rate scheduling and early stopping
Configuration
SNPE_A is configured via nested dataclasses:
```yaml
estimator:
  _target_: falcon.contrib.SNPE_A
  loop:
    num_epochs: 300
    batch_size: 128
  network:
    net_type: nsf
  optimizer:
    lr: 0.01
  inference:
    gamma: 0.5
```
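The same settings can also be passed as a plain nested dict when constructing the estimator directly in Python. A minimal sketch; `my_simulator` is a placeholder for your prior/simulator instance:

```python
from falcon.contrib import SNPE_A

# `my_simulator` is a placeholder for your prior/simulator instance.
config = {
    "loop": {"num_epochs": 300, "batch_size": 128},
    "network": {"net_type": "nsf"},
    "optimizer": {"lr": 0.01},
    "inference": {"gamma": 0.5},
}
estimator = SNPE_A(my_simulator, config=config)
```

Any field left out of the dict falls back to the dataclass defaults listed under Configuration Classes below.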
Class Reference
SNPE_A
```python
SNPE_A(simulator_instance, theta_key=None, condition_keys=None, config=None)
```
Bases: StepwiseEstimator
Sequential Neural Posterior Estimation (SNPE-A).
Implementation-specific features:
- Dual flow architecture (conditional + marginal)
- Parameter space normalization via hypercube mapping
- Importance sampling for posterior/proposal
Initialize SNPE_A estimator.
Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `simulator_instance` | | Prior/simulator instance | *required* |
| `theta_key` | `Optional[str]` | Key for theta in batch data | `None` |
| `condition_keys` | `Optional[List[str]]` | Keys for condition data in batch | `None` |
| `config` | `Optional[dict]` | Configuration dict with `loop`, `network`, `optimizer`, `inference` sections | `None` |
Source code in falcon/contrib/SNPE_A.py
```python
def __init__(
    self,
    simulator_instance,
    theta_key: Optional[str] = None,
    condition_keys: Optional[List[str]] = None,
    config: Optional[dict] = None,
):
    """
    Initialize SNPE_A estimator.

    Args:
        simulator_instance: Prior/simulator instance
        theta_key: Key for theta in batch data
        condition_keys: Keys for condition data in batch
        config: Configuration dict with loop, network, optimizer, inference sections
    """
    # Merge user config with defaults using OmegaConf structured config
    schema = OmegaConf.structured(SNPEConfig)
    config = OmegaConf.merge(schema, config or {})

    super().__init__(
        simulator_instance=simulator_instance,
        loop_config=config.loop,
        theta_key=theta_key,
        condition_keys=condition_keys,
    )
    self.config = config

    # Device setup
    self.device = self._setup_device(config.device)

    # Embedding network
    # Convert to plain dict for instantiate_embedding which uses isinstance(x, dict)
    embedding_config = OmegaConf.to_container(config.network.embedding, resolve=True)
    self._embedding = instantiate_embedding(embedding_config).to(self.device)

    # Flow networks (initialized lazily)
    self._conditional_flow = None
    self._marginal_flow = None
    self._best_conditional_flow = None
    self._best_marginal_flow = None
    self._best_embedding = None
    self._init_parameters = None

    # Best loss tracking
    self.best_conditional_flow_val_loss = float("inf")
    self.best_marginal_flow_val_loss = float("inf")

    # Optimizer/scheduler (initialized lazily)
    self._optimizer = None
    self._scheduler = None

    # Extended history for SNPE-specific tracking
    self.history.update({
        "theta_mins": [],
        "theta_maxs": [],
    })
```
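Because the user config is merged onto a structured schema, partial configs are valid: omitted fields fall back to the dataclass defaults listed under Configuration Classes below. A quick illustration (assuming `SNPEConfig` is importable from `falcon.contrib.SNPE_A`, the module shown above):

```python
from omegaconf import OmegaConf
from falcon.contrib.SNPE_A import SNPEConfig  # assumed import path

# Partial user config: only the learning rate is overridden.
schema = OmegaConf.structured(SNPEConfig)
cfg = OmegaConf.merge(schema, {"optimizer": {"lr": 0.001}})

print(cfg.optimizer.lr)     # 0.001 (user override)
print(cfg.loop.num_epochs)  # 100   (TrainingLoopConfig default)
print(cfg.inference.gamma)  # 0.5   (InferenceConfig default)
```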
train_step
SNPE-A training step with gradient update and optional sample discarding.
Source code in falcon/contrib/SNPE_A.py
```python
def train_step(self, batch) -> Dict[str, float]:
    """SNPE-A training step with gradient update and optional sample discarding."""
    ids, theta, theta_logprob, conditions, u, u_device, conditions_device = \
        self._unpack_batch(batch, "train")

    # Initialize networks on first batch
    if not self.networks_initialized:
        self._initialize_networks(u, conditions)

    # Embed conditions
    s = self._embed(conditions_device, train=True)

    # Track theta ranges
    with torch.no_grad():
        self.history["theta_mins"].append(theta.min(dim=0).values.cpu().numpy())
        self.history["theta_maxs"].append(theta.max(dim=0).values.cpu().numpy())

    # Forward and backward pass
    self._optimizer.zero_grad()
    loss_cond, loss_marg = self._compute_flow_losses(u_device, s, train=True)
    (loss_cond + loss_marg).backward()
    self._optimizer.step()

    # Discard samples based on log-likelihood ratio
    if self.config.inference.discard_samples:
        discard_mask = self._compute_discard_mask(theta, theta_logprob, conditions_device)
        batch.discard(discard_mask)

    return {"loss": loss_cond.item(), "loss_aux": loss_marg.item()}
```
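The discard step drops training samples whose originating proposal has drifted too far from the current posterior estimate. The body of `_compute_discard_mask` is not shown in this reference, but given the configured `log_ratio_threshold` (default `-20.0`), a rule of roughly this shape is plausible. A hypothetical sketch, not Falcon's actual code:

```python
import torch

# Hypothetical sketch of a log-ratio discard rule (not the actual
# _compute_discard_mask, whose source is not shown here).
def discard_mask(current_logprob: torch.Tensor,
                 theta_logprob: torch.Tensor,
                 log_ratio_threshold: float = -20.0) -> torch.Tensor:
    """Flag samples whose log-density under the current posterior has
    fallen far below the log-density they were originally drawn with."""
    log_ratio = current_logprob - theta_logprob
    return log_ratio < log_ratio_threshold
```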
val_step
SNPE-A validation step without gradient computation.
Source code in falcon/contrib/SNPE_A.py
```python
def val_step(self, batch) -> Dict[str, float]:
    """SNPE-A validation step without gradient computation."""
    _, theta, theta_logprob, conditions, u, u_device, conditions_device = \
        self._unpack_batch(batch, "val")

    # Embed conditions (eval mode)
    s = self._embed(conditions_device, train=False)

    # Compute losses without gradients
    with torch.no_grad():
        loss_cond, loss_marg = self._compute_flow_losses(u_device, s, train=False)

    return {"loss": loss_cond.item(), "loss_aux": loss_marg.item()}
```
sample_prior
```python
sample_prior(num_samples, parent_conditions=None)
```
Sample from the prior distribution.
Source code in falcon/contrib/SNPE_A.py
```python
def sample_prior(self, num_samples: int, parent_conditions: Optional[List] = None) -> RVBatch:
    """Sample from the prior distribution."""
    if parent_conditions:
        raise ValueError("Conditions are not supported for sample_prior.")
    samples = self.simulator_instance.simulate_batch(num_samples)

    # Log probability for a uniform prior over the hypercube [-bound, bound]^d:
    # the density is (2 * bound) ** -d, so the log-prob is -d * log(2 * bound).
    bound = self.config.inference.hypercube_bound
    logprob = np.full(num_samples, -self.param_dim * np.log(2 * bound))

    return RVBatch(samples, logprob=logprob)
```
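For a uniform density on the hypercube `[-bound, bound]^d`, the density is `(2 * bound) ** -d`, hence the per-sample log-probability `-d * log(2 * bound)` used above. A quick numeric check, assuming the default `hypercube_bound=2.0` and, purely for illustration, `param_dim=3`:

```python
import numpy as np

bound, d = 2.0, 3  # default hypercube_bound; d=3 is illustrative
density = (1.0 / (2 * bound)) ** d
assert np.isclose(np.log(density), -d * np.log(2 * bound))  # both = -3*log(4)
```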
sample_posterior
```python
sample_posterior(num_samples, parent_conditions=None, evidence_conditions=None)
```
Sample from the posterior distribution q(theta|x).
Source code in falcon/contrib/SNPE_A.py
```python
def sample_posterior(
    self,
    num_samples: int,
    parent_conditions: Optional[List] = None,
    evidence_conditions: Optional[List] = None,
) -> RVBatch:
    """Sample from the posterior distribution q(theta|x)."""
    # Fall back to prior if networks not yet initialized (training hasn't started)
    if not self.networks_initialized:
        return self.sample_prior(num_samples, parent_conditions)

    samples, logprob = self._importance_sample(
        num_samples,
        mode="posterior",
        parent_conditions=parent_conditions or [],
        evidence_conditions=evidence_conditions or [],
    )
    return RVBatch(samples.numpy(), logprob=logprob.numpy())
```
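Typical usage after (or during) training; here `estimator` is a SNPE_A instance and `x_obs` is a placeholder for your observed evidence data:

```python
# `x_obs` is a placeholder for the observed data used as evidence.
posterior_batch = estimator.sample_posterior(
    num_samples=1000,
    evidence_conditions=[x_obs],
)
# Before network initialization this transparently falls back to the prior.
```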
sample_proposal
```python
sample_proposal(num_samples, parent_conditions=None, evidence_conditions=None)
```
Sample from the widened proposal distribution for adaptive resampling.
Source code in falcon/contrib/SNPE_A.py
```python
def sample_proposal(
    self,
    num_samples: int,
    parent_conditions: Optional[List] = None,
    evidence_conditions: Optional[List] = None,
) -> RVBatch:
    """Sample from the widened proposal distribution for adaptive resampling."""
    # Fall back to prior if networks not yet initialized (training hasn't started)
    if not self.networks_initialized:
        return self.sample_prior(num_samples, parent_conditions)

    cfg_inf = self.config.inference
    parent_conditions = parent_conditions or []
    evidence_conditions = evidence_conditions or []

    if cfg_inf.sample_reference_posterior:
        post_samples, _ = self._importance_sample(
            cfg_inf.reference_samples,
            mode="posterior",
            parent_conditions=parent_conditions,
            evidence_conditions=evidence_conditions,
        )
        mean, std = post_samples.mean(dim=0).cpu(), post_samples.std(dim=0).cpu()
        log({f"sample_proposal:posterior_mean_{i}": mean[i].item() for i in range(len(mean))})
        log({f"sample_proposal:posterior_std_{i}": std[i].item() for i in range(len(std))})

    samples, logprob = self._importance_sample(
        num_samples,
        mode="proposal",
        parent_conditions=parent_conditions,
        evidence_conditions=evidence_conditions,
    )
    log({
        "sample_proposal:mean": samples.mean().item(),
        "sample_proposal:std": samples.std().item(),
        "sample_proposal:logprob": logprob.mean().item(),
    })
    return RVBatch(samples.numpy(), logprob=logprob.numpy())
```
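The Overview mentions importance sampling with effective sample size (ESS) monitoring. For reference, the standard Kish ESS over normalized importance weights looks like this; a generic sketch, not necessarily Falcon's internal implementation:

```python
import torch

def effective_sample_size(log_weights: torch.Tensor) -> float:
    """Kish effective sample size from unnormalized log importance weights."""
    w = torch.softmax(log_weights, dim=0)  # normalize weights stably via log-space
    return (1.0 / (w ** 2).sum()).item()   # 1 / sum(w_i^2), in [1, N]
```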
save
Save SNPE-A state.
Source code in falcon/contrib/SNPE_A.py
```python
def save(self, node_dir: Path) -> None:
    """Save SNPE-A state."""
    debug(f"Saving: {node_dir}")
    if not self.networks_initialized:
        raise RuntimeError("Networks not initialized.")

    torch.save(self._best_conditional_flow.state_dict(), node_dir / "conditional_flow.pth")
    torch.save(self._best_marginal_flow.state_dict(), node_dir / "marginal_flow.pth")
    torch.save(self._init_parameters, node_dir / "init_parameters.pth")

    # Save history
    torch.save(self.history["train_ids"], node_dir / "train_id_history.pth")
    torch.save(self.history["val_ids"], node_dir / "validation_id_history.pth")
    torch.save(self.history["theta_mins"], node_dir / "theta_mins_batches.pth")
    torch.save(self.history["theta_maxs"], node_dir / "theta_maxs_batches.pth")
    torch.save(self.history["epochs"], node_dir / "epochs.pth")
    torch.save(self.history["train_loss"], node_dir / "loss_train_posterior.pth")
    torch.save(self.history["val_loss"], node_dir / "loss_val_posterior.pth")
    torch.save(self.history["n_samples"], node_dir / "n_samples_total.pth")
    torch.save(self.history["elapsed_min"], node_dir / "elapsed_minutes.pth")

    if self._best_embedding is not None:
        torch.save(self._best_embedding.state_dict(), node_dir / "embedding.pth")
```
load
Load SNPE-A state.
Source code in falcon/contrib/SNPE_A.py
```python
def load(self, node_dir: Path) -> None:
    """Load SNPE-A state."""
    debug(f"Loading: {node_dir}")
    init_parameters = torch.load(node_dir / "init_parameters.pth")
    self._initialize_networks(init_parameters[0], init_parameters[1])
    self._best_conditional_flow.load_state_dict(
        torch.load(node_dir / "conditional_flow.pth")
    )
    self._best_marginal_flow.load_state_dict(
        torch.load(node_dir / "marginal_flow.pth")
    )
    if (node_dir / "embedding.pth").exists() and self._best_embedding is not None:
        self._best_embedding.load_state_dict(torch.load(node_dir / "embedding.pth"))
```
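A save/load round trip looks like this. A sketch: `node_dir` can be any writable directory, and `estimator`, `my_simulator`, and `config` are placeholders; the restored estimator should be constructed with a compatible simulator and config:

```python
from pathlib import Path

node_dir = Path("checkpoints/node_0")
node_dir.mkdir(parents=True, exist_ok=True)
estimator.save(node_dir)  # requires initialized networks

# Restore into a fresh estimator built with the same simulator/config.
restored = SNPE_A(my_simulator, config=config)
restored.load(node_dir)   # re-initializes networks from saved init parameters
```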
Configuration Classes
SNPEConfig (dataclass)
Top-level SNPE_A configuration.
NetworkConfig (dataclass)

```python
NetworkConfig(
    net_type='zuko_nice',
    theta_norm=True,
    norm_momentum=0.01,
    adaptive_momentum=False,
    use_log_update=False,
    embedding=None,
)
```

Neural network architecture parameters.
OptimizerConfig (dataclass)

```python
OptimizerConfig(lr=0.01, lr_decay_factor=0.1, scheduler_patience=8)
```

Optimizer parameters (training-time).
InferenceConfig (dataclass)

```python
InferenceConfig(
    gamma=0.5,
    discard_samples=True,
    log_ratio_threshold=-20.0,
    sample_reference_posterior=False,
    use_best_models_during_inference=True,
    num_proposals=256,
    reference_samples=128,
    hypercube_bound=2.0,
    out_of_bounds_penalty=100.0,
    nan_replacement=-100.0,
)
```

Inference and sampling parameters.
TrainingLoopConfig (dataclass)

```python
TrainingLoopConfig(
    num_epochs=100,
    batch_size=128,
    early_stop_patience=16,
    reset_network_after_pause=False,
)
```

Generic training loop parameters.