library(engression)
# Nonlinear DGP with heteroskedastic noise
n <- 1000
X <- matrix(rnorm(n * 2), ncol = 2)
Y <- (X[,1])^2 + X[,2] + rnorm(n) * (0.5 + abs(X[,1]))
# Fit engression — learns the full conditional distribution
engr <- engression(X, Y)
# Predict conditional mean (like lm, but nonlinear)
Yhat_mean <- predict(engr, X, type = "mean")
# Predict conditional quantiles (like quantile regression, but joint)
Yhat_quant <- predict(engr, X, type = "quantile")
# Sample from the learned distribution (generative)
Ysample <- predict(engr, X, type = "sample", nsample = 1)
41 Frengression and Engression
In the causal simulation chapter, we saw how Evans & Didelez (2024) decompose the joint distribution \(p(x,a,y)\) into three pieces:
\[ p(x,a,y) = p(x,a) \cdot p_a(y) \cdot c(x,y|a) \]
where \(p(x,a)\) is the “past” (marginal of covariates and treatment), \(p_a(y)\) is the causal margin (marginal structural model), and \(c(x,y|a)\) is the copula (dependence between covariates and outcome given treatment). The causl package implements this with parametric families: you specify Gaussian, binomial, etc. for each piece.
But what if the real data has complicated structure that’s hard to specify parametrically? What if we don’t know the right copula family, or the marginal distributions are non-standard?
Frengression (Yang, Evans and Shen, 2025) solves this by learning all three pieces from data using neural networks. Same decomposition, but nonparametric.
It has two major uses: first, causal simulation with nonparametric generative models; second, interventional prediction using the learned confounding structure. Both build on Engression, so let's start there.
41.1 Engression
Engression (Shen & Meinshausen, 2024, JRSSB) stands for “energy regression.” Standard regression predicts \(E[Y|X]\) — a single number. Engression learns the entire conditional distribution \(P(Y|X)\) using neural networks trained with the energy score.
41.1.1 The Energy Score
The energy score is a proper scoring rule for distributions. Given observations \(Y\) and samples \(\hat Y, \hat Y'\) from the model:
\[ S(Y, \hat Y, \hat Y') = E\|Y - \hat Y\| - \frac{1}{2}E\|\hat Y - \hat Y'\| \]
The first term measures how close the model’s samples are to the truth. The second term penalizes lack of spread — without it, the model could collapse to a point prediction. Minimizing this score trains the network to match the full distribution, not just the mean.
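To make the two terms concrete, here is a minimal base-R sketch (not part of the engression package) that estimates the score by Monte Carlo for a univariate target \(Y \sim N(0,1)\). The helper energy_score() and the three candidate samplers are assumptions made up for this illustration.

```r
set.seed(1)
n <- 5000
y <- rnorm(n)  # observations from the true distribution N(0, 1)

# Empirical energy score: mean|Y - Yhat| - 0.5 * mean|Yhat - Yhat'|
# (hypothetical helper; yhat and yhat2 are two independent model draws)
energy_score <- function(y, yhat, yhat2) {
  mean(abs(y - yhat)) - 0.5 * mean(abs(yhat - yhat2))
}

score_true      <- energy_score(y, rnorm(n), rnorm(n))                  # correct spread
score_collapsed <- energy_score(y, rep(0, n), rep(0, n))                # point prediction
score_wide      <- energy_score(y, rnorm(n, sd = 3), rnorm(n, sd = 3))  # too much spread

print(round(c(true = score_true, collapsed = score_collapsed, wide = score_wide), 3))
```

The correctly spread sampler should score lowest. In this Gaussian setup the theoretical values are roughly 0.56 (true), 0.80 (collapsed) and 0.83 (too wide), so both under- and over-dispersion are penalized.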
41.1.2 The Stochastic Network (StoNet)
The building block is the StoNet — a neural network that concatenates random noise \(\epsilon \sim N(0,I)\) with the input at each layer:
\[\text{StoNet}(x) = f(x, \epsilon), \quad \epsilon \sim N(0,I)\]
Different noise draws produce different outputs — each one is a sample from the learned conditional distribution. This is what makes engression a generative model. You can:
- Predict the mean: average over many noise draws → \(E[Y|X=x]\)
- Predict quantiles: take quantiles of the noise draws → conditional quantile regression
- Sample: each forward pass with fresh noise is a sample from \(P(Y|X=x)\)
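The three uses above can be sketched in base R with a toy stand-in for a trained StoNet. The function toy_stonet() is a hand-written assumption, not a fitted network; it only mimics the interface \(f(x, \epsilon)\) so that the mean, quantile and sample operations are visible.

```r
set.seed(1)

# Toy stand-in for a trained StoNet: a fixed function of input x and noise eps
# (hypothetical; a real StoNet would be a neural network fitted with the energy score)
toy_stonet <- function(x, eps) x^2 + (0.5 + abs(x)) * eps

# Draw nsample outputs at a single input by feeding fresh noise each time
draw <- function(x, nsample = 2000) toy_stonet(x, rnorm(nsample))

x0 <- 1.5
draws <- draw(x0)

cond_mean  <- mean(draws)                         # estimate of E[Y | X = x0]
cond_quant <- quantile(draws, c(0.1, 0.5, 0.9))   # conditional quantiles at x0
one_sample <- draw(x0, nsample = 1)               # a single generative sample

print(cond_mean)
print(cond_quant)
```

Here the conditional mean at x0 = 1.5 is 2.25 by construction, so cond_mean should land close to that value; all three summaries come from the same set of noise draws.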
41.1.3 Why Engression?
What makes engression special compared to other distributional regression methods (quantile regression, mixture density networks)?
- Extrapolation: the energy score encourages the model to extrapolate sensibly outside the training range. Standard neural networks often produce wild predictions outside training data; engression stays more calibrated.
- Multivariate: it handles multivariate responses naturally — no need to specify marginals separately.
- No distributional assumptions: unlike mixture density networks (which assume Gaussian mixtures) or quantile regression (which estimates quantiles one at a time), engression learns the full joint distribution of \(Y|X\) without parametric assumptions.
41.1.4 Quick Example
The engression R package (separate from Rfrengression) provides the standalone energy-regression estimator. The code below shows the API; see the engression CRAN page for installation. The Rfrengression examples later in this chapter are fully executable.
The key point: engression gives you a distributional model of \(Y|X\), not just a point prediction.
41.2 From Engression to Frengression
Frengression = frugal simulation + engression. It uses three engression models (StoNets) to decompose the joint distribution \(P(X,A,Y)\) into the causal simulation framework. Each sub-network is an engression model trained with the energy score.
41.3 The Architecture
Frengression has three sub-networks, matching the three pieces of the decomposition:
| Decomposition | causl (parametric) | Frengression (neural net) |
|---|---|---|
| \(P(X,A)\) — the past | Specified by family + params | model_xz: noise-only generator |
| \(P_a(Y)\) — causal margin | Specified by family + params | model_y: takes \((x, \text{eta})\), outputs \(y\) |
| \(c(X,Y \mid A)\) — copula | Gaussian/Frank/etc. copula | model_eta: takes \((x,z)\), outputs eta |
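The key insight is how model_eta works. Note the change of notation: in Frengression, \(x\) is the treatment (the \(a\) of the decomposition above) and \(z\) holds the covariates. model_eta takes \((x,z)\) and outputs a latent variable eta that encodes the dependence between confounders and outcome. During training, a Z-permutation trick forces eta to be marginally \(N(0,I)\). This means: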
- When eta comes from model_eta(x,z) (observational): it carries the confounding information, reproducing the observed data
- When eta is sampled from \(N(0,1)\) independently (interventional): no confounding, giving you \(P(Y|do(X))\)
Same network, different source of eta. That’s the whole trick.
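A small base-R simulation can illustrate the idea, under simplifying assumptions: the "networks" below are hand-written linear functions (toy_model_y() and the formula for eta_obs are hypothetical stand-ins, not Rfrengression code), and eta depends on the confounder only.

```r
set.seed(2)
n <- 20000
z <- rnorm(n)                       # confounder
x <- rbinom(n, 1, plogis(2 * z))    # treatment depends on z
effect <- 0.5                       # known causal effect of x on y

# Toy stand-in for model_eta: depends on z but is marginally N(0, 1)
# because 0.8^2 + 0.6^2 = 1
eta_obs <- 0.8 * z + 0.6 * rnorm(n)

# Toy stand-in for model_y
toy_model_y <- function(x, eta) effect * x + eta

# Observational: eta carries confounding, so the naive contrast is biased
y_obs <- toy_model_y(x, eta_obs)
naive_diff <- mean(y_obs[x == 1]) - mean(y_obs[x == 0])

# Interventional: eta drawn independently of x recovers the causal contrast
do_diff <- mean(toy_model_y(1, rnorm(n))) - mean(toy_model_y(0, rnorm(n)))

print(c(naive = naive_diff, interventional = do_diff))
```

The outcome function is identical in both branches; only the source of eta changes. The naive contrast is pushed well above 0.5 by the confounder, while the interventional contrast should sit close to 0.5.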
41.4 Using Rfrengression
The Rfrengression package is a native R implementation using the torch package.
library(Rfrengression)
41.4.1 Example: Binary Treatment DGP
Let’s use a DGP with binary treatment, binary outcome, and four confounders:
set.seed(42)
n <- 2000
w1 <- rbinom(n, 1, 0.5)
w2 <- rbinom(n, 1, 0.5)
w3 <- round(runif(n, 0, 4), 3)
w4 <- round(runif(n, 0, 5), 3)
A <- rbinom(n, 1, plogis(-0.4 + 0.2*w2 + 0.15*w3 + 0.2*w4 + 0.15*w2*w4))
Y.1 <- rbinom(n, 1, plogis(-1 + 1 - 0.1*w1 + 0.3*w2 + 0.25*w3 + 0.2*w4 + 0.15*w2*w4))
Y.0 <- rbinom(n, 1, plogis(-1 + 0 - 0.1*w1 + 0.3*w2 + 0.25*w3 + 0.2*w4 + 0.15*w2*w4))
Y <- Y.1 * A + Y.0 * (1 - A)
true_ate <- mean(Y.1 - Y.0)
cat("True ATE:", round(true_ate, 4), "\n")
True ATE: 0.1935
41.4.2 Train Frengression
Two-stage training: first the outcome model (model_y + model_eta), then the marginal (model_xz).
x_mat <- matrix(A, ncol = 1)
z_mat <- as.matrix(data.frame(w1, w2, w3, w4))
y_mat <- matrix(Y, ncol = 1)
model <- frengression(x_dim = 1, y_dim = 1, z_dim = 4,
x_binary = TRUE, y_binary = TRUE,
z_binary_dims = 2, noise_dim = 10)
model <- train_y(model, x_mat, z_mat, y_mat,
num_iters = 1000, lr = 1e-3, print_every = 500)
Epoch 1: loss 1.2467, loss_y 0.5001 (0.5227, 0.0452), loss_eta 0.7466 (0.8311, 0.1690)
Stopping at iter 75
model <- train_xz(model, x_mat, z_mat,
num_iters = 1000, lr = 1e-4, print_every = 500)
Epoch 1: loss 3.4889, loss1 3.6549, loss2 0.3320
Epoch 500: loss 1.3532, loss1 2.6650, loss2 2.6235
Epoch 1000: loss 1.3733, loss1 2.7221, loss2 2.6977
41.4.3 Estimate ATE
The interventional prediction samples eta from \(N(0,1)\) — no confounding:
y1 <- predict(model, matrix(1, ncol = 1), type = "mean", nsample = 2000)
y0 <- predict(model, matrix(0, ncol = 1), type = "mean", nsample = 2000)
cat("Frengression ATE:", round(y1 - y0, 4), "\n")
Frengression ATE: 0.2056
cat("True ATE:", round(true_ate, 4), "\n")
True ATE: 0.1935
41.5 Frengression as DGP
Frengression can estimate ATE, but that’s not its main strength — dedicated semiparametric estimators (npcausal, TMLE, DoubleML) are designed for that. The real value is using frengression as a data-generating process for benchmarking.
The idea:
- Train frengression on observational data — it learns the realistic confounding structure
- Replace the causal margin with a known function via specify_causal()
- Generate synthetic datasets: realistic confounding, known ground truth
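The benchmarking logic of these steps can be sketched in base R without any neural network. In the sketch below, toy_sample_joint() is a hypothetical stand-in for a trained generator with a plugged-in causal effect; the point is only that a known ground truth lets us score estimators by bias and RMSE.

```r
set.seed(7)
true_ate <- 0.3

# Hypothetical stand-in for a trained generator: confounded data with a known effect
toy_sample_joint <- function(n) {
  z <- rnorm(n)
  x <- rbinom(n, 1, plogis(1.5 * z))
  y <- true_ate * x + z + rnorm(n)
  data.frame(x = x, y = y, z = z)
}

# Generate B synthetic datasets and apply two estimators to each
B <- 50
est <- t(replicate(B, {
  d <- toy_sample_joint(1000)
  c(naive    = mean(d$y[d$x == 1]) - mean(d$y[d$x == 0]),
    adjusted = unname(coef(lm(y ~ x + z, data = d))["x"]))
}))

# Score each estimator against the known truth
bias <- colMeans(est) - true_ate
rmse <- sqrt(colMeans((est - true_ate)^2))
print(round(rbind(bias, rmse), 3))
```

The unadjusted difference in means should show a large positive bias, while the regression that adjusts for z should be close to unbiased. With real observational data neither could be checked, because the true effect is unknown.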
# Specify a known causal effect: Y = sigmoid(0.2*X + eta)
model_bench <- specify_causal(model, function(x, eta) {
torch::torch_sigmoid(0.2 * x + eta)
})
# True ATE computed empirically from the model
y1_true <- predict(model_bench, matrix(1, ncol = 1), type = "mean", nsample = 5000)
y0_true <- predict(model_bench, matrix(0, ncol = 1), type = "mean", nsample = 5000)
cat("True ATE (specified):", round(as.numeric(y1_true - y0_true), 4), "\n")
True ATE (specified): 0.0829
Now we can generate synthetic datasets and benchmark any method:
syn <- sample_joint(model_bench, sample_size = 1000)
cat("Synthetic data: n =", nrow(syn$x), "\n")
Synthetic data: n = 1000
cat("Treatment prevalence:", mean(syn$x), "\n")
Treatment prevalence: 0.917
cat("Outcome prevalence:", mean(syn$y), "\n")
Outcome prevalence: 0.618
41.6 Real Data: LaLonde
Where this really shines is with real data. The LaLonde dataset (job training program, 614 observations) has complex confounding that no one could specify parametrically — many zeros in prior earnings, non-standard covariate distributions, poor overlap.
library(MatchIt)
data("lalonde", package = "MatchIt")
lalonde$black <- as.integer(lalonde$race == "black")
lalonde$hispan <- as.integer(lalonde$race == "hispan")
x_lal <- matrix(lalonde$treat, ncol = 1)
y_lal <- matrix(lalonde$re78, ncol = 1)
z_lal <- as.matrix(lalonde[, c("age", "educ", "black", "hispan",
"married", "nodegree", "re74", "re75")])
# Standardize for training
z_sc <- scale(z_lal)
y_sc <- (y_lal - mean(y_lal)) / sd(y_lal)
model_lal <- frengression(x_dim = 1, y_dim = 1, z_dim = 8,
x_binary = TRUE, y_binary = FALSE,
z_binary_dims = 4, noise_dim = 10)
model_lal <- train_y(model_lal, x_lal, z_sc, y_sc,
num_iters = 1000, lr = 1e-3, print_every = 500)
Epoch 1: loss 1.5082, loss_y 0.7632 (0.8178, 0.1092), loss_eta 0.7449 (0.8150, 0.1402)
Stopping at iter 39
model_lal <- train_xz(model_lal, x_lal, z_sc,
num_iters = 1000, lr = 1e-4, print_every = 500)
Epoch 1: loss 2.8319, loss1 2.9915, loss2 0.3193
Epoch 500: loss 2.0077, loss1 3.8244, loss2 3.6334
Epoch 1000: loss 1.9812, loss1 3.8653, loss2 3.7682
Now specify a known effect and generate:
model_lal_bench <- specify_causal(model_lal, function(x, eta) {
0.3 * x + eta
})
# Generate synthetic LaLonde-like data with known ATE = 0.3 (standardized)
syn_lal <- sample_joint(model_lal_bench, sample_size = 2000)
cat("Synthetic LaLonde: n =", nrow(syn_lal$x), "\n")
Synthetic LaLonde: n = 2000
cat("Treatment prevalence:", round(mean(syn_lal$x), 3), "\n")
Treatment prevalence: 0.032
cat("True ATE:", 0.3, "(standardized) =", round(0.3 * sd(y_lal), 0), "dollars\n")
True ATE: 0.3 (standardized) = 2241 dollars
The synthetic data inherits LaLonde’s messy confounding structure — the complex propensity score, the zero-inflated prior earnings, the demographic imbalances — but we know the true causal effect. No researcher could hand-specify this DGP.
41.7 Comparison: causl vs frengression
Both packages implement the same decomposition:
| Feature | causl | Frengression |
|---|---|---|
| Marginal \(P(X,A)\) | Parametric families | Learned by neural net |
| Causal margin \(P_a(Y)\) | Parametric families | Learned (or specified) |
| Copula | Gaussian, Frank, etc. | Learned by model_eta |
| Requires specification | Yes — families + params | No — learns from data |
| Can learn from real data | Limited (plasmode) | Yes — full joint |
| Distributional flexibility | Limited by family choice | Arbitrary |
The causl approach is more transparent — you know exactly what you specified. Frengression is more flexible — it can capture structure you didn’t anticipate. They’re complementary.
41.8 Benchmarking: Frengression vs Semiparametric Methods
Now let’s compare frengression against standard semiparametric causal inference tools on the same data. The three comparison methods are:
- npcausal: AIPW estimator with SuperLearner nuisance models
- TMLE: Targeted minimum loss-based estimation with SuperLearner
- DoubleML: Partially linear regression with ranger (random forest)
These methods are purpose-built for ATE estimation using influence function theory — they should be more precise for this specific estimand. The question is how frengression compares, and what it can do that they cannot.
library(npcausal)
library(tmle)
library(SuperLearner)
library(DoubleML)
library(mlr3)
library(mlr3learners)
library(tidyverse)
library(knitr)
41.8.1 Single-Dataset Comparison
First, run all four methods on the binary DGP data created above.
W <- data.frame(w1, w2, w3, w4)
SL.library <- c("SL.earth", "SL.glm.interaction", "SL.mean",
"SL.ranger", "SL.glmnet")
set.seed(123)
aipw_fit <- ate(y = Y, a = A, x = W, nsplits = 2, sl.lib = SL.library)
parameter est se ci.ll ci.ul pval
1 E{Y(0)} 0.5697312 0.02038159 0.5297832 0.6096791 0
2 E{Y(1)} 0.7608322 0.01207030 0.7371744 0.7844900 0
3 E{Y(1)-Y(0)} 0.1911010 0.02335102 0.1453330 0.2368690 0
ate_npcausal <- aipw_fit$res$est[3]
se_npcausal <- aipw_fit$res$se[3]
ci_npcausal <- c(aipw_fit$res$ci.ll[3], aipw_fit$res$ci.ul[3])
cat("npcausal ATE:", round(ate_npcausal, 4), "\n")
npcausal ATE: 0.1911
cat("SE:", round(se_npcausal, 4), "\n")
SE: 0.0234
set.seed(123)
tmle_fit <- tmle(Y = Y, A = A, W = W, family = "binomial",
Q.SL.library = SL.library, g.SL.library = SL.library)
ate_tmle <- tmle_fit$estimates$ATE$psi
se_tmle <- sqrt(tmle_fit$estimates$ATE$var.psi)
ci_tmle <- tmle_fit$estimates$ATE$CI
cat("TMLE ATE:", round(ate_tmle, 4), "\n")
TMLE ATE: 0.1916
cat("SE:", round(se_tmle, 4), "\n")
SE: 0.0209
lgr::get_logger("mlr3")$set_threshold("warn")
data_obs <- data.frame(w1, w2, w3, w4, A, Y)
dml_data <- DoubleMLData$new(data_obs,
y_col = "Y", d_cols = "A",
x_cols = c("w1", "w2", "w3", "w4"))
learner_l <- lrn("regr.ranger", num.trees = 500, max.depth = 5, min.node.size = 2)
learner_m <- learner_l$clone()
set.seed(123)
dml_plr <- DoubleMLPLR$new(dml_data, ml_l = learner_l, ml_m = learner_m)
dml_plr$fit()
ate_dml <- dml_plr$coef
se_dml <- dml_plr$se
ci_dml <- c(dml_plr$confint()[1], dml_plr$confint()[2])
cat("DoubleML ATE:", round(ate_dml, 4), "\n")
DoubleML ATE: 0.1872
cat("SE:", round(se_dml, 4), "\n")
SE: 0.0229
# Frengression ATE already computed above (y1 - y0)
ate_freng <- as.numeric(y1 - y0)
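# Monte Carlo SE via repeated prediction. This captures only the sampling noise
# in predict(), not the uncertainty from training, so it is not a true bootstrap SE
# and the resulting interval is not comparable to the semiparametric CIs.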
set.seed(123)
ate_boot <- replicate(200, {
predict(model, matrix(1, ncol = 1), type = "mean", nsample = 500) -
predict(model, matrix(0, ncol = 1), type = "mean", nsample = 500)
})
se_freng <- sd(ate_boot)
results <- data.frame(
Method = c("True ATE", "Frengression", "npcausal (AIPW)", "TMLE", "DoubleML (PLR)"),
ATE = c(true_ate, ate_freng, ate_npcausal, ate_tmle, ate_dml),
SE = c(NA, se_freng, se_npcausal, se_tmle, se_dml),
CI_lower = c(NA,
ate_freng - 1.96 * se_freng,
ci_npcausal[1], ci_tmle[1], ci_dml[1]),
CI_upper = c(NA,
ate_freng + 1.96 * se_freng,
ci_npcausal[2], ci_tmle[2], ci_dml[2])
)
results$Bias <- results$ATE - true_ate
results$Covers <- ifelse(is.na(results$CI_lower), NA,
results$CI_lower <= true_ate & true_ate <= results$CI_upper)
kable(results, digits = 4,
caption = "ATE Estimation: Frengression vs Semiparametric Methods")
| Method | ATE | SE | CI_lower | CI_upper | Bias | Covers |
|---|---|---|---|---|---|---|
| True ATE | 0.1935 | NA | NA | NA | 0.0000 | NA |
| Frengression | 0.2056 | 0.0329 | 0.1411 | 0.2700 | 0.0121 | TRUE |
| npcausal (AIPW) | 0.1911 | 0.0234 | 0.1453 | 0.2369 | -0.0024 | TRUE |
| TMLE | 0.1916 | 0.0209 | 0.1506 | 0.2326 | -0.0019 | TRUE |
| DoubleML (PLR) | 0.1872 | 0.0229 | 0.1424 | 0.2320 | -0.0063 | TRUE |
The semiparametric estimators are purpose-built for ATE and leverage influence function theory to achieve \(\sqrt{n}\)-consistent, doubly-robust estimates. Frengression learns the entire joint distribution \(P(X,Y,Z)\), from which ATE is just one functional — so some loss of precision for this specific estimand is expected.
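To show what "influence function theory" buys in practice, here is a minimal base-R sketch of an AIPW estimator on simulated data. It is an illustration under stated assumptions, not the npcausal implementation: the nuisance models are plain glm fits rather than SuperLearner, there is no cross-fitting, and the data-generating coefficients are made up for the example.

```r
set.seed(11)
n <- 4000
w <- rnorm(n)                                    # single confounder
a <- rbinom(n, 1, plogis(0.5 * w))               # treatment
y <- rbinom(n, 1, plogis(-0.5 + a + 0.7 * w))    # binary outcome
d <- data.frame(y = y, a = a, w = w)

# Nuisance models: propensity score and outcome regression
ps_fit <- glm(a ~ w, family = binomial, data = d)
or_fit <- glm(y ~ a + w, family = binomial, data = d)

ps  <- predict(ps_fit, type = "response")
mu1 <- predict(or_fit, newdata = transform(d, a = 1), type = "response")
mu0 <- predict(or_fit, newdata = transform(d, a = 0), type = "response")

# Uncentered efficient influence function values for the ATE
phi <- mu1 - mu0 + a * (y - mu1) / ps - (1 - a) * (y - mu0) / (1 - ps)

ate_aipw <- mean(phi)                 # point estimate
se_aipw  <- sd(phi) / sqrt(n)         # influence-function-based standard error
ci_aipw  <- ate_aipw + c(-1, 1) * qnorm(0.975) * se_aipw

print(round(c(ate = ate_aipw, se = se_aipw, lower = ci_aipw[1], upper = ci_aipw[2]), 4))
```

The estimate and its standard error both come from the same vector phi, which is why these methods deliver confidence intervals almost for free. The estimator stays consistent if either the propensity model or the outcome model is correctly specified, which is the double robustness referred to above.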
41.8.2 Monte Carlo Benchmarking (Simulated DGP)
This is frengression’s “killer app.” We use the model trained above, plug in a known causal effect via specify_causal(), generate B=10 synthetic datasets, and benchmark all four estimators.
B <- 10
n_syn <- 2000
methods <- c("Frengression", "npcausal", "TMLE", "DoubleML")
ate_mat <- matrix(NA, nrow = B, ncol = length(methods),
dimnames = list(NULL, methods))
se_mat <- ate_mat
cover_mat <- ate_mat
SL.lib_fast <- c("SL.glm.interaction", "SL.ranger", "SL.mean")
# True ATE from the specified model (already computed above as y1_true - y0_true)
true_ate_bench <- as.numeric(y1_true - y0_true)
set.seed(2024)
for (b in seq_len(B)) {
syn <- sample_joint(model_bench, sample_size = n_syn)
x_syn <- as.numeric(syn$x)
y_syn <- as.numeric(syn$y)
z_syn <- as.data.frame(syn$z)
colnames(z_syn) <- paste0("z", 1:ncol(z_syn))
df_syn <- cbind(data.frame(A = x_syn, Y = y_syn), z_syn)
# --- Frengression ---
tryCatch({
mod_f <- frengression(x_dim = 1, y_dim = 1, z_dim = ncol(z_syn),
x_binary = TRUE, y_binary = TRUE,
noise_dim = 10, hidden_dim = 100, num_layer = 3)
mod_f <- train_y(mod_f, matrix(x_syn, ncol=1), as.matrix(z_syn),
matrix(y_syn, ncol=1),
num_iters = 500, lr = 1e-3, silent = TRUE)
mod_f <- train_xz(mod_f, matrix(x_syn, ncol=1), as.matrix(z_syn),
num_iters = 500, lr = 1e-4, silent = TRUE)
y1_f <- predict(mod_f, matrix(1, ncol=1), type="mean", nsample=1000)
y0_f <- predict(mod_f, matrix(0, ncol=1), type="mean", nsample=1000)
ate_mat[b, "Frengression"] <- y1_f - y0_f
ate_mc <- replicate(100, {
predict(mod_f, matrix(1,ncol=1), type="mean", nsample=200) -
predict(mod_f, matrix(0,ncol=1), type="mean", nsample=200)
})
se_mat[b, "Frengression"] <- sd(ate_mc)
ci_f <- ate_mat[b, "Frengression"] + c(-1,1) * 1.96 * se_mat[b, "Frengression"]
cover_mat[b, "Frengression"] <- (ci_f[1] <= true_ate_bench & true_ate_bench <= ci_f[2])
}, error = function(e) message("Frengression failed on rep ", b, ": ", e$message))
# --- npcausal ---
tryCatch({
fit_np <- ate(y = y_syn, a = x_syn, x = z_syn,
nsplits = 2, sl.lib = SL.lib_fast)
ate_mat[b, "npcausal"] <- fit_np$res$est[3]
se_mat[b, "npcausal"] <- fit_np$res$se[3]
ci <- c(fit_np$res$ci.ll[3], fit_np$res$ci.ul[3])
cover_mat[b, "npcausal"] <- (ci[1] <= true_ate_bench & true_ate_bench <= ci[2])
}, error = function(e) message("npcausal failed on rep ", b, ": ", e$message))
# --- TMLE ---
tryCatch({
fit_tmle <- tmle(Y = y_syn, A = x_syn, W = z_syn, family = "binomial",
Q.SL.library = SL.lib_fast, g.SL.library = SL.lib_fast)
ate_mat[b, "TMLE"] <- fit_tmle$estimates$ATE$psi
se_mat[b, "TMLE"] <- sqrt(fit_tmle$estimates$ATE$var.psi)
ci_t <- fit_tmle$estimates$ATE$CI
cover_mat[b, "TMLE"] <- (ci_t[1] <= true_ate_bench & true_ate_bench <= ci_t[2])
}, error = function(e) message("TMLE failed on rep ", b, ": ", e$message))
# --- DoubleML ---
tryCatch({
dml_d <- DoubleMLData$new(df_syn, y_col = "Y", d_cols = "A",
x_cols = colnames(z_syn))
lr_l <- lrn("regr.ranger", num.trees = 300, max.depth = 5, min.node.size = 2)
lr_m <- lr_l$clone()
dml_obj <- DoubleMLPLR$new(dml_d, ml_l = lr_l, ml_m = lr_m)
dml_obj$fit()
ate_mat[b, "DoubleML"] <- dml_obj$coef
se_mat[b, "DoubleML"] <- dml_obj$se
ci_d <- c(dml_obj$confint()[1], dml_obj$confint()[2])
cover_mat[b, "DoubleML"] <- (ci_d[1] <= true_ate_bench & true_ate_bench <= ci_d[2])
}, error = function(e) message("DoubleML failed on rep ", b, ": ", e$message))
if (b %% 5 == 0) cat("Completed replication", b, "of", B, "\n")
}
parameter est se ci.ll ci.ul pval
1 E{Y(0)} 0.51790846 0.02453065 0.46982839 0.5659885 0.000
2 E{Y(1)} 0.56966469 0.01261647 0.54493641 0.5943930 0.000
3 E{Y(1)-Y(0)} 0.05175623 0.02759015 -0.00232047 0.1058329 0.061
parameter est se ci.ll ci.ul pval
1 E{Y(0)} 0.1191438 0.02932246 0.06167176 0.1766158 0
2 E{Y(1)} 0.5974827 0.01237519 0.57322736 0.6217381 0
3 E{Y(1)-Y(0)} 0.4783389 0.03182103 0.41596972 0.5407082 0
parameter est se ci.ll ci.ul pval
1 E{Y(0)} 0.60836527 0.04825968 0.5137763 0.70295425 0.000
2 E{Y(1)} 0.57332669 0.01240628 0.5490104 0.59764300 0.000
3 E{Y(1)-Y(0)} -0.03503858 0.04983024 -0.1327058 0.06262869 0.482
parameter est se ci.ll ci.ul pval
1 E{Y(0)} 0.50239076 0.03315524 0.43740650 0.5673750 0.000
2 E{Y(1)} 0.55835634 0.01260901 0.53364268 0.5830700 0.000
3 E{Y(1)-Y(0)} 0.05596558 0.03548689 -0.01358872 0.1255199 0.115
parameter est se ci.ll ci.ul pval
1 E{Y(0)} 0.63973494 0.03345868 0.5741559 0.7053139655 0.000
2 E{Y(1)} 0.56976968 0.01242470 0.5454173 0.5941220914 0.000
3 E{Y(1)-Y(0)} -0.06996527 0.03550954 -0.1395640 -0.0003665801 0.049
Completed replication 5 of 10
parameter est se ci.ll ci.ul pval
1 E{Y(0)} 0.4412840 0.07174100 0.300671635 0.5818963 0.000
2 E{Y(1)} 0.5851423 0.01196745 0.561686098 0.6085985 0.000
3 E{Y(1)-Y(0)} 0.1438583 0.07275166 0.001265052 0.2864516 0.048
parameter est se ci.ll ci.ul pval
1 E{Y(0)} 0.50505959 0.10731072 0.2947306 0.7153886 0.000
2 E{Y(1)} 0.60104913 0.01216778 0.5772003 0.6248980 0.000
3 E{Y(1)-Y(0)} 0.09598954 0.10792671 -0.1155468 0.3075259 0.374
parameter est se ci.ll ci.ul pval
1 E{Y(0)} 0.48758840 0.05521542 0.37936618 0.5958106 0.000
2 E{Y(1)} 0.57245505 0.01242651 0.54809909 0.5968110 0.000
3 E{Y(1)-Y(0)} 0.08486665 0.05659988 -0.02606911 0.1958024 0.134
parameter est se ci.ll ci.ul pval
1 E{Y(0)} 0.561243323 0.06147479 0.4407527 0.6817339 0.000
2 E{Y(1)} 0.566465820 0.01351743 0.5399716 0.5929600 0.000
3 E{Y(1)-Y(0)} 0.005222498 0.06293277 -0.1181257 0.1285707 0.934
parameter est se ci.ll ci.ul pval
1 E{Y(0)} 0.4058837 0.03521900 0.3368545 0.4749129 0
2 E{Y(1)} 0.5948978 0.01629884 0.5629520 0.6268435 0
3 E{Y(1)-Y(0)} 0.1890141 0.03878461 0.1129962 0.2650319 0
Completed replication 10 of 10
bench_summary <- data.frame(
Method = colnames(ate_mat),
Mean_ATE = colMeans(ate_mat, na.rm = TRUE),
Bias = colMeans(ate_mat, na.rm = TRUE) - true_ate_bench,
SD = apply(ate_mat, 2, sd, na.rm = TRUE),
RMSE = sqrt(colMeans((ate_mat - true_ate_bench)^2, na.rm = TRUE)),
Mean_SE = colMeans(se_mat, na.rm = TRUE),
Coverage_95 = colMeans(cover_mat, na.rm = TRUE),
N_valid = colSums(!is.na(ate_mat))
)
kable(bench_summary, digits = 4,
caption = paste0("Monte Carlo Benchmarking (B=", B,
", n=", n_syn, ", True ATE=",
round(true_ate_bench, 4), ")"))
| Method | Mean_ATE | Bias | SD | RMSE | Mean_SE | Coverage_95 | N_valid |
|---|---|---|---|---|---|---|---|
| Frengression | 0.0564 | -0.0264 | 0.0423 | 0.0481 | 0.0541 | 1.0 | 10 |
| npcausal | 0.1000 | 0.0171 | 0.1541 | 0.1472 | 0.0519 | 0.6 | 10 |
| TMLE | 0.2388 | 0.1559 | 0.2345 | 0.2717 | 0.0195 | 0.1 | 10 |
| DoubleML | 0.0841 | 0.0012 | 0.0638 | 0.0605 | 0.0590 | 1.0 | 10 |
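For reference, the summary columns follow the usual Monte Carlo definitions. Writing \(\hat\tau_b\) for the estimate and \(\mathrm{CI}_b\) for the interval on replication \(b = 1, \dots, B\), and \(\tau\) for the true effect (symbols introduced only for this note):
\[ \text{Bias} = \frac{1}{B}\sum_{b=1}^{B} \hat\tau_b - \tau, \qquad \text{RMSE} = \sqrt{\frac{1}{B}\sum_{b=1}^{B} (\hat\tau_b - \tau)^2}, \qquad \text{Coverage} = \frac{1}{B}\sum_{b=1}^{B} 1\{\tau \in \mathrm{CI}_b\} \]
Because sd() in R divides by \(B-1\), the columns satisfy \(\text{RMSE}^2 = \text{Bias}^2 + \frac{B-1}{B}\,\text{SD}^2\) whenever no replication failed. For the TMLE row, \(0.1559^2 + 0.9 \times 0.2345^2 \approx 0.0738\), whose square root is about 0.2717, matching the reported RMSE.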
ate_long <- as.data.frame(ate_mat) |>
pivot_longer(everything(), names_to = "Method", values_to = "ATE") |>
filter(!is.na(ATE))
# Trim extreme outliers for readable plot
q_bounds <- quantile(ate_long$ATE, c(0.02, 0.98), na.rm = TRUE)
ate_long_trim <- ate_long |> filter(ATE >= q_bounds[1] & ATE <= q_bounds[2])
ggplot(ate_long_trim, aes(x = Method, y = ATE, fill = Method)) +
geom_boxplot(alpha = 0.7) +
geom_hline(yintercept = true_ate_bench, linetype = "dashed", color = "red",
linewidth = 1) +
annotate("text", x = 0.5, y = true_ate_bench, label = "True ATE",
hjust = 0, vjust = -0.5, color = "red") +
labs(title = "ATE Estimates Across Frengression-Generated Datasets",
subtitle = paste0("B=", B, " replications, n=", n_syn, " per dataset"),
y = "Estimated ATE") +
theme_minimal() +
theme(legend.position = "none")
41.8.3 Monte Carlo Benchmarking (LaLonde)
Now the same exercise, but trained on real LaLonde data — where the confounding structure is unknown and complex. We already trained model_lal above and specified a known effect with model_lal_bench.
# Compute true ATE empirically from the specified model
y1_lal_true <- predict(model_lal_bench, matrix(1, ncol=1), type="mean", nsample=5000)
y0_lal_true <- predict(model_lal_bench, matrix(0, ncol=1), type="mean", nsample=5000)
true_ate_lal <- as.numeric(y1_lal_true - y0_lal_true)
cat("True ATE (standardized):", round(true_ate_lal, 4), "\n")
True ATE (standardized): 0.287
cat("True ATE (dollars):", round(true_ate_lal * sd(y_lal), 0), "\n")
True ATE (dollars): 2144
B_lal <- 10
n_syn_lal <- 2000
methods <- c("Frengression", "npcausal", "TMLE", "DoubleML")
ate_mat_lal <- matrix(NA, nrow = B_lal, ncol = length(methods),
dimnames = list(NULL, methods))
se_mat_lal <- ate_mat_lal
cover_mat_lal <- ate_mat_lal
SL.lib_fast <- c("SL.glm.interaction", "SL.ranger", "SL.mean")
set.seed(2024)
for (b in seq_len(B_lal)) {
syn <- sample_joint(model_lal_bench, sample_size = n_syn_lal)
x_syn <- as.numeric(syn$x)
y_syn <- as.numeric(syn$y)
z_syn <- as.data.frame(syn$z)
colnames(z_syn) <- paste0("z", 1:ncol(z_syn))
df_syn <- cbind(data.frame(A = x_syn, Y = y_syn), z_syn)
# --- Frengression ---
tryCatch({
mod_f <- frengression(x_dim = 1, y_dim = 1, z_dim = ncol(z_syn),
x_binary = TRUE, y_binary = FALSE,
noise_dim = 10, hidden_dim = 100, num_layer = 3)
mod_f <- train_y(mod_f, matrix(x_syn, ncol=1), as.matrix(z_syn),
matrix(y_syn, ncol=1),
num_iters = 500, lr = 1e-3, silent = TRUE)
mod_f <- train_xz(mod_f, matrix(x_syn, ncol=1), as.matrix(z_syn),
num_iters = 500, lr = 1e-4, silent = TRUE)
y1_f <- predict(mod_f, matrix(1, ncol=1), type="mean", nsample=1000)
y0_f <- predict(mod_f, matrix(0, ncol=1), type="mean", nsample=1000)
ate_mat_lal[b, "Frengression"] <- y1_f - y0_f
ate_mc <- replicate(100, {
predict(mod_f, matrix(1,ncol=1), type="mean", nsample=200) -
predict(mod_f, matrix(0,ncol=1), type="mean", nsample=200)
})
se_mat_lal[b, "Frengression"] <- sd(ate_mc)
ci_f <- ate_mat_lal[b, "Frengression"] + c(-1,1) * 1.96 * se_mat_lal[b, "Frengression"]
cover_mat_lal[b, "Frengression"] <- (ci_f[1] <= true_ate_lal & true_ate_lal <= ci_f[2])
}, error = function(e) message("Frengression failed on rep ", b, ": ", e$message))
# --- npcausal ---
tryCatch({
fit_np <- ate(y = y_syn, a = x_syn, x = z_syn,
nsplits = 2, sl.lib = SL.lib_fast)
ate_mat_lal[b, "npcausal"] <- fit_np$res$est[3]
se_mat_lal[b, "npcausal"] <- fit_np$res$se[3]
ci <- c(fit_np$res$ci.ll[3], fit_np$res$ci.ul[3])
cover_mat_lal[b, "npcausal"] <- (ci[1] <= true_ate_lal & true_ate_lal <= ci[2])
}, error = function(e) message("npcausal failed on rep ", b, ": ", e$message))
# --- TMLE ---
tryCatch({
fit_tmle <- tmle(Y = y_syn, A = x_syn, W = z_syn, family = "gaussian",
Q.SL.library = SL.lib_fast, g.SL.library = SL.lib_fast)
ate_mat_lal[b, "TMLE"] <- fit_tmle$estimates$ATE$psi
se_mat_lal[b, "TMLE"] <- sqrt(fit_tmle$estimates$ATE$var.psi)
ci_t <- fit_tmle$estimates$ATE$CI
cover_mat_lal[b, "TMLE"] <- (ci_t[1] <= true_ate_lal & true_ate_lal <= ci_t[2])
}, error = function(e) message("TMLE failed on rep ", b, ": ", e$message))
# --- DoubleML ---
tryCatch({
dml_d <- DoubleMLData$new(df_syn, y_col = "Y", d_cols = "A",
x_cols = colnames(z_syn))
lr_l <- lrn("regr.ranger", num.trees = 300, max.depth = 5, min.node.size = 2)
lr_m <- lr_l$clone()
dml_obj <- DoubleMLPLR$new(dml_d, ml_l = lr_l, ml_m = lr_m)
dml_obj$fit()
ate_mat_lal[b, "DoubleML"] <- dml_obj$coef
se_mat_lal[b, "DoubleML"] <- dml_obj$se
ci_d <- c(dml_obj$confint()[1], dml_obj$confint()[2])
cover_mat_lal[b, "DoubleML"] <- (ci_d[1] <= true_ate_lal & true_ate_lal <= ci_d[2])
}, error = function(e) message("DoubleML failed on rep ", b, ": ", e$message))
if (b %% 5 == 0) cat("Completed replication", b, "of", B_lal, "\n")
}
parameter est se ci.ll ci.ul pval
1 E{Y(0)} -0.07909122 0.02022728 -0.1187367 -0.03944574 0.000
2 E{Y(1)} -0.18524079 0.06501341 -0.3126671 -0.05781449 0.004
3 E{Y(1)-Y(0)} -0.10614957 0.06668041 -0.2368432 0.02454403 0.111
parameter est se ci.ll ci.ul pval
1 E{Y(0)} -0.1136350 0.01997280 -0.1527817 -0.07448836 0.000
2 E{Y(1)} -0.3627628 0.07751419 -0.5146906 -0.21083501 0.000
3 E{Y(1)-Y(0)} -0.2491278 0.07918114 -0.4043228 -0.09393275 0.002
parameter est se ci.ll ci.ul pval
1 E{Y(0)} -0.05088695 0.01954409 -0.08919336 -0.01258054 0.009
2 E{Y(1)} 0.08152119 0.07953985 -0.07437692 0.23741929 0.305
3 E{Y(1)-Y(0)} 0.13240813 0.08032237 -0.02502371 0.28983998 0.099
parameter est se ci.ll ci.ul pval
1 E{Y(0)} -0.05978268 0.02018534 -0.09934595 -0.02021942 0.003
2 E{Y(1)} -0.44774436 0.05369259 -0.55298183 -0.34250689 0.000
3 E{Y(1)-Y(0)} -0.38796168 0.05695379 -0.49959109 -0.27633226 0.000
parameter est se ci.ll ci.ul pval
1 E{Y(0)} -0.06119222 0.02013633 -0.1006594 -0.02172502 0.002
2 E{Y(1)} -0.21121393 0.06576119 -0.3401059 -0.08232201 0.001
3 E{Y(1)-Y(0)} -0.15002171 0.06793501 -0.2831743 -0.01686910 0.027
Completed replication 5 of 10
parameter est se ci.ll ci.ul pval
1 E{Y(0)} -0.1078971 0.02005173 -0.1471985 -0.06859574 0
2 E{Y(1)} -0.6635015 0.07570968 -0.8118925 -0.51511058 0
3 E{Y(1)-Y(0)} -0.5556044 0.07750740 -0.7075189 -0.40368990 0
parameter est se ci.ll ci.ul pval
1 E{Y(0)} -0.10382821 0.0198795 -0.1427920 -0.0648644 0.000
2 E{Y(1)} -0.01801598 0.1005460 -0.2150861 0.1790541 0.858
3 E{Y(1)-Y(0)} 0.08581223 0.1016927 -0.1135055 0.2851300 0.399
parameter est se ci.ll ci.ul pval
1 E{Y(0)} -0.0807125 0.02039841 -0.1206934 -0.04073161 0
2 E{Y(1)} -0.3672311 0.04252015 -0.4505706 -0.28389158 0
3 E{Y(1)-Y(0)} -0.2865186 0.04635946 -0.3773831 -0.19565404 0
parameter est se ci.ll ci.ul pval
1 E{Y(0)} -0.06658224 0.02052495 -0.1068111 -0.026353335 0.001
2 E{Y(1)} -0.23176493 0.08138659 -0.3912826 -0.072247221 0.004
3 E{Y(1)-Y(0)} -0.16518270 0.08365006 -0.3291368 -0.001228585 0.048
parameter est se ci.ll ci.ul pval
1 E{Y(0)} -0.09248716 0.01987075 -0.1314338 -0.05354048 0.000
2 E{Y(1)} -0.20270221 0.12972404 -0.4569613 0.05155691 0.118
3 E{Y(1)-Y(0)} -0.11021505 0.13129867 -0.3675604 0.14713035 0.401
Completed replication 10 of 10
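Some fitting functions print a progress bar or a results table on every call, which adds up quickly inside a replication loop. The following is a minimal base-R sketch of one way to silence that output while keeping the return value. Both `noisy_fit()` and `quiet()` are hypothetical helpers written for this illustration; they are not part of npcausal or any other package used in this chapter.

```r
# Hypothetical stand-in for an estimator that prints while it runs
noisy_fit <- function(y) {
  cat("|==========| 100%\n")
  print(data.frame(parameter = "E{Y}", est = mean(y)))
  list(est = mean(y))
}

# Evaluate an expression, discard anything it writes to stdout,
# and return the value of the expression
quiet <- function(expr) {
  value <- NULL
  utils::capture.output(value <- expr)
  value
}

fit <- quiet(noisy_fit(c(1, 2, 3)))
fit$est
```

`capture.output()` diverts standard output only. Messages and warnings still reach the console, so the failure messages emitted by the `tryCatch()` handlers in the loop would remain visible.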
bench_lal <- data.frame(
Method = colnames(ate_mat_lal),
Mean_ATE = colMeans(ate_mat_lal, na.rm = TRUE),
Bias = colMeans(ate_mat_lal, na.rm = TRUE) - true_ate_lal,
SD = apply(ate_mat_lal, 2, sd, na.rm = TRUE),
RMSE = sqrt(colMeans((ate_mat_lal - true_ate_lal)^2, na.rm = TRUE)),
Mean_SE = colMeans(se_mat_lal, na.rm = TRUE),
Coverage_95 = colMeans(cover_mat_lal, na.rm = TRUE),
N_valid = colSums(!is.na(ate_mat_lal))
)
kable(bench_lal, digits = 4, row.names = FALSE,
caption = paste0("LaLonde-Based Benchmarking (B=", B_lal,
", n=", n_syn_lal,
", True ATE=", round(true_ate_lal, 4),
" [", round(true_ate_lal * sd(y_lal), 0), " dollars])"))
| Method | Mean_ATE | Bias | SD | RMSE | Mean_SE | Coverage_95 | N_valid |
|---|---|---|---|---|---|---|---|
| Frengression | 0.0805 | -0.2065 | 0.1345 | 0.2427 | 0.0885 | 0.4 | 10 |
| npcausal | -0.1793 | -0.4663 | 0.2055 | 0.5054 | 0.0792 | 0.1 | 10 |
| TMLE | 0.1701 | -0.1169 | 0.3174 | 0.3230 | 0.0386 | 0.2 | 10 |
| DoubleML | 0.0619 | -0.2251 | 0.0859 | 0.2394 | 0.1022 | 0.4 | 10 |
ate_long_lal <- as.data.frame(ate_mat_lal) |>
pivot_longer(everything(), names_to = "Method", values_to = "ATE") |>
filter(!is.na(ATE))
q_bounds_lal <- quantile(ate_long_lal$ATE, c(0.02, 0.98), na.rm = TRUE)
ate_long_lal_trim <- ate_long_lal |>
filter(ATE >= q_bounds_lal[1] & ATE <= q_bounds_lal[2])
ggplot(ate_long_lal_trim, aes(x = Method, y = ATE, fill = Method)) +
geom_boxplot(alpha = 0.7) +
geom_hline(yintercept = true_ate_lal, linetype = "dashed", color = "red",
linewidth = 1) +
annotate("text", x = 0.5, y = true_ate_lal, label = "True ATE",
hjust = 0, vjust = -0.5, color = "red") +
labs(title = "ATE Estimates: LaLonde-Based Frengression DGP",
subtitle = paste0("B=", B_lal, " replications, n=", n_syn_lal,
" per dataset (confounding learned from real data)"),
y = "Estimated ATE (standardized)") +
theme_minimal() +
theme(legend.position = "none")
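All four methods show meaningful bias on the LaLonde-generated data. Every bias in the table is negative, ranging from -0.12 (TMLE) to -0.47 (npcausal) against a true ATE of about 0.29, and 95% interval coverage falls between 10% and 40%. With only B = 10 replications these numbers are noisy: TMLE's bias, for example, is only about one Monte Carlo standard error (\(0.3174/\sqrt{10} \approx 0.10\)) from zero. Still, the pattern suggests that confounding learned by a neural network from real data is harder to adjust for than confounding in hand-specified simulations. A plausible explanation is that the synthetic data inherits LaLonde's messy structure (complex propensity score surfaces, zero-inflated prior earnings, demographic imbalances), which may strain the smoothness assumptions that SuperLearner and ranger rely on. Note that the Frengression-based estimate is biased by a similar amount, so this benchmark does not single out the semiparametric estimators.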
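With so few replications it helps to ask how much of the table is Monte Carlo noise. The sketch below copies the reported values from the benchmark table and derives three diagnostics, assuming the replications are independent: the Monte Carlo standard error of the bias, the ratio of the average reported standard error to the empirical standard deviation, and the binomial standard error of the coverage estimate. The column names `MCSE_bias`, `Bias_z`, `SE_ratio`, and `MCSE_cov` are introduced here for illustration.

```r
# Values copied from the benchmark table (B = 10 replications)
B <- 10
bench <- data.frame(
  Method   = c("Frengression", "npcausal", "TMLE", "DoubleML"),
  Bias     = c(-0.2065, -0.4663, -0.1169, -0.2251),
  SD       = c(0.1345, 0.2055, 0.3174, 0.0859),
  Mean_SE  = c(0.0885, 0.0792, 0.0386, 0.1022),
  Coverage = c(0.4, 0.1, 0.2, 0.4)
)

# Monte Carlo standard error of the estimated bias: SD / sqrt(B)
bench$MCSE_bias <- bench$SD / sqrt(B)
# Bias expressed in units of its Monte Carlo standard error
bench$Bias_z <- bench$Bias / bench$MCSE_bias
# A ratio near 1 means reported SEs match the spread across replications
bench$SE_ratio <- bench$Mean_SE / bench$SD
# Binomial Monte Carlo standard error of the coverage estimate
bench$MCSE_cov <- sqrt(bench$Coverage * (1 - bench$Coverage) / B)

print(bench, digits = 3)
```

On these numbers the TMLE bias is a little over one Monte Carlo standard error from zero, while the other three are more than four. TMLE's `SE_ratio` of about 0.12 indicates reported standard errors far smaller than the observed spread, whereas DoubleML's ratio above 1 means its low coverage comes from bias rather than from understated uncertainty.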
41.9 Summary
- Frengression implements the same causal simulation framework as causl (Evans & Didelez, 2024), but learns the decomposition nonparametrically
- Three networks: model_xz (past), model_y (causal margin), model_eta (copula/dependence)
- The specify_causal() function lets you plug in known effects while keeping learned confounding
- Main use case: realistic Monte Carlo benchmarking with ground truth
- R package: Rfrengression (native R torch, no Python dependency)
- Paper: Yang, Evans, and Shen (2025), “Frugal, Flexible, Faithful: Causal Data Simulation via Frengression” (arXiv:2508.01018)
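To make the three-piece decomposition concrete, here is a minimal parametric sketch in base R. It calls no functions from causl or Rfrengression, and every modelling choice in it is an illustrative assumption: a single Gaussian covariate `Z`, a logistic treatment model for `A`, a Gaussian causal margin with effect `tau`, and a Gaussian copula with correlation `rho`. Frengression replaces each of these hand-specified pieces with a learned network.

```r
set.seed(1)
n   <- 20000
tau <- 0.3   # assumed causal margin: Y | do(A = a) ~ N(tau * a, 1)
rho <- 0.6   # assumed Gaussian-copula correlation between Z and Y given A

# 1. Past: covariate Z, then treatment A given Z
Z <- rnorm(n)
A <- rbinom(n, 1, plogis(1.5 * Z))

# 2. Copula: draw the rank of Y conditional on the rank of Z
U_z <- pnorm(Z)
U_y <- pnorm(rho * qnorm(U_z) + sqrt(1 - rho^2) * rnorm(n))

# 3. Causal margin: push the rank through the margin's quantile function
Y <- qnorm(U_y, mean = tau * A, sd = 1)

# The unadjusted contrast is confounded by Z; adjusting for Z recovers tau
naive    <- mean(Y[A == 1]) - mean(Y[A == 0])
adjusted <- unname(coef(lm(Y ~ A + Z))["A"])
c(naive = naive, adjusted = adjusted, truth = tau)
```

Because the causal margin is specified directly, the true ATE is `tau` by construction, which is what makes this style of simulation useful for benchmarking estimators against a known ground truth.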