11  Bayesian Causal Inference

library(tidyverse)
library(BART)
library(bcf)

The frequentist methods in earlier chapters return point estimates with asymptotic confidence intervals. Bayesian methods place priors on the unknowns (here, the regression function) and return a full posterior distribution. That buys credible intervals, principled propagation of uncertainty into derived quantities such as subgroup or individual effects, and natural handling of hierarchical structure. These are particularly attractive in causal inference, where quantifying uncertainty about an effect is often as important as the point estimate itself.

This chapter covers two workhorses:

  1. BART for causal inference (Hill 2011) — Bayesian Additive Regression Trees as a flexible nonparametric outcome model, used to estimate ATE, ATT, and individual-level treatment effects via the posterior.
  2. Bayesian Causal Forests (BCF) (Hahn, Murray, & Carvalho 2020) — an extension of BART that separates the prognostic and treatment effect components, with a regularising prior on heterogeneity to reduce overfitting.

Related reading: The chapter Using Numpyro in Topics on Econometrics and Causal Inference demonstrates a Bayesian hierarchical model with Numpyro. For an MCMC-from-scratch perspective on causal inference, that companion piece is the natural next step after this chapter.

11.1 Setup

We use the same data-generating process as the Heterogeneous Treatment Effects chapter so that the Bayesian estimators can be compared directly to the frequentist meta-learners:

set.seed(42)
n <- 2000
p <- 5

X <- matrix(runif(n * p), n, p)
colnames(X) <- paste0("X", 1:p)
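U <- rnorm(n)                                  # confounder, deliberately left out of X
ps <- plogis(-0.3 + 1.5 * X[, 2] + 0.5 * U)    # true propensity: depends on X2 and U
D  <- rbinom(n, 1, ps)
tau <- 1 + 2 * X[, 1]                          # true CATE: varies with X1 only
Y0  <- 0.5 * X[, 2] + 0.3 * U + rnorm(n)       # U also shifts the untreated outcome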
Y1  <- Y0 + tau
Y   <- ifelse(D == 1, Y1, Y0)

df <- data.frame(Y = Y, D = D, X)
cat(sprintf("True ATE = %.3f\n", mean(tau)))
True ATE = 1.985

11.2 BART for causal inference (Hill 2011)

BART (Chipman, George, & McCulloch 2010) is a sum-of-trees model with a strong regularising prior on tree depth. It is a flexible nonparametric regressor with native uncertainty quantification from MCMC.

The Hill (2011) recipe for causal inference is:

  1. Fit BART to model \(\mu(d, x) = \mathbb{E}[Y \mid D = d, X = x]\) using the observed data.
  2. For each unit \(i\), predict counterfactuals \(\mu(1, x_i)\) and \(\mu(0, x_i)\).
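  3. The individual treatment effect (ITE) is \(\hat\tau_i = \mu(1, x_i) - \mu(0, x_i)\), computed separately within every posterior draw.
  4. For the ATE, average \(\hat\tau_i\) over units within each draw and summarise the resulting draws; for an individual effect, summarise unit \(i\)'s draws directly.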
# BART::wbart fits a continuous-outcome BART model
# Combine D and X as the design matrix
X_train <- cbind(D = D, X)
bart_fit <- wbart(
  x.train = X_train,
  y.train = Y,
  ndpost = 1000,    # posterior samples after burn-in
  nskip  = 200,     # burn-in
  printevery = 1000
)
*****Into main of wbart
*****Data:
data:n,p,np: 2000, 6, 0
y1,yn: -0.969524, -0.658949
x1,x[n*p]: 0.000000, 0.647277
*****Number of Trees: 200
*****Number of Cut Points: 1 ... 100
*****burn and ndpost: 200, 1000
*****Prior:beta,alpha,tau,nu,lambda: 2.000000,0.950000,0.163247,3.000000,0.232155
*****sigma: 1.091702
*****w (weights): 1.000000 ... 1.000000
*****Dirichlet:sparse,theta,omega,a,b,rho,augment: 0,0,1,0.5,1,6,0
*****nkeeptrain,nkeeptest,nkeeptestme,nkeeptreedraws: 1000,1000,1000,1000
*****printevery: 1000
*****skiptr,skipte,skipteme,skiptreedraws: 1,1,1,1

MCMC
done 0 (out of 1200)
done 1000 (out of 1200)
time: 16s
check counts
trcnt,tecnt,temecnt,treedrawscnt: 1000,0,0,1000
# Counterfactual predictions
X_treat   <- cbind(D = rep(1, n), X)
X_control <- cbind(D = rep(0, n), X)

mu1_post <- predict(bart_fit, newdata = X_treat)   # 1000 x n posterior draws
*****In main of C++ for bart prediction
tc (threadcount): 1
number of bart draws: 1000
number of trees in bart sum: 200
number of x columns: 6
from x,np,p: 6, 2000
***using serial code
mu0_post <- predict(bart_fit, newdata = X_control)
*****In main of C++ for bart prediction
tc (threadcount): 1
number of bart draws: 1000
number of trees in bart sum: 200
number of x columns: 6
from x,np,p: 6, 2000
***using serial code
# Posterior of individual treatment effects
tau_post <- mu1_post - mu0_post   # 1000 x n posterior draws of ITEs

# Posterior summaries
tau_mean <- rowMeans(tau_post)    # ATE in each draw
ate_post_mean   <- mean(tau_mean)
ate_post_sd     <- sd(tau_mean)
ate_post_ci     <- quantile(tau_mean, c(0.025, 0.975))

cat(sprintf("BART ATE posterior:\n"))
BART ATE posterior:
cat(sprintf("  Mean:    %.3f\n", ate_post_mean))
  Mean:    2.198
cat(sprintf("  SD:      %.3f\n", ate_post_sd))
  SD:      0.051
cat(sprintf("  95%% CrI: [%.3f, %.3f]\n", ate_post_ci[1], ate_post_ci[2]))
  95% CrI: [2.099, 2.300]
cat(sprintf("  True ATE: %.3f\n", mean(tau)))
  True ATE: 1.985
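The list at the top of the chapter also mentions the ATT, which the code above never computes. It comes from the same draws: within each draw, average the ITEs of the treated units only. A minimal sketch, using a hypothetical helper `summarise_effect()` (not part of the BART package), checked on a toy draw matrix:

```r
# Hypothetical helper: posterior summary of an average effect over a subset of units.
# tau_draws: ndraw x n matrix of ITE draws (rows = MCMC draws, columns = units),
#            e.g. tau_post from the BART fit above.
# units:     logical vector of length n selecting the units to average over.
summarise_effect <- function(tau_draws, units, level = 0.95) {
  stopifnot(ncol(tau_draws) == length(units), any(units))
  # Within each draw, average the selected units' ITEs: one draw of the estimand
  draws <- rowMeans(tau_draws[, units, drop = FALSE])
  a <- (1 - level) / 2
  c(mean = mean(draws),
    sd   = sd(draws),
    lo   = unname(quantile(draws, a)),
    hi   = unname(quantile(draws, 1 - a)))
}

# Toy check: 3 draws x 4 units; units 1-2 are treated
toy_draws <- matrix(c(1, 2, 3, 4,
                      2, 3, 4, 5,
                      3, 4, 5, 6), nrow = 3, byrow = TRUE)
treated <- c(TRUE, TRUE, FALSE, FALSE)
att <- summarise_effect(toy_draws, treated)    # averages units 1-2 in each draw
atc <- summarise_effect(toy_draws, !treated)   # averages units 3-4 in each draw
# In the chapter: summarise_effect(tau_post, D == 1) for the ATT
```

With the chapter's objects, `D == 1` gives the ATT, `D == 0` the ATC, and `rep(TRUE, n)` reproduces the ATE summary above. All three are sample averages over the observed covariates.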

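The posterior credible interval comes straight from the MCMC draws, with no asymptotic approximation; this is the central appeal of the Bayesian approach for causal inference. The interval is conditional on the model and on unconfoundedness given \(X\), however. In this DGP the confounder U drives both the propensity score and \(Y(0)\) but is not among the covariates, so adjusting for \(X\) alone does not identify the ATE, and the 95% CrI [2.099, 2.300] excludes the true value 1.985. No credible interval can absorb bias from an unmeasured confounder.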
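Because U is simulated, we can check directly how much its omission matters, something that is impossible with real data. The sketch below re-simulates the setup's DGP with the same seed inside `local()` (so the chapter's objects are not overwritten) and fits two correctly specified linear regressions with `lm`, one including U and one using \(X\) only. This is an oracle comparison, not part of the Bayesian workflow.

```r
oracle <- local({
  # Re-simulate the setup's DGP with the same seed and draw order
  set.seed(42)
  n <- 2000
  p <- 5
  X <- matrix(runif(n * p), n, p)
  U <- rnorm(n)
  D <- rbinom(n, 1, plogis(-0.3 + 1.5 * X[, 2] + 0.5 * U))
  tau <- 1 + 2 * X[, 1]
  Y <- 0.5 * X[, 2] + 0.3 * U + rnorm(n) + D * tau   # Y = Y0 + D * tau

  x1 <- X[, 1]
  x2 <- X[, 2]

  # Both models have the correct functional form; only the second omits U
  fit_with_U    <- lm(Y ~ D + D:x1 + x2 + U)
  fit_without_U <- lm(Y ~ D + D:x1 + x2)

  # Sample ATE implied by a fit: b_D + b_{D:x1} * mean(x1)
  ate_from <- function(fit) {
    b <- coef(fit)
    unname(b["D"] + b["D:x1"] * mean(x1))
  }

  c(true            = mean(tau),
    adjusting_for_U = ate_from(fit_with_U),
    X_only          = ate_from(fit_without_U))
})
oracle
```

If the \(X\)-only fit lands above the truth while the U-adjusted fit does not, the gap is an identification problem that no choice of outcome model (BART, BCF, or `lm`) can fix without access to U.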

11.2.1 Individual-level credible intervals

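Most frequentist meta-learners return only point estimates of \(\tau(x)\); BART gives a full posterior for each unit's effect, so every unit gets its own credible interval. Plot the posterior mean and 95% credible interval of \(\hat\tau_i\) against \(X_1\), with the true \(\tau_i\) overlaid: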

ite_mean <- colMeans(tau_post)
ite_lo   <- apply(tau_post, 2, quantile, 0.025)
ite_hi   <- apply(tau_post, 2, quantile, 0.975)

df_ite <- tibble(
  X1     = X[, 1],
  true   = tau,
  est    = ite_mean,
  lo     = ite_lo,
  hi     = ite_hi
)

ggplot(df_ite, aes(x = X1)) +
  # one vertical bar per unit: its pointwise 95% credible interval
  geom_linerange(aes(ymin = lo, ymax = hi), alpha = 0.1,
                 colour = "steelblue") +
  geom_point(aes(y = est), size = 0.4, alpha = 0.3, colour = "steelblue") +
  # constant reference line, so intercept and slope go outside aes()
  geom_abline(intercept = 1, slope = 2, colour = "firebrick",
              linetype = "dashed", linewidth = 1) +
  labs(x = "X1", y = "Individual treatment effect",
       caption = "Red dashed = true τ(x). Blue bars = per-unit 95% credible intervals.") +
  theme_minimal()

BART-estimated individual treatment effects (mean of posterior) vs true τ(x) = 1 + 2 X₁. Vertical bars show 95% credible intervals.

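Each blue interval is one unit's pointwise 95% credible interval; the red dashed line is the truth. Coverage should be checked rather than assumed. Credible intervals carry no frequentist coverage guarantee, and because U raises both the chance of treatment and \(Y(0)\) without being in \(X\), the contrast BART actually estimates, \(\mathbb{E}[Y \mid D=1, X] - \mathbb{E}[Y \mid D=0, X]\), lies above \(\tau(X)\) at every \(X\).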
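Since the true \(\tau_i\) is known in simulation, coverage can be measured rather than eyeballed. A minimal sketch with a hypothetical helper `interval_coverage()`, checked on toy vectors:

```r
# Hypothetical helper: share of units whose interval [lo, hi] contains the truth,
# plus the average interval width.
interval_coverage <- function(truth, lo, hi) {
  stopifnot(length(truth) == length(lo), length(lo) == length(hi))
  inside <- truth >= lo & truth <= hi
  c(coverage = mean(inside), mean_width = mean(hi - lo))
}

# Toy check: the second interval misses its true value
toy_truth <- c(1.0, 1.5, 2.0, 2.5)
toy_lo    <- c(0.8, 1.6, 1.7, 2.0)
toy_hi    <- c(1.2, 2.0, 2.3, 3.0)
toy_cov   <- interval_coverage(toy_truth, toy_lo, toy_hi)
# In the chapter: interval_coverage(tau, ite_lo, ite_hi)
```

Comparing the empirical coverage with the nominal 0.95, overall or within bins of X1, shows how much the confounding by U matters at the individual level.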

11.3 Bayesian Causal Forests (BCF)

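BART for causal inference is conceptually clean, but it treats \(D\) as just another covariate, so it does not separate the prognostic effect (how \(X\) predicts \(Y(0)\)) from the treatment effect (how \(X\) modifies the effect of \(D\)). Two problems follow. The prior gives no direct control over how much treatment-effect heterogeneity is fit, so variation in the prognostic function can leak into the estimated effects. And when treatment assignment depends on the same covariates that drive the prognostic function (as here, where \(X_2\) drives both), regularisation can bias the effect estimates, which Hahn, Murray, & Carvalho (2020) call regularisation-induced confounding.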

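Hahn, Murray, & Carvalho (2020) propose Bayesian Causal Forests (BCF), which decompose the conditional mean as

\[ \mathbb{E}[Y \mid D, X] = \mu\big(X, \hat\pi(X)\big) + \tau(X)\, D, \]

where \(\hat\pi(X)\) is an estimated propensity score included as an extra covariate of the prognostic function \(\mu\). BCF puts separate BART priors on \(\mu\) and \(\tau\), with a tighter prior on \(\tau\) (fewer, shallower trees) that regularises towards homogeneous effects. This "separation prior" reduces spurious heterogeneity from prognostic variation, and including \(\hat\pi\) in \(\mu\) helps when treatment assignment tracks the prognostic function.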
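How much tighter is the \(\tau\) prior? In the BART tree prior of Chipman, George, & McCulloch (2010), a node at depth \(d\) splits with probability \(\alpha(1+d)^{-\beta}\). The wbart log above reports the BART defaults \(\alpha = 0.95\), \(\beta = 2\); for the \(\tau\) forest, Hahn, Murray, & Carvalho (2020) use \(\alpha = 0.25\), \(\beta = 3\). We believe these are also the bcf package's `base_moderate`/`power_moderate` defaults, but check `?bcf` for your installed version. The sketch compares the implied split probabilities:

```r
# Depth prior of a single BART tree: P(node at depth d splits) = alpha * (1 + d)^(-beta)
split_prob <- function(depth, alpha, beta) alpha * (1 + depth)^(-beta)

depths <- 0:3
bart_default <- split_prob(depths, alpha = 0.95, beta = 2)  # values in the wbart log above
bcf_tau      <- split_prob(depths, alpha = 0.25, beta = 3)  # tau-forest values (assumed, see text)

# Rows: prior; columns: depth 0..3
round(rbind(bart_default, bcf_tau), 4)
```

Even the root of a \(\tau\) tree splits with prior probability 0.25, versus 0.95 under the BART default, so most \(\tau\) trees stay stumps unless the data push for heterogeneity.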

# bcf() needs an estimated propensity score, passed as its pihat argument
ps_fit <- glm(D ~ X1 + X2 + X3 + X4 + X5, data = df, family = binomial)
piHat  <- predict(ps_fit, type = "response")

# bcf's MCMC progress output conflicts with knitr's encoding, so run the fit
# in a separate R session whose stdout/stderr are discarded (outside knitr,
# bcf() can be called directly).
tmp_rds <- tempfile(fileext = ".rds")
saveRDS(list(Y = df$Y, D = df$D, X = X, piHat = piHat), tmp_rds)
status <- system(paste0("Rscript -e '",
  "args <- readRDS(\"", tmp_rds, "\"); ",
  "suppressMessages(library(bcf)); ",
  "f <- bcf(y=args$Y, z=args$D, x_control=args$X, x_moderate=args$X, ",
  "pihat=args$piHat, nburn=200, nsim=1000, n_chains=1, no_output=TRUE, ",
  "verbose=FALSE); ",
  "saveRDS(list(tau=f$tau), \"", tmp_rds, "\")'"),
  ignore.stdout = TRUE, ignore.stderr = TRUE)
bcf_fit <- readRDS(tmp_rds)
unlink(tmp_rds)
# A failed subprocess leaves the input list in tmp_rds, which has no $tau
stopifnot(status == 0, !is.null(bcf_fit$tau))

Extract the posterior of individual treatment effects:

tau_post_bcf <- bcf_fit$tau    # nsim x n matrix of posterior ITEs

bcf_ate_post <- rowMeans(tau_post_bcf)
bcf_ite_mean <- colMeans(tau_post_bcf)

cat(sprintf("BCF ATE posterior:\n"))
BCF ATE posterior:
cat(sprintf("  Mean:    %.3f\n", mean(bcf_ate_post)))
  Mean:    2.195
cat(sprintf("  SD:      %.3f\n", sd(bcf_ate_post)))
  SD:      0.048
cat(sprintf("  95%% CrI: [%.3f, %.3f]\n",
            quantile(bcf_ate_post, 0.025),
            quantile(bcf_ate_post, 0.975)))
  95% CrI: [2.101, 2.287]
cat(sprintf("  True ATE: %.3f\n", mean(tau)))
  True ATE: 1.985

11.3.1 BART vs BCF: individual treatment effects

Compare the estimated CATEs across the two methods:

df_compare <- bind_rows(
  tibble(X1 = X[, 1], est = bcf_ite_mean, true = tau, method = "BCF"),
  tibble(X1 = X[, 1], est = ite_mean,     true = tau, method = "BART")
)

ggplot(df_compare, aes(X1, est)) +
  geom_point(alpha = 0.2, size = 0.3, colour = "steelblue") +
  geom_smooth(method = "loess", se = FALSE, colour = "darkblue",
              linewidth = 1) +
  geom_abline(intercept = 1, slope = 2, linetype = "dashed",
              colour = "firebrick") +
  facet_wrap(~ method) +
  labs(x = "X1", y = "Estimated ITE",
       caption = "Red dashed = truth = 1 + 2 X1") +
  theme_minimal()

BART (left) vs BCF (right) individual treatment effect estimates. BCF's separation prior is designed to give smoother, less noisy CATE estimates.

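BCF is designed to yield smoother, less noisy CATE estimates than vanilla BART: its \(\tau\) forest has its own, stronger prior, so treatment-effect heterogeneity is fit only when the data support it. The two methods agree closely on the ATE (posterior means 2.198 and 2.195), so their shared gap to the true 1.985 is not an artefact of either prior.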
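Because the true \(\tau\) is known, the comparison can be made quantitative rather than visual. A minimal sketch with a hypothetical helper `cate_accuracy()`, checked on toy vectors:

```r
# Hypothetical helper: accuracy of posterior-mean CATEs against the known truth.
cate_accuracy <- function(est, truth) {
  stopifnot(length(est) == length(truth))
  err <- est - truth
  c(rmse = sqrt(mean(err^2)),   # overall error
    bias = mean(err),           # common shift shared by all units
    cor  = cor(est, truth))     # is the heterogeneity pattern recovered?
}

# Toy check: a pure level shift of 0.5
toy_truth <- c(1, 2, 3, 4)
toy_est   <- toy_truth + 0.5
acc <- cate_accuracy(toy_est, toy_truth)
# In the chapter:
# rbind(BART = cate_accuracy(ite_mean, tau), BCF = cate_accuracy(bcf_ite_mean, tau))
```

The `bias` column absorbs any common shift (such as one induced by the unmeasured confounder U), while `cor` isolates how well each method recovers the variation of the effect in \(X_1\).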

11.4 When to use Bayesian methods

| Reason to use Bayesian | Reason to use frequentist |
|---|---|
| Prior information matters | Asymptotic theory is well understood |
| Need the full posterior (e.g. for downstream decisions) | Computational cost matters (BART is slow on large \(n\)) |
| Hierarchical / multilevel data structure | Effects are simple averages |
| Small sample, weak identification | Effects are well identified |
| Want native individual-level CrIs | Existing frequentist toolkit is sufficient |

The Bayesian approach is not just different machinery for the same answer. It provides additional outputs, such as a posterior over individual effects and principled propagation of uncertainty to derived quantities, that frequentist methods supply only indirectly (e.g. via the delta method or the bootstrap), if at all. For decision-theoretic applications such as treatment recommendation under cost-benefit trade-offs, the Bayesian posterior is the natural input.
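As an illustration of the decision-theoretic point, suppose treating a unit costs \(c\) in outcome units (a hypothetical cost). Under a linear utility \(\tau_i - c\), the Bayes rule treats unit \(i\) when the posterior mean of \(\tau_i\) exceeds \(c\); the posterior probability \(P(\tau_i > c)\) is a useful companion summary. A minimal sketch with a hypothetical helper `decide_treatment()`:

```r
# Hypothetical helper: per-unit treatment rule from a draws x units matrix of ITEs,
# assuming a linear utility tau_i - cost.
decide_treatment <- function(tau_draws, cost) {
  post_mean <- colMeans(tau_draws)          # E[tau_i | data]
  prob_gain <- colMeans(tau_draws > cost)   # P(tau_i > cost | data)
  data.frame(post_mean = post_mean,
             prob_gain = prob_gain,
             treat     = post_mean > cost)  # Bayes rule under linear utility
}

# Toy check: 3 draws x 3 units
toy_draws <- matrix(c(0.5, 2.0,  3.0,
                      1.5, 2.5, -1.0,
                      1.0, 3.0,  1.0), nrow = 3, byrow = TRUE)
rule <- decide_treatment(toy_draws, cost = 1.2)
# In the chapter (cost value hypothetical): decide_treatment(tau_post_bcf, cost = 2)
```

The rule inherits any bias in the posterior, including the shift caused by the unmeasured confounder U in this chapter's DGP.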

11.5 Practical guidance

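  • Start with BCF, not BART, for CATEs. BCF's separate prior on \(\tau\) and its use of the propensity score target the failure modes (spurious heterogeneity, regularisation-induced confounding) that plain BART is prone to.
  • Diagnose convergence. Inspect trace plots and effective sample sizes of the posterior draws, e.g. with the coda package (traceplot() and effectiveSize() on coda::mcmc(tau_mean)), and compare several chains (bcf's n_chains argument).
  • Don't ignore the propensity score. BCF requires piHat as an input, and a poorly estimated propensity score weakens its protection against regularisation-induced confounding. Use a flexible classifier (e.g. random forests, or BART's own binary-outcome model, BART::pbart()).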
  • Posterior summaries. Report posterior mean, SD, and 95% credible intervals for both ATE and bin-level CATEs. The full posterior matters more than the point estimate.
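Bin-level CATEs come from the same per-draw averaging as the ATE, restricted to the units in each bin. A minimal sketch with a hypothetical helper `bin_cate()`; bins that contain no units are dropped:

```r
# Hypothetical helper: posterior summaries of bin-level CATEs.
# tau_draws: ndraw x n matrix of ITE draws; x: covariate used for binning.
bin_cate <- function(tau_draws, x, breaks, level = 0.95) {
  bins <- cut(x, breaks = breaks, include.lowest = TRUE)
  a <- (1 - level) / 2
  out <- lapply(levels(bins), function(b) {
    sel <- which(bins == b)
    if (length(sel) == 0) return(NULL)                   # skip empty bins
    draws <- rowMeans(tau_draws[, sel, drop = FALSE])    # one draw of the bin CATE per MCMC draw
    data.frame(bin = b, n = length(sel),
               mean = mean(draws), sd = sd(draws),
               lo = unname(quantile(draws, a)),
               hi = unname(quantile(draws, 1 - a)))
  })
  do.call(rbind, Filter(Negate(is.null), out))
}

# Toy check: 2 draws x 4 units, the middle bin is empty
toy_draws <- rbind(c(1, 2, 5, 6),
                   c(3, 4, 7, 8))
res <- bin_cate(toy_draws, x = c(0.1, 0.2, 0.7, 0.9), breaks = c(0, 0.5, 0.6, 1))
# In the chapter: bin_cate(tau_post_bcf, X[, 1], breaks = seq(0, 1, by = 0.2))
```

In the chapter's DGP the true bin-level CATE is \(1 + 2\) times the bin mean of \(X_1\), which gives a direct check on each row.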

11.6 Connections

  • The Heterogeneous Treatment Effects chapter covers frequentist meta-learners (S/T/X/R/DR) and causal forests via grf. BCF can be viewed as the Bayesian counterpart to causal forests: both use ensembles of trees, but BCF puts priors on the trees while grf uses honest sample-splitting.
  • The Sensitivity Analysis chapter’s Cinelli-Hazlett bounds are frequentist; the Bayesian analog is a prior on the unmeasured-confounder strength (cf. McCandless et al. 2007).
  • For Bayesian Numpyro-style hierarchical models, see the companion blog chapter.

11.7 Summary

  • BART for causal inference (Hill 2011): fit BART to \(E[Y \mid D, X]\), predict counterfactuals, compute ITEs from posterior draws.
  • Bayesian Causal Forests (BCF) (Hahn, Murray, & Carvalho 2020): separate prognostic and treatment-effect components with distinct priors; reduces spurious heterogeneity.
  • Native credible intervals at the individual level: the Bayesian approach's main differentiator from most frequentist HTE methods. These intervals are conditional on the model and on unconfoundedness; they do not absorb bias from unmeasured confounders.
  • Computational cost is the main caveat: BART/BCF scale poorly to \(n > 10^5\). For large datasets, frequentist alternatives (HTE chapter) are usually preferred.