Crumble.jl Main Vignette

Crumble.jl supports several mediation estimands:

  • Natural direct and indirect effects
  • Organic direct and indirect effects
  • Randomized interventional effects
  • Recanting twins decompositions

Basic Usage
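The calls in this vignette use `data`, `d0`, and `d1` without defining them. The sketch below shows one way they might be built. The data-generating process and the `simulate_data` helper are hypothetical, with arbitrary coefficients. It also assumes that `d0`/`d1` take `(data, trt)` and return the treatment value each row receives, as in the R crumble package. That signature is an assumption for Crumble.jl.

```julia
using Random

logistic(x) = 1 / (1 + exp(-x))

# Hypothetical data-generating process (illustrative only):
# W1, W2 baseline covariates; A binary treatment; Z a treatment-induced
# mediator-outcome confounder; M binary mediator; Y continuous outcome.
function simulate_data(n::Int; rng::AbstractRNG = MersenneTwister(1))
    W1 = randn(rng, n)
    W2 = randn(rng, n)
    A = Float64.(rand(rng, n) .< logistic.(0.5 .* W1))
    Z = Float64.(rand(rng, n) .< logistic.(-0.3 .+ 0.8 .* A .+ 0.2 .* W2))
    M = Float64.(rand(rng, n) .< logistic.(-0.2 .+ 0.6 .* A .+ 0.5 .* Z .+ 0.3 .* W1))
    Y = 0.5 .* A .+ 0.7 .* M .+ 0.4 .* Z .+ 0.2 .* W1 .- 0.1 .* W2 .+ randn(rng, n)
    # A NamedTuple of columns; wrap it in DataFrame(...) from DataFrames.jl
    # if crumble expects a DataFrame.
    return (W1 = W1, W2 = W2, A = A, Z = Z, M = M, Y = Y)
end

# Hypothetical intervention functions (signature assumed from R crumble):
# d0 sets every unit's treatment to 0, d1 sets it to 1.
d0(data, trt) = zeros(length(getproperty(data, Symbol(first(trt)))))
d1(data, trt) = ones(length(getproperty(data, Symbol(first(trt)))))

data = simulate_data(500)
```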

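using Crumble

result = crumble(
    data,                      # input data
    ["A"],                     # treatment column(s)
    outcome = "Y",
    mediators = ["M"],
    covar = ["W1", "W2"],      # baseline covariates
    d0 = d0,                   # function giving treatment values under the reference intervention
    d1 = d1,                   # function giving treatment values under the active intervention
    effect = "N",              # natural direct and indirect effects
    # 2 cross-fitting folds; 1 epoch and batch size 32 keep the run short
    control = crumble_control(crossfit_folds = 2, epochs = 1, batch_size = 32)
)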

Crumble.tidy(result)
3×6 DataFrame
 Row │ estimand  estimate      std_error  conf_low  conf_high  p_value
     │ String    Float64       Float64    Float64   Float64    Float64
─────┼──────────────────────────────────────────────────────────────────
   1 │ direct    -0.000754328  0.641549   -1.25819  1.25668    0.999062
   2 │ ate       -0.00137536   0.641355   -1.25843  1.25568    0.998289
   3 │ indirect  -0.000621029  0.641579   -1.25812  1.25687    0.999228
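The interval and p-value columns match normal-approximation (Wald) inference. Each bound equals estimate ± 1.96 × std_error, and p_value is the two-sided normal p-value for estimate / std_error. This construction is inferred from the numbers, not taken from the Crumble.jl source. The sketch below recomputes the direct row under that assumption. The normal CDF uses Marsaglia's series, so only Base Julia is needed.

```julia
# Standard normal CDF via Marsaglia's series (accurate for moderate |x|):
#   Φ(x) = 1/2 + φ(x) * (x + x^3/3 + x^5/(3·5) + x^7/(3·5·7) + ...)
function normal_cdf(x::Float64)
    term = x        # current series term
    total = x       # running sum of the series
    k = 1           # last odd factor in the denominator
    while abs(term) > 1e-16 * max(abs(total), 1.0)
        k += 2
        term *= x^2 / k
        total += term
    end
    return 0.5 + exp(-x^2 / 2) / sqrt(2π) * total
end

# Wald-type 95% interval and two-sided p-value from an estimate and its SE
# (assumption: the 1.96 multiplier, which reproduces the table above).
function wald_summary(estimate::Float64, se::Float64; zcrit::Float64 = 1.96)
    z = estimate / se
    p = 2 * (1 - normal_cdf(abs(z)))
    return (conf_low = estimate - zcrit * se,
            conf_high = estimate + zcrit * se,
            p_value = p)
end

# Recompute the "direct" row of the tidy table.
direct = wald_summary(-0.000754328, 0.641549)
println(round(direct.conf_low; digits = 5), "  ",
        round(direct.conf_high; digits = 5), "  ",
        round(direct.p_value; digits = 6))
```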

Recanting Twins

result = crumble(
    data,
    ["A"],
    outcome = "Y",
    mediators = ["M"],
    moc = ["Z"],               # mediator-outcome confounder affected by treatment
    covar = ["W1", "W2"],
    d0 = d0,
    d1 = d1,
    effect = "RT",             # recanting twins decomposition
    control = crumble_control(crossfit_folds = 2, epochs = 1, batch_size = 32)
)

print(result)
CrumbleResult
  Effect type: RT

Estimates:
  Direct Effect                              0.0001 (SE:   0.6468) [95% CI:  -1.2676,   1.2679]
  Average Treatment Effect                  -0.0022 (SE:   0.6477) [95% CI:  -1.2717,   1.2673]
  Indirect Effect                           -0.0023 (SE:   0.6474) [95% CI:  -1.2712,   1.2666]
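Z (passed as `moc`) is itself affected by A. As a result, the effect of A on Y flows along several directed paths, and the recanting twins approach is designed to separate contributions along them. The sketch below enumerates those paths by depth-first search over a simple assumed DAG. The DAG is an illustration, not Crumble.jl's internal representation. The printed summary above reports only the direct, total, and indirect effects.

```julia
# Hypothetical DAG for this sketch: A -> Y, A -> Z, A -> M, Z -> Y, Z -> M, M -> Y
# (baseline covariates W1 and W2 omitted for brevity).
const DAG = Dict(
    "A" => ["Y", "Z", "M"],
    "Z" => ["Y", "M"],
    "M" => ["Y"],
    "Y" => String[],
)

# Depth-first enumeration of every directed path from `from` to `to`
# (terminates because the graph is acyclic).
function directed_paths(g::Dict{String,Vector{String}}, from::String, to::String)
    paths = Vector{String}[]
    function visit(node::String, path::Vector{String})
        if node == to
            push!(paths, path)
            return nothing
        end
        for child in g[node]
            visit(child, vcat(path, child))
        end
        return nothing
    end
    visit(from, [from])
    return paths
end

for p in directed_paths(DAG, "A", "Y")
    println(join(p, " -> "))
end
```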

Custom Neural Networks

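# Two hidden Dense(32 => 32) blocks of width 32, with dropout probability 0.2
custom_nn = sequential_module(layers = 2, hidden = 32, dropout = 0.2)
# Calling the module with the number of input features (here 3) builds the Flux Chain
custom_nn(3)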
Chain(
  Dense(3 => 32, elu),                  # 128 parameters
  Chain(
    Chain(
      Dense(32 => 32, elu),             # 1_056 parameters
    ),
    Chain(
      Dense(32 => 32, elu),             # 1_056 parameters
    ),
  ),
  Dense(32 => 1),                       # 33 parameters
  Dropout(0.2),
  NNlib.softplus,
)                   # Total: 8 arrays, 2_273 parameters, 9.309 KiB.
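The 2_273 parameters in the chain follow from the layer shapes. A `Dense(n => m)` layer stores n·m weights and m biases, while `Dropout` and `softplus` have no parameters. The sketch below reproduces the count, assuming the layout printed above: an input layer, then `layers` hidden blocks, then a one-unit output layer. The helper names are hypothetical.

```julia
# Parameters in a Dense(n => m) layer: n*m weights plus m biases.
dense_params(n, m) = n * m + m

# Hypothetical helper mirroring the printed layout:
# Dense(d_in => hidden) -> `layers` × Dense(hidden => hidden) -> Dense(hidden => 1).
function sequential_param_count(d_in::Int; layers::Int, hidden::Int)
    input  = dense_params(d_in, hidden)
    blocks = layers * dense_params(hidden, hidden)
    output = dense_params(hidden, 1)
    return input + blocks + output
end

println(sequential_param_count(3; layers = 2, hidden = 32))  # prints 2273
```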