using DataFrames
using Distributions
using Random
using Statistics
using LinearAlgebra
using Printf
using MLJ
using MLJDecisionTreeInterface
using GLM
using CairoMakie
9 Heterogeneous Treatment Effects with Machine Learning
The estimands chapter defined the conditional average treatment effect
\[ \text{CATE}(x) = \mathbb{E}[Y(1) - Y(0) \mid X = x] \]
as the effect of treatment for individuals with covariates \(X = x\). There, we computed CATE from known potential outcomes — a luxury only simulation affords. In real data, \(Y(0)\) and \(Y(1)\) are never both observed, and CATE must be estimated from \((X, D, Y)\) without strong functional-form assumptions.
This chapter covers the meta-learner recipe: any flexible regression algorithm can be turned into a CATE estimator. We implement the five most common meta-learners — S-, T-, X-, R-, and DR-learners — using MLJ + DecisionTree.jl’s random forests as the base learner.
The chapter focuses on meta-learners rather than honest causal forests (Wager and Athey 2018) because, at the time of writing, Julia does not have a native implementation of honest sample-splitting trees. A T-learner with random forests is a reasonable practical substitute. For production work with formal confidence intervals on CATE predictions, the R companion chapter covers grf::causal_forest.
9.1 The data-generating process
Reuse the DGP from the estimands chapter, where the true CATE is \(\tau(x) = 1 + 2 x_1\):
Random.seed!(42)
n = 5000
p = 5
X = rand(n, p)
U = randn(n)
# Treatment depends on X2 (which also shifts Y(0)) and on the unobserved U: backdoor confounding
ps = @. 1 / (1 + exp(-(-0.3 + 1.5 * X[:, 2] + 0.5 * U)))
D = Float64.(rand(n) .< ps)
tau = @. 1 + 2 * X[:, 1]
Y0 = 0.5 .* X[:, 2] .+ 0.3 .* U .+ randn(n)   # randn(n): independent N(0, 1) noise per unit
Y1 = Y0 .+ tau
Y = ifelse.(D .== 1, Y1, Y0)
@printf("n = %d, true ATE = %.3f\n", n, mean(tau))
@printf("True CATE at X1 = 0.2: %.2f\n", 1 + 2 * 0.2)
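@printf("True CATE at X1 = 0.8: %.2f\n", 1 + 2 * 0.8)
The treatment is confounded through \(X_2\) and the unobserved \(U\), both of which also shift \(Y(0)\). Conditioning on \(X_2\) blocks the backdoor path through \(X_2\). We cannot condition on \(U\), so the path through \(U\) stays open. Strictly, CATE is therefore not identified from \((X, D, Y)\) in this DGP, and every estimator below carries some residual confounding bias.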
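To make the open path through \(U\) concrete, here is a minimal stdlib-only sketch. It re-simulates the same DGP on a larger sample (the seed and `n = 20_000` are arbitrary choices) and fits a correctly specified linear regression adjustment twice. The first fit uses the observed covariates only. The second, an oracle, also includes \(U\). The coefficient on \(D\) targets the intercept of the true CATE, \(1\). Omitting \(U\) pushes it upward, and that gap is the residual confounding bias.

```julia
using Random, Statistics, LinearAlgebra

# Same DGP as above, on a larger sample so sampling noise is small
Random.seed!(1)
n = 20_000
X = rand(n, 5)
U = randn(n)                                     # unobserved confounder
ps = @. 1 / (1 + exp(-(-0.3 + 1.5 * X[:, 2] + 0.5 * U)))
D = Float64.(rand(n) .< ps)
Y = 0.5 .* X[:, 2] .+ 0.3 .* U .+ randn(n) .+ D .* (1 .+ 2 .* X[:, 1])

# Linear adjustment with a D x X1 interaction, matching the true CATE 1 + 2*X1:
# the D coefficient targets 1 and the D*X1 coefficient targets 2
Z_obs  = hcat(ones(n), D, D .* X[:, 1], X)       # observed covariates only
Z_full = hcat(Z_obs, U)                          # oracle: U included as well
b_obs  = Z_obs \ Y                               # least squares (QR)
b_full = Z_full \ Y

println("D coefficient, adjusting for X:    ", round(b_obs[2], digits = 3))
println("D coefficient, adjusting for X, U: ", round(b_full[2], digits = 3))
```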
9.2 A wrapper for fitting random forests
To keep the code compact, define a small helper that fits a random forest and returns the predictions:
const RFR = @load RandomForestRegressor pkg=DecisionTree verbosity=0
"""
    rf_fit_predict(Xtrain, ytrain, Xpredict; n_trees=500)

Fit a random forest on (Xtrain, ytrain) and return predictions at Xpredict.
"""
function rf_fit_predict(Xtrain, ytrain, Xpredict; n_trees::Int=500)
    learner = RFR(n_trees=n_trees, max_depth=-1)
    Xtrain_t = MLJ.table(Xtrain)
    Xpred_t = MLJ.table(Xpredict)
    mach = machine(learner, Xtrain_t, Float64.(ytrain))
    MLJ.fit!(mach, verbosity=0)      # qualified to avoid clashes with GLM exports
    return MLJ.predict(mach, Xpred_t)
end
nothing
9.3 S-learner (“single”)
Fit one regression of \(Y\) on \((D, X)\), then predict the difference between \(D = 1\) and \(D = 0\) for each \(x\):
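X_with_D = hcat(D, X)
X1_test = hcat(ones(n), X)    # every unit with D set to 1
X0_test = hcat(zeros(n), X)   # every unit with D set to 0
# Fit the forest once and predict both counterfactual rows with the same model,
# so forest-to-forest randomness does not leak into the difference
mu_S = rf_fit_predict(X_with_D, Y, vcat(X1_test, X0_test))
mu1_S = mu_S[1:n]
mu0_S = mu_S[(n + 1):end]
tau_S = mu1_S .- mu0_S
@printf("S-learner CATE correlation with truth: %.3f\n", cor(tau_S, tau))
The S-learner is simple but tends to shrink estimated effects toward zero when the treatment effect is small relative to outcome variation. If \(D\) has little predictive value, the forest rarely splits on it, so the predictions at \(D = 1\) and \(D = 0\) barely differ.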
9.4 T-learner (“two”)
Fit two separate regressions, one on treated units and one on controls:
idx_T = D .== 1
idx_C = D .== 0
mu1_T = rf_fit_predict(X[idx_T, :], Y[idx_T], X)
mu0_T = rf_fit_predict(X[idx_C, :], Y[idx_C], X)
tau_T = mu1_T .- mu0_T
@printf("T-learner CATE correlation with truth: %.3f\n", cor(tau_T, tau))
The T-learner avoids the S-learner's shrinkage toward zero because each arm gets its own model. It suffers, however, when the treated and control covariate distributions diverge. Each fit must extrapolate into regions where its arm has few observations, and those errors compound when the two fitted functions are subtracted.
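The S/T recipe does not depend on the forest. Here is a minimal sketch that passes the base learner in as a function, assuming an OLS base learner: the hypothetical `ols_fit_predict`, standing in for `rf_fit_predict`. The names `s_learner` and `t_learner` are also illustrative, and the data are a toy randomized experiment with true CATE \(1 + 2X_1\).

```julia
using Random, Statistics, LinearAlgebra

# Hypothetical base learner: OLS with an intercept. Any function with the
# signature (Xtrain, ytrain, Xpredict) -> predictions could be swapped in.
function ols_fit_predict(Xtrain, ytrain, Xpredict)
    beta = hcat(ones(size(Xtrain, 1)), Xtrain) \ ytrain
    return hcat(ones(size(Xpredict, 1)), Xpredict) * beta
end

# S-learner (hypothetical helper): one model on (D, X, D .* X), predicted with
# D set to 1 and to 0. The D .* X columns let a linear learner express heterogeneity.
function s_learner(fit_predict, X, D, Y)
    n = size(X, 1)
    design(d) = hcat(d, X, d .* X)
    mu = fit_predict(design(D), Y, vcat(design(ones(n)), design(zeros(n))))
    return mu[1:n] .- mu[(n + 1):end]
end

# T-learner (hypothetical helper): separate models for treated and controls
function t_learner(fit_predict, X, D, Y)
    t, c = D .== 1, D .== 0
    return fit_predict(X[t, :], Y[t], X) .- fit_predict(X[c, :], Y[c], X)
end

# Toy randomized experiment with true CATE 1 + 2 * X1
Random.seed!(3)
n = 4_000
X = rand(n, 3)
D = Float64.(rand(n) .< 0.5)
tau = 1 .+ 2 .* X[:, 1]
Y = 0.5 .* X[:, 2] .+ D .* tau .+ randn(n)

tau_S = s_learner(ols_fit_predict, X, D, Y)
tau_T = t_learner(ols_fit_predict, X, D, Y)
println("max |tau_S - tau_T| = ", maximum(abs.(tau_S .- tau_T)))
```

With a fully interacted linear base learner, the two learners coincide up to floating-point error, which makes a handy sanity check. With forests, the gap between them reflects how a single forest handles \(D\) compared with two separate forests.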
9.5 X-learner (Künzel et al. 2019)
The X-learner imputes the missing potential outcomes using the T-learner fits, then regresses the resulting pseudo-outcomes on \(X\) separately within each arm, and combines them with weights from the propensity score:
# Step 1: pseudo-outcomes
D1_pseudo = Y[idx_T] .- mu0_T[idx_T] # observed - imputed Y(0)
D0_pseudo = mu1_T[idx_C] .- Y[idx_C] # imputed Y(1) - observed
# Step 2: regress pseudo-outcomes on X in each arm
tau_X1 = rf_fit_predict(X[idx_T, :], D1_pseudo, X)
tau_X0 = rf_fit_predict(X[idx_C, :], D0_pseudo, X)
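# Step 3: weight by propensity score (fit and predicted in-sample here; the
# R- and DR-learners below cross-fit their propensity models)
e_hat = rf_fit_predict(X, D, X)
e_hat = clamp.(e_hat, 0.02, 0.98)
tau_X = e_hat .* tau_X0 .+ (1 .- e_hat) .* tau_X1
@printf("X-learner CATE correlation with truth: %.3f\n", cor(tau_X, tau))
The X-learner is particularly effective when the propensity score is far from 0.5, so that one treatment arm is much smaller than the other. This is common in observational studies. The weights favour the arm-specific estimate whose pseudo-outcomes were imputed with the model fitted on the larger arm. Where \(\hat e(x)\) is small (few treated units), most of the weight goes to `tau_X1`, which relies on \(\hat\mu_0\) fitted on the many controls.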
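The three X-learner steps can be written against a generic base learner. Here is a minimal sketch under stated assumptions:

- the base learner is OLS (the hypothetical `ols_fit_predict`);
- `x_learner` is an illustrative helper;
- the propensity is a known constant of 0.1, so only about one unit in ten is treated.

With \(\hat e = 0.1\), the final blend puts weight 0.9 on `tau1`, whose pseudo-outcomes use \(\hat\mu_0\) fitted on the large control arm.

```julia
using Random, Statistics, LinearAlgebra

# Hypothetical OLS base learner (intercept added internally)
ols_fit_predict(Xtr, ytr, Xp) =
    hcat(ones(size(Xp, 1)), Xp) * (hcat(ones(size(Xtr, 1)), Xtr) \ ytr)

# Hypothetical X-learner helper: generic base learner, given propensity vector e
function x_learner(fit_predict, X, D, Y, e)
    t, c = D .== 1, D .== 0
    mu1 = fit_predict(X[t, :], Y[t], X)             # stage 1: outcome model, treated
    mu0 = fit_predict(X[c, :], Y[c], X)             # stage 1: outcome model, controls
    tau1 = fit_predict(X[t, :], Y[t] .- mu0[t], X)  # stage 2: observed - imputed Y(0)
    tau0 = fit_predict(X[c, :], mu1[c] .- Y[c], X)  # stage 2: imputed Y(1) - observed
    return e .* tau0 .+ (1 .- e) .* tau1            # stage 3: propensity-weighted blend
end

# Toy data with a small treated arm (constant propensity 0.1, assumed known)
Random.seed!(4)
n = 5_000
X = rand(n, 3)
e = fill(0.1, n)
D = Float64.(rand(n) .< e)
tau = 1 .+ 2 .* X[:, 1]
Y = 0.5 .* X[:, 2] .+ D .* tau .+ randn(n)

tau_X = x_learner(ols_fit_predict, X, D, Y, e)
println("cor(tau_X, tau) = ", round(cor(tau_X, tau), digits = 3))
```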
9.6 R-learner (Nie & Wager 2021)
The R-learner uses Robinson’s (1988) partialled-out residual approach. Define pseudo-outcomes that, in expectation, equal \(\tau(X)\):
\[ \tilde Y_i = \frac{Y_i - \hat m(X_i)}{D_i - \hat e(X_i)}, \qquad \text{weight}_i = (D_i - \hat e(X_i))^2, \]
where \(\hat m\) and \(\hat e\) are cross-fitted estimates of the nuisance functions \(m(x) = \mathbb{E}[Y \mid X = x]\) and \(e(x) = \mathbb{E}[D \mid X = x]\).
# Cross-fitting: 2 folds
Random.seed!(7)
folds = rand(1:2, n)
m_hat = zeros(n)
e_hat_cv = zeros(n)
for k in 1:2
    train = folds .!= k
    test = folds .== k
    m_hat[test] = rf_fit_predict(X[train, :], Y[train], X[test, :])
    e_hat_cv[test] = rf_fit_predict(X[train, :], D[train], X[test, :])
end
e_hat_cv = clamp.(e_hat_cv, 0.02, 0.98)
pseudo_R = (Y .- m_hat) ./ (D .- e_hat_cv)
weights_R = (D .- e_hat_cv) .^ 2
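# The R-learner estimates tau(x) by a weighted regression of pseudo_R on X with
# weights weights_R. MLJ passes per-observation weights only to models whose
# supports_weights trait is true, so check that trait for your forest first.
# Here we use an unweighted forest. With accurate nuisances pseudo_R is still
# centred on tau(X), but units with D - e_hat_cv near 0 give extremely noisy
# pseudo-outcomes, so this fit has higher variance than the weighted one.
tau_R = rf_fit_predict(X, pseudo_R, X)
@printf("R-learner CATE correlation with truth: %.3f\n", cor(tau_R, tau))
The R-learner's loss is Neyman-orthogonal: first-order errors in \(\hat m\) and \(\hat e\) do not bias \(\hat\tau\), which is what allows flexible, cross-fitted nuisance models. With proper weights it tends to be among the most efficient meta-learners when overlap is reasonable. The unweighted approximation above gives up part of that advantage.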
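When the CATE is restricted to a linear form \(\tau(x) = z'\beta\) with \(z = [1, x]\), the weighted regression can be done exactly. The identity \(\sum_i (D_i - \hat e_i)^2 (\tilde Y_i - z_i'\beta)^2 = \sum_i (Y_i - \hat m_i - (D_i - \hat e_i) z_i'\beta)^2\) shows that weighted least squares on the pseudo-outcomes is the same problem as minimising the R-loss directly. The sketch below is stdlib-only. As an oracle shortcut for illustration, it plugs in the true \(m\) and \(e\) of a toy DGP instead of cross-fitted estimates.

```julia
using Random, Statistics, LinearAlgebra

Random.seed!(11)
n = 20_000
X = rand(n, 2)
e = @. 1 / (1 + exp(-(-0.5 + X[:, 2])))    # true propensity (toy DGP)
tau = 1 .+ 2 .* X[:, 1]                     # true CATE
D = Float64.(rand(n) .< e)
Y = 0.5 .* X[:, 2] .+ D .* tau .+ randn(n)
m = 0.5 .* X[:, 2] .+ e .* tau              # true E[Y | X] (oracle nuisance)

Z = hcat(ones(n), X)                        # linear CATE basis: tau(x) = Z * beta
r = D .- e                                  # treatment residual
pseudo = (Y .- m) ./ r                      # R-learner pseudo-outcome
w = r .^ 2                                  # R-learner weight

# (a) weighted least squares of the pseudo-outcome on Z (normal equations)
beta_wls = (Z' * (w .* Z)) \ (Z' * (w .* pseudo))
# (b) least squares on the R-loss: (Y - m) regressed on r .* Z
beta_rloss = (r .* Z) \ (Y .- m)
println("beta (WLS):    ", round.(beta_wls, digits = 3))
println("beta (R-loss): ", round.(beta_rloss, digits = 3))
```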
9.7 DR-learner (Kennedy 2020)
The DR-learner constructs AIPW-based pseudo-outcomes and regresses them on \(X\):
\[ \tilde Y_i^{DR} = \hat\mu_1(X_i) - \hat\mu_0(X_i) + \frac{D_i (Y_i - \hat\mu_1(X_i))}{\hat e(X_i)} - \frac{(1 - D_i) (Y_i - \hat\mu_0(X_i))}{1 - \hat e(X_i)}. \]
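The double robustness of this pseudo-outcome can be checked numerically before any forest is fit. Here is a minimal stdlib-only sketch on a toy DGP, where every setting and the helper `dr_pseudo` are hypothetical. The average pseudo-outcome stays close to the true ATE when either the outcome models or the propensity score are correct, even if the other is deliberately wrong. It drifts away when both are wrong.

```julia
using Random, Statistics

# Hypothetical helper: DR (AIPW) pseudo-outcome, same formula as the equation above
dr_pseudo(Y, D, mu1, mu0, e) =
    @. (mu1 - mu0) + D * (Y - mu1) / e - (1 - D) * (Y - mu0) / (1 - e)

Random.seed!(5)
n = 100_000
X1, X2 = rand(n), rand(n)
e_true = @. 1 / (1 + exp(-(-1 + 2 * X2)))   # treatment more likely for large X2
tau = @. 1 + 2 * X1
mu0_true = X2                               # E[Y(0) | X], also rising with X2
mu1_true = mu0_true .+ tau                  # E[Y(1) | X]
D = Float64.(rand(n) .< e_true)
Y = mu0_true .+ D .* tau .+ randn(n)

wrong_mu = zeros(n)                         # deliberately wrong outcome models
wrong_e = fill(0.5, n)                      # deliberately wrong propensity

ate_right_e  = mean(dr_pseudo(Y, D, wrong_mu, wrong_mu, e_true))
ate_right_mu = mean(dr_pseudo(Y, D, mu1_true, mu0_true, wrong_e))
ate_both_bad = mean(dr_pseudo(Y, D, wrong_mu, wrong_mu, wrong_e))
println("true ATE:          ", round(mean(tau), digits = 3))
println("right e, wrong mu: ", round(ate_right_e, digits = 3))
println("right mu, wrong e: ", round(ate_right_mu, digits = 3))
println("both wrong:        ", round(ate_both_bad, digits = 3))
```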
mu1_cf = zeros(n)
mu0_cf = zeros(n)
e_cf = zeros(n)
for k in 1:2
    train = folds .!= k
    test = folds .== k
    mu1_cf[test] = rf_fit_predict(X[train .& (D .== 1), :], Y[train .& (D .== 1)],
                                  X[test, :])
    mu0_cf[test] = rf_fit_predict(X[train .& (D .== 0), :], Y[train .& (D .== 0)],
                                  X[test, :])
    e_cf[test] = rf_fit_predict(X[train, :], D[train], X[test, :])
end
e_cf = clamp.(e_cf, 0.02, 0.98)
pseudo_DR = @. (mu1_cf - mu0_cf) +
D * (Y - mu1_cf) / e_cf -
(1 - D) * (Y - mu0_cf) / (1 - e_cf)
tau_DR = rf_fit_predict(X, pseudo_DR, X)
@printf("DR-learner CATE correlation with truth: %.3f\n", cor(tau_DR, tau))
9.8 Comparing all five estimators
fig = Figure(size = (1000, 700))
labs = ["S-learner", "T-learner", "X-learner", "R-learner", "DR-learner"]
preds = [tau_S, tau_T, tau_X, tau_R, tau_DR]
for (i, (lab, pred)) in enumerate(zip(labs, preds))
    row, col = divrem(i - 1, 3) .+ (1, 1)
    ax = Axis(fig[row, col],
              xlabel = "X1", ylabel = "Estimated CATE",
              title = lab)
    scatter!(ax, X[:, 1], pred, color = (:steelblue, 0.2), markersize = 4)
    # True CATE 1 + 2 * X1 as a dashed reference line
    lines!(ax, 0:0.01:1, x -> 1 + 2x, color = :firebrick, linestyle = :dash,
           linewidth = 2)
end
fig
9.9 Best Linear Projection (BLP)
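Even when CATE is high-dimensional, we often want a simple summary: how does the treatment effect change with each covariate? The best linear projection (BLP) of CATE onto \(X\) is the linear function of \(X\) closest to \(\tau(X)\) in mean square:
\[ (\beta_0^*, \beta^*) = \arg\min_{\beta_0,\, \beta} \mathbb{E}\left[(\tau(X) - \beta_0 - X'\beta)^2\right], \qquad \text{BLP}(X) = \beta_0^* + X'\beta^*. \]
Since \(\tau(X)\) is unobserved, we replace it with pseudo-outcomes whose conditional mean is \(\tau(X)\) and run OLS. A doubly robust BLP uses the DR-learner pseudo-outcomes: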
blp_df = DataFrame(hcat(pseudo_DR, X), [:tau_pseudo, :X1, :X2, :X3, :X4, :X5])
blp_fit = lm(@formula(tau_pseudo ~ X1 + X2 + X3 + X4 + X5), blp_df)
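println(coeftable(blp_fit))
The BLP coefficient on \(X_1\) should be close to 2 (the true slope of the CATE) and statistically significant. The coefficients on \(X_2\) through \(X_5\) should be close to 0. This gives a direct check for heterogeneity along specific covariates. One caveat applies: `coeftable` reports classical OLS standard errors, which assume homoskedastic errors. DR pseudo-outcomes are strongly heteroskedastic, being noisiest where \(\hat e\) is near 0 or 1, so heteroskedasticity-robust standard errors are the safer basis for inference.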
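Heteroskedasticity-robust (Eicker-Huber-White, HC0) standard errors can be computed directly from the OLS residuals. Here is a minimal stdlib-only sketch, in which `ols_with_se` is a hypothetical helper. The toy pseudo-outcomes use an assumed noise pattern whose variance grows toward the edges of \(X_2\)'s range. This loosely mimics how DR pseudo-outcomes get noisier where the propensity approaches 0 or 1.

```julia
using Random, Statistics, LinearAlgebra

# Hypothetical helper: OLS coefficients with classical and HC0 (sandwich) SEs.
# Z must already contain an intercept column.
function ols_with_se(Z, y)
    n, k = size(Z)
    beta = Z \ y
    u = y .- Z * beta                             # residuals
    bread = inv(Symmetric(Z' * Z))                # (Z'Z)^(-1)
    se_classical = sqrt.(diag(bread) .* (sum(abs2, u) / (n - k)))
    meat = Z' * (u .^ 2 .* Z)                     # sum_i u_i^2 z_i z_i'
    se_hc0 = sqrt.(diag(bread * meat * bread))
    return beta, se_classical, se_hc0
end

Random.seed!(8)
n = 10_000
X = rand(n, 3)
noise_sd = 0.1 .+ 4 .* abs.(X[:, 2] .- 0.5)      # assumed heteroskedastic noise
pseudo = 1 .+ 2 .* X[:, 1] .+ noise_sd .* randn(n)
Z = hcat(ones(n), X)

beta, se_cl, se_hc0 = ols_with_se(Z, pseudo)
for (j, name) in enumerate(["(Intercept)", "X1", "X2", "X3"])
    println(rpad(name, 12), round(beta[j], digits = 3),
            "   classical SE ", round(se_cl[j], digits = 4),
            "   HC0 SE ", round(se_hc0[j], digits = 4))
end
```

For the chapter's BLP, the same helper could be applied to `hcat(ones(n), X)` and `pseudo_DR`.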
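9.10 GATES and CLAN: testing for heterogeneity
A formal test of heterogeneity comes from Chernozhukov, Demirer, Duflo and Fernández-Val (2018). First, bin the population by predicted CATE. Then estimate the average treatment effect in each bin; these are their sorted group average treatment effects (GATES). If the bin-specific ATEs differ significantly, there is heterogeneity. Their companion CLAN (Classification Analysis) step then compares average covariate values between the most and least affected groups. In their procedure, the CATE predictions used for binning come from a different subsample than the one used to estimate the bin ATEs.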
nq = 5
quintile_edges = quantile(tau_DR, range(0, 1, length = nq + 1))
quintiles = searchsortedfirst.(Ref(quintile_edges), tau_DR) .- 1
quintiles = clamp.(quintiles, 1, nq)
# AIPW-style ATE within each quintile using the DR pseudo-outcomes
function quintile_ate(tau_pseudo, idx)
    n_q = sum(idx)
    est = mean(tau_pseudo[idx])
    se = std(tau_pseudo[idx]) / sqrt(n_q)   # standard error of the bin mean
    return (est = est, se = se)
end
clan_df = DataFrame(quintile = 1:nq,
ATE = [quintile_ate(pseudo_DR, quintiles .== q).est for q in 1:nq],
SE = [quintile_ate(pseudo_DR, quintiles .== q).se for q in 1:nq])
clan_df.lo = clan_df.ATE .- 1.96 .* clan_df.SE
clan_df.hi = clan_df.ATE .+ 1.96 .* clan_df.SE
@printf("%-10s %8s %8s %8s %8s\n", "Quintile", "ATE", "SE", "95% LB", "95% UB")
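for row in eachrow(clan_df)
    @printf("%-10d %8.3f %8.3f %8.3f %8.3f\n",
            row.quintile, row.ATE, row.SE, row.lo, row.hi)
end
Quintile 5 should show a clearly larger ATE than quintile 1. The difference between them divided by \(\sqrt{SE_5^2 + SE_1^2}\) gives a formal test of whether the heterogeneity is real rather than estimation noise. This holds only if the bins were formed from CATE predictions fit on other data. With in-sample predictions such as `tau_DR`, which was fit to the same `pseudo_DR`, the spread across bins is overstated.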
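Here is a minimal stdlib-only sketch of the sample-split version together with the top-minus-bottom test. It uses toy pseudo-outcomes: true CATE \(1 + 2X_1\) plus heavy noise, or a constant effect as a null case. All settings and the helpers `gates`, `top_minus_bottom` and `score_on_B` are hypothetical. A linear CATE score is fitted on one half of the sample; the groups and their ATEs are formed on the other half. Under a constant effect, the z-statistic should behave roughly like a standard normal draw, so splitting does not manufacture heterogeneity.

```julia
using Random, Statistics, LinearAlgebra

# Hypothetical helper: bin by quantiles of `score`, return bin means and SEs of `pseudo`
function gates(pseudo, score; nq = 5)
    edges = quantile(score, (1:(nq - 1)) ./ nq)
    bin = [1 + count(<(s), edges) for s in score]   # bin k+1 if s exceeds k edges
    est = [mean(pseudo[bin .== q]) for q in 1:nq]
    se = [std(pseudo[bin .== q]) / sqrt(count(==(q), bin)) for q in 1:nq]
    return est, se
end

# Hypothetical helper: top-minus-bottom gap and z-statistic (bins are disjoint)
function top_minus_bottom(pseudo, score; nq = 5)
    est, se = gates(pseudo, score; nq = nq)
    gap = est[nq] - est[1]
    return gap, gap / sqrt(se[nq]^2 + se[1]^2)
end

Random.seed!(21)
n = 10_000
X = rand(n, 3)
noise = 3 .* randn(n)
pseudo_het = 1 .+ 2 .* X[:, 1] .+ noise     # heterogeneous effect
pseudo_null = 2.0 .+ noise                  # constant effect

# Sample split: fit a linear CATE score on half A, form groups on half B
A = randperm(n)[1:(n ÷ 2)]
B = setdiff(1:n, A)
score_on_B(p) = hcat(ones(length(B)), X[B, :]) * (hcat(ones(length(A)), X[A, :]) \ p[A])

gap_het, z_het = top_minus_bottom(pseudo_het[B], score_on_B(pseudo_het))
gap_null, z_null = top_minus_bottom(pseudo_null[B], score_on_B(pseudo_null))
println("heterogeneous: gap = ", round(gap_het, digits = 3), ", z = ", round(z_het, digits = 2))
println("constant:      gap = ", round(gap_null, digits = 3), ", z = ", round(z_null, digits = 2))
```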
9.11 Variable importance — which covariate drives heterogeneity?
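A quick diagnostic: regress the DR pseudo-outcomes on each covariate separately and compare the \(R^2\) of these single-covariate regressions. This is an ordinary \(R^2\), not a partial \(R^2\), since no other covariates are held fixed: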
function single_var_r2(target, x)
    df_one = DataFrame(t = target, x = x)
    ols = lm(@formula(t ~ x), df_one)
    return 1 - sum(abs2, residuals(ols)) / sum(abs2, target .- mean(target))
end
vi = [single_var_r2(pseudo_DR, X[:, j]) for j in 1:p]
vi_df = DataFrame(variable = ["X$j" for j in 1:p], R2 = vi)
sort!(vi_df, :R2, rev = true)
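println(vi_df)
\(X_1\) should top the list. In real applications this is a useful first guide to which covariates carry heterogeneity. Because each covariate is examined on its own, correlated covariates can share credit, so the ranking complements the BLP above rather than replacing it.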
9.12 Policy learning: who should we treat?
A CATE estimate tells us the effect of treating each individual. But the policy question is often different: given a cost-benefit trade-off, which individuals should we treat? Define a treatment cost \(c\) in outcome units; the optimal individualised policy is
\[ \pi^*(x) = \mathbb{1}\{\tau(x) > c\}. \]
For interpretability, restrict to a shallow decision rule. A simple approach: pick a single covariate threshold that maximises welfare.
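# Cost in outcome units. The true CATE runs from 1 to 3, so a cost of 2.0 makes
# the decision non-trivial (the ideal rule is: treat if X1 > 0.5).
cost = 2.0
# Welfare is measured relative to treating no one: the mean over all units of
# (pseudo_DR - cost) for treated units and 0 for untreated units.
# Plug-in policy using the estimated CATE
treat_optimal = tau_DR .> cost
welfare_opt = mean(treat_optimal .* (pseudo_DR .- cost))
# Naive "treat everyone" policy
welfare_all = mean(pseudo_DR .- cost)
# Simple threshold rule on X1
function threshold_welfare(thresh)
    treat = X[:, 1] .> thresh
    return mean(treat .* (pseudo_DR .- cost))
end
threshes = 0:0.05:1
welfares = [threshold_welfare(t) for t in threshes]
best_t = threshes[argmax(welfares)]
@printf("Welfare (treat everyone): %.3f\n", welfare_all)
@printf("Welfare (treat if τ̂(x) > %.2f): %.3f\n", cost, welfare_opt)
@printf("Welfare (best X1-threshold rule, X1 > %.2f): %.3f\n",
        best_t, maximum(welfares))
The plug-in policy treats only individuals whose estimated CATE exceeds the cost. The interpretable threshold rule on \(X_1\) alone is a reasonable compromise between welfare and explainability. All three welfare figures are measured relative to treating no one and are computed in-sample. Because the same pseudo-outcomes both choose and score the rules, the figures are optimistic; an honest comparison evaluates the chosen rule on held-out units.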
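Here is a minimal stdlib-only sketch of that honest comparison on toy pseudo-outcomes: true CATE \(1 + 2X_1\) plus noise, with a hypothetical cost of 2.0. The ideal rule is then "treat if \(X_1 > 0.5\)". Its welfare relative to treating no one is \(\int_{0.5}^{1}(2x - 1)\,dx = 0.25\) per person. The sketch chooses the threshold on one half of the sample and evaluates it on the other; `rule_welfare` is a hypothetical helper.

```julia
using Random, Statistics

# Hypothetical helper: welfare of "treat if x > t", relative to treating no one
rule_welfare(x, pseudo, cost, t) = mean((x .> t) .* (pseudo .- cost))

Random.seed!(13)
n = 100_000
x1 = rand(n)
pseudo = 1 .+ 2 .* x1 .+ 3 .* randn(n)   # toy DR-style pseudo-outcomes
cost = 2.0

# Choose the threshold on half A ...
perm = randperm(n)
A, B = perm[1:(n ÷ 2)], perm[(n ÷ 2 + 1):end]
grid = 0:0.05:1
w_A = [rule_welfare(x1[A], pseudo[A], cost, t) for t in grid]
t_hat = grid[argmax(w_A)]

# ... and evaluate it on the untouched half B
w_B = rule_welfare(x1[B], pseudo[B], cost, t_hat)
println("chosen threshold:            ", t_hat)
println("in-sample welfare (half A):  ", round(maximum(w_A), digits = 3))
println("held-out welfare (half B):   ", round(w_B, digits = 3))
```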
For a richer treatment of policy trees with IPW and AIPW losses (using R’s policytree package), see the companion blog chapter on policytree. For a cross-software comparison of CATE estimators (including Stata 19’s new cate command), see the Stata CATE blog chapter.
9.13 Causal forests with panel data
When the data are a panel — repeated observations on the same units — and the unobserved unit effects are correlated with treatment or covariates, the cross-sectional CATE methods above can be biased. The classical fix is the within-transform: demean each variable by its unit-level mean before applying the meta-learner.
Random.seed!(2024)
n_firms = 200
n_t = 5
N = n_firms * n_t
firm_id = repeat(1:n_firms, inner = n_t)
unit_fe = randn(n_firms) .* 1.5
V1_firm = randn(n_firms) .* 1.0
V1_panel = V1_firm[firm_id] .+ 0.3 .* randn(N)
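# Treatment probability rises with the firm effect, so unit_fe confounds W
W_panel = Float64.(rand(N) .< @. 1 / (1 + exp(-(0.5 + 0.3 * V1_firm[firm_id] +
                                                 0.5 * unit_fe[firm_id]))))
tau_panel = @. 0.5 + 1.0 * V1_panel
# The effect of switching W from 0 to 1 is tau_panel (centring W only shifts
# the baseline by a function of V1)
Y_panel = unit_fe[firm_id] .+ V1_panel .+
          tau_panel .* (W_panel .- mean(W_panel)) .+ randn(N)
df_panel = DataFrame(firm = firm_id, V1 = V1_panel, W = W_panel, Y = Y_panel)
@printf("True panel ATE: %.3f\n", mean(tau_panel))
A naive T-learner on the panel ignores the firm fixed effects, so its treated-versus-control comparison is confounded by `unit_fe`: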
X_panel_naive = reshape(df_panel.V1, N, 1)
idx_T_p = df_panel.W .== 1
idx_C_p = df_panel.W .== 0
mu1_naive = rf_fit_predict(X_panel_naive[idx_T_p, :], df_panel.Y[idx_T_p],
                           X_panel_naive)
mu0_naive = rf_fit_predict(X_panel_naive[idx_C_p, :], df_panel.Y[idx_C_p],
                           X_panel_naive)
tau_naive = mu1_naive .- mu0_naive
@printf("Naive panel T-learner ATE: %.3f (true = %.3f)\n",
        mean(tau_naive), mean(tau_panel))
The within-transform removes unit-level variation, leaving only the panel’s within-firm signal for CATE estimation:
# Within transformation: subtract firm means
df_dm = combine(groupby(df_panel, :firm),
:Y => (y -> y .- mean(y)) => :Y_dm,
:W => (w -> w .- mean(w)) => :W_dm,
:V1 => (v -> v .- mean(v)) => :V1_dm)
X_dm = reshape(df_dm.V1_dm, N, 1)
idx_T_dm = df_dm.W_dm .> 0
idx_C_dm = df_dm.W_dm .<= 0
mu1_dm = rf_fit_predict(X_dm[idx_T_dm, :], df_dm.Y_dm[idx_T_dm], X_dm)
mu0_dm = rf_fit_predict(X_dm[idx_C_dm, :], df_dm.Y_dm[idx_C_dm], X_dm)
tau_dm = mu1_dm .- mu0_dm
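@printf("Within-transform T-learner ATE: %.3f (true = %.3f)\n",
        mean(tau_dm), mean(tau_panel))
Demeaning removes the firm fixed effects, so confounding through the unit effects no longer drives the treated-versus-control comparison. The within-transform T-learner is still only an approximation, for two reasons. First, splitting on the sign of the demeaned treatment pools within-firm contrasts from firms with different treatment shares. Second, firms whose treatment never changes land entirely in the control group with demeaned outcomes near zero. Both tend to attenuate the estimate. Demeaning \(V_1\) also discards its between-firm variation, which is where most of the variation in the effect comes from in this DGP. This version therefore says more about the average effect than about how the effect varies across firms. Finally, the cost is variance: each unit contributes only its within-firm variation, so the effective sample size shrinks.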
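If one is willing to assume the CATE is linear in the covariate, the fixed-effects logic can be applied directly to the interaction model \(Y_{it} = \alpha_i + \gamma V_{1,it} + (a + b V_{1,it}) W_{it} + \varepsilon_{it}\). Demeaning within firm removes \(\alpha_i\). OLS on the demeaned \(V_1\), \(W\) and \(V_1 W\) then targets \(a\) and \(b\). Here is a minimal stdlib-only sketch with these assumptions:

- the toy panel is in the same spirit as the one above, with treatment probability rising with the firm effect;
- it uses 1,000 firms so the check is less noisy;
- `demean_by` is a hypothetical helper.

The fixed-effects fit is compared with pooled OLS that ignores the firm effects.

```julia
using Random, Statistics, LinearAlgebra

# Hypothetical helper: the within transformation (subtract group means)
function demean_by(group, x)
    sums = Dict{eltype(group), Float64}()
    counts = Dict{eltype(group), Int}()
    for (g, v) in zip(group, x)
        sums[g] = get(sums, g, 0.0) + v
        counts[g] = get(counts, g, 0) + 1
    end
    return [v - sums[g] / counts[g] for (g, v) in zip(group, x)]
end

Random.seed!(2024)
n_firms, n_t = 1_000, 5
N = n_firms * n_t
firm = repeat(1:n_firms, inner = n_t)
fe = 1.5 .* randn(n_firms)                           # unobserved firm effects
V1 = randn(n_firms)[firm] .+ 0.3 .* randn(N)
p_treat = @. 1 / (1 + exp(-(0.5 + 0.5 * fe[firm])))  # treatment rises with fe
W = Float64.(rand(N) .< p_treat)
Y = fe[firm] .+ V1 .+ (0.5 .+ V1) .* W .+ randn(N)   # true CATE: 0.5 + V1

# Pooled OLS ignoring firm effects, columns [1, V1, W, V1*W]
b_pool = hcat(ones(N), V1, W, V1 .* W) \ Y
# Within (fixed-effects) OLS on demeaned V1, W and V1*W
Z_dm = hcat(demean_by(firm, V1), demean_by(firm, W), demean_by(firm, V1 .* W))
b_fe = Z_dm \ demean_by(firm, Y)
println("pooled: a = ", round(b_pool[3], digits = 3), ", b = ", round(b_pool[4], digits = 3))
println("within: a = ", round(b_fe[2], digits = 3), ", b = ", round(b_fe[3], digits = 3))
```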
For R’s grf::causal_forest, the cleaner approach is passing clusters = firm_id so the forest’s honest sample-splitting keeps each firm’s observations in the same training fold. The Julia ecosystem currently lacks a native equivalent — see the companion blog chapter on causal forests in panel data and the R companion to this chapter for the GRF-based workflow.
9.14 Summary
- Meta-learners (S, T, X, R, DR) turn any regression algorithm into a CATE estimator. The R- and DR-learners pair Neyman-orthogonal losses or pseudo-outcomes with cross-fitting, which makes them less sensitive to errors in the nuisance estimates. The X-learner helps with imbalanced treatment arms.
- All five fit on top of MLJ + random forests in Julia in roughly 10 to 30 lines each.
- Best Linear Projection (BLP) gives a simple regression-style summary of CATE: coefficients describing how the effect varies with each covariate.
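- Quintile-binned ATEs (GATES, which Chernozhukov et al. pair with the CLAN covariate comparison) provide a formal test for heterogeneity when the bins are formed out of sample.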
- Policy learning translates a CATE into a treatment-assignment rule; even a simple threshold rule on the most-heterogeneous covariate often achieves most of the welfare gain.
- For applied work: estimate CATE with at least two methods, report BLP coefficients and quintile-binned ATEs, and use an interpretable rule when treatment-assignment recommendations are needed.