Exercise 8: Multilevel regression

In Exercise 7 we fit a single regression line to all the data at once. But many real datasets have group structure: measurements come from different subjects, sites, batches, or schools, and each group may behave somewhat differently. A multilevel (or hierarchical) model treats each group as having its own parameters, while tying those parameters together through a shared prior whose hyperparameters are themselves learned from the data. This produces partial pooling: groups with little data borrow strength from the global pattern, while groups with lots of data are allowed to deviate.

The dataset is the classic sleepstudy from Belenky et al. (2003), in which 18 subjects had their sleep restricted to 3 hours per night for 9 nights and their reaction times measured daily on a psychomotor vigilance task. We will fit three regressions to it — complete pooling, no pooling, and partial pooling — diagnose their MCMC fits, visualize the shrinkage that partial pooling induces, criticize each model with posterior predictive checks, and finally compare them by their expected out-of-sample predictive accuracy via PSIS-LOO and a held-out test set.

The BDA workflow from the Bayesian Data Analysis chapter includes a prior predictive check and a fake-data parameter-recovery step before fitting on real data. We practiced both in Exercise 7 and will not repeat them here in order to keep this sheet focused on the multilevel-specific issues (pooling, shrinkage, model comparison). For your own modeling work, treat their omission as a shortcut justified by the simplicity of the linear-Gaussian likelihood, not as a recommendation.

Setup

This exercise uses the same libraries as Exercise 7, with the addition of pandas for loading the CSV file:

  • JAX — numerical computing backend.
  • NumPyro — probabilistic programming and MCMC inference.
  • ArviZ — diagnostics and visualization.
  • matplotlib — plotting.
  • pandas — reading the CSV data file.

Install via:

pip install jax numpyro arviz matplotlib pandas

import matplotlib.pyplot as plt
import pandas as pd

import jax
from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

import arviz as az

rng_key = random.PRNGKey(42)

The data

The sleepstudy dataset records the average reaction time (in milliseconds) on a 10-minute psychomotor vigilance task for each of 18 subjects on each of 10 consecutive days of sleep restriction. To make the multilevel structure more interesting, we have artificially thinned the data so that some subjects retain all 10 measurements, others only 5, and a few only 2. The dropped rows are kept in a separate held-out file, which we will use in Part 5 as a clean test set for model comparison.

Load the data from sleepstudy_processed.csv:

df = pd.read_csv("../data/exercise-8/sleepstudy_processed.csv")

day = jnp.array(df["day"].to_numpy(), dtype=jnp.float32)
reaction = jnp.array(df["reaction"].to_numpy(), dtype=jnp.float32)
subject_idx = jnp.array(df["subject_idx"].to_numpy(), dtype=jnp.int32)

n_subjects = int(df["subject_idx"].nunique())
n_obs = len(df)
n_subjects, n_obs
(18, 110)

The columns are subject_idx (taking values in \(\{0, \ldots, S-1\}\)), day, and reaction. There are \(S = 18\) subjects and \(N = 110\) observations in total.

Task 0 — Visualize the per-subject trends. Plot reaction time vs. day for each subject, fit a least-squares line through each subject’s points, and overlay it. jnp.polyfit with deg=1 returns the slope and intercept. What features do you notice that a single global regression would miss?

Code
fig, axes = plt.subplots(5, 4, figsize=(7, 7), sharex=True, sharey=True)
day_grid = jnp.linspace(0, 9, 50)

for s, ax in enumerate(axes.flat):
    if s >= n_subjects:
        ax.axis("off")
        continue
    mask = subject_idx == s
    xs = day[mask]
    ys = reaction[mask]
    slope, intercept = jnp.polyfit(xs, ys, 1)
    ax.scatter(xs, ys, s=15, color="black", alpha=0.7)
    ax.plot(day_grid, intercept + slope * day_grid, color="C1", linewidth=1.2)
    ax.set_title(f"Subject {s} (n={int(mask.sum())})", fontsize=8)

for ax in axes[-1, :]:
    ax.set_xlabel("Day")
for ax in axes[:, 0]:
    ax.set_ylabel("Reaction (ms)")
plt.tight_layout()
plt.show()
Figure 1: Reaction time (ms) vs. day for each of the 18 subjects, with a best-fit line overlaid.

Most subjects with enough observations show a positive trend: as the days of sleep restriction accumulate, reaction times rise. But the amount of degradation varies substantially across subjects, and baseline reaction times also differ markedly. A single global regression line would average these very different trajectories into a moderate slope that fits no individual particularly well.

For the subjects with only 2 observations, the line is an interpolation. It passes exactly through both points by construction and is therefore at the mercy of their noise: a single noisy measurement can flip the slope from steep to flat or even negative. We will see that this is precisely the situation in which partial pooling helps.

Standardizing the data

Before fitting any model, it pays to put the predictor and the response on a common, unit-free scale. We do this by standardizing (also called z-scoring): for each variable \(v\) we subtract the sample mean and divide by the sample standard deviation, \[ z_v = \frac{v - \bar{v}}{\text{sd}(v)}, \] so that \(z_v\) has mean \(0\) and standard deviation \(1\). We will fit our models on the standardized scale, and back-transform when we want to report results in the original units (milliseconds and days).

day_mean, day_std = day.mean(), day.std()
reaction_mean, reaction_std = reaction.mean(), reaction.std()

z_day = (day - day_mean) / day_std
z_reaction = (reaction - reaction_mean) / reaction_std

We keep the four constants day_mean, day_std, reaction_mean, reaction_std so that we can later un-standardize the model’s predictions: a posterior draw \(z_y\) on the standardized scale corresponds to \(y = \bar{y} + \text{sd}(y) \cdot z_y\) on the original scale.

Standardization helps in two ways: it makes priors easier to choose, and it makes the posterior easier for MCMC to explore.

Priors become easier to set. On the standardized scale, both predictor and response have unit standard deviation by construction, so a \(\text{Normal}(0, 1)\) prior on slopes and intercepts is automatically weakly informative in the sense of Choosing priors and prior predictive checks: it allows the data to speak while ruling out absurdly extreme values. The same prior recipe transports across datasets, since the units have been removed.

The posterior geometry becomes nicer. When the predictor is centered, the posterior correlation between the intercept and the slope is greatly reduced: the intercept now describes the fitted value at the center of the data, so changing the slope pivots the line around the data instead of dragging the intercept through a long lever arm at \(x = 0\). Less correlated posteriors are easier for MCMC to explore because they lack the narrow diagonal ridges that force tiny step sizes.

Part 1: Three pooling regimes

We will fit three regression models that differ only in how they treat the per-subject parameters: complete pooling (one global line), no pooling (one independent line per subject), and partial pooling (per-subject lines tied together by a shared prior). In each subtask below we state the model mathematically and then translate it into NumPyro code.

Throughout we use \(i = 1, \ldots, N\) for observations and \(s = 1, \ldots, S\) for subjects, and write \(s[i]\) for the subject of the \(i\)-th observation. All three models use the standardized variables \(x_i = z_{\text{day},i}\) and \(y_i = z_{\text{reaction},i}\), and share a common call signature:

def model_*(x, subject_idx, y=None, n_subjects=n_subjects): ...

Each model ends with numpyro.sample("obs", dist.Normal(mu, sigma), obs=y), so the same function generates fake data when y=None and conditions on data when y is provided. The keyword argument n_subjects=n_subjects captures the value computed in the data section as a default, so we don’t have to pass it explicitly to mcmc.run, Predictive, or log_likelihood.

Note that all three models share a single observation-level scale \(\sigma\) across subjects. Letting each subject have its own \(\sigma_s\) is a natural extension, but we keep \(\sigma\) pooled to keep the focus on the intercepts and slopes.

Task 1a — Complete pooling

Ignore subject identity entirely: one global intercept and slope describe everyone.

\[ \begin{aligned} \alpha &\sim \text{Normal}(0, 1) \\ \beta &\sim \text{Normal}(0, 1) \\ \sigma &\sim \text{Exponential}(1) \\ \mu_i &= \alpha + \beta \, x_i \\ y_i &\sim \text{Normal}(\mu_i, \sigma) \end{aligned} \]

This is just a Gaussian linear regression and serves as a baseline: if subject identity does not matter, this simple model will do as well as any.

Write model_pooled(x, subject_idx, y=None). The argument subject_idx is unused here but kept for a uniform signature.

def model_pooled(x, subject_idx, y=None):
    alpha = numpyro.sample("alpha", dist.Normal(0, 1))
    beta = numpyro.sample("beta", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))

    mu = alpha + beta * x
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
Note

In this and the following models, we purposefully do not register mu with numpyro.deterministic. In Task 5 we evaluate log_likelihood on a test set of a different size, and a stored mu of the training length would cause a shape mismatch that requires a workaround.
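
Because y defaults to None, the same function also runs forward without data: the "obs" site is then sampled rather than conditioned on. The sketch below is not part of the exercise; it just confirms this mechanic using Predictive with no posterior samples, so the parameters are drawn from their priors.

prior_pred = Predictive(model_pooled, num_samples=200)
fake = prior_pred(random.PRNGKey(1), x=z_day, subject_idx=subject_idx, y=None)
# With y=None, "obs" is sampled: 200 fake datasets on the standardized scale.
fake["obs"].shape   # (200, 110)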

Task 1b — No pooling

Each subject gets its own independent intercept and slope, with no information shared between subjects.

\[ \begin{aligned} \alpha_s &\sim \text{Normal}(0, 1), \quad s = 1, \ldots, S \\ \beta_s &\sim \text{Normal}(0, 1), \quad s = 1, \ldots, S \\ \sigma &\sim \text{Exponential}(1) \\ \mu_i &= \alpha_{s[i]} + \beta_{s[i]} \, x_i \\ y_i &\sim \text{Normal}(\mu_i, \sigma) \end{aligned} \]

This is the opposite extreme: each subject is fit independently, as if the other subjects did not exist. It captures the per-subject variation but cannot generalize to a new subject, and overfits subjects with very few observations.

Write model_unpooled(x, subject_idx, y=None, n_subjects=n_subjects) with one \((\alpha_s, \beta_s)\) pair per subject.

The per-subject parameters form a vector of \(S\) independent draws from the same prior. This kind of structure can be represented in NumPyro using a plate. See the box below for a quick introduction. Once you have \(\alpha_s\) and \(\beta_s\) as length-\(S\) vectors, gather the right entries for each observation with array indexing: alpha_s[subject_idx] produces a length-\(N\) vector containing the intercept of the subject for each row.

numpyro.plate(name, size) is a context manager that turns a single numpyro.sample call into size independent draws. It corresponds to the plate notation for graphical models introduced in Bayesian Data Analysis. The line

with numpyro.plate("subjects", n_subjects):
    alpha_s = numpyro.sample("alpha_s", dist.Normal(0, 1))

draws a vector of i.i.d. \(\text{Normal}(0, 1)\) values of length n_subjects and registers them as a single inference site named "alpha_s". You can have multiple sample statements inside the same plate. They are vectorized over the same axis.

Why not just a for loop? The same model could be written without a plate as

alpha_s = []

for s in range(n_subjects):
    alpha_s.append(numpyro.sample(f"alpha_{s}", dist.Normal(0, 1)))

alpha_s = jnp.stack(alpha_s)

This loop creates n_subjects separate sample sites, each with its own scalar distribution. However, this runs much slower than a single vectorized sample inside a plate, and it also creates a more cluttered trace with many more sites to keep track of. Whenever you want to use a for loop to create multiple independent parameters, it’s a sign that you should be using a plate instead.

def model_unpooled(x, subject_idx, y=None, n_subjects=n_subjects):
    with numpyro.plate("subjects", n_subjects):
        alpha_s = numpyro.sample("alpha_s", dist.Normal(0, 1))
        beta_s = numpyro.sample("beta_s", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))

    mu = alpha_s[subject_idx] + beta_s[subject_idx] * x
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

Task 1c — Partial pooling

Each subject still has its own \((\alpha_s, \beta_s)\), but they are drawn from a common prior whose location and scale are themselves inferred. Subjects with few or noisy observations are pulled (shrunk) toward the population means.

\[ \begin{aligned} \mu_\alpha, \mu_\beta &\sim \text{Normal}(0, 1) \\ \tau_\alpha, \tau_\beta &\sim \text{Exponential}(1) \\ \alpha_s &\sim \text{Normal}(\mu_\alpha, \tau_\alpha), \quad s = 1, \ldots, S \\ \beta_s &\sim \text{Normal}(\mu_\beta, \tau_\beta), \quad s = 1, \ldots, S \\ \sigma &\sim \text{Exponential}(1) \\ \mu_i &= \alpha_{s[i]} + \beta_{s[i]} \, x_i \\ y_i &\sim \text{Normal}(\mu_i, \sigma) \end{aligned} \]

The hyperparameters \(\mu_\alpha\) and \(\mu_\beta\) describe the population-level intercept and slope — what we would expect for an average new subject — and \(\tau_\alpha\), \(\tau_\beta\) describe how much subjects vary around them. The model uses the data from all subjects to learn these, then uses them as a prior on each individual \(\alpha_s, \beta_s\).

Write model_partial(x, subject_idx, y=None, n_subjects=n_subjects). The structure is the same as model_unpooled, but the priors on \(\alpha_s\) and \(\beta_s\) now use the learned hyperparameters \((\mu_\alpha, \tau_\alpha)\) and \((\mu_\beta, \tau_\beta)\) instead of fixed \((0, 1)\). Sample the hyperparameters outside the plate.

def model_partial(x, subject_idx, y=None, n_subjects=n_subjects):
    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0, 1))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0, 1))
    tau_alpha = numpyro.sample("tau_alpha", dist.Exponential(1.0))
    tau_beta = numpyro.sample("tau_beta", dist.Exponential(1.0))

    with numpyro.plate("subjects", n_subjects):
        alpha_s = numpyro.sample("alpha_s", dist.Normal(mu_alpha, tau_alpha))
        beta_s = numpyro.sample("beta_s", dist.Normal(mu_beta, tau_beta))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))

    mu = alpha_s[subject_idx] + beta_s[subject_idx] * x
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

When you suspect a shape bug, the Debugging models technique from Implementing in NumPyro is invaluable: trace the model once with numpyro.handlers.seed + numpyro.handlers.trace and inspect the shape of every site. Here we use it to confirm that the per-subject parameters are length-\(S\) and obs is length-\(N\):

with numpyro.handlers.seed(rng_seed=0):
    trace = numpyro.handlers.trace(model_partial).get_trace(
        x=z_day, subject_idx=subject_idx, y=None
    )

for name, site in trace.items():
    print(f"{name:>10s}  shape={tuple(site['value'].shape)}")
  mu_alpha  shape=()
   mu_beta  shape=()
 tau_alpha  shape=()
  tau_beta  shape=()
  subjects  shape=(18,)
   alpha_s  shape=(18,)
    beta_s  shape=(18,)
     sigma  shape=()
       obs  shape=(110,)

The model above is the centered parametrization: each \(\alpha_s\) is drawn directly from \(\text{Normal}(\mu_\alpha, \tau_\alpha)\), so the prior on \(\alpha_s\) depends on the values of the hyperparameters. An algebraically equivalent way to write the same model introduces a standardized auxiliary variable and reconstructs \(\alpha_s\) as a deterministic transform:

\[ z_s^\alpha \sim \text{Normal}(0, 1), \qquad \alpha_s = \mu_\alpha + \tau_\alpha \, z_s^\alpha, \]

and analogously for \(\beta_s\). By the change-of-variables rule the two formulations encode the same Bayesian model, so the resulting posteriors over \((\mu_\alpha, \tau_\alpha, \alpha_1, \ldots, \alpha_S)\) are identical.

What is not identical is the geometry that NUTS sees. NUTS samples the latent variables the model registers: \(\alpha_s\) in one case, \(z_s^\alpha\) in the other. In the centered version the posterior of \((\log \tau_\alpha, \alpha_s)\) has a funnel shape: small \(\tau_\alpha\) pinches the \(\alpha_s\) near \(\mu_\alpha\), large \(\tau_\alpha\) lets them spread out, and a single step size cannot fit both regions. In the non-centered version \(z_s^\alpha\) is sampled from a fixed unit Normal regardless of \(\tau_\alpha\), removing the funnel from the latent geometry.

def model_partial_nc(x, subject_idx, y=None, n_subjects=n_subjects):
    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0, 1))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0, 1))
    tau_alpha = numpyro.sample("tau_alpha", dist.Exponential(1.0))
    tau_beta = numpyro.sample("tau_beta", dist.Exponential(1.0))

    with numpyro.plate("subjects", n_subjects):
        z_alpha = numpyro.sample("z_alpha", dist.Normal(0, 1))
        z_beta = numpyro.sample("z_beta", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))

    alpha_s = numpyro.deterministic("alpha_s", mu_alpha + tau_alpha * z_alpha)
    beta_s = numpyro.deterministic("beta_s", mu_beta + tau_beta * z_beta)

    mu = alpha_s[subject_idx] + beta_s[subject_idx] * x
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

The non-centered form is the safer default for hierarchical models, especially when the data are sparse or the population scales \(\tau\) are small. For our sleepstudy data, both parametrizations fit fine and produce essentially identical diagnostics. We stick with the centered version model_partial in the remainder of the exercise.
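
If you want to verify this on your own machine, the check below (optional, not used later) fits the non-centered model with the same settings we use in Part 2 and prints its hyperparameter summary for comparison with the centered fit.

# Optional check: the non-centered fit should give essentially the same
# posterior summary for the hyperparameters as the centered model_partial.
mcmc_nc = MCMC(NUTS(model_partial_nc), num_warmup=1000, num_samples=2000, num_chains=4)
mcmc_nc.run(random.PRNGKey(3), x=z_day, subject_idx=subject_idx, y=z_reaction)
az.summary(
    az.from_numpyro(mcmc_nc),
    var_names=["mu_alpha", "mu_beta", "tau_alpha", "tau_beta", "sigma"],
)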

Part 2: Inference

We now run NUTS on each of the three models and inspect the diagnostics. To avoid repeating boilerplate, we wrap the inference in a small helper that runs four chains and returns the MCMC object. We will convert each MCMC object to an ArviZ InferenceData only when needed (e.g. for az.summary and az.compare), and otherwise work with mcmc.get_samples() directly. We use 1000 warmup and 2000 sampling iterations per chain.

def run_mcmc(model, rng_key, **model_kwargs):
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=1000,
        num_samples=2000,
        num_chains=4,
        progress_bar=False,
    )
    mcmc.run(rng_key, **model_kwargs)
    return mcmc

We split the master rng_key once so each fit uses a fresh sub-key:

k_pooled, k_unpooled, k_partial = random.split(rng_key, 3)
data_kwargs = dict(x=z_day, subject_idx=subject_idx, y=z_reaction)

mcmc_pooled = run_mcmc(model_pooled, k_pooled, **data_kwargs)
mcmc_unpooled = run_mcmc(model_unpooled, k_unpooled, **data_kwargs)
mcmc_partial = run_mcmc(model_partial, k_partial, **data_kwargs)

Task 2 — Diagnose convergence. For each fit, call az.summary on the top-level parameters and check that \(\hat{R} \le 1.01\) and that the effective sample sizes ess_bulk and ess_tail are at least a few hundred. Also count divergent transitions: NumPyro reports them via mcmc.get_extra_fields()["diverging"].

NUTS (and HMC more generally, see the chapter on MCMC) explores the posterior by simulating a Hamiltonian trajectory using the leapfrog integrator with some step size \(\varepsilon\). Leapfrog is only stable when \(\varepsilon\) is small compared to the local curvature of the log-posterior. When the chain enters a region where the curvature is too sharp for the current \(\varepsilon\) the simulated energy explodes and the trajectory is flagged as a divergent transition. NumPyro stores these flags alongside the samples; ArviZ exposes them as idata.sample_stats.diverging.

A handful of divergences (a few out of thousands) often does not invalidate the fit, but a substantial fraction means the sampler is systematically failing to enter parts of the posterior, and the resulting estimates may be biased. Standard remedies are: (i) increase NUTS’s target acceptance probability so it picks a smaller \(\varepsilon\) (NUTS(model, target_accept_prob=0.95)); (ii) reparametrize the model, typically by switching to a non-centered parametrization for hierarchical priors.
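
Both remedies are one-line changes to the code we already have. The sketch below shows remedy (i) for the partial-pooling model; it is only needed if your own run reports a substantial number of divergences (remedy (ii) would be to swap in model_partial_nc).

# Remedy (i): adapt to a smaller step size by raising the target acceptance
# probability. Only run this if the default fit shows many divergences.
kernel_strict = NUTS(model_partial, target_accept_prob=0.95)
mcmc_strict = MCMC(kernel_strict, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc_strict.run(random.PRNGKey(7), x=z_day, subject_idx=subject_idx, y=z_reaction)
int(mcmc_strict.get_extra_fields()["diverging"].sum())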

We first convert the MCMC objects to ArviZ InferenceData and summarize the results:

idata_pooled = az.from_numpyro(mcmc_pooled)
idata_unpooled = az.from_numpyro(mcmc_unpooled)
idata_partial = az.from_numpyro(mcmc_partial)
az.summary(idata_pooled, var_names=["alpha", "beta", "sigma"])
mean sd eti89_lb eti89_ub ess_bulk ess_tail r_hat mcse_mean mcse_sd
alpha -0.001 0.082 -0.13 0.13 7911 5591 1.00 0.00092 0.00063
beta 0.518 0.083 0.38 0.65 8370 6081 1.00 0.00091 0.00064
sigma 0.869 0.059 0.78 0.97 7227 6254 1.00 0.0007 0.00053
az.summary(idata_unpooled, var_names=["alpha_s", "beta_s", "sigma"])
mean sd eti89_lb eti89_ub ess_bulk ess_tail r_hat mcse_mean mcse_sd
alpha_s[0] 0.576 0.173 0.3 0.85 16487 5657 1.00 0.0013 0.00099
alpha_s[1] -1.238 0.377 -1.8 -0.63 14607 6181 1.00 0.0031 0.0022
alpha_s[2] -1.086 0.233 -1.5 -0.71 17344 5426 1.00 0.0018 0.0013
alpha_s[3] 0.04 0.167 -0.23 0.3 16243 6143 1.00 0.0013 0.00096
alpha_s[4] -0.012 0.253 -0.41 0.4 13499 6452 1.00 0.0022 0.0016
alpha_s[5] 0.068 0.171 -0.21 0.34 17901 6348 1.00 0.0013 0.00092
alpha_s[6] 0.219 0.238 -0.15 0.59 14696 6436 1.00 0.002 0.0014
alpha_s[7] -0.297 0.457 -1 0.42 11899 6581 1.00 0.0042 0.003
alpha_s[8] -0.818 0.166 -1.1 -0.55 17039 5987 1.00 0.0013 0.00089
alpha_s[9] 1.159 0.17 0.88 1.4 15233 6253 1.00 0.0014 0.00099
alpha_s[10] -0.434 0.24 -0.82 -0.051 14571 6031 1.00 0.002 0.0014
alpha_s[11] 0.57 0.369 -0.012 1.1 14346 6885 1.00 0.0031 0.0022
alpha_s[12] -0.203 0.171 -0.47 0.071 15267 5399 1.00 0.0014 0.001
alpha_s[13] 0.424 0.231 0.052 0.8 16089 6250 1.00 0.0018 0.0013
alpha_s[14] 0.247 0.361 -0.32 0.82 17160 6284 1.00 0.0028 0.002
alpha_s[15] -0.401 0.236 -0.78 -0.013 17452 5780 1.00 0.0018 0.0013
alpha_s[16] -0.139 0.171 -0.41 0.13 19681 6237 1.00 0.0012 0.00084
alpha_s[17] -0.18 0.74 -1.4 1 10013 6211 1.00 0.0074 0.0053
beta_s[0] 1.093 0.176 0.82 1.4 17777 6401 1.00 0.0013 0.00099
beta_s[1] -0.005 0.343 -0.56 0.54 15415 6080 1.00 0.0028 0.002
beta_s[2] 0.361 0.24 -0.011 0.74 17078 5892 1.00 0.0018 0.0013
beta_s[3] 0.149 0.176 -0.13 0.43 16755 6060 1.00 0.0014 0.001
beta_s[4] 0.043 0.25 -0.36 0.45 14082 6475 1.00 0.0021 0.0015
beta_s[5] 0.48 0.176 0.2 0.76 17133 6179 1.00 0.0013 0.00096
beta_s[6] 0.487 0.268 0.065 0.91 17575 6008 1.00 0.002 0.0015
beta_s[7] 0.184 0.466 -0.57 0.93 10848 6287 1.00 0.0045 0.0032
beta_s[8] -0.146 0.178 -0.42 0.14 16859 5900 1.00 0.0014 0.001
beta_s[9] 0.957 0.172 0.68 1.2 16035 6047 1.00 0.0014 0.00098
beta_s[10] 0.618 0.226 0.25 0.98 14143 6188 1.00 0.0019 0.0014
beta_s[11] 0.902 0.36 0.33 1.5 14774 6202 1.00 0.003 0.0021
beta_s[12] 0.322 0.174 0.041 0.6 16458 6655 1.00 0.0014 0.00098
beta_s[13] 0.792 0.223 0.43 1.1 17006 5990 1.00 0.0017 0.0012
beta_s[14] 0.505 0.314 -0.0089 1 17614 5897 1.00 0.0024 0.0017
beta_s[15] 0.767 0.218 0.42 1.1 15867 5676 1.00 0.0017 0.0013
beta_s[16] 0.458 0.172 0.18 0.73 16134 6279 1.00 0.0014 0.00096
beta_s[17] 0.23 0.627 -0.78 1.2 10520 6403 1.00 0.0061 0.0044
sigma 0.538 0.0434 0.47 0.61 6831 6382 1.00 0.00053 0.00039
az.summary(
    idata_partial,
    var_names=["mu_alpha", "mu_beta", "tau_alpha", "tau_beta", "sigma"],
)
mean sd eti89_lb eti89_ub ess_bulk ess_tail r_hat mcse_mean mcse_sd
mu_alpha -0.055 0.157 -0.3 0.19 7084 5878 1.00 0.0019 0.0014
mu_beta 0.5 0.103 0.34 0.66 5866 5444 1.00 0.0013 0.0011
tau_alpha 0.613 0.132 0.44 0.84 6632 5907 1.00 0.0017 0.0017
tau_beta 0.329 0.09 0.2 0.48 3025 3207 1.00 0.0016 0.0015
sigma 0.543 0.0451 0.48 0.62 6318 5685 1.00 0.00058 0.00043

Finally, we record the number of divergences for each fit:

for name, mcmc in [
    ("pooled", mcmc_pooled),
    ("unpooled", mcmc_unpooled),
    ("partial", mcmc_partial),
]:
    div = mcmc.get_extra_fields()["diverging"]
    n_div = int(div.sum()) if div is not None else 0
    print(f"{name:>10s}: {n_div} divergent transitions")
    pooled: 0 divergent transitions
  unpooled: 0 divergent transitions
   partial: 1 divergent transitions

All three fits give \(\hat{R} \approx 1.00\) and ESS values in the thousands. The pooled and unpooled fits report no divergences. The centered partial-pooling fit may produce a small number of divergences on this dataset, but not enough to require reparametrization.

Part 3: Shrinkage

The defining behaviour of partial pooling is called shrinkage: each per-subject slope \(\beta_s\) is pulled toward the population mean \(\mu_\beta\) by an amount that depends on how much the data say about subject \(s\). Subjects with many observations and a clear trend stay close to their no-pooling estimate; subjects with few observations or noisy data get pulled in much more strongly. We can see both effects directly by comparing the per-subject slopes under no pooling and partial pooling.

Task 3 — Forest plot of shrinkage. For each subject, plot the posterior mean and 89% credible interval of \(\beta_s\) under both the no-pooling and partial-pooling fits, side by side. Add a vertical line at the posterior mean of the population slope \(\mu_\beta\) from the partial-pooling fit. Sort subjects so the pattern of shrinkage is easy to read, and pay attention to subjects with only 2 observations.

Calling mcmc.get_samples() returns a dictionary of JAX arrays whose first axis stacks all chains and draws together. For example, mcmc_unpooled.get_samples()["beta_s"] has shape (num_chains * num_samples, S). Reduce along axis 0 with jnp.mean and jnp.quantile to get length-\(S\) vectors of posterior means and credible-interval bounds:

samples_unp = mcmc_unpooled.get_samples()
beta_unp = samples_unp["beta_s"]                          # shape (num_draws, S)
mean_unp = jnp.mean(beta_unp, axis=0)                     # shape (S,)
lo_unp, hi_unp = jnp.quantile(beta_unp, jnp.array([0.055, 0.945]), axis=0)

Use the same pattern for mcmc_partial. The population mean is jnp.mean(mcmc_partial.get_samples()["mu_beta"]).

Code
def per_subject_summary(samples, name):
    arr = samples[name]                                       # shape (num_draws, S)
    mean = jnp.mean(arr, axis=0)
    lo, hi = jnp.quantile(arr, jnp.array([0.055, 0.945]), axis=0)
    return mean, lo, hi

samples_unp = mcmc_unpooled.get_samples()
samples_par = mcmc_partial.get_samples()

mean_unp, lo_unp, hi_unp = per_subject_summary(samples_unp, "beta_s")
mean_par, lo_par, hi_par = per_subject_summary(samples_par, "beta_s")
mu_beta_hat = float(jnp.mean(samples_par["mu_beta"]))

n_obs_per_subject = jnp.array(
    [int((subject_idx == s).sum()) for s in range(n_subjects)]
)

# Sort subjects by no-pooling estimate.
order = jnp.argsort(mean_unp)
y_pos = jnp.arange(n_subjects)
offset = 0.18

fig, ax = plt.subplots(figsize=(6, 6))
# matplotlib's `xerr` expects a 2xN array with rows (lower_offset, upper_offset),
# each non-negative — i.e. distances from the mean, not the absolute bounds.
ax.errorbar(
    mean_unp[order], y_pos + offset,
    xerr=jnp.stack([mean_unp[order] - lo_unp[order], hi_unp[order] - mean_unp[order]]),
    fmt="o", color="C1", label="no pooling", capsize=2, markersize=4,
)
ax.errorbar(
    mean_par[order], y_pos - offset,
    xerr=jnp.stack([mean_par[order] - lo_par[order], hi_par[order] - mean_par[order]]),
    fmt="o", color="C0", label="partial pooling", capsize=2, markersize=4,
)
ax.axvline(mu_beta_hat, color="black", linestyle="--", linewidth=1, label=r"$\mu_\beta$")

ax.set_yticks(y_pos.tolist())
ax.set_yticklabels(
    [f"Subject {int(s)} (n={int(n_obs_per_subject[int(s)])})"
     for s in order],
    fontsize=8,
)
ax.set_xlabel(r"$\beta_s$ (standardized scale)")
ax.legend(loc="lower right", fontsize=8)
plt.tight_layout()
plt.show()
Figure 2: Per-subject slopes \(\beta_s\) on the standardized scale under no-pooling (orange) and partial-pooling (blue) fits, with 89% credible intervals. The vertical line marks the partial-pooling population mean \(\mu_\beta\). Subjects are sorted by their no-pooling estimate, and the count \(n_s\) of observations is shown next to each subject label.

Two patterns stand out. First, the partial-pooling estimates sit closer to the dashed line \(\mu_\beta\) than the no-pooling estimates do: the per-subject slopes are pulled toward the population mean. Second, the shrinkage is not uniform: it is strongest precisely for the subjects with only 2 observations (whose no-pooling slopes are at the mercy of two noisy points) and weakest for subjects with 10 observations and a clear trend. The credible intervals also tighten considerably under partial pooling, because each subject now benefits from the information the other subjects provide about plausible slope values.
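
You can read the same pattern off numerically. The follow-up below (not required by the task) measures how far each posterior-mean slope moved between the two fits; the largest moves should belong to the subjects with the fewest observations.

# How far did each posterior-mean slope move from its no-pooling value?
shrink = jnp.abs(mean_unp - mean_par)
for s in jnp.argsort(-shrink)[:5]:
    s = int(s)
    print(f"Subject {s} (n={int(n_obs_per_subject[s])}): moved {float(shrink[s]):.2f}")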

Part 4: Posterior predictive checks

Diagnostics like \(\hat R\) and ESS only tell us that MCMC converged, not whether the model is a good description of the data. For that we use posterior predictive checks (PPCs): we draw fake data from the fitted model and compare it to the real data. If the two look similar, the model can reproduce the patterns we care about; if not, the discrepancy tells us how the model fails.

Concretely, for each posterior draw \(\theta^{(d)}\) we generate a replicated dataset \(\tilde y^{(d)}\) from the likelihood at the observed predictors \(x_i\) and subject indices \(s[i]\). Stacking these over the \(D\) draws gives a posterior predictive sample of shape \(D \times N\). NumPyro provides the Predictive class for exactly this purpose.

Task 4 — Per-subject posterior predictive checks. For each of the three models, draw posterior predictive reaction times at the observed \((x_i, s[i])\), back-transform to the original millisecond scale, and overlay the per-subject observed scatter with the posterior mean line and an 89% predictive band. How does each model perform?

Predictive(model, posterior_samples=samples) returns a callable that runs the model once per posterior draw, conditioning on the supplied samples and re-sampling everything else. Pass y=None so that the "obs" site is sampled rather than conditioned on:

predictive = Predictive(model_partial, posterior_samples=mcmc_partial.get_samples())
ppc = predictive(rng_key_pp, x=z_day, subject_idx=subject_idx, y=None)
y_rep = ppc["obs"]   # shape (num_draws, N), on the standardized scale

To get back to milliseconds: y_rep_ms = reaction_mean + reaction_std * y_rep.

We loop over the three models, generating posterior predictive draws on the standardized scale and immediately un-standardizing them.

k_pp_pooled, k_pp_unpooled, k_pp_partial = random.split(random.PRNGKey(0), 3)

def posterior_predictive(model, mcmc, rng_key_pp):
    predictive = Predictive(model, posterior_samples=mcmc.get_samples())
    ppc = predictive(rng_key_pp, x=z_day, subject_idx=subject_idx, y=None)
    return reaction_mean + reaction_std * ppc["obs"]   # shape (num_draws, N), milliseconds

y_rep_pooled = posterior_predictive(model_pooled, mcmc_pooled, k_pp_pooled)
y_rep_unpooled = posterior_predictive(model_unpooled, mcmc_unpooled, k_pp_unpooled)
y_rep_partial = posterior_predictive(model_partial, mcmc_partial, k_pp_partial)

For each subject we summarize the predictive distribution at each observed day by its mean and 5.5%/94.5% quantiles (an 89% interval), then overlay the observed scatter:

Code
def ppc_summary(y_rep):
    mean = jnp.mean(y_rep, axis=0)
    lo, hi = jnp.quantile(y_rep, jnp.array([0.055, 0.945]), axis=0)
    return mean, lo, hi

models_to_plot = [
    ("Complete pooling", y_rep_pooled, "C2"),
    ("No pooling", y_rep_unpooled, "C1"),
    ("Partial pooling", y_rep_partial, "C0"),
]

fig, axes = plt.subplots(
    len(models_to_plot), n_subjects,
    figsize=(14, 5), sharex=True, sharey=True,
)

for row, (title, y_rep, color) in enumerate(models_to_plot):
    mean_pred, lo_pred, hi_pred = ppc_summary(y_rep)
    for s in range(n_subjects):
        ax = axes[row, s]
        mask = subject_idx == s
        order = jnp.argsort(day[mask])
        xs = day[mask][order]
        ax.fill_between(xs, lo_pred[mask][order], hi_pred[mask][order],
                        color=color, alpha=0.25)
        ax.plot(xs, mean_pred[mask][order], color=color, linewidth=1.2)
        ax.scatter(day[mask], reaction[mask], s=10, color="black", alpha=0.8)
        if row == 0:
            ax.set_title(f"S{s}", fontsize=8)
    axes[row, 0].set_ylabel(title, fontsize=9)

for ax in axes[-1, :]:
    ax.set_xlabel("Day")
plt.tight_layout()
plt.show()
Figure 3: Per-subject posterior predictive mean (line) and 89% predictive band for each model (rows: complete pooling, no pooling, partial pooling), overlaid on the observed reaction times (black points).

Each model produces a visually different fit:

  • Complete pooling draws the same line for every subject, since it has no notion of subject identity. Subjects that are far from this average trend are poorly predicted. The 89% band is wide because a single \(\sigma\) has to absorb all the between-subject variation as if it were noise.
  • No pooling fits each subject’s points individually. The line goes through the cloud and the band is narrow. But for the subjects with only 2 observations the line is essentially an interpolation through two noisy points: the slope is over-confident and would likely generalize poorly to a third measurement on the same subject.
  • Partial pooling is a compromise. Subjects with plenty of data get a line very close to the no-pooling fit. Subjects with only 2 observations get a line and band that are visibly pulled toward the population trend.

Part 5: Model comparison

The PPCs gave a qualitative picture of how the three models fit. To put a number on it, we now compare them by their expected out-of-sample predictive accuracy: how well each model would predict a new observation from the same data-generating process.

For a single new observation \(\tilde y\) at predictor \(\tilde x\), the predictive density under our fitted model is the posterior-averaged likelihood, \[ p(\tilde y \mid \tilde x, \mathcal{D}) \;=\; \int p(\tilde y \mid \tilde x, \theta)\, p(\theta \mid \mathcal{D})\, d\theta \;\approx\; \frac{1}{D} \sum_{d=1}^{D} p\bigl(\tilde y \mid \tilde x, \theta^{(d)}\bigr), \] where \(\theta^{(d)}\) are draws from the posterior. We score how well the model predicts \(\tilde y\) by the log of this quantity (higher is better: the model placed more probability mass near the true value). Summing over a set of held-out points gives an estimate of the expected log pointwise predictive density (ELPD), \[ \text{ELPD} \;=\; \sum_{i} \log\!\left( \frac{1}{D} \sum_{d=1}^{D} p\bigl(y_i \mid x_i, \theta^{(d)}\bigr) \right). \]

In code we already have access to the per-draw log-likelihoods \(\ell_i^{(d)} = \log p(y_i \mid x_i, \theta^{(d)})\). We compute the inner average in log-space using jax.scipy.special.logsumexp, which avoids the numerical underflow that would occur if we exponentiated first: \[ \log \frac{1}{D}\sum_d e^{\ell_i^{(d)}} \;=\; \operatorname{logsumexp}_d \ell_i^{(d)} \;-\; \log D. \]
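
The toy computation below uses made-up numbers to show why the log-space identity matters: exponentiating large negative log-likelihoods underflows to zero, while logsumexp stays finite.

# Made-up per-draw log-likelihoods for a single observation (D = 3 draws).
ll = jnp.array([-700.0, -701.0, -702.0])
jnp.log(jnp.mean(jnp.exp(ll)))                          # -inf: exp underflows to 0
jax.scipy.special.logsumexp(ll) - jnp.log(ll.shape[0])  # ≈ -700.69, computed stably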

Task 5 — Compare the three models. We will obtain ELPD in two ways and check that they agree on which model to prefer.

For PSIS-LOO, convert each MCMC object to an InferenceData with log_likelihood=True and pass the resulting objects to az.compare. Which model has the highest elpd_loo? Are there any Pareto-\(\hat k\) warnings? If so, which observations are problematic and why?

For the held-out ELPD, load the held-out CSV, standardize with the training statistics, and use numpyro.infer.util.log_likelihood to evaluate the per-draw log-likelihoods on the test set:

from numpyro.infer.util import log_likelihood

test_kwargs = dict(x=x_test, subject_idx=subject_idx_test, y=y_test)
log_lik = log_likelihood(model, mcmc.get_samples(), **test_kwargs)["obs"]
# log_lik has shape (num_draws, n_test).

Then use jax.scipy.special.logsumexp to average over draws in log-space and sum over test points to get the ELPD.

Which model is preferred under each criterion, and do they agree?

For PSIS-LOO, ArviZ needs the per-observation training log-likelihoods \(\ell_i^{(d)}\). Passing log_likelihood=True to az.from_numpyro traces the model under each posterior draw and records log_prob(y_i) at the obs site:

idata_pooled_ll = az.from_numpyro(mcmc_pooled, log_likelihood=True)
idata_unpooled_ll = az.from_numpyro(mcmc_unpooled, log_likelihood=True)
idata_partial_ll = az.from_numpyro(mcmc_partial, log_likelihood=True)

az.compare(
    {
        "pooled": idata_pooled_ll,
        "unpooled": idata_unpooled_ll,
        "partial": idata_partial_ll,
    }
)
UserWarning (emitted for both the unpooled and partial fits): Estimated shape parameter of Pareto distribution is greater than 0.70 for one or more samples. You should consider using a more robust model, this is because importance sampling is less likely to work well if the marginal posterior and LOO posterior are very different. This is more likely to happen with a non-robust model and highly influential observations.
rank elpd p elpd_diff weight se dse warning
partial 0 -110.0 28.6 0.0 0.62 16.0 0.0 True
unpooled 1 -110.0 32.7 2.0 0.29 16.0 2.5 True
pooled 2 -140.0 3.2 30.0 0.09 8.0 14.0 False

Complete pooling lands far below the other two: a single global line cannot capture the between-subject variation, and LOO penalizes it heavily. The no-pooling and partial-pooling models come out close on elpd_loo. ArviZ also issues Pareto-\(\hat k\) warnings — typically for the subjects with only two observations, since removing one of those two points genuinely changes the fit for that subject and the importance-sampling correction can no longer trust the full-data posterior as a proxy.
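
To pin down which observations are responsible, you can compute pointwise LOO and inspect the estimated Pareto \(\hat k\) values. The sketch below assumes az.loo exposes them under the pareto_k attribute (true for recent ArviZ releases); the 0.7 threshold is the one quoted in the warning.

import numpy as np

# Pointwise LOO for the no-pooling fit; flag observations with high k-hat.
loo_unpooled = az.loo(idata_unpooled_ll, pointwise=True)
k_hat = np.asarray(loo_unpooled.pareto_k)
high_k = np.where(k_hat > 0.7)[0]
print(high_k)                                     # row indices with k-hat > 0.7
print(df["subject_idx"].to_numpy()[high_k])       # the subjects they belong to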

The held-out CSV has the same columns as the training file. We standardize it with the training mean and standard deviation, so it lives on the same scale as the posterior:

df_test = pd.read_csv("../data/exercise-8/sleepstudy_heldout.csv")
x_test = (jnp.array(df_test["day"].to_numpy(), dtype=jnp.float32) - day_mean) / day_std
y_test = (jnp.array(df_test["reaction"].to_numpy(), dtype=jnp.float32) - reaction_mean) / reaction_std
subject_idx_test = jnp.array(df_test["subject_idx"].to_numpy(), dtype=jnp.int32)

To get \(\ell_i^{(d)} = \log p(y_i^{\text{test}} \mid x_i^{\text{test}}, \theta^{(d)})\) on the new inputs we use numpyro.infer.util.log_likelihood. It substitutes each posterior draw into the model, re-traces it on the supplied inputs, and returns log_prob(y) at every observed site.

from numpyro.infer.util import log_likelihood

test_kwargs = dict(x=x_test, subject_idx=subject_idx_test, y=y_test)

def held_out_elpd(model, mcmc):
    log_lik = log_likelihood(model, mcmc.get_samples(), **test_kwargs)["obs"]
    # log_lik has shape (num_draws, n_test).
    # logsumexp over the draws axis, minus log(num_draws), gives the per-point
    # log predictive density; sum over test points gives the held-out ELPD.
    pointwise = jax.scipy.special.logsumexp(log_lik, axis=0) - jnp.log(log_lik.shape[0])
    return float(pointwise.sum())

{
    "pooled":   held_out_elpd(model_pooled,   mcmc_pooled),
    "unpooled": held_out_elpd(model_unpooled, mcmc_unpooled),
    "partial":  held_out_elpd(model_partial,  mcmc_partial),
}
{'pooled': -83.05768585205078,
 'unpooled': -56.706207275390625,
 'partial': -51.243282318115234}

The partial-pooling model attains the highest held-out ELPD, followed by no pooling and then complete pooling. Both criteria therefore agree on the ranking, and the held-out comparison separates partial pooling from no pooling more clearly than PSIS-LOO did: partial pooling genuinely helps the model generalize to unseen observations.