9  Heterogeneous Treatment Effects with Machine Learning

library(tidyverse)
library(grf)
library(policytree)
library(ggplot2)

The estimands chapter defined the conditional average treatment effect

\[ \text{CATE}(x) = \mathbb{E}[Y(1) - Y(0) \mid X = x] \]

as the effect of treatment for individuals with covariates \(X = x\). There, we computed CATE from known potential outcomes — a luxury only simulation affords. In real data, \(Y(0)\) and \(Y(1)\) are never both observed, and CATE must be estimated from \((X, D, Y)\) without strong functional-form assumptions.

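This chapter covers the modern toolkit for that problem: meta-learners (S, T, X, R and DR), causal forests, summaries of heterogeneity (best linear projection, variable importance and quintile-binned ATEs), policy trees for treatment assignment, and causal forests for panel data.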

9.1 The data-generating process

Reuse the DGP from the estimands chapter, where the true CATE is \(\tau(x) = 1 + 2x_1\):

set.seed(42)
n <- 5000
p <- 5      # 5 covariates: X1 is the effect modifier, X2 a confounder, X3..X5 pure noise

X <- matrix(runif(n * p), n, p)
colnames(X) <- paste0("X", 1:p)
U <- rnorm(n)
# Treatment depends on X2 and the unobserved U, both of which also enter Y(0): confounding
ps  <- plogis(-0.3 + 1.5 * X[, 2] + 0.5 * U)
D   <- rbinom(n, 1, ps)
tau <- 1 + 2 * X[, 1]          # CATE varies along X1 only
Y0  <- 0.5 * X[, 2] + 0.3 * U + rnorm(n)
Y1  <- Y0 + tau
Y   <- ifelse(D == 1, Y1, Y0)

cat(sprintf("n = %d, true ATE = %.3f\n", n, mean(tau)))
n = 5000, true ATE = 2.006
cat(sprintf("True CATE at X1 = 0.2: %.2f\n", 1 + 2 * 0.2))
True CATE at X1 = 0.2: 1.40
cat(sprintf("True CATE at X1 = 0.8: %.2f\n", 1 + 2 * 0.8))
True CATE at X1 = 0.8: 2.60

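The treatment is confounded through \(X_2\) and the unobserved \(U\). Conditioning on \(X_2\) closes the first backdoor path but not the second: \(U\) shifts both the propensity score and \(Y(0)\), so unconfoundedness given \(X\) fails. Every estimator below conditions on \(X\) only (we cannot condition on the unobserved \(U\)) and therefore inherits a modest omitted-confounder bias. Flexible learners cannot repair a missing confounder.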
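To gauge that bias, the sketch below re-simulates the same DGP with a much larger sample and, purely for illustration, swaps the ML learners for OLS. All names with the _big suffix are hypothetical. Adjusting for the observed \(X\) alone should overstate the ATE of 2 by roughly 0.3 times the difference in mean \(U\) between treated and control units with the same \(X\). The infeasible regression that also controls for \(U\) should recover the ATE:

```r
# Linear illustration of the omitted-confounder bias in this DGP (same structure, larger n)
set.seed(1)
n_big <- 1e5
X_big <- matrix(runif(n_big * 5), n_big, 5)
U_big <- rnorm(n_big)
D_big <- rbinom(n_big, 1, plogis(-0.3 + 1.5 * X_big[, 2] + 0.5 * U_big))
tau_big <- 1 + 2 * X_big[, 1]                          # population ATE = 2
Y_big <- 0.5 * X_big[, 2] + 0.3 * U_big + tau_big * D_big + rnorm(n_big)

# Feasible adjustment: observed covariates only
b_obs <- coef(lm(Y_big ~ D_big + X_big))[["D_big"]]
# Infeasible adjustment: also control for the unobserved U
b_all <- coef(lm(Y_big ~ D_big + X_big + U_big))[["D_big"]]

cat(sprintf("OLS effect of D, X only: %.3f; X and U: %.3f (ATE = 2)\n", b_obs, b_all))
```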

9.2 Meta-learners

A meta-learner is a recipe that turns a generic regression estimator (any black-box ML method) into a CATE estimator. The S- and T-learners are the simplest; the X-, R- and DR-learners build on them.

9.2.1 S-learner (“single”)

Fit one regression of \(Y\) on \((D, X)\), then predict the difference between \(D=1\) and \(D=0\) at each \(x\):

\[ \hat\tau^S(x) = \hat\mu(1, x) - \hat\mu(0, x). \]

# Use a flexible learner — random forest from grf::regression_forest
XD_train <- cbind(D = D, X)
fit_S <- regression_forest(XD_train, Y, num.trees = 500)

# Counterfactual predictions at D = 1 and D = 0 for every X
mu1 <- predict(fit_S, cbind(D = rep(1, n), X))$predictions
mu0 <- predict(fit_S, cbind(D = rep(0, n), X))$predictions
tau_S <- mu1 - mu0

cor_S <- cor(tau_S, tau)
cat(sprintf("S-learner CATE correlation with truth: %.3f\n", cor_S))
S-learner CATE correlation with truth: 0.976

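The S-learner is simple but can be biased toward zero when the treatment effect is small relative to outcome variation. A regularised learner treats \(D\) as just another covariate. If \(D\) adds little predictive value, the forest rarely splits on it, and \(\hat\mu(1, x) - \hat\mu(0, x)\) is shrunk toward zero.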
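The shrinkage mechanism is easiest to see in a stylised linear S-learner. Here it stands in for a regularised ML learner; this is an assumption, because forests regularise differently. The model is linear in \((D, x)\) with a ridge penalty on the coefficient of \(D\) only. By the Frisch-Waugh-Lovell theorem, that coefficient has the closed form \(\hat\beta_D(\lambda) = \tilde d'\tilde y / (\tilde d'\tilde d + \lambda)\), where \(\tilde d\) and \(\tilde y\) are residuals after partialling out the intercept and \(x\). A heavier penalty therefore pulls the estimated effect toward zero:

```r
# Stylised S-learner: linear in (D, x), ridge penalty on the D coefficient only
set.seed(2)
n_s <- 5000
x_s <- runif(n_s)
d_s <- rbinom(n_s, 1, 0.5)
y_s <- 3 * x_s + 0.2 * d_s + rnorm(n_s)     # small constant treatment effect of 0.2

# Frisch-Waugh-Lovell: partial the unpenalised terms (intercept, x) out of d and y
Z_s <- cbind(1, x_s)
d_res <- lm.fit(Z_s, d_s)$residuals
y_res <- lm.fit(Z_s, y_s)$residuals

# Penalised estimate of the D coefficient; lambda = 0 gives OLS
s_effect <- function(lambda) sum(d_res * y_res) / (sum(d_res^2) + lambda)
effects_s <- sapply(c(0, 500, 2500), s_effect)
print(round(effects_s, 3))
```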

9.2.2 T-learner (“two”)

Fit two separate regressions: \(\hat\mu_1\) on treated units, \(\hat\mu_0\) on controls. Then \(\hat\tau^T(x) = \hat\mu_1(x) - \hat\mu_0(x)\).

fit_T1 <- regression_forest(X[D == 1, ], Y[D == 1], num.trees = 500)
fit_T0 <- regression_forest(X[D == 0, ], Y[D == 0], num.trees = 500)

mu1_T <- predict(fit_T1, X)$predictions
mu0_T <- predict(fit_T0, X)$predictions
tau_T <- mu1_T - mu0_T

cor_T <- cor(tau_T, tau)
cat(sprintf("T-learner CATE correlation with truth: %.3f\n", cor_T))
T-learner CATE correlation with truth: 0.977

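The T-learner avoids the shrinkage of the S-learner but can suffer when the treated and control groups differ sharply in size or covariate distribution. Each regression is fitted and regularised separately, so their errors do not cancel when the two fitted functions are subtracted. These errors are largest where one arm has few observations and must extrapolate.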
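Because the T-learner only needs a fit-and-predict pair for each arm, any base learner can be plugged in. The sketch below assumes OLS as the base learner and a randomised toy DGP with the same CATE \(1 + 2x_1\). The helpers t_learner(), ols_fit() and ols_predict() are hypothetical and not part of grf:

```r
# Hypothetical helper: a T-learner that accepts any base learner as a fit/predict pair
t_learner <- function(X, D, Y, fit_fn, predict_fn, X_new = X) {
  fit1 <- fit_fn(X[D == 1, , drop = FALSE], Y[D == 1])   # model for E[Y | X, D = 1]
  fit0 <- fit_fn(X[D == 0, , drop = FALSE], Y[D == 0])   # model for E[Y | X, D = 0]
  predict_fn(fit1, X_new) - predict_fn(fit0, X_new)
}

# Base learner: OLS with an intercept (an illustrative choice)
ols_fit     <- function(X, y) lm.fit(cbind(1, X), y)$coefficients
ols_predict <- function(beta, X_new) drop(cbind(1, X_new) %*% beta)

set.seed(3)
n_tl <- 20000
X_tl <- matrix(runif(n_tl * 2), n_tl, 2)
D_tl <- rbinom(n_tl, 1, 0.5)                              # randomised treatment
Y_tl <- X_tl[, 2] + (1 + 2 * X_tl[, 1]) * D_tl + rnorm(n_tl)

X_grid <- cbind(c(0.2, 0.8), 0.5)                         # X1 = 0.2 and 0.8, X2 = 0.5
tau_grid <- t_learner(X_tl, D_tl, Y_tl, ols_fit, ols_predict, X_grid)
print(round(tau_grid, 2))   # true CATEs are 1.40 and 2.60
```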

9.2.3 X-learner (Künzel et al. 2019)

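The X-learner uses the T-learner fits to impute each unit's missing potential outcome, which gives an imputed individual effect. It then regresses these pseudo-outcomes on \(X\) separately within each arm, giving \(\hat\tau_1(x)\) from the treated and \(\hat\tau_0(x)\) from the controls. Finally, the two are combined with propensity-score weights, \(\hat\tau^X(x) = \hat e(x)\,\hat\tau_0(x) + (1 - \hat e(x))\,\hat\tau_1(x)\).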

# Step 1: T-learner fits (already done above)
# Step 2: Pseudo-outcomes
D1_pseudo <- Y[D == 1] - predict(fit_T0, X[D == 1, ])$predictions   # observed - imputed Y(0)
D0_pseudo <- predict(fit_T1, X[D == 0, ])$predictions - Y[D == 0]   # imputed Y(1) - observed

# Step 3: Regress pseudo-outcomes on X in each arm
fit_X1 <- regression_forest(X[D == 1, ], D1_pseudo, num.trees = 500)
fit_X0 <- regression_forest(X[D == 0, ], D0_pseudo, num.trees = 500)

tau_X1 <- predict(fit_X1, X)$predictions
tau_X0 <- predict(fit_X0, X)$predictions

# Step 4: Weighted combination — use propensity score as the weight
ps_fit <- regression_forest(X, D, num.trees = 500)
ps_hat <- predict(ps_fit, X)$predictions
ps_hat <- pmin(pmax(ps_hat, 0.02), 0.98)   # trim

tau_X <- ps_hat * tau_X0 + (1 - ps_hat) * tau_X1

cor_X <- cor(tau_X, tau)
cat(sprintf("X-learner CATE correlation with truth: %.3f\n", cor_X))
X-learner CATE correlation with truth: 0.988

The X-learner is particularly effective when the propensity score is far from 0.5 and one treatment arm is much smaller than the other — common in observational studies.
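The four steps can be sketched with simple parametric pieces. This version assumes OLS for the outcome and effect regressions and a logistic model for the propensity score. It uses a toy DGP in which roughly a fifth of units are treated. The data frame xl and all names with the xl_ prefix are hypothetical:

```r
# X-learner with OLS outcome/effect models and a logistic propensity model (illustrative)
set.seed(4)
n_xl <- 1e5
xl <- data.frame(x = runif(n_xl))
xl$d <- rbinom(n_xl, 1, plogis(-2 + 1.5 * xl$x))      # roughly a fifth of units treated
xl$y <- xl$x + (1 + 2 * xl$x) * xl$d + rnorm(n_xl)
trt <- xl[xl$d == 1, ]
ctl <- xl[xl$d == 0, ]

# Step 1: outcome model in each arm (T-learner fits)
xl_mu1 <- lm(y ~ x, data = trt)
xl_mu0 <- lm(y ~ x, data = ctl)

# Step 2: imputed individual effects
trt$delta <- trt$y - predict(xl_mu0, newdata = trt)    # observed Y(1) - imputed Y(0)
ctl$delta <- predict(xl_mu1, newdata = ctl) - ctl$y    # imputed Y(1) - observed Y(0)

# Step 3: regress the imputed effects on x within each arm
xl_tau1 <- lm(delta ~ x, data = trt)
xl_tau0 <- lm(delta ~ x, data = ctl)

# Step 4: combine with estimated propensity weights, e(x) * tau0 + (1 - e(x)) * tau1
xl_ps <- glm(d ~ x, family = binomial, data = xl)
grid_xl <- data.frame(x = c(0.2, 0.8))
e_grid <- predict(xl_ps, newdata = grid_xl, type = "response")
tau_xl <- e_grid * predict(xl_tau0, newdata = grid_xl) +
  (1 - e_grid) * predict(xl_tau1, newdata = grid_xl)
print(round(unname(tau_xl), 2))   # true CATEs are 1.40 and 2.60
```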

9.2.4 R-learner (Nie & Wager 2021)

The R-learner builds on Robinson’s (1988) partialling-out approach. Define the pseudo-outcomes and weights

\[ \tilde Y_i = \frac{Y_i - \hat m(X_i)}{D_i - \hat e(X_i)}, \quad \text{weights} = (D_i - \hat e(X_i))^2, \]

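Here \(\hat m(x)\) and \(\hat e(x)\) are cross-fitted estimates of \(m(x) = \mathbb{E}[Y \mid X = x]\) and \(e(x) = \mathbb{E}[D \mid X = x]\). Under unconfoundedness, and with the true nuisances, the \((D - e(X))^2\)-weighted conditional mean of \(\tilde Y\) equals \(\tau(X)\). Regressing \(\tilde Y\) on \(X\) with the indicated weights is the same as minimising the R-loss \(\sum_i [(Y_i - \hat m(X_i)) - (D_i - \hat e(X_i))\,\tau(X_i)]^2\).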
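The equivalence between the weighted pseudo-outcome regression and the R-loss can be checked numerically. The sketch below assumes a linear CATE model \(\tau(x) = b_0 + b_1 x\) and uses simple in-sample nuisance fits (lm and glm) in place of cross-fitted forests. The two fits should return identical coefficients:

```r
# Weighted pseudo-outcome regression versus direct R-loss minimisation, tau(x) = b0 + b1 * x
set.seed(5)
n_rl <- 3000
x_rl <- runif(n_rl)
d_rl <- rbinom(n_rl, 1, plogis(-0.5 + x_rl))
y_rl <- x_rl + (1 + 2 * x_rl) * d_rl + rnorm(n_rl)

# Simple in-sample nuisance fits, for illustration only (no cross-fitting)
m_rl <- fitted(lm(y_rl ~ x_rl))                         # estimate of E[Y | X]
e_rl <- fitted(glm(d_rl ~ x_rl, family = binomial))    # estimate of E[D | X]
res_y <- y_rl - m_rl
res_d <- d_rl - e_rl

# (1) Regress the pseudo-outcome on x with weights (D - e)^2
fit_pseudo <- lm(I(res_y / res_d) ~ x_rl, weights = res_d^2)
# (2) Minimise the sum of [(Y - m) - (D - e) * (b0 + b1 * x)]^2 directly
fit_rloss <- lm(res_y ~ 0 + res_d + I(res_d * x_rl))

print(rbind(pseudo = unname(coef(fit_pseudo)), r_loss = unname(coef(fit_rloss))))
```

Multiplying out the weights shows why: \((D_i - \hat e_i)^2 (\tilde Y_i - b_0 - b_1 x_i)^2 = [(Y_i - \hat m_i) - (D_i - \hat e_i)(b_0 + b_1 x_i)]^2\) term by term.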

# Cross-fitting: split sample into 2 folds, fit nuisances on one, predict on the other
set.seed(7)
folds <- sample(rep(1:2, length.out = n))

m_hat <- numeric(n)
e_hat <- numeric(n)
for (k in 1:2) {
  train <- folds != k
  test  <- folds == k
  m_fit <- regression_forest(X[train, ], Y[train], num.trees = 500)
  e_fit <- regression_forest(X[train, ], D[train], num.trees = 500)
  m_hat[test] <- predict(m_fit, X[test, ])$predictions
  e_hat[test] <- predict(e_fit, X[test, ])$predictions
}
e_hat <- pmin(pmax(e_hat, 0.02), 0.98)

# R-learner pseudo-outcome and weights
pseudo_R <- (Y - m_hat) / (D - e_hat)
weights_R <- (D - e_hat)^2

# Regress pseudo-outcomes on X with weights — use a regression forest
fit_R <- regression_forest(X, pseudo_R, sample.weights = weights_R, num.trees = 500)
tau_R <- predict(fit_R, X)$predictions

cor_R <- cor(tau_R, tau)
cat(sprintf("R-learner CATE correlation with truth: %.3f\n", cor_R))
R-learner CATE correlation with truth: 0.943

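The R-loss is Neyman-orthogonal: small errors in \(\hat m\) and \(\hat e\) affect \(\hat\tau\) only at second order. The R-learner can therefore approach the accuracy of an oracle that knew \(m\) and \(e\), provided the nuisance estimates are reasonably accurate (Nie & Wager 2021). It often performs well when overlap is reasonable. In this simulation, however, it tracks the truth less closely (correlation 0.943) than the X-learner (0.988).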

9.2.5 DR-learner (Kennedy 2020)

The DR-learner constructs pseudo-outcomes via the AIPW formula and regresses them on \(X\):

\[ \tilde Y_i^{DR} = \hat\mu_1(X_i) - \hat\mu_0(X_i) + \frac{D_i (Y_i - \hat\mu_1(X_i))}{\hat e(X_i)} - \frac{(1 - D_i) (Y_i - \hat\mu_0(X_i))}{1 - \hat e(X_i)}. \]

mu1_cf <- numeric(n); mu0_cf <- numeric(n); e_cf <- numeric(n)
for (k in 1:2) {
  train <- folds != k
  test  <- folds == k
  mu1_fit <- regression_forest(X[train & D == 1, , drop = FALSE], Y[train & D == 1],
                               num.trees = 500)
  mu0_fit <- regression_forest(X[train & D == 0, , drop = FALSE], Y[train & D == 0],
                               num.trees = 500)
  e_fit   <- regression_forest(X[train, ], D[train], num.trees = 500)
  mu1_cf[test] <- predict(mu1_fit, X[test, ])$predictions
  mu0_cf[test] <- predict(mu0_fit, X[test, ])$predictions
  e_cf[test]   <- predict(e_fit,   X[test, ])$predictions
}
e_cf <- pmin(pmax(e_cf, 0.02), 0.98)

pseudo_DR <- (mu1_cf - mu0_cf) +
             D * (Y - mu1_cf) / e_cf -
             (1 - D) * (Y - mu0_cf) / (1 - e_cf)

fit_DR <- regression_forest(X, pseudo_DR, num.trees = 500)
tau_DR <- predict(fit_DR, X)$predictions

cor_DR <- cor(tau_DR, tau)
cat(sprintf("DR-learner CATE correlation with truth: %.3f\n", cor_DR))
DR-learner CATE correlation with truth: 0.941
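The DR-learner's pseudo-outcome is doubly robust: its mean is correct if either the outcome models or the propensity score is correct. A toy check with known nuisance functions (no learning, a large sample, and hypothetical names) makes this concrete:

```r
# Double robustness of the AIPW pseudo-outcome, using known (or deliberately wrong) nuisances
set.seed(6)
n_dr <- 1e5
x_dr <- runif(n_dr)
e_dr <- plogis(-1 + 2 * x_dr)                             # true propensity score
d_dr <- rbinom(n_dr, 1, e_dr)
y_dr <- x_dr + (1 + 2 * x_dr) * d_dr + rnorm(n_dr)       # ATE = 1 + 2 * 0.5 = 2

dr_pseudo <- function(d, y, mu1, mu0, e) {
  (mu1 - mu0) + d * (y - mu1) / e - (1 - d) * (y - mu0) / (1 - e)
}

# Wrong outcome models (both set to zero) but the correct propensity score
ate_wrong_mu <- mean(dr_pseudo(d_dr, y_dr, mu1 = 0, mu0 = 0, e = e_dr))
# Correct outcome models but a wrong, constant propensity score
ate_wrong_e <- mean(dr_pseudo(d_dr, y_dr, mu1 = 1 + 3 * x_dr, mu0 = x_dr, e = 0.5))

cat(sprintf("wrong outcome models: %.3f; wrong propensity: %.3f (ATE = 2)\n",
            ate_wrong_mu, ate_wrong_e))
```

Getting both nuisances wrong at once removes this protection, because the confounding through x is then left unadjusted. An example is zero outcome models combined with a constant propensity of 0.5.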

9.3 Causal forests

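Wager-Athey (2018) causal forests, implemented in grf, are tailored for CATE estimation. Before growing trees, grf residualises \(Y\) and \(D\) on \(X\) with separate regression forests, the same partialling-out idea as the R-learner. Each tree then chooses splits that maximise heterogeneity in the estimated treatment effect rather than reduction in outcome variance. The trees are also honest: one part of each tree's subsample chooses the splits and a separate part estimates the leaf effects. Honesty combined with subsampling makes the pointwise CATE estimates asymptotically normal, and hence usable for inference.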
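Honesty is easiest to see on a single split. The sketch below is a simplified stand-in, not grf's splitting rule (grf works with residualised data and a gradient-based criterion). It picks a threshold on \(x\) using half of a randomised toy sample, scoring candidates by \(n_L n_R (\hat\tau_L - \hat\tau_R)^2\), which is an illustrative criterion. It then estimates each leaf's effect by a difference in means on the other half. As a result, the leaf estimates are not contaminated by the search that chose the split:

```r
# One honest "causal stump": choose a split on one half, estimate leaf effects on the other
set.seed(7)
n_h <- 20000
x_h <- runif(n_h)
d_h <- rbinom(n_h, 1, 0.5)                                # randomised treatment
y_h <- x_h + (1 + 2 * x_h) * d_h + rnorm(n_h)
splitting <- sample(rep(c(TRUE, FALSE), length.out = n_h))   # TRUE = used to choose the split

diff_means <- function(y, d) mean(y[d == 1]) - mean(y[d == 0])

# Splitting half: choose the threshold maximising n_L * n_R * (tau_L - tau_R)^2
cands <- seq(0.1, 0.9, by = 0.05)
crit <- sapply(cands, function(s) {
  L <- splitting & x_h <= s
  R <- splitting & x_h > s
  sum(L) * sum(R) * (diff_means(y_h[L], d_h[L]) - diff_means(y_h[R], d_h[R]))^2
})
s_best <- cands[which.max(crit)]

# Estimation half: leaf effects use only observations that played no part in the split
L_est <- !splitting & x_h <= s_best
R_est <- !splitting & x_h > s_best
tau_left  <- diff_means(y_h[L_est], d_h[L_est])
tau_right <- diff_means(y_h[R_est], d_h[R_est])
cat(sprintf("split at x = %.2f: left effect %.2f, right effect %.2f\n",
            s_best, tau_left, tau_right))
```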

cf <- causal_forest(X = X, Y = Y, W = D, num.trees = 2000)

tau_cf <- predict(cf)$predictions
cor_cf <- cor(tau_cf, tau)
cat(sprintf("Causal forest CATE correlation with truth: %.3f\n", cor_cf))
Causal forest CATE correlation with truth: 0.987
# Doubly robust ATE estimate
ate_cf <- average_treatment_effect(cf)
cat(sprintf("\nCausal forest AIPW ATE: %.3f  (SE = %.3f, true = %.3f)\n",
            ate_cf["estimate"], ate_cf["std.err"], mean(tau)))

Causal forest AIPW ATE: 2.158  (SE = 0.033, true = 2.006)

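Calling predict(cf) without new data, as above, returns out-of-bag CATE estimates. Adding estimate.variance = TRUE also returns pointwise variance estimates, which give asymptotically valid confidence intervals for each CATE prediction. The AIPW ATE, however, misses the true value of 2.006 by about 0.15, roughly 4.6 standard errors. This is the omitted-confounder bias from \(U\), which no estimator that conditions on \(X\) alone can remove.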

9.3.1 Comparing all estimators

comparison <- tibble(
  X1     = X[, 1],
  true   = tau,
  S      = tau_S,
  Tlearn = tau_T,
  X      = tau_X,
  R      = tau_R,
  DR     = tau_DR,
  CF     = tau_cf
) |>
  pivot_longer(cols = c(S, Tlearn, X, R, DR, CF),
               names_to = "Estimator", values_to = "Predicted")

ggplot(comparison, aes(x = X1, y = Predicted)) +
  geom_point(alpha = 0.1, size = 0.4) +
  geom_smooth(method = "loess", se = FALSE, colour = "steelblue", linewidth = 1) +
  geom_abline(intercept = 1, slope = 2, linetype = "dashed",
              colour = "firebrick") +
  facet_wrap(~ Estimator, ncol = 3) +
  labs(x = expression(X[1]),
       y = "Estimated CATE",
       caption = "Red dashed line = true τ(x) = 1 + 2*X1") +
  theme_minimal()

Estimated CATE vs true τ(x) = 1 + 2*X1 for each estimator. Closer to the dashed line = better.

9.4 Best linear projection (BLP)

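Even when the CATE is a complicated function of many covariates, we often want a simple summary: how much does the treatment effect change with each covariate? The best linear projection (BLP) of the CATE onto \(X\) is the population OLS projection, with an intercept:

\[ (\beta_0, \beta) = \arg\min_{b_0,\, b} \mathbb{E}\left[ (\tau(X) - b_0 - X'b)^2 \right]. \]

grf's best_linear_projection() estimates it by regressing the forest's doubly robust scores on \(X\), reporting heteroskedasticity-robust standard errors: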

blp <- best_linear_projection(cf, A = X)
print(blp)

Best linear projection of the conditional average treatment effect.
Confidence intervals are cluster- and heteroskedasticity-robust (HC3):

             Estimate Std. Error t value Pr(>|t|)    
(Intercept)  1.159737   0.126526  9.1660   <2e-16 ***
X1           2.171075   0.111928 19.3971   <2e-16 ***
X2           0.029308   0.112772  0.2599   0.7950    
X3           0.079042   0.110110  0.7178   0.4729    
X4          -0.151089   0.109918 -1.3746   0.1693    
X5          -0.147769   0.111352 -1.3270   0.1846    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

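The BLP coefficient on \(X_1\) (2.17, SE 0.11) is close to the true CATE slope of 2 and highly significant. The coefficients on \(X_2\) through \(X_5\) are small and insignificant. Each coefficient's t-test asks whether the CATE varies linearly with that covariate, holding the others fixed. This makes the BLP a convenient test for heterogeneity along specific covariates.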
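The mechanics behind best_linear_projection() can be sketched in base R: compute AIPW scores, regress them on the covariates by OLS, and form HC3 standard errors. The sketch assumes a randomised toy design with a known propensity of 0.5. For brevity, it plugs in the true outcome models where grf would use the forest's out-of-bag estimates. Names ending in _b are hypothetical:

```r
# BLP by hand: OLS of AIPW scores on covariates, with HC3 standard errors
set.seed(8)
n_b <- 20000
X_b <- matrix(runif(n_b * 3), n_b, 3, dimnames = list(NULL, c("X1", "X2", "X3")))
d_b <- rbinom(n_b, 1, 0.5)                               # randomised, known propensity 0.5
mu0_b <- X_b[, 2]                                        # E[Y(0) | X]
tau_b <- 1 + 2 * X_b[, 1]                                # true CATE
y_b <- mu0_b + tau_b * d_b + rnorm(n_b)

# AIPW scores; true outcome models are plugged in here, a real analysis would estimate them
mu1_b <- mu0_b + tau_b
gamma_b <- (mu1_b - mu0_b) + d_b * (y_b - mu1_b) / 0.5 - (1 - d_b) * (y_b - mu0_b) / 0.5

A_b <- cbind(Intercept = 1, X_b)
fit_b <- lm.fit(A_b, gamma_b)
bread_b <- solve(crossprod(A_b))                            # (A'A)^{-1}
lev_b <- rowSums((A_b %*% bread_b) * A_b)                   # leverages h_i
meat_b <- crossprod(A_b * (fit_b$residuals / (1 - lev_b)))  # sum of (r_i / (1 - h_i))^2 a_i a_i'
se_b <- sqrt(diag(bread_b %*% meat_b %*% bread_b))
print(round(cbind(estimate = fit_b$coefficients, std.error = se_b), 3))
```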

9.5 Variable importance

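grf's variable_importance() reports a depth-weighted count of how often each covariate is used to split in the causal forest's trees, with shallower splits weighted more heavily. Because these trees split to maximise treatment-effect heterogeneity, the measure reflects effect modification rather than outcome prediction: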

vi <- variable_importance(cf)
vi_df <- tibble(variable = paste0("X", 1:p), importance = as.numeric(vi)) |>
  arrange(desc(importance))
print(vi_df)
# A tibble: 5 × 2
  variable importance
  <chr>         <dbl>
1 X1           0.728 
2 X5           0.0753
3 X4           0.0716
4 X2           0.0657
5 X3           0.0598

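\(X_1\) tops the list at 0.73, while the four non-modifiers each score roughly 0.06 to 0.08. In real applications this is a useful guide to which covariates carry heterogeneity, but not a formal test. Split counts come without standard errors, and importance can be shared among correlated covariates.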

9.6 Policy trees: who should we treat?

A CATE estimate tells us the effect of treating each individual. But the policy question is often different: given a budget or a cost-benefit trade-off, which individuals should we treat? Athey-Wager (2021) develop policy learning from doubly robust scores. The policytree package implements this with policy trees, shallow decision trees that maximise estimated welfare net of a treatment cost.

# Doubly robust scores for each arm: an n x 2 matrix (column 1 = control, column 2 = treated)
dr_scores <- double_robust_scores(cf)
cost <- 0.5     # cost of treating one unit, in outcome units

# Reward of each action: control keeps its DR score; treatment pays the cost
gain <- cbind(control = dr_scores[, 1],
              treated = dr_scores[, 2] - cost)

# Fit a depth-2 policy tree (shallow for interpretability)
tree <- policy_tree(X, gain, depth = 2)
print(tree)

# Recommended action for each unit: 1 = control, 2 = treated
actions <- predict(tree, X)
cat(sprintf("\nTreatment rate under policy tree: %.1f%%\n",
            mean(actions == 2) * 100))

# Estimated net benefit of treating each unit: DR treatment-effect score minus cost
net <- dr_scores[, 2] - dr_scores[, 1] - cost
cat(sprintf("Estimated welfare gain over treat-everyone: %.3f units per person\n",
            mean((actions == 2) * net) - mean(net)))

In this DGP the true \(\tau(x) = 1 + 2x_1 \ge 1\) exceeds the cost of 0.5 for every unit, so the optimal rule treats everyone. A well-fitted tree should therefore recommend treatment for (nearly) all units, with an estimated welfare gain over treat-everyone close to zero.

The policy tree provides an interpretable assignment rule. Real-world policies are often required to be interpretable (e.g. fair-housing laws, medical decision rules), and shallow trees are the natural compromise between flexibility and explainability.
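policy_tree() searches exhaustively over trees of the requested depth, and a depth-1 analog fits in a few lines of base R. The helper best_stump() below is hypothetical, not policytree's implementation. The scores are simulated stand-ins for estimated \(\tau\) minus cost. Each side of a candidate split is treated only if its summed net benefit is positive, and the split with the largest total wins:

```r
# Exhaustive depth-1 policy search on net-benefit scores (hypothetical helper)
best_stump <- function(X, net) {
  best <- list(value = -Inf)
  for (j in seq_len(ncol(X))) {
    o <- order(X[, j])
    xs <- X[o, j]
    cs <- cumsum(net[o])
    k <- seq_len(length(cs) - 1)            # split after the k-th smallest value
    left <- cs[k]
    right <- cs[length(cs)] - left
    # Treat a side only if its summed net benefit is positive (not treating scores 0)
    value <- pmax(left, 0) + pmax(right, 0)
    i <- which.max(value)
    if (value[i] > best$value) {
      best <- list(value = value[i], variable = j,
                   threshold = (xs[i] + xs[i + 1]) / 2,
                   treat_left = left[i] > 0, treat_right = right[i] > 0)
    }
  }
  best
}

set.seed(9)
n_p <- 20000
X_p <- matrix(runif(n_p * 2), n_p, 2)
# Simulated noisy net-benefit scores: tau(x) = 1 + 2 * x1 minus a cost of 2
net_p <- (1 + 2 * X_p[, 1]) - 2 + rnorm(n_p, sd = 0.5)
rule <- best_stump(X_p, net_p)
str(rule)
```

With \(\tau(x) = 1 + 2x_1\) and a cost of 2, treatment pays off exactly when \(x_1 > 0.5\). The search should therefore split on X1 near 0.5 and treat only the right-hand side.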

For an extended walkthrough of policy trees with IPW and AIPW losses, see the companion blog chapter on policytree. For a comparison of CATE estimators across software ecosystems (including Stata 19’s new cate command), see the Stata CATE blog chapter.

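9.7 GATES: testing for heterogeneity

A formal check for heterogeneity is the sorted group average treatment effect (GATES) analysis of Chernozhukov-Demirer-Duflo-Fernandez-Val (2018). The idea: bin the population by predicted CATE, then estimate the average treatment effect in each bin. If the bin-specific ATEs differ significantly, there is evidence of real heterogeneity. CLAN (Classification Analysis), from the same paper, is the companion step: it compares average covariate values between the most and least affected groups.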

# Sort by predicted CATE, split into 5 quintiles
nq <- 5
quintile <- cut(tau_cf, breaks = quantile(tau_cf, probs = seq(0, 1, 1/nq)),
                include.lowest = TRUE, labels = FALSE)

# Compute AIPW ATE within each quintile
get_ate <- function(idx) {
  ate <- average_treatment_effect(cf, subset = idx)
  c(est = ate["estimate"], se = ate["std.err"])
}

clan <- t(sapply(1:nq, function(q) get_ate(quintile == q)))
clan_df <- as_tibble(clan) |>
  mutate(quintile = 1:nq,
         lo = est.estimate - 1.96 * se.std.err,
         hi = est.estimate + 1.96 * se.std.err)

knitr::kable(clan_df[, c("quintile", "est.estimate", "se.std.err", "lo", "hi")],
             col.names = c("Quintile", "ATE", "SE", "95% LB", "95% UB"),
             digits = 3,
             caption = "Quintile-specific ATEs: heterogeneity is evident if quintile estimates differ.")
Quintile-specific ATEs: heterogeneity is evident if quintile estimates differ.

Quintile    ATE     SE   95% LB   95% UB
       1  1.310  0.071    1.170    1.450
       2  1.710  0.071    1.570    1.849
       3  2.244  0.069    2.109    2.380
       4  2.607  0.071    2.467    2.747
       5  2.918  0.074    2.773    3.063

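The quintile ATEs rise monotonically from 1.310 to 2.918, and the 95% intervals for the bottom and top quintiles do not overlap. This is strong evidence that the heterogeneity is real rather than estimation noise. One caveat: the bins are formed from out-of-bag CATE predictions on the same sample that is then used to estimate the bin ATEs. The original procedure instead splits the sample, so that the CATE proxy and the group ATEs come from different observations, which guards against overstating heterogeneity.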
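The comparison of the extreme quintiles can be made explicit with a two-sample z-statistic computed from the table. This treats the Q1 and Q5 estimates as approximately independent, since they come from disjoint subsets. Like the table itself, it ignores that the bins were formed from the same data, so it is a rough check rather than the formal GATES test:

```r
# Bottom versus top quintile ATE, values copied from the table above
est_q <- c(q1 = 1.310, q5 = 2.918)
se_q  <- c(q1 = 0.071, q5 = 0.074)
diff_q <- est_q[["q5"]] - est_q[["q1"]]
se_diff <- sqrt(se_q[["q1"]]^2 + se_q[["q5"]]^2)   # treats the two estimates as independent
z_q <- diff_q / se_diff
cat(sprintf("Q5 - Q1 = %.3f (SE = %.3f), z = %.1f\n", diff_q, se_diff, z_q))
# prints: Q5 - Q1 = 1.608 (SE = 0.103), z = 15.7
```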

9.8 Causal forests with panel data

When the data are a panel — repeated observations on the same units — and the unobserved unit effects are correlated with treatment or covariates, the cross-sectional CATE methods above can be biased. Two practical solutions:

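  1. Apply the within transformation: demean \(Y\) and \(W\) (and the covariates) by unit before running the causal forest.
  2. Pass the unit identifier as the clusters argument of causal_forest. Subsampling, honest splitting and out-of-bag prediction then operate on whole units rather than individual observations, and standard errors account for within-unit correlation.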
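The core of what clusters changes can be sketched in base R: subsample units, not rows. This is a simplified illustration with hypothetical names, not grf's internal code. grf's sampling has further options, such as equalize.cluster.weights:

```r
# Cluster-level subsampling: draw whole firms, then keep all of their rows
set.seed(10)
firm_ids <- rep(1:200, each = 5)                  # 200 firms observed 5 times each
drawn <- sample(unique(firm_ids), size = 100)     # half of the firms
in_sub <- firm_ids %in% drawn

# Share of each firm's rows that landed in the subsample: always 0 or 1
share_in <- tapply(in_sub, firm_ids, mean)
print(table(share_in))                            # 100 firms fully out, 100 fully in
```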
set.seed(2024)
n_firms <- 200
n_t     <- 5
n_total <- n_firms * n_t

# Panel data with a firm-specific effect (drawn independently of V1 and W in this simulation)
firm_id  <- rep(1:n_firms, each = n_t)
unit_fe  <- rnorm(n_firms, sd = 1.5)[firm_id]
V1_firm  <- rnorm(n_firms, sd = 1.0)[firm_id]   # firm-level covariate
V1       <- V1_firm + rnorm(n_total, sd = 0.3)  # within-firm variation
W        <- rbinom(n_total, 1, plogis(0.5 + 0.3 * V1_firm))
# Effect modifier. Because Y below contains W + tau_panel * (W - mean(W)), the true CATE is 1 + tau_panel = 1.5 + V1
tau_panel <- 0.5 + 1.0 * V1
Y_panel  <- unit_fe + V1 + W + tau_panel * (W - mean(W)) + rnorm(n_total)

df_panel <- tibble(firm = firm_id, V1 = V1, W = W, Y = Y_panel)
glimpse(df_panel)
Rows: 1,000
Columns: 4
$ firm <int> 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5,…
$ V1   <dbl> 0.60231981, 0.23871074, 0.05916058, -0.24663844, 0.28957601, -0.1…
$ W    <int> 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1,…
$ Y    <dbl> 4.32198952, 4.54919561, 1.16597186, 3.35752234, 2.89403659, 0.037…

A naive causal forest on the panel data ignores the unit effects:

X_panel <- as.matrix(df_panel[, "V1", drop = FALSE])

cf_naive <- causal_forest(X = X_panel, Y = df_panel$Y, W = df_panel$W,
                          num.trees = 1000)
ate_naive <- average_treatment_effect(cf_naive)

cat(sprintf("Naive panel CF ATE: %.3f  (SE = %.3f)\n",
            ate_naive["estimate"], ate_naive["std.err"]))
Naive panel CF ATE: 1.523  (SE = 0.137)
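# The effect of W on Y is 1 + tau_panel (see the outcome equation above)
cat(sprintf("True ATE: %.3f\n", mean(1 + tau_panel)))
True ATE: 1.566

The naive estimate lies within about a third of a standard error of the truth. In this simulation the firm effects are drawn independently of V1 and W, so they add noise but do not confound treatment. Here, clustering matters for inference. With clusters = firm, grf subsamples and holds out whole firms, so no firm's own observations inform its out-of-bag predictions, and the standard errors account for within-firm correlation: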

cf_clust <- causal_forest(X = X_panel, Y = df_panel$Y, W = df_panel$W,
                          clusters = df_panel$firm,
                          num.trees = 1000)
ate_clust <- average_treatment_effect(cf_clust)

cat(sprintf("Clustered panel CF ATE: %.3f  (SE = %.3f)\n",
            ate_clust["estimate"], ate_clust["std.err"]))
Clustered panel CF ATE: 1.541  (SE = 0.164)

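The clustered standard error (0.164) is larger than the naive one (0.137), consistent with outcomes that are correlated within firm. Clustering alone does not remove bias when unit effects are correlated with treatment, as they often are in real panels. The within transformation does remove time-invariant unit effects: demean \(Y\), \(W\) and the covariate by firm before fitting. Demeaning \(V1\) also discards its between-firm variation, so the forest can only detect heterogeneity along within-firm deviations of \(V1\):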

df_dm <- df_panel |>
  group_by(firm) |>
  mutate(Y_dm = Y - mean(Y),
         W_dm = W - mean(W),
         V1_dm = V1 - mean(V1)) |>
  ungroup()

cf_fe <- causal_forest(X = as.matrix(df_dm[, "V1_dm"]),
                       Y = df_dm$Y_dm,
                       W = df_dm$W_dm,
                       clusters = df_dm$firm,
                       num.trees = 1000)
ate_fe <- average_treatment_effect(cf_fe)
cat(sprintf("Within-transform CF ATE: %.3f  (SE = %.3f)\n",
            ate_fe["estimate"], ate_fe["std.err"]))
Within-transform CF ATE: 1.484  (SE = 0.104)

The clustered, within-transformed causal forest is a natural analog of a panel fixed-effects regression for heterogeneous treatment effects. In this simulation the three estimates agree to within 0.06, as expected when the firm effects do not confound treatment. See the companion blog chapter on causal forests in panel data for an extended simulation study comparing the panel CF to fixed-effect regression and random-effect regression under linear and nonlinear DGPs.
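Why the within transformation matters is clearest in a linear model. In the hypothetical DGP below, unlike the simulation above, the firm effect does drive treatment. Pooled OLS then loads part of the firm effect onto the treatment coefficient, while demeaning within firm removes it exactly:

```r
# Pooled OLS versus the within (fixed-effects) estimator when a firm effect drives treatment
set.seed(11)
n_f <- 2000
n_per <- 5
firm_f <- rep(seq_len(n_f), each = n_per)
fe_f <- rnorm(n_f)[firm_f]                          # firm effect, repeated over periods
w_f <- rbinom(n_f * n_per, 1, plogis(fe_f))         # high-effect firms are treated more often
y_f <- 2 * fe_f + 1.5 * w_f + rnorm(n_f * n_per)    # true effect of w_f is 1.5

b_pooled <- coef(lm(y_f ~ w_f))[["w_f"]]

# Within transformation: subtract each firm's mean, then regress without an intercept
y_dm_f <- y_f - ave(y_f, firm_f)
w_dm_f <- w_f - ave(w_f, firm_f)
b_within <- coef(lm(y_dm_f ~ 0 + w_dm_f))[["w_dm_f"]]

cat(sprintf("pooled OLS: %.2f; within: %.2f (true effect = 1.50)\n", b_pooled, b_within))
```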

9.9 Summary

  • Meta-learners (S, T, X, R, DR) turn any regression algorithm into a CATE estimator. The R- and DR-learners combine Neyman orthogonality with cross-fitting, which makes them less sensitive to errors in the nuisance estimates; the X-learner helps with imbalanced treatment arms.
  • Causal forests (grf) provide a purpose-built CATE estimator with honest sample splitting, pointwise confidence intervals, and convenient diagnostics (BLP, variable importance, GATES).
  • Best Linear Projection (BLP) gives a simple regression-style summary of CATE — coefficients with valid standard errors on how the effect varies with each covariate.
  • GATES (quintile-binned ATEs) provide a formal check for heterogeneity and a clean way to communicate it; the companion CLAN analysis describes the characteristics of the most and least affected groups.
  • Policy trees (policytree) translate a CATE into an interpretable treatment-assignment rule, balancing welfare against cost.
  • For applied work: estimate CATE with at least two methods (e.g. causal forest + R-learner), report BLP coefficients and quintile-binned ATEs, and use a shallow policy tree when treatment-assignment recommendations are needed.