Machine Learning for Astroparticle Physics:
A Crash-course in SBI

Lecture 2a - Linear regression as conditional density estimation

Christoph Weniger — University of Amsterdam (GRAPPA)

Today's Roadmap

The goal: approximate the true posterior density \(p(\theta\mid x)\), defined through the simulator \(p(x\mid\theta)\) and the prior \(p(\theta)\), with a fitting function \(q_\phi(\theta\mid x)\).
1. Conditional densities from linear regression: read the joint cloud column-wise; a Gaussian band \(\mathcal{N}(\mu_\theta(x), \sigma_\theta^{2})\) from a linear basis.
2. Fitting by maximum likelihood: the Gaussian likelihood becomes a loss; closed-form weights and variance.
3. Overfitting and validation: when flexible models memorise noise, and how train / validation / test catch it.
The thread: least-squares fitting is already conditional density estimation. Make it honest about uncertainty, then watch where a fixed basis breaks — which sets up Lecture 2b.

Conditional densities from linear regression

Read the joint cloud column-wise: the simplest \(q(\theta\mid x)\) comes out almost for free.

What we are after: the conditional density

A simulator gives us pairs \((\theta, x)\): parameter in, observation out. Plotted together they form a joint cloud.

Our job is to describe the cloud in the \(\theta\) direction: at each value of \(x\), what is the distribution of \(\theta\) consistent with it?

\[ q(\theta\mid x) \;=\; \text{vertical slice of the cloud at } x \]

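In the interactive figure, the slider moves an orange band that picks out a slice at \(x_\mathrm{obs}\). The points inside the band are samples of \(q(\theta\mid x_\mathrm{obs})\). They are approximate samples, because the band has a finite width.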
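The slicing idea can be tried numerically with a minimal sketch. It assumes a hypothetical toy simulator that is not from the lecture: \(\theta \sim \mathcal{N}(0,1)\) and \(x\mid\theta \sim \mathcal{N}(2\theta, 0.5^2)\). The sketch draws the joint cloud and keeps the \(\theta\) values whose \(x\) falls in a narrow band around \(x_\mathrm{obs}\). The names `simulate` and `slice_samples` are illustrative.

```python
import numpy as np


def simulate(n, rng):
    """Hypothetical toy simulator: theta ~ N(0, 1) prior, then x | theta ~ N(2 theta, 0.5^2)."""
    theta = rng.normal(0.0, 1.0, size=n)
    x = 2.0 * theta + rng.normal(0.0, 0.5, size=n)
    return theta, x


def slice_samples(theta, x, x_obs, half_width):
    """Keep the theta values whose x lands inside the band |x - x_obs| < half_width."""
    return theta[np.abs(x - x_obs) < half_width]


rng = np.random.default_rng(0)
theta, x = simulate(200_000, rng)                 # the joint cloud of (theta, x) pairs
theta_slice = slice_samples(theta, x, x_obs=1.0, half_width=0.05)
print(theta_slice.size, theta_slice.mean(), theta_slice.std())
```

For this toy simulator the exact conditional is Gaussian, with mean \(8x_\mathrm{obs}/17\) and standard deviation \(1/\sqrt{17}\approx 0.24\). The slice statistics should therefore land close to 0.47 and 0.24. Narrowing `half_width` reduces the smearing across \(x\), at the cost of fewer points in the slice.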

For now we assume \(q(\theta\mid x)\) is unimodal at every \(x\). The multi-modal case comes later.

The simplest conditional density model

Fit a curve \(\mu_\theta(x)\) through the cloud and assume the slice at every \(x\) is Gaussian with the same width \(\sigma_\theta\).

\[ q_\phi(\theta\mid x) \;=\; \mathcal{N}\!\bigl(\theta\,\big|\,\mu_\theta(x),\,\sigma_\theta^{2}\bigr) \]

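Trainable parameters \(\phi = (\mathbf{w}, \sigma_\theta)\): the weights \(\mathbf{w}\) that define the curve, and the band width. This \(\phi\) is only a label for the trainable parameters; it is unrelated to the basis functions \(\phi_j(x)\) introduced next.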

Black line: \(\mu_\theta(x)\). Shaded: \(\mu_\theta(x) \pm \sigma_\theta\).
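A minimal sketch of evaluating the Gaussian band model \(\ln q_\phi(\theta\mid x)\) follows. The band centre here is a hypothetical straight line. Any callable for \(\mu_\theta(x)\) works, and `log_q` and `mu_line` are illustrative names.

```python
import numpy as np


def log_q(theta, x, mu_fn, sigma):
    """Log-density of the Gaussian band model q(theta | x) = N(theta | mu(x), sigma^2).

    mu_fn is any callable returning the band centre mu_theta(x); sigma is the shared width.
    """
    resid = np.asarray(theta, dtype=float) - mu_fn(np.asarray(x, dtype=float))
    return -0.5 * np.log(2.0 * np.pi * sigma**2) - 0.5 * resid**2 / sigma**2


def mu_line(x):
    """Hypothetical band centre: a straight line through the origin."""
    return 0.5 * x


# The density peaks on the curve (theta = 0.5 at x = 1) and falls off away from it.
print(log_q([0.5, 1.0], [1.0, 1.0], mu_line, sigma=0.3))
```

Because \(\sigma_\theta\) does not depend on \(x\), the band has the same width everywhere. Only the centre moves with \(x\).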

The linear-basis picture

Pick basis functions \(\phi_j(x)\) that encode plausible functional behaviour. Combine \(M\) of them with free parameters:

\[ \mu_\theta(x; \mathbf{w}) \;=\; \sum_{j=0}^{M-1} w_j\,\phi_j(x) \;=\; \mathbf{w}^T\boldsymbol{\phi}(x), \qquad \mathbf{w} \in \mathbb{R}^M \]
  • Basis functions \(\phi_j(x)\) — fixed templates capturing expected shapes
  • Weight vector \(\mathbf{w} = (w_0, \ldots, w_{M-1})^T\) — \(M\) free parameters, fitted to data
  • "Linear" refers to linearity in \(\mathbf{w}\), not in \(x\) — the templates themselves can be nonlinear
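The linear-basis model can be sketched in a few lines of NumPy. The sketch evaluates every template at every input and stacks the results into a matrix, so \(\mu_\theta(x_n;\mathbf{w}) = \mathbf{w}^T\boldsymbol{\phi}(x_n)\) becomes a single matrix-vector product. The polynomial templates, the weights and the helper name `design_matrix` are illustrative choices.

```python
import numpy as np


def design_matrix(x, basis_fns):
    """Evaluate every basis function at every input: Phi[n, j] = phi_j(x_n)."""
    x = np.asarray(x, dtype=float)
    return np.column_stack([phi(x) for phi in basis_fns])


# Illustrative choice: polynomial templates phi_j(x) = x**j for j = 0..M-1.
M = 3
poly_basis = [lambda x, j=j: x**j for j in range(M)]

x = np.array([0.0, 1.0, 2.0])
w = np.array([1.0, -2.0, 0.5])
Phi = design_matrix(x, poly_basis)
mu = Phi @ w   # mu_theta(x_n; w) = w^T phi(x_n) for all n at once; mu = [1.0, -0.5, -1.0]
```

The templates \(x^j\) are nonlinear in \(x\), but scaling or adding weight vectors scales or adds the resulting curves. This is the sense in which the model is "linear".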

The modelling decision is the choice of \(\boldsymbol{\phi}\). Everything that follows in this lecture is one continuous attack on that choice.

Example basis functions

Common families include polynomials, Gaussians and sigmoids; the interactive figure shows their shapes. Each \(\phi_j(x)\) is one template. Stack \(M\) of them to build your model.
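The three families named above can be written as NumPy functions that return an \((N, M)\) array of template values. In this sketch the centre placement and widths are arbitrary assumptions; picking them is exactly the modelling decision the lecture refers to.

```python
import numpy as np


def polynomial_basis(x, M):
    """phi_j(x) = x**j for j = 0..M-1; returns shape (len(x), M)."""
    return np.stack([x**j for j in range(M)], axis=-1)


def gaussian_basis(x, centres, s):
    """phi_j(x) = exp(-(x - c_j)^2 / (2 s^2)): bumps at fixed centres c_j, shared width s."""
    return np.exp(-((x[:, None] - centres[None, :]) ** 2) / (2.0 * s**2))


def sigmoid_basis(x, centres, s):
    """phi_j(x) = 1 / (1 + exp(-(x - c_j) / s)): smooth steps located at c_j."""
    return 1.0 / (1.0 + np.exp(-(x[:, None] - centres[None, :]) / s))


x = np.linspace(-1.0, 1.0, 5)
centres = np.linspace(-1.0, 1.0, 4)   # illustrative centre placement
for name, Phi in [("polynomial", polynomial_basis(x, 4)),
                  ("gaussian", gaussian_basis(x, centres, 0.3)),
                  ("sigmoid", sigmoid_basis(x, centres, 0.1))]:
    print(name, Phi.shape)   # each is (len(x), M) = (5, 4)
```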

Sampling random linear model realisations

Pick random weights \(\mathbf{w}\); they define a function \(\mu_\theta(x) = \mathbf{w}^T\boldsymbol{\phi}(x)\). Observe noisy samples \(\theta_n = \mu_\theta(x_n) + \varepsilon\), \(\varepsilon \sim \mathcal{N}(0, \sigma_\theta^{2})\).

Black: true curve \(\mu_\theta(x)\). Red points: noisy samples \(\theta_n\). Given only the points, can you recover the black curve?
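One random realisation can be sampled with the following sketch. It draws \(\mathbf{w}\) from a standard normal, forms \(\mu_\theta(x)\) from a Gaussian basis, and adds noise \(\varepsilon \sim \mathcal{N}(0,\sigma_\theta^2)\). The basis, \(M\), \(N\), \(\sigma_\theta\) and the weight distribution are assumptions made for the example.

```python
import numpy as np


def gaussian_basis(x, centres, s):
    """phi_j(x) = exp(-(x - c_j)^2 / (2 s^2))."""
    return np.exp(-((x[:, None] - centres[None, :]) ** 2) / (2.0 * s**2))


rng = np.random.default_rng(1)
M, N, sigma_theta = 6, 30, 0.1            # illustrative sizes and noise level
centres = np.linspace(0.0, 1.0, M)

w_true = rng.normal(0.0, 1.0, size=M)     # random weights -> one realisation of mu_theta(x)
x_n = rng.uniform(0.0, 1.0, size=N)
mu_n = gaussian_basis(x_n, centres, 0.15) @ w_true
theta_n = mu_n + rng.normal(0.0, sigma_theta, size=N)   # theta_n = mu_theta(x_n) + eps

# A dense grid for drawing the black "true" curve.
x_grid = np.linspace(0.0, 1.0, 200)
mu_grid = gaussian_basis(x_grid, centres, 0.15) @ w_true
```

Recovering `w_true` from `(x_n, theta_n)` alone is the fitting problem treated in the next section.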

Fitting by maximum likelihood

The Gaussian likelihood becomes a loss; the weights and the variance follow in closed form.

From model likelihood to loss function

With independent Gaussian noise, the joint likelihood factorises:

\[ p(\boldsymbol\theta \mid \mathbf{w}, \sigma_\theta) \;=\; \prod_{n=1}^{N} \mathcal{N}\!\bigl(\theta_n \,\big|\, \mu_\theta(x_n;\mathbf{w}),\, \sigma_\theta^{2}\bigr) \]

Take the log and write the constant explicitly in terms of \(N\) and \(\sigma_\theta\):

\[ \ln p(\boldsymbol\theta \mid \mathbf{w}, \sigma_\theta) \;=\; -\frac{N}{2}\ln\!\bigl(2\pi\sigma_\theta^{2}\bigr) \;-\; \frac{1}{2\sigma_\theta^{2}}\sum_{n=1}^{N}\bigl(\theta_n - \mu_\theta(x_n;\mathbf{w})\bigr)^{2} \]

Negate to turn maximising the likelihood into minimising a loss:

\[ E(\mathbf{w}, \sigma_\theta) \;=\; \frac{N}{2}\ln\!\bigl(2\pi\sigma_\theta^{2}\bigr) \;+\; \frac{1}{2\sigma_\theta^{2}}\sum_{n=1}^{N}\bigl(\theta_n - \mu_\theta(x_n;\mathbf{w})\bigr)^{2} \]
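The loss can be transcribed directly into code and cross-checked on made-up data. \(E\) should equal minus the log of the product of Gaussian densities in the factorised likelihood. The helper name `nll_loss` is illustrative, and the random matrix merely stands in for \(\phi_j(x_n)\).

```python
import numpy as np


def nll_loss(w, sigma, Phi, theta):
    """E(w, sigma) = N/2 ln(2 pi sigma^2) + S(w) / (2 sigma^2), S = sum of squared residuals."""
    N = len(theta)
    S = np.sum((theta - Phi @ w) ** 2)
    return 0.5 * N * np.log(2.0 * np.pi * sigma**2) + S / (2.0 * sigma**2)


# Cross-check on a small made-up data set: E should equal -ln of the product of Gaussian densities.
rng = np.random.default_rng(0)
Phi = rng.normal(size=(10, 3))    # stand-in for phi_j(x_n); values are arbitrary
w = rng.normal(size=3)
theta = rng.normal(size=10)
sigma = 0.7

dens = np.exp(-0.5 * (theta - Phi @ w) ** 2 / sigma**2) / np.sqrt(2.0 * np.pi * sigma**2)
E_val = nll_loss(w, sigma, Phi, theta)
E_ref = -np.log(np.prod(dens))
print(E_val, E_ref)   # agree up to floating-point rounding
```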

Solving for the weights and the variance

Write the loss in terms of the sum of squared residuals \(S(\mathbf{w})\), which does not depend on \(\sigma_\theta\):

\[ E(\mathbf{w}, \sigma_\theta) \;=\; \frac{N}{2}\ln\!\bigl(2\pi\sigma_\theta^{2}\bigr) \;+\; \frac{S(\mathbf{w})}{2\sigma_\theta^{2}}, \qquad S(\mathbf{w}) = \sum_{n}\bigl(\theta_n - \mathbf{w}^T\boldsymbol{\phi}(x_n)\bigr)^{2} \]

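Step 1 — weights. The prefactor \(1/(2\sigma_\theta^{2})\) is positive, so minimising \(E\) in \(\mathbf{w}\) is the same as minimising \(S(\mathbf{w})\). With the design matrix \(\boldsymbol\Phi_{nj} = \phi_j(x_n)\) we have \(S(\mathbf{w}) = \lVert\boldsymbol\theta - \boldsymbol\Phi\mathbf{w}\rVert^{2}\). Setting \(\nabla_\mathbf{w} S = -2\boldsymbol\Phi^T(\boldsymbol\theta - \boldsymbol\Phi\mathbf{w}) = 0\) gives the normal equations \(\boldsymbol\Phi^T\boldsymbol\Phi\,\mathbf{w} = \boldsymbol\Phi^T\boldsymbol\theta\). If \(\boldsymbol\Phi^T\boldsymbol\Phi\) is invertible,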

\[ \mathbf{w}_{\mathrm{ML}} \;=\; (\boldsymbol\Phi^T\boldsymbol\Phi)^{-1}\boldsymbol\Phi^T\boldsymbol\theta \]
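The closed-form weights can be computed in two ways in NumPy, shown in the sketch below. `np.linalg.lstsq` minimises \(\lVert\boldsymbol\theta - \boldsymbol\Phi\mathbf{w}\rVert^2\) without forming \((\boldsymbol\Phi^T\boldsymbol\Phi)^{-1}\), which is the usual numerically safer route. Solving the normal equations directly gives the same answer when \(\boldsymbol\Phi^T\boldsymbol\Phi\) is well conditioned. The quadratic basis, true weights and noise level are made up for the example.

```python
import numpy as np


def fit_weights(Phi, theta):
    """Least-squares (= maximum-likelihood) weights via an SVD-based solver."""
    w, *_ = np.linalg.lstsq(Phi, theta, rcond=None)
    return w


rng = np.random.default_rng(2)
x = rng.uniform(-1.0, 1.0, size=50)
Phi = np.column_stack([x**j for j in range(3)])            # quadratic basis (illustrative)
theta = Phi @ np.array([0.5, -1.0, 2.0]) + rng.normal(0.0, 0.05, size=50)

w_lstsq = fit_weights(Phi, theta)
w_normal = np.linalg.solve(Phi.T @ Phi, Phi.T @ theta)     # normal equations Phi^T Phi w = Phi^T theta
```

At the optimum the residual vector is orthogonal to every column of \(\boldsymbol\Phi\), so \(\boldsymbol\Phi^T(\boldsymbol\theta - \boldsymbol\Phi\mathbf{w}_{\mathrm{ML}}) \approx 0\).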

Step 2 — variance. Plug \(\mathbf{w}_{\mathrm{ML}}\) back and set \(\partial E/\partial\sigma_\theta = 0\):

\[ \frac{N}{\sigma_\theta} \;-\; \frac{S(\mathbf{w}_{\mathrm{ML}})}{\sigma_\theta^{3}} \;=\; 0 \quad\Longrightarrow\quad \sigma_{\mathrm{ML}}^{2} \;=\; \frac{1}{N}\sum_{n}\bigl(\theta_n - \mu_\theta(x_n;\mathbf{w}_{\mathrm{ML}})\bigr)^{2} \]

The MLE divides by \(N\). The unbiased estimator divides by \(N-M\) (Bessel-style correction for the \(M\) fitted weights); the two agree as \(N \to \infty\).
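Both variance estimators can be computed from the residuals of a fit. In the sketch they differ exactly by the factor \(N/(N-M)\). With many data points both land near the true noise variance. The data set, \(N\), \(M\) and the true \(\sigma\) are illustrative assumptions, and `noise_variance` is a hypothetical helper name.

```python
import numpy as np


def noise_variance(Phi, theta, w):
    """Return (ML variance S/N, unbiased variance S/(N - M)) from the residuals of a fit."""
    N, M = Phi.shape
    S = np.sum((theta - Phi @ w) ** 2)
    return S / N, S / (N - M)


rng = np.random.default_rng(3)
N, M, sigma_true = 2000, 4, 0.2
x = rng.uniform(-1.0, 1.0, size=N)
Phi = np.column_stack([x**j for j in range(M)])
theta = Phi @ rng.normal(size=M) + rng.normal(0.0, sigma_true, size=N)

w_ml, *_ = np.linalg.lstsq(Phi, theta, rcond=None)
var_ml, var_unbiased = noise_variance(Phi, theta, w_ml)
print(var_ml, var_unbiased)   # both close to sigma_true**2 = 0.04 for large N
```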

Worked example: polynomial fit with a Gaussian band

Black curve: \(\mu_\theta(x;\mathbf{w}_{\mathrm{ML}})\). Shaded band: \(\mu_\theta(x)\pm\sigma_{\mathrm{ML}}\). The fitted Gaussian density is \(q(\theta\mid x)=\mathcal{N}(\theta\mid\mu_\theta(x),\sigma_{\mathrm{ML}}^{2})\).
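The whole pipeline can be sketched end to end on made-up data from a sine curve with Gaussian noise (an assumption, not the lecture's data set). The steps are a polynomial design matrix, \(\mathbf{w}_{\mathrm{ML}}\) by least squares, and \(\sigma_{\mathrm{ML}}\) with the \(1/N\) normalisation. The function `q_params` is a hypothetical helper that returns the mean and width of the fitted \(q(\theta\mid x)\) at new inputs.

```python
import numpy as np

rng = np.random.default_rng(4)

# Made-up training data: a smooth curve plus Gaussian noise of width 0.3.
N = 500
x = rng.uniform(0.0, 1.0, size=N)
theta = np.sin(2 * np.pi * x) + rng.normal(0.0, 0.3, size=N)

M = 6                                                     # polynomial degree M - 1 = 5
Phi = np.vander(x, M, increasing=True)                   # columns 1, x, x^2, ..., x^(M-1)
w_ml, *_ = np.linalg.lstsq(Phi, theta, rcond=None)
sigma_ml = np.sqrt(np.mean((theta - Phi @ w_ml) ** 2))  # divides by N: the MLE


def q_params(x_new):
    """Mean and standard deviation of the fitted q(theta | x) = N(mu(x), sigma_ML^2)."""
    return np.vander(np.atleast_1d(x_new), M, increasing=True) @ w_ml, sigma_ml


mu_train, _ = q_params(x)
coverage = np.mean(np.abs(theta - mu_train) < sigma_ml)   # fraction inside the +/- 1 sigma band
```

If the Gaussian band model is adequate, roughly 68% of the points should fall inside \(\mu_\theta(x)\pm\sigma_{\mathrm{ML}}\). A coverage far from that suggests the shared-width or Gaussian assumption is off.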

Overfitting and validation

A flexible enough basis can memorise the noise. Held-out data is how we catch it.

Example: polynomial regression

RMSE = root mean square error: \(\sqrt{\frac{1}{N}\sum_n (\theta_n - \mu_\theta(x_n;\mathbf{w}))^2}\).
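The RMSE formula is a one-line function. On the training set at \(\mathbf{w}_{\mathrm{ML}}\), it coincides with \(\sigma_{\mathrm{ML}}\) from the closed-form variance, because both are \(\sqrt{S(\mathbf{w}_{\mathrm{ML}})/N}\). The function name `rmse` is illustrative.

```python
import numpy as np


def rmse(theta, mu):
    """Root mean square error between targets theta_n and predictions mu_theta(x_n; w)."""
    theta, mu = np.asarray(theta, dtype=float), np.asarray(mu, dtype=float)
    return np.sqrt(np.mean((theta - mu) ** 2))


print(rmse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))  # sqrt(4 / 3), about 1.1547
```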


Overfitting & Underfitting

Underfitting

Model too simple — misses the pattern.

Training error: high
Validation error: high

Good fit

Right complexity.

Training error: low
Validation error: low

Overfitting

Model too flexible — memorises noise.

Training error: very low
Validation error: high

The gap between training and validation error is the signature of overfitting.

Validation & Testing

Training set — fit parameters \(\mathbf{w}\).

Validation set — choose model complexity (\(M\), \(\lambda\), architecture, ...).

Test set — one-shot final score. Look once.

Example split: training 70%, validation 15%, test 15%.

Never use test data for model selection — it defeats the purpose.

Training error keeps decreasing with \(M\). Validation error is U-shaped — the minimum picks the sweet spot.
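The workflow can be sketched on a small made-up data set. It shuffles once, splits 70 / 15 / 15, fits polynomials of increasing size \(M\) on the training split only, and picks \(M\) by validation RMSE. The test split is then scored exactly once. The data, the range of \(M\) and the helper names are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(5)

# Made-up data set (illustrative): noisy samples of a smooth curve.
N = 60
x = rng.uniform(-1.0, 1.0, size=N)
theta = np.sin(3.0 * x) + rng.normal(0.0, 0.2, size=N)

# Shuffle once, then split 70 / 15 / 15 into training / validation / test indices.
idx = rng.permutation(N)
n_tr, n_va = 70 * N // 100, 15 * N // 100
tr, va, te = idx[:n_tr], idx[n_tr:n_tr + n_va], idx[n_tr + n_va:]


def design(xs, M):
    """Polynomial design matrix with columns 1, x, ..., x^(M-1)."""
    return np.vander(xs, M, increasing=True)


def rmse(a, b):
    return np.sqrt(np.mean((a - b) ** 2))


def fit(M):
    """Least-squares weights using the training split only."""
    w, *_ = np.linalg.lstsq(design(x[tr], M), theta[tr], rcond=None)
    return w


Ms = list(range(1, 16))
train_err, val_err = [], []
for M in Ms:
    w = fit(M)
    train_err.append(rmse(theta[tr], design(x[tr], M) @ w))
    val_err.append(rmse(theta[va], design(x[va], M) @ w))

M_best = Ms[int(np.argmin(val_err))]     # model selection sees the validation split only
w_best = fit(M_best)
test_rmse = rmse(theta[te], design(x[te], M_best) @ w_best)   # one-shot final score
```

Plotting `train_err` and `val_err` against `Ms` should show the behaviour described above. Training error never increases with \(M\) because the polynomial bases are nested, while validation error eventually turns back up.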

Choosing a basis gets hard in higher dimensions

In a \(D\)-dimensional input \(\mathbf{x} = (x_1, \ldots, x_D)\) we still write \(\mu_\theta(\mathbf{x}) = \mathbf{w}^T\boldsymbol{\phi}(\mathbf{x})\), but the basis must be hand-picked to tile the whole space:

  • Polynomials: \(1,\; x_1,\; x_2,\; x_1^2,\; x_1 x_2,\; x_2^2,\; \ldots\)
  • Gaussian blobs: \(e^{-\lVert\mathbf{x}-\boldsymbol{\mu}_1\rVert^2/2s^2},\; e^{-\lVert\mathbf{x}-\boldsymbol{\mu}_2\rVert^2/2s^2},\; \ldots\)
  • Fourier modes: \(\sin(\mathbf{k}_1\!\cdot\mathbf{x}),\; \cos(\mathbf{k}_1\!\cdot\mathbf{x}),\; \sin(\mathbf{k}_2\!\cdot\mathbf{x}),\; \ldots\)

With \(M\) basis functions per dimension, tiling \(D\) dimensions needs \(\sim M^D\) of them in total, and you must place every centre, width and frequency yourself. This is the curse of dimensionality. Lecture 2b's fix: let the network learn the basis.
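The \(\sim M^D\) growth can be made concrete by counting Gaussian-blob centres on a regular grid with \(M\) points per axis. The sketch assumes this grid layout, which is only one way to tile the space, and `grid_centres` is a hypothetical helper.

```python
from itertools import product

import numpy as np


def grid_centres(m_per_dim, D, lo=0.0, hi=1.0):
    """Centres for Gaussian blobs on a regular grid: m_per_dim points along each of D axes."""
    axis = np.linspace(lo, hi, m_per_dim)
    return np.array(list(product(axis, repeat=D)))


for D in (1, 2, 3, 6):
    print(D, grid_centres(5, D).shape[0])   # 5**D centres: 5, 25, 125, 15625
```

Even a coarse grid with five centres per axis already needs 15,625 basis functions in six dimensions, each needing its own column in the design matrix.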

Key Takeaways

  • Regression is conditional density estimation: read the joint cloud column-wise, and a least-squares fit already hands you a full \(q(\theta\mid x)\) — the fitted curve is the mean, the residual scatter is the width.
  • Gaussian band model: assume the slice at every \(x\) is Gaussian with a mean curve and a single shared width — the simplest honest model of uncertainty.
  • Linear basis: build the mean curve as a weighted sum of fixed basis functions (polynomials, Gaussians, sigmoids). It is linear in the weights even when the basis itself is nonlinear.
  • Maximum likelihood = least squares: Gaussian noise turns the log-likelihood into a sum-of-squares loss; both the weights and the noise width then follow in closed form.
  • Overfitting and validation: a flexible enough basis memorises noise — training error keeps dropping while validation error turns up. Use train / validation / test, and never tune on the test set.
  • The catch with fixed bases: one knob (the basis size) trades expressiveness against noise sensitivity, and hand-picking a good basis gets exponentially harder in higher dimensions.