Counterfactual Inference

Summary

Counterfactual inference asks “what would have happened under an alternative scenario?” Answering it requires fitting a model to pre-intervention data, then running posterior predictive sampling under the counterfactual conditions. The worked example here estimates excess deaths in England and Wales during COVID-19.

The Core Formula

Excess deaths = Reported deaths - Expected deaths

Expected deaths is a counterfactual: the deaths that would have occurred if nothing had changed. It can never be observed, because we cannot experience the actual and the counterfactual timelines simultaneously.
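Excess deaths are reported deaths minus expected (counterfactual) deaths; once expected deaths are available, the rest is arithmetic. A minimal sketch with made-up numbers (not real ONS figures); the helper `excess_deaths` is hypothetical:

```python
import numpy as np


def excess_deaths(reported, expected):
    """Excess deaths = reported deaths - expected (counterfactual) deaths."""
    return np.asarray(reported, dtype=float) - np.asarray(expected, dtype=float)


# Illustrative monthly numbers only
reported = [52_000, 61_000, 48_000]
expected = [45_000, 44_000, 43_000]
monthly = excess_deaths(reported, expected)
cumulative = monthly.cumsum()  # running total of excess deaths
print(monthly, cumulative)
```

In the Bayesian version, `expected` is not a single series but one series per posterior draw, so the subtraction yields a distribution of excess deaths.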

Strategy

  1. Build a model on pre-intervention data: fit to pre-COVID deaths (2006–2019) using month, a linear time trend, and average temperature as predictors
  2. Retrodict: check that the model fits the pre-intervention period (posterior predictive check)
  3. Counterfactual forecast: use pm.set_data() to swap in post-COVID predictors while keeping the same pre-COVID posterior, generating predictions of “what would have happened”
  4. Compute excess: subtract the counterfactual predictions from reported deaths
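The four steps can be sketched end to end in plain numpy. This deliberately swaps the Bayesian model for ordinary least squares, so it yields point estimates rather than a posterior; the synthetic data, the 8,000-death "pandemic" shift, and the helper `design_matrix` are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)


def design_matrix(month, t, temp):
    """Columns: intercept, 11 month dummies (January is the reference), trend, temperature."""
    dummies = (month[:, None] == np.arange(2, 13)[None, :]).astype(float)
    return np.column_stack([np.ones_like(t), dummies, t, temp])


# Hypothetical synthetic data: 14 years of pre-period months, then 24 post-period months
n_pre, n_post = 14 * 12, 24
t = np.arange(n_pre + n_post, dtype=float)
month = np.arange(n_pre + n_post) % 12 + 1
temp = 10 - 6 * np.cos(2 * np.pi * (month - 1) / 12) + rng.normal(0, 1, t.size)
baseline = 40_000 + 20 * t - 700 * temp + rng.normal(0, 1_000, t.size)
shock = np.where(t >= n_pre, 8_000.0, 0.0)  # hypothetical pandemic effect, post period only
deaths = baseline + shock

pre, post = slice(0, n_pre), slice(n_pre, None)
X = design_matrix(month, t, temp)

# 1. Fit on pre-intervention data only
beta, *_ = np.linalg.lstsq(X[pre], deaths[pre], rcond=None)

# 2. Retrodict: how well does the fit reproduce the pre-period?
resid = deaths[pre] - X[pre] @ beta
r2 = 1 - resid.var() / deaths[pre].var()

# 3. Counterfactual forecast: same coefficients, post-period predictors
expected_post = X[post] @ beta

# 4. Excess = reported - counterfactual
excess = deaths[post] - expected_post
print(f"pre-period R^2 = {r2:.2f}; mean monthly excess = {excess.mean():.0f}")
```

The key design point carries over to the PyMC version: the coefficients are learned only from the pre-period, so the post-period shock cannot leak into the counterfactual.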

The do-operator

Formally, this is an intervention in the sense of the do-operator: we set the world to pre-COVID conditions, do(no pandemic), and ask what deaths would then have been. Practically, this is implemented by running posterior predictive sampling on out-of-sample (post-COVID) predictor values, while keeping the posterior learned from pre-COVID data.
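Conceptually, posterior predictive sampling on new predictor values works as follows: for each posterior draw, evaluate the linear predictor at the post-COVID inputs and add observation noise. The sketch below assumes the same model structure as this note (intercept, monthly effects, trend, temperature), but the posterior draws are hypothetical stand-ins for a fitted model, and `posterior_predictive` is an illustrative helper, not a PyMC function.

```python
import numpy as np

rng = np.random.default_rng(1)
n_draws = 4_000

# Hypothetical posterior draws standing in for a model fitted to pre-COVID data
month_mu = rng.normal(0, 1_000, (n_draws, 12))
draws = {
    "intercept": rng.normal(40_000, 300, n_draws),
    "month_mu": month_mu - month_mu.mean(axis=1, keepdims=True),  # zero-sum per draw
    "linear_trend": np.abs(rng.normal(20, 2, n_draws)),
    "temp_coeff": rng.normal(-700, 30, n_draws),
    "sigma": np.abs(rng.normal(1_500, 50, n_draws)),
}


def posterior_predictive(draws, month, t, temp, rng):
    """Simulate one series per posterior draw at the given predictor values."""
    mu = (
        draws["intercept"][:, None]
        + draws["linear_trend"][:, None] * t[None, :]
        + draws["month_mu"][:, month - 1]
        + draws["temp_coeff"][:, None] * temp[None, :]
    )
    # Plain normal noise: TruncatedNormal(lower=0) is practically identical when mu >> sigma
    return rng.normal(mu, draws["sigma"][:, None])


# Hypothetical post-COVID predictor values (March to May)
month_new = np.array([3, 4, 5])
t_new = np.array([170.0, 171.0, 172.0])
temp_new = np.array([7.0, 10.0, 13.0])
counterfactual = posterior_predictive(draws, month_new, t_new, temp_new, rng)
print(counterfactual.shape)  # (4000, 3)
```

Only the inputs change; every parameter draw still comes from the pre-COVID posterior, which is what makes the output a "no pandemic" counterfactual.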

Model Structure

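import pymc as pm

# Assumed inputs: `pre` is a pandas DataFrame of pre-COVID monthly rows with
# columns "month" (1-12), "t", "temp" and "deaths"; `month_strings` holds the 12 month labels.
with pm.Model(coords={"month": month_strings}) as model:
    # Mutable data containers (pm.Data in recent PyMC; older releases used pm.MutableData)
    month = pm.Data("month", pre["month"].to_numpy())
    time  = pm.Data("time",  pre["t"].to_numpy())
    temp  = pm.Data("temp",  pre["temp"].to_numpy())

    intercept    = pm.Normal("intercept", 40_000, 10_000)
    month_mu     = pm.ZeroSumNormal("month mu", sigma=3000, dims="month")  # seasonal effects
    linear_trend = pm.TruncatedNormal("linear trend", 0, 50, lower=0)  # increasing baseline
    temp_coeff   = pm.Normal("temp coeff", 0, 200)  # ~-764 deaths/°C

    mu = intercept + (linear_trend * time) + month_mu[month - 1] + (temp_coeff * temp)
    sigma = pm.HalfNormal("sigma", 2_000)
    # shape=time.shape lets the likelihood resize when pm.set_data swaps in new predictors
    pm.TruncatedNormal("obs", mu=mu, sigma=sigma, lower=0,
                       observed=pre["deaths"].to_numpy(), shape=time.shape)
    idata = pm.sample()  # posterior conditioned on pre-COVID data only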

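ZeroSumNormal: the monthly deviations are constrained to sum to zero. This removes one degree of freedom (12 months, 11 free effects) and stops the month effects from trading off against the intercept, which aids identifiability.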
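The identifiability problem, and how a zero-sum constraint resolves it, can be shown with a few lines of numpy. This is a minimal illustration with hypothetical values; it is not how PyMC implements ZeroSumNormal internally (PyMC enforces the constraint through a transform inside the sampler).

```python
import numpy as np

rng = np.random.default_rng(2)
month = np.arange(1, 13)
intercept = 40_000.0
month_effects = rng.normal(0, 3_000, 12)  # hypothetical unconstrained seasonal effects


def linear_predictor(a, m):
    """Intercept plus the effect of each month."""
    return a + m[month - 1]


# Without a constraint, shifting a constant between intercept and months changes nothing
c = 5_000.0
assert np.allclose(linear_predictor(intercept, month_effects),
                   linear_predictor(intercept + c, month_effects - c))

# Zero-sum parameterisation: move the mean of the month effects into the intercept
zs_effects = month_effects - month_effects.mean()
zs_intercept = intercept + month_effects.mean()
print(np.isclose(zs_effects.sum(), 0.0))  # True
print(np.allclose(linear_predictor(intercept, month_effects),
                  linear_predictor(zs_intercept, zs_effects)))  # True
```

Both parameterisations give identical predictions, but only the zero-sum one has a unique intercept, which is why the sampler finds it easier to explore.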

Counterfactual Sampling

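import numpy as np
import pymc as pm
import xarray as xr

# Assumed: `post` has the same predictor columns as `pre`; `post_deaths` holds reported deaths.
# Swap in post-COVID predictors; `idata` still holds the pre-COVID posterior.
with model:
    pm.set_data({"month": post["month"].to_numpy(), "time": post["t"].to_numpy(), "temp": post["temp"].to_numpy()})
    counterfactual = pm.sample_posterior_predictive(idata, var_names=["obs"])

# Excess deaths per posterior draw (requires the likelihood's shape to follow the data)
pred = counterfactual.posterior_predictive["obs"]  # dims: (chain, draw, time index)
time_dim = pred.dims[-1]
excess = xr.DataArray(np.asarray(post_deaths), dims=[time_dim]) - pred
cumulative_excess = excess.cumsum(dim=time_dim)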

The result is a posterior distribution over excess deaths rather than a single point estimate, so it provides uncertainty bands for both monthly and cumulative excess.
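Summarising that distribution takes one step that is easy to get wrong: cumulate within each draw first, then summarise, rather than adding up per-month intervals. A sketch with hypothetical counterfactual draws and reported counts; it uses a 94% equal-tailed interval for simplicity (ArviZ defaults to 94% HDIs).

```python
import numpy as np

rng = np.random.default_rng(3)

# Hypothetical inputs: counterfactual draws (n_draws x n_months) and reported deaths
n_draws, n_months = 4_000, 6
counterfactual = rng.normal(45_000, 1_500, (n_draws, n_months))
reported = np.array([52_000, 61_000, 48_000, 46_000, 45_500, 47_000], dtype=float)

excess = reported[None, :] - counterfactual  # one excess series per draw
cumulative = excess.cumsum(axis=1)           # cumulate within each draw

total = cumulative[:, -1]                    # posterior of total excess over the window
lo, hi = np.percentile(total, [3, 97])       # 94% equal-tailed interval
print(f"cumulative excess: mean {total.mean():.0f}, 94% interval [{lo:.0f}, {hi:.0f}]")
```

Cumulating per draw preserves the correlation between months within each draw, so the interval on the total is coherent.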

Causal Caveats

  • The model assumes no unmodelled confounders changed between pre and post periods (other than COVID)
  • Many other things changed in 2020 (policy, healthcare access, behaviour); strong causal claims about COVID-19 itself require accounting for these
  • Excess deaths captures all-cause excess mortality, not just direct COVID deaths

Connections

Source