I just discovered numpyro, a probabilistic programming library in Python. It is built on JAX and is very powerful.
You set up a Bayesian model, and numpyro runs the MCMC sampling and inference for you.
Here I take one of the numpyro examples and compare it with a traditional (frequentist) model.
Example 1
I am using the numpyro example here: https://num.pyro.ai/en/stable/tutorials/bayesian_hierarchical_linear_regression.html
The data set is from https://www.kaggle.com/c/osic-pulmonary-fibrosis-progression
“Pulmonary fibrosis is a disorder with no known cause and no known cure, created by scarring of the lungs. In this competition, we were asked to predict a patient’s severity of decline in lung function. Lung function is assessed based on output from a spirometer, which measures the forced vital capacity (FVC), i.e. the volume of air exhaled.
In medical applications, it is useful to evaluate a model’s confidence in its decisions. Accordingly, the metric used to rank the teams was designed to reflect both the accuracy and certainty of each prediction.”
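The competition metric can be written down explicitly. As I understand it from the Kaggle evaluation page (and the way the numpyro tutorial computes it), it is a modified Laplace log-likelihood. The predicted standard deviation is floored at 70, and the absolute error is capped at 1000. A prediction therefore scores well only if it is accurate and also reports an uncertainty that matches its error.

Below is a minimal numpy sketch under those assumptions. The function name and its default arguments are my own.

```python
import numpy as np

def laplace_log_likelihood(fvc_true, fvc_pred, sigma, sigma_floor=70.0, delta_cap=1000.0):
    """Modified Laplace log-likelihood (higher is better), averaged over predictions.

    Assumes the OSIC clipping constants: sigma is floored at 70 and the
    absolute error is capped at 1000.
    """
    sigma_clipped = np.maximum(sigma, sigma_floor)
    delta = np.minimum(np.abs(fvc_true - fvc_pred), delta_cap)
    metric = -np.sqrt(2.0) * delta / sigma_clipped - np.log(np.sqrt(2.0) * sigma_clipped)
    return float(np.mean(metric))

# Same 200-unit error, two different reported uncertainties
y_true = np.array([2000.0])
y_pred = np.array([2200.0])
print(laplace_log_likelihood(y_true, y_pred, np.array([70.0])))   # overconfident
print(laplace_log_likelihood(y_true, y_pred, np.array([200.0])))  # scores higher
```

Here the overconfident sigma of 70 is penalized relative to a sigma of 200 that better reflects the actual error.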
I first read the data into R.
library(tidyverse)
# read in the csv file and sort rows by patient, then by week
library(readr)
data <- read_csv("./osic_pulmonary_fibrosis.csv") |>
  arrange(Patient, Weeks)
data
# A tibble: 1,549 × 7
Patient Weeks FVC Percent Age Sex SmokingStatus
<chr> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
1 ID00007637202177411956430 -4 2315 58.3 79 Male Ex-smoker
2 ID00007637202177411956430 5 2214 55.7 79 Male Ex-smoker
3 ID00007637202177411956430 7 2061 51.9 79 Male Ex-smoker
4 ID00007637202177411956430 9 2144 54.0 79 Male Ex-smoker
5 ID00007637202177411956430 11 2069 52.1 79 Male Ex-smoker
6 ID00007637202177411956430 17 2101 52.9 79 Male Ex-smoker
7 ID00007637202177411956430 29 2000 50.3 79 Male Ex-smoker
8 ID00007637202177411956430 41 2064 51.9 79 Male Ex-smoker
9 ID00007637202177411956430 57 2057 51.8 79 Male Ex-smoker
10 ID00009637202177434476278 8 3660 85.3 69 Male Ex-smoker
# ℹ 1,539 more rows
a random effect model
We fit a linear mixed-effects model with FVC as the outcome and a linear trend in Weeks. Each patient gets a random effect on both the intercept and the Weeks slope.
# a random effects model with FVC as DV, a linear time trend, and a random intercept and slope per Patient
library(lme4)
reg1 <- lmer(FVC ~ Weeks + (1 + Weeks | Patient), data = data)
summary(reg1)
Linear mixed model fit by REML ['lmerMod']
Formula: FVC ~ Weeks + (1 + Weeks | Patient)
Data: data
REML criterion at convergence: 20891.4
Scaled residuals:
Min 1Q Median 3Q Max
-9.3209 -0.4257 0.0063 0.4458 5.6990
Random effects:
Groups Name Variance Std.Dev. Corr
Patient (Intercept) 686691.11 828.668
Weeks 25.88 5.087 -0.14
Residual 18592.11 136.353
Number of obs: 1549, groups: Patient, 176
Fixed effects:
Estimate Std. Error t value
(Intercept) 2810.2843 62.9497 44.643
Weeks -4.2629 0.4377 -9.739
Correlation of Fixed Effects:
(Intr)
Weeks -0.174
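The Random effects table reports two kinds of numbers for the patient-level intercepts and slopes: their variances and standard deviations, and the correlation between them. Together these imply a 2×2 covariance matrix.

As a small worked example, the numpy snippet below builds that matrix from the Std.Dev. and Corr columns above. Squaring the standard deviations recovers the Variance column, up to rounding.

```python
import numpy as np

# Values copied from the lmer summary above
sd_intercept = 828.668   # Patient (Intercept) Std.Dev.
sd_slope = 5.087         # Patient Weeks Std.Dev.
corr = -0.14             # Corr between random intercepts and slopes

sds = np.array([sd_intercept, sd_slope])
corr_matrix = np.array([[1.0, corr], [corr, 1.0]])

# Covariance = D R D, where D = diag(standard deviations) and R = correlation matrix
cov = np.diag(sds) @ corr_matrix @ np.diag(sds)
print(cov)
```

The off-diagonal entry is negative. Patients with a higher baseline FVC therefore tend to have a slightly steeper decline, although the correlation is weak.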
Next we fit the same structure as a Bayesian hierarchical model, which should give similar results.
hierarchical model in numpyro
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
train = pd.read_csv(
    "https://gist.githubusercontent.com/ucals/"
    "2cf9d101992cb1b78c2cdd6e3bac6a4b/raw/"
    "43034c39052dcf97d4b894d2ec1bc3f90f3623d9/"
    "osic_pulmonary_fibrosis.csv"
)
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
assert numpyro.__version__.startswith("0.19.0")
def model(patient_code, Weeks, FVC_obs=None):
    # population-level priors for the mean and spread of intercepts and slopes
    μ_α = numpyro.sample("μ_α", dist.Normal(0.0, 500.0))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.0))
    μ_β = numpyro.sample("μ_β", dist.Normal(0.0, 3.0))
    σ_β = numpyro.sample("σ_β", dist.HalfNormal(3.0))
    n_patients = len(np.unique(patient_code))
    # one intercept α_i and one slope β_i per patient
    with numpyro.plate("plate_i", n_patients):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β = numpyro.sample("β", dist.Normal(μ_β, σ_β))
    # residual standard deviation, shared by all observations
    σ = numpyro.sample("σ", dist.HalfNormal(100.0))
    # expected FVC for each row, looked up through the row's patient code
    FVC_est = α[patient_code] + β[patient_code] * Weeks
    with numpyro.plate("data", len(patient_code)):
        numpyro.sample("obs", dist.Normal(FVC_est, σ), obs=FVC_obs)
from sklearn.preprocessing import LabelEncoder
# map each Patient ID string to an integer code 0..175
patient_encoder = LabelEncoder()
train["patient_code"] = patient_encoder.fit_transform(train["Patient"].values)
FVC_obs = train["FVC"].values
Weeks = train["Weeks"].values
patient_code = train["patient_code"].values
# NUTS sampler: 2000 warmup and 2000 posterior draws (one chain by default)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=2000, num_warmup=2000, progress_bar=False)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, patient_code, Weeks, FVC_obs=FVC_obs)
posterior_samples = mcmc.get_samples()
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
α[0] 2181.96 60.74 2182.26 2080.02 2276.12 3478.15 1.00
α[1] 3783.49 73.96 3784.85 3667.70 3913.58 4928.66 1.00
α[2] 3265.39 58.09 3264.75 3175.63 3366.59 4088.11 1.00
α[3] 3478.90 66.67 3477.34 3371.61 3589.76 3993.87 1.00
α[4] 3700.88 118.15 3699.46 3507.03 3892.74 4173.49 1.00
α[5] 3724.77 59.33 3725.87 3623.35 3817.91 5760.40 1.00
α[6] 3050.36 87.13 3052.33 2901.09 3187.95 4301.39 1.00
α[7] 2151.55 78.65 2149.32 2031.82 2287.03 4117.25 1.00
α[8] 2344.04 89.13 2344.16 2197.93 2493.25 4350.38 1.00
α[9] 1496.10 59.90 1495.72 1400.21 1595.97 3821.66 1.00
α[10] 3147.59 59.34 3146.73 3054.35 3247.87 3823.34 1.00
α[11] 3061.16 134.88 3059.77 2859.88 3302.93 4108.39 1.00
α[12] 2776.40 81.49 2777.66 2638.22 2904.36 5336.77 1.00
α[13] 2681.60 125.34 2682.50 2470.12 2876.24 5991.18 1.00
α[14] 5195.74 116.71 5196.11 5008.86 5385.35 4844.13 1.00
α[15] 2780.42 75.71 2780.33 2663.12 2909.89 4404.98 1.00
α[16] 3924.87 78.16 3923.05 3795.71 4053.21 4959.80 1.00
α[17] 3613.71 61.09 3613.50 3513.57 3715.78 3498.32 1.00
α[18] 3276.12 66.24 3276.51 3162.19 3382.03 2883.28 1.00
α[19] 1477.57 69.84 1478.19 1352.44 1579.33 3962.45 1.00
α[20] 1703.49 56.27 1704.27 1612.68 1796.76 2952.87 1.00
α[21] 2984.97 218.20 2988.79 2625.00 3337.65 5150.51 1.00
α[22] 3392.43 93.03 3391.44 3252.03 3554.48 3935.83 1.00
α[23] 3637.04 103.11 3639.29 3468.56 3799.16 4211.06 1.00
α[24] 2483.06 144.52 2483.37 2247.02 2721.52 5069.97 1.00
α[25] 3453.93 66.54 3455.23 3345.72 3565.21 4795.24 1.00
α[26] 2834.53 75.77 2834.91 2716.25 2961.84 4413.27 1.00
α[27] 2419.90 83.80 2420.70 2289.84 2568.34 4094.23 1.00
α[28] 2495.48 68.58 2494.98 2379.58 2601.87 5322.98 1.00
α[29] 1828.49 109.39 1829.96 1650.90 2013.15 3854.51 1.00
α[30] 2387.46 59.38 2387.19 2291.45 2484.52 4065.98 1.00
α[31] 3258.15 64.09 3259.99 3156.46 3366.40 4339.93 1.00
α[32] 1885.49 63.12 1886.79 1779.17 1986.83 3903.91 1.00
α[33] 2946.67 109.14 2945.00 2761.04 3113.03 4230.53 1.00
α[34] 3373.80 62.44 3373.85 3277.11 3478.82 4635.49 1.00
α[35] 2537.18 72.11 2536.82 2427.51 2662.56 4037.66 1.00
α[36] 3219.40 92.97 3219.59 3067.16 3361.41 3682.34 1.00
α[37] 3578.86 64.44 3579.33 3473.77 3685.27 4499.79 1.00
α[38] 4806.80 67.82 4806.16 4704.59 4926.04 4021.18 1.00
α[39] 3142.70 128.24 3140.92 2944.80 3357.44 4483.11 1.00
α[40] 2303.84 64.02 2304.47 2191.04 2403.73 4225.49 1.00
α[41] 2567.85 104.45 2568.37 2408.97 2752.63 4547.57 1.00
α[42] 2746.64 132.81 2749.36 2525.71 2960.18 4787.09 1.00
α[43] 2856.82 79.62 2858.73 2734.38 2990.30 5177.49 1.00
α[44] 2553.31 161.52 2556.90 2279.42 2807.01 4763.51 1.00
α[45] 2028.61 99.29 2029.85 1870.11 2198.04 4007.81 1.00
α[46] 2415.84 85.37 2412.85 2274.98 2551.99 4389.82 1.00
α[47] 2559.61 91.70 2560.16 2415.14 2713.06 4703.63 1.00
α[48] 2903.48 67.59 2903.82 2783.52 3003.76 3074.21 1.00
α[49] 2916.36 61.66 2917.28 2815.96 3018.28 3231.64 1.00
α[50] 2433.25 62.36 2433.10 2333.02 2535.01 4112.43 1.00
α[51] 1874.57 149.88 1874.53 1636.37 2125.48 4466.82 1.00
α[52] 3885.27 152.80 3884.64 3634.18 4129.37 4764.01 1.00
α[53] 2453.30 74.16 2454.34 2332.25 2572.95 4032.64 1.00
α[54] 2289.13 98.67 2291.88 2113.48 2432.82 4228.26 1.00
α[55] 1968.45 148.75 1967.96 1727.68 2223.98 4738.61 1.00
α[56] 2448.41 168.45 2449.80 2177.20 2731.69 4271.57 1.00
α[57] 2284.98 61.21 2285.08 2188.96 2386.97 5302.78 1.00
α[58] 1690.54 75.88 1692.69 1572.17 1820.44 4384.45 1.00
α[59] 3652.09 175.83 3651.85 3371.15 3952.49 5137.58 1.00
α[60] 2659.02 72.32 2658.87 2540.28 2776.74 4343.19 1.00
α[61] 3162.24 59.04 3164.49 3066.44 3260.57 3628.39 1.00
α[62] 2697.13 68.00 2697.47 2589.83 2809.31 3821.55 1.00
α[63] 3943.73 119.42 3946.85 3729.60 4121.03 4423.43 1.00
α[64] 3339.19 64.10 3340.30 3229.28 3435.80 3789.94 1.00
α[65] 4422.48 85.06 4424.31 4280.59 4564.35 6713.01 1.00
α[66] 4017.49 92.23 4016.28 3868.26 4169.53 4683.90 1.00
α[67] 4089.30 82.06 4091.63 3959.81 4223.71 4983.39 1.00
α[68] 3549.15 138.06 3549.95 3327.27 3776.16 3890.86 1.00
α[69] 2901.59 73.02 2900.42 2790.57 3030.32 4683.16 1.00
α[70] 2181.65 219.45 2175.94 1841.71 2541.47 4550.12 1.00
α[71] 4770.96 74.12 4771.11 4654.07 4898.68 4209.58 1.00
α[72] 3148.67 68.49 3149.51 3038.27 3261.98 4719.43 1.00
α[73] 2568.25 84.10 2568.06 2422.14 2691.78 4620.57 1.00
α[74] 2071.70 135.79 2069.03 1851.62 2289.57 4258.39 1.00
α[75] 2805.05 126.29 2808.18 2587.92 3004.50 4962.74 1.00
α[76] 2851.84 84.64 2854.08 2708.30 2982.13 4178.09 1.00
α[77] 3069.15 75.47 3067.51 2947.36 3191.41 4765.39 1.00
α[78] 2473.79 134.71 2472.45 2246.52 2689.86 4316.09 1.00
α[79] 3176.65 112.70 3181.20 2992.57 3366.85 4780.32 1.00
α[80] 2918.38 75.56 2919.69 2795.13 3045.32 4288.53 1.00
α[81] 4380.73 73.78 4380.60 4256.85 4500.40 4449.85 1.00
α[82] 1901.83 73.17 1900.47 1776.18 2016.42 3629.52 1.00
α[83] 3939.29 74.29 3939.85 3822.65 4061.64 4840.37 1.00
α[84] 2302.36 64.27 2301.19 2197.80 2406.84 3878.34 1.00
α[85] 2803.07 66.43 2803.56 2706.09 2922.83 4427.60 1.00
α[86] 4430.48 94.38 4430.86 4275.12 4581.68 4876.03 1.00
α[87] 2862.12 122.95 2862.95 2663.85 3066.27 5159.40 1.00
α[88] 2831.12 59.21 2830.93 2738.42 2929.85 3505.62 1.00
α[89] 3174.82 69.33 3175.22 3065.44 3291.17 2897.17 1.00
α[90] 2907.47 63.66 2908.82 2807.63 3013.57 3894.72 1.00
α[91] 2321.38 73.67 2321.39 2197.67 2435.79 4540.13 1.00
α[92] 2534.82 92.63 2533.36 2380.92 2685.90 4466.08 1.00
α[93] 5897.89 61.90 5898.57 5798.89 6001.60 3792.77 1.00
α[94] 1493.36 77.54 1494.47 1357.99 1613.48 4334.43 1.00
α[95] 2553.11 64.25 2552.00 2446.32 2654.63 5198.04 1.00
α[96] 2703.26 61.69 2704.33 2597.79 2800.33 3588.30 1.00
α[97] 1254.25 84.35 1254.03 1111.30 1386.08 4096.89 1.00
α[98] 1977.65 81.96 1975.93 1850.62 2117.53 4584.07 1.00
α[99] 3626.00 94.04 3624.16 3453.35 3765.02 6152.51 1.00
α[100] 1960.98 74.47 1960.78 1839.09 2081.26 5508.20 1.00
α[101] 3618.52 56.55 3618.62 3520.58 3705.30 3322.28 1.00
α[102] 1974.77 65.12 1974.23 1860.57 2073.28 3503.83 1.00
α[103] 2105.64 128.71 2105.25 1886.33 2306.98 6616.46 1.00
α[104] 2910.32 68.85 2910.23 2789.15 3014.76 4584.25 1.00
α[105] 1401.37 63.49 1401.44 1303.52 1512.56 4920.60 1.00
α[106] 1840.26 65.45 1840.61 1734.82 1946.79 4199.05 1.00
α[107] 1381.50 80.50 1379.99 1243.57 1515.55 5144.95 1.00
α[108] 3918.88 134.89 3916.91 3674.47 4118.55 4558.49 1.00
α[109] 2231.49 103.29 2231.06 2077.49 2414.70 4677.99 1.00
α[110] 1565.15 66.49 1567.76 1456.32 1674.67 3762.40 1.00
α[111] 2075.41 87.26 2076.29 1941.11 2223.85 4193.44 1.00
α[112] 1065.46 73.31 1067.54 938.10 1176.44 4130.45 1.00
α[113] 2863.31 68.76 2865.82 2749.60 2974.66 4549.16 1.00
α[114] 2517.66 146.55 2517.04 2289.99 2768.97 4708.89 1.00
α[115] 2973.36 73.24 2973.90 2847.22 3086.66 4239.57 1.00
α[116] 4160.73 70.88 4160.89 4043.80 4277.77 4756.33 1.00
α[117] 1981.23 59.08 1981.86 1887.80 2081.02 3847.90 1.00
α[118] 1402.20 63.46 1401.69 1303.28 1512.53 3938.85 1.00
α[119] 2096.95 108.78 2098.23 1910.34 2266.26 4640.21 1.00
α[120] 2363.06 80.42 2361.92 2237.11 2505.73 3919.36 1.00
α[121] 4055.78 59.60 4055.97 3955.94 4151.83 3641.06 1.00
α[122] 2289.47 66.27 2286.09 2188.57 2405.43 4131.95 1.00
α[123] 2191.83 74.29 2192.15 2071.21 2315.32 3979.08 1.00
α[124] 1923.23 66.87 1922.47 1815.24 2032.27 3807.17 1.00
α[125] 2250.75 100.50 2249.12 2086.16 2409.58 5053.45 1.00
α[126] 2915.61 75.51 2915.58 2794.77 3037.53 5279.26 1.00
α[127] 3239.20 64.65 3237.52 3144.70 3354.14 4017.07 1.00
α[128] 3222.20 66.91 3222.45 3109.25 3328.30 3365.91 1.00
α[129] 1859.73 61.92 1858.81 1758.87 1964.10 4568.50 1.00
α[130] 2614.03 115.54 2615.68 2420.21 2804.06 4947.98 1.00
α[131] 3910.90 81.49 3911.47 3779.04 4048.18 4029.07 1.00
α[132] 2618.84 63.93 2618.52 2509.26 2720.43 4529.40 1.00
α[133] 1617.22 126.03 1614.94 1406.29 1819.48 4226.61 1.00
α[134] 2891.23 130.63 2892.29 2684.80 3108.29 4700.87 1.00
α[135] 3080.86 116.70 3078.24 2867.59 3246.76 5356.82 1.00
α[136] 2736.10 65.37 2737.26 2632.36 2845.87 4798.27 1.00
α[137] 2616.46 70.43 2617.43 2499.96 2731.58 4693.34 1.00
α[138] 3931.41 113.71 3931.37 3744.14 4120.55 6092.02 1.00
α[139] 3540.65 66.25 3541.71 3435.03 3652.82 4975.09 1.00
α[140] 2319.48 64.50 2319.73 2214.62 2429.24 3795.68 1.00
α[141] 2933.86 114.23 2931.55 2746.68 3115.25 4674.73 1.00
α[142] 2405.33 99.42 2404.73 2248.27 2568.26 4378.19 1.00
α[143] 1608.70 94.94 1607.20 1452.57 1764.13 4365.56 1.00
α[144] 2065.78 66.48 2066.25 1962.33 2177.87 4630.61 1.00
α[145] 2608.69 80.14 2607.56 2476.55 2740.32 5272.79 1.00
α[146] 4202.96 105.58 4201.76 4034.35 4378.06 5225.38 1.00
α[147] 2052.41 75.69 2052.85 1930.84 2175.75 4737.47 1.00
α[148] 2888.84 66.31 2888.21 2785.91 3000.87 4069.44 1.00
α[149] 3538.52 130.67 3538.85 3311.72 3734.16 5014.65 1.00
α[150] 2671.02 152.86 2672.30 2421.11 2920.64 5058.42 1.00
α[151] 1734.95 130.10 1735.16 1529.26 1966.11 4134.36 1.00
α[152] 3036.62 77.43 3037.41 2913.63 3163.50 5443.08 1.00
α[153] 2280.62 80.84 2280.98 2157.90 2419.70 6519.05 1.00
α[154] 2535.44 67.69 2535.98 2424.18 2645.61 4251.41 1.00
α[155] 4401.42 142.25 4403.37 4172.91 4646.13 4468.32 1.00
α[156] 3468.31 103.94 3469.87 3318.57 3657.42 4614.79 1.00
α[157] 2300.55 74.07 2301.65 2186.87 2430.85 4202.74 1.00
α[158] 3354.23 157.23 3350.28 3104.44 3608.42 4578.01 1.00
α[159] 3142.29 62.26 3142.13 3032.66 3240.83 4841.13 1.00
α[160] 2204.44 70.52 2202.80 2089.87 2320.99 4692.09 1.00
α[161] 3737.23 61.33 3736.79 3634.25 3833.27 3300.82 1.00
α[162] 3000.90 86.93 3000.32 2858.64 3146.88 4496.44 1.00
α[163] 3965.14 100.36 3964.14 3793.29 4120.08 3782.56 1.00
α[164] 2020.35 142.58 2021.99 1791.25 2264.88 5086.22 1.00
α[165] 1580.13 64.85 1579.95 1486.46 1700.85 2607.56 1.00
α[166] 3250.36 94.84 3251.18 3103.08 3407.85 4652.29 1.00
α[167] 2455.41 89.25 2455.88 2307.87 2606.59 4002.42 1.00
α[168] 1770.46 70.57 1770.27 1659.72 1888.74 4205.60 1.00
α[169] 2892.62 76.82 2890.41 2774.02 3024.54 3921.34 1.00
α[170] 3357.47 69.56 3356.52 3252.30 3478.19 4614.54 1.00
α[171] 2836.92 67.01 2839.11 2729.38 2944.73 4768.27 1.00
α[172] 2828.49 81.06 2828.37 2687.76 2949.13 4523.36 1.00
α[173] 1989.75 77.96 1989.84 1854.52 2106.57 3484.11 1.00
α[174] 3007.70 85.53 3010.05 2869.75 3152.74 4488.10 1.00
α[175] 2931.63 67.48 2931.62 2813.62 3032.52 4880.61 1.00
β[0] -3.49 2.13 -3.52 -6.86 0.11 3777.86 1.00
β[1] -8.02 2.36 -8.00 -11.68 -4.02 4356.88 1.00
β[2] -14.27 2.28 -14.26 -17.84 -10.21 3486.20 1.00
β[3] -4.27 2.31 -4.27 -7.77 -0.22 4751.57 1.00
β[4] -7.13 2.19 -7.13 -10.57 -3.43 4758.71 1.00
β[5] -11.69 2.32 -11.69 -15.47 -7.84 4411.74 1.00
β[6] -6.71 2.27 -6.70 -10.40 -3.02 4800.55 1.00
β[7] -5.62 2.20 -5.63 -9.54 -2.30 4401.94 1.00
β[8] -6.81 2.27 -6.88 -10.33 -2.83 4538.73 1.00
β[9] -3.06 2.28 -3.04 -6.67 0.83 5195.12 1.00
β[10] -9.32 2.33 -9.33 -13.17 -5.60 3699.40 1.00
β[11] -3.07 2.15 -3.07 -6.56 0.49 4320.79 1.00
β[12] -5.17 2.26 -5.14 -8.90 -1.54 4824.94 1.00
β[13] -4.64 2.92 -4.56 -9.56 -0.14 4782.43 1.00
β[14] -11.04 2.19 -11.04 -14.47 -7.42 4283.43 1.00
β[15] -8.45 2.30 -8.48 -12.12 -4.57 4260.51 1.00
β[16] -1.18 2.36 -1.12 -5.03 2.76 4967.28 1.00
β[17] -0.16 2.14 -0.19 -3.62 3.29 3726.29 1.00
β[18] -6.65 3.82 -6.71 -13.21 -0.86 3985.31 1.00
β[19] -2.16 2.20 -2.18 -5.76 1.43 3160.09 1.00
β[20] -3.71 2.20 -3.68 -7.30 -0.02 3877.82 1.00
β[21] -11.26 3.68 -11.34 -17.13 -5.13 4327.97 1.00
β[22] -5.95 2.40 -5.94 -9.97 -2.07 4703.96 1.00
β[23] 5.92 2.37 5.92 2.08 9.67 4153.01 1.00
β[24] -4.07 2.26 -4.07 -7.76 -0.37 5563.09 1.00
β[25] -5.09 3.02 -5.11 -10.29 -0.54 3885.94 1.00
β[26] -6.07 2.34 -6.07 -10.20 -2.40 5647.53 1.00
β[27] -2.60 2.13 -2.55 -6.46 0.69 4187.02 1.00
β[28] -13.06 3.09 -13.13 -18.11 -7.93 4795.54 1.00
β[29] -7.67 2.38 -7.71 -11.57 -3.82 4117.66 1.00
β[30] 8.04 2.25 8.04 4.13 11.55 4600.00 1.00
β[31] 2.45 2.26 2.47 -1.03 6.30 4007.59 1.00
β[32] -6.58 2.23 -6.59 -10.21 -2.82 3832.25 1.00
β[33] -0.80 2.37 -0.81 -4.99 2.80 4361.64 1.00
β[34] -0.83 2.16 -0.83 -4.16 2.91 4025.73 1.00
β[35] -10.61 2.40 -10.61 -14.60 -6.78 3563.09 1.00
β[36] -5.19 2.17 -5.18 -8.89 -1.93 3299.45 1.00
β[37] -0.25 2.42 -0.19 -4.01 3.83 3704.36 1.00
β[38] -2.73 2.35 -2.70 -6.32 1.24 4342.30 1.00
β[39] -3.52 1.84 -3.51 -6.77 -0.84 3976.06 1.00
β[40] -1.34 2.30 -1.37 -5.12 2.42 3742.77 1.00
β[41] -11.93 2.96 -11.83 -16.93 -7.23 4682.23 1.00
β[42] -5.60 1.96 -5.62 -8.69 -2.24 4732.31 1.00
β[43] -4.12 2.35 -4.08 -7.66 -0.08 4616.45 1.00
β[44] -5.79 2.20 -5.80 -9.27 -2.09 4924.40 1.00
β[45] -19.96 3.06 -19.91 -24.63 -14.61 3679.62 1.00
β[46] -7.46 3.85 -7.50 -13.75 -1.59 4155.61 1.00
β[47] -5.62 2.99 -5.66 -10.20 -0.32 4985.98 1.00
β[48] 0.74 2.23 0.77 -2.89 4.42 3883.18 1.00
β[49] -0.05 1.96 -0.03 -3.14 3.21 5014.63 1.00
β[50] -6.66 2.21 -6.70 -10.32 -3.04 3677.17 1.00
β[51] -4.38 2.23 -4.35 -8.12 -0.87 4395.64 1.00
β[52] -6.26 2.33 -6.21 -10.11 -2.64 4683.23 1.00
β[53] -11.29 3.00 -11.30 -16.50 -6.57 4172.36 1.00
β[54] 0.53 2.20 0.49 -2.74 4.33 3909.65 1.00
β[55] -4.76 2.25 -4.77 -8.53 -1.01 4825.57 1.00
β[56] -6.44 3.85 -6.41 -12.56 0.06 4113.27 1.00
β[57] -2.80 2.27 -2.74 -6.67 0.70 3687.04 1.00
β[58] -7.49 2.26 -7.51 -10.84 -3.50 4276.53 1.00
β[59] -4.83 2.15 -4.82 -8.31 -1.30 5575.94 1.00
β[60] -1.40 2.24 -1.41 -4.98 2.35 4515.19 1.00
β[61] -3.11 2.28 -3.11 -6.95 0.55 4081.08 1.00
β[62] -8.59 2.31 -8.61 -12.35 -4.93 3377.48 1.00
β[63] -0.72 2.21 -0.73 -4.17 3.04 4547.00 1.00
β[64] -1.68 2.26 -1.63 -5.77 1.66 3385.73 1.00
β[65] -3.64 2.22 -3.64 -7.07 0.12 5571.42 1.00
β[66] -3.56 2.21 -3.63 -7.35 -0.16 4991.63 1.00
β[67] -1.55 1.90 -1.54 -4.73 1.41 4627.93 1.00
β[68] -7.49 2.09 -7.48 -11.06 -4.34 4079.94 1.00
β[69] -1.36 2.33 -1.39 -5.11 2.52 4569.82 1.00
β[70] -0.29 2.23 -0.25 -4.25 2.95 4242.55 1.00
β[71] -10.03 1.98 -10.06 -13.32 -6.94 4442.15 1.00
β[72] -7.53 2.24 -7.46 -11.36 -3.98 5482.66 1.00
β[73] -3.37 2.34 -3.39 -7.11 0.40 5019.89 1.00
β[74] 2.41 2.23 2.45 -1.18 6.12 4508.04 1.00
β[75] -1.25 2.26 -1.33 -4.92 2.55 5646.07 1.00
β[76] -4.17 2.23 -4.17 -7.97 -0.66 5675.45 1.00
β[77] -4.32 2.26 -4.33 -8.17 -0.89 4923.77 1.00
β[78] -0.77 2.26 -0.75 -4.48 2.97 4750.75 1.00
β[79] -3.90 2.27 -3.89 -7.59 -0.15 4487.57 1.00
β[80] -12.61 2.64 -12.60 -17.07 -8.36 4246.54 1.00
β[81] -1.52 2.29 -1.45 -5.58 2.05 4651.25 1.00
β[82] -4.46 2.18 -4.50 -7.95 -0.89 4468.52 1.00
β[83] -1.38 2.24 -1.39 -4.98 2.40 5150.16 1.00
β[84] 10.63 2.32 10.60 6.78 14.33 4505.47 1.00
β[85] 0.58 2.27 0.59 -3.09 4.36 3539.17 1.00
β[86] -10.39 2.28 -10.40 -14.11 -6.87 4488.76 1.00
β[87] -3.61 2.29 -3.60 -7.36 0.10 5316.67 1.00
β[88] -0.30 2.22 -0.32 -3.80 3.35 3819.55 1.00
β[89] -1.93 2.32 -1.92 -5.96 1.63 4179.68 1.00
β[90] -2.15 2.37 -2.15 -5.91 1.70 3341.19 1.00
β[91] -3.70 2.26 -3.69 -7.61 -0.12 3904.55 1.00
β[92] 2.18 2.20 2.17 -1.22 6.00 4908.78 1.00
β[93] -3.33 2.28 -3.41 -7.01 0.50 3571.76 1.00
β[94] -2.74 2.32 -2.75 -6.52 0.98 4606.86 1.00
β[95] -4.31 2.28 -4.40 -7.87 -0.61 4132.73 1.00
β[96] -0.62 2.29 -0.61 -4.24 3.25 3616.64 1.00
β[97] -7.27 2.28 -7.18 -11.10 -3.62 4140.66 1.00
β[98] -5.21 2.29 -5.23 -8.89 -1.43 5202.64 1.00
β[99] -3.25 1.91 -3.24 -6.55 -0.23 4539.26 1.00
β[100] -2.23 2.32 -2.18 -6.27 1.43 5201.07 1.00
β[101] -4.05 2.14 -4.01 -7.60 -0.70 3869.35 1.00
β[102] -3.24 2.97 -3.28 -7.89 1.76 3127.43 1.00
β[103] 0.79 2.43 0.81 -2.78 5.17 6492.31 1.00
β[104] -4.63 2.39 -4.58 -8.28 -0.38 4463.70 1.00
β[105] -3.55 2.32 -3.48 -7.44 0.17 3883.25 1.00
β[106] -3.12 2.38 -3.04 -7.04 0.68 4317.92 1.00
β[107] -0.32 2.26 -0.27 -3.93 3.38 4878.10 1.00
β[108] -19.46 2.24 -19.44 -23.17 -15.80 3707.86 1.00
β[109] -5.55 2.08 -5.50 -8.91 -2.23 5596.44 1.00
β[110] -3.00 2.37 -2.99 -6.79 0.91 4362.65 1.00
β[111] -1.77 2.96 -1.76 -6.89 2.89 4287.54 1.00
β[112] -5.76 4.03 -5.67 -12.41 0.60 4568.87 1.00
β[113] 0.30 2.26 0.29 -3.39 4.05 4201.45 1.00
β[114] -8.96 2.15 -8.93 -12.54 -5.55 4902.72 1.00
β[115] -6.70 3.75 -6.70 -12.80 -0.58 3675.94 1.00
β[116] -4.58 2.27 -4.57 -7.94 -0.57 4612.40 1.00
β[117] -0.47 2.21 -0.45 -4.30 2.92 4604.65 1.00
β[118] -0.88 2.17 -0.89 -4.45 2.67 3857.26 1.00
β[119] 3.37 3.06 3.34 -1.63 8.40 4641.89 1.00
β[120] -7.61 2.76 -7.61 -12.08 -3.07 4584.45 1.00
β[121] -7.80 2.18 -7.84 -11.53 -4.31 4143.18 1.00
β[122] -4.47 2.28 -4.42 -8.62 -1.11 5003.13 1.00
β[123] -1.87 3.11 -1.85 -6.48 3.77 3837.84 1.00
β[124] -7.47 2.30 -7.42 -11.15 -3.56 3488.98 1.00
β[125] -2.96 2.30 -2.95 -6.51 0.86 4045.44 1.00
β[126] -4.62 2.32 -4.67 -8.37 -0.90 4501.75 1.00
β[127] -7.88 2.22 -7.87 -11.55 -4.17 4004.76 1.00
β[128] -4.28 2.31 -4.32 -7.78 -0.44 3548.34 1.00
β[129] -5.02 2.21 -5.08 -8.72 -1.32 3864.48 1.00
β[130] -9.61 2.21 -9.60 -13.21 -6.13 4645.20 1.00
β[131] 5.26 2.23 5.33 1.71 9.07 3949.80 1.00
β[132] -5.85 2.23 -5.85 -9.48 -2.19 4969.66 1.00
β[133] 0.48 3.14 0.54 -4.68 5.72 4701.96 1.00
β[134] -1.64 2.16 -1.65 -5.28 1.70 4533.16 1.00
β[135] 1.34 2.32 1.36 -2.49 5.12 4581.77 1.00
β[136] -5.27 2.41 -5.27 -9.53 -1.62 4665.97 1.00
β[137] -5.41 2.89 -5.32 -10.61 -1.06 4848.17 1.00
β[138] -0.34 2.25 -0.35 -3.84 3.53 6135.90 1.00
β[139] -8.16 2.34 -8.12 -11.84 -4.21 4310.14 1.00
β[140] -3.98 2.25 -3.96 -7.82 -0.19 5176.38 1.00
β[141] -2.42 2.22 -2.46 -5.94 1.29 4938.04 1.00
β[142] -7.50 2.23 -7.48 -10.93 -3.83 5237.66 1.00
β[143] -1.22 1.95 -1.24 -4.48 1.89 4102.71 1.00
β[144] -6.33 3.78 -6.32 -12.02 0.48 3855.15 1.00
β[145] 0.36 2.30 0.31 -3.04 4.50 5106.15 1.00
β[146] 5.67 2.24 5.69 2.32 9.50 4880.28 1.00
β[147] -1.28 2.33 -1.32 -4.71 2.91 4092.35 1.00
β[148] -2.58 2.34 -2.60 -6.47 1.23 3386.65 1.00
β[149] -12.39 2.23 -12.40 -16.17 -8.91 5298.42 1.00
β[150] -6.77 3.07 -6.72 -11.85 -1.95 4858.98 1.00
β[151] -7.95 2.38 -7.92 -11.65 -3.74 4461.45 1.00
β[152] -2.49 2.33 -2.50 -6.22 1.21 3764.36 1.00
β[153] -1.30 2.37 -1.29 -5.04 2.67 6336.56 1.00
β[154] -3.86 2.20 -3.87 -7.43 -0.39 3279.35 1.00
β[155] -2.65 2.41 -2.66 -7.01 0.97 4424.41 1.00
β[156] -3.71 2.22 -3.71 -7.33 -0.15 4825.16 1.00
β[157] -8.63 3.94 -8.60 -14.57 -1.90 3967.71 1.00
β[158] -3.99 2.31 -3.98 -7.49 -0.07 4349.28 1.00
β[159] -4.94 1.98 -4.90 -8.41 -1.91 4700.91 1.00
β[160] -6.93 2.43 -6.99 -10.98 -2.99 4785.30 1.00
β[161] -3.47 2.13 -3.47 -7.00 -0.16 3972.92 1.00
β[162] -15.86 3.11 -15.80 -21.10 -10.84 4183.32 1.00
β[163] -0.94 1.87 -0.95 -4.08 2.08 4370.92 1.00
β[164] -2.01 2.22 -2.00 -5.44 1.98 5639.13 1.00
β[165] -2.00 2.22 -1.98 -5.55 1.68 3479.20 1.00
β[166] -4.23 2.27 -4.25 -7.71 -0.32 4338.94 1.00
β[167] -1.70 3.12 -1.79 -6.56 3.57 4373.63 1.00
β[168] -1.91 2.28 -1.90 -5.93 1.68 4487.30 1.00
β[169] -17.05 2.31 -17.07 -20.95 -13.41 4857.17 1.00
β[170] -1.69 2.27 -1.68 -5.50 1.95 4521.26 1.00
β[171] -2.54 2.31 -2.57 -6.12 1.41 4106.66 1.00
β[172] -1.94 1.91 -1.91 -4.93 1.32 3996.98 1.00
β[173] -4.88 2.92 -4.92 -9.76 -0.26 4223.72 1.00
β[174] -7.95 2.24 -7.97 -11.47 -4.07 4681.34 1.00
β[175] -1.66 2.25 -1.69 -4.95 2.32 3655.75 1.00
μ_α 2774.79 54.96 2775.49 2686.53 2863.70 3474.50 1.00
μ_β -4.17 0.41 -4.17 -4.85 -3.51 2474.19 1.00
σ 136.71 2.74 136.64 132.14 141.13 1906.90 1.00
σ_α 723.12 31.96 722.08 672.81 774.22 3737.96 1.00
σ_β 4.99 0.36 4.96 4.39 5.55 2482.51 1.00
Number of divergences: 0
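A note on reading the summary table. As far as I know, numpyro's `print_summary` does not report equal-tailed percentiles in the `5.0%` and `95.0%` columns. Instead, those columns are the bounds of a 90% highest posterior density interval (its default `prob=0.9`). For roughly symmetric posteriors, like most of the ones here, the two kinds of interval nearly coincide. For skewed posteriors they can differ.

Below is a minimal numpy sketch of the shortest-interval idea. The function name `hpdi_sketch` is my own and is not numpyro's API.

```python
import numpy as np

def hpdi_sketch(samples, prob=0.9):
    """Shortest interval containing a fraction `prob` of 1-D posterior draws."""
    x = np.sort(np.asarray(samples))
    n = x.size
    k = int(np.floor(prob * n))        # index offset spanning `prob` of the draws
    widths = x[k:] - x[: n - k]        # width of every candidate interval
    start = int(np.argmin(widths))     # keep the narrowest one
    return x[start], x[start + k]

rng = np.random.default_rng(0)
draws = rng.exponential(size=10_000)   # a deliberately skewed "posterior"
lo, hi = hpdi_sketch(draws, prob=0.9)
p5, p95 = np.percentile(draws, [5, 95])
print(lo, hi)     # HPDI: lower bound close to 0
print(p5, p95)    # equal-tailed interval: shifted right and wider
```
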
import arviz as az
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True, figsize=(15, 25));
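The trace plot is a visual check that the sampler mixed well, and the `r_hat` column in the summary is its numeric counterpart. The numpy sketch below computes split-R-hat using the classic Gelman-Rubin formula, applied to chains cut in half. Splitting each chain is what allows a drift within a single chain to show up.

The sketch is meant to show the idea, not to reproduce numpyro's `split_gelman_rubin` digit for digit. The function name is my own.

```python
import numpy as np

def split_rhat_sketch(chains):
    """Split-R-hat for draws of shape (num_chains, num_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_half = chains.shape[1] // 2
    # cut every chain into two halves and treat the halves as separate chains
    halves = np.concatenate([chains[:, :n_half], chains[:, n_half : 2 * n_half]], axis=0)
    n = halves.shape[1]
    chain_means = halves.mean(axis=1)
    within = halves.var(axis=1, ddof=1).mean()        # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)             # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n     # pooled variance estimate
    return float(np.sqrt(var_plus / within))

rng = np.random.default_rng(1)
good = rng.normal(size=(1, 2000))                     # one well-mixed chain, as in the run above
drifting = good + np.linspace(0.0, 3.0, 2000)         # a chain whose level keeps moving
print(split_rhat_sketch(good))       # close to 1
print(split_rhat_sketch(drifting))   # clearly above 1
```
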
Let's look at the model more closely. It is described by the following equations. As in numpyro, the second argument of each Normal and Half-Normal is the scale, i.e. the standard deviation:
\[
\begin{aligned}
\mu_{\alpha} &\sim \text{Normal}(0, 500) \\
\sigma_{\alpha} &\sim \text{Half-Normal}(100) \\
\mu_{\beta} &\sim \text{Normal}(0, 3) \\
\sigma_{\beta} &\sim \text{Half-Normal}(3) \\
\alpha_i &\sim \text{Normal}(\mu_{\alpha}, \sigma_{\alpha}) \\
\beta_i &\sim \text{Normal}(\mu_{\beta}, \sigma_{\beta}) \\
\sigma &\sim \text{Half-Normal}(100) \\
\text{FVC}_{ij} &\sim \text{Normal}(\alpha_i + \beta_i t_{ij}, \sigma)
\end{aligned}
\]
where \(i\) indexes patients, \(j\) indexes a patient's visits, and \(t_{ij}\) is the week of visit \(j\) for patient \(i\).
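One way to get a feel for these equations is to simulate from them. The sketch below draws one set of parameters from the priors and then generates FVC values for a few hypothetical patients. It uses plain numpy in place of numpyro. The number of patients and the week grid are arbitrary choices for illustration.

Repeating this many times is a prior predictive check: it shows which FVC trajectories the model considers plausible before seeing any data.

```python
import numpy as np

rng = np.random.default_rng(42)
n_patients = 5                      # hypothetical, for illustration only
weeks = np.arange(-5, 60, 5)        # hypothetical visit weeks t_ij, shared by all patients here

# Population-level priors (numpy's `scale` is the standard deviation, as in numpyro)
mu_alpha = rng.normal(0.0, 500.0)
sigma_alpha = np.abs(rng.normal(0.0, 100.0))   # Half-Normal(100) as |Normal(0, 100)|
mu_beta = rng.normal(0.0, 3.0)
sigma_beta = np.abs(rng.normal(0.0, 3.0))      # Half-Normal(3)

# Patient-level intercepts alpha_i and slopes beta_i
alpha = rng.normal(mu_alpha, sigma_alpha, size=n_patients)
beta = rng.normal(mu_beta, sigma_beta, size=n_patients)
sigma = np.abs(rng.normal(0.0, 100.0))         # residual Half-Normal(100)

# FVC_ij ~ Normal(alpha_i + beta_i * t_ij, sigma); rows = patients, columns = weeks
mean_fvc = alpha[:, None] + beta[:, None] * weeks[None, :]
fvc = rng.normal(mean_fvc, sigma)
print(fvc.shape)
```

Note that \(\mu_{\alpha} \sim \text{Normal}(0, 500)\) centers the simulated intercepts near 0, not near the roughly 2800 seen in the data. The posterior for \(\mu_{\alpha}\) still lands close to the lmer estimate, so here the likelihood dominates that prior.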
There are 176 patients, each with multiple weekly observations. The model therefore has:

- 176 \(\alpha_i\) and 176 \(\beta_i\), the random effects for intercept and slope;
- two “fixed effects”, \(\mu_{\alpha}\) and \(\mu_{\beta}\);
- one residual standard deviation \(\sigma\);
- two random-effect standard deviations, \(\sigma_{\alpha}\) and \(\sigma_{\beta}\).
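The 176 \(\alpha_i\) and \(\beta_i\) are linked to the 1,549 data rows through integer patient codes. `LabelEncoder` maps each patient ID to a code, and `α[patient_code]` then looks up each row's intercept.

Below is a numpy-only sketch of the same mechanics on a toy data set; the IDs and numbers are made up. As far as I know, `np.unique(..., return_inverse=True)` produces the same codes as `LabelEncoder`, because both index into the sorted unique values.

```python
import numpy as np

# Toy long-format data: one row per visit (hypothetical IDs and values)
patient_ids = np.array(["ID_B", "ID_A", "ID_B", "ID_C", "ID_A"])
weeks = np.array([0.0, 3.0, 10.0, 5.0, 20.0])

# Integer code per row; codes index the sorted unique IDs (ID_A -> 0, ID_B -> 1, ID_C -> 2)
unique_ids, patient_code = np.unique(patient_ids, return_inverse=True)

# One intercept and one slope per patient (made-up values)
alpha = np.array([2500.0, 3000.0, 2000.0])
beta = np.array([-4.0, -2.0, -6.0])

# Expected FVC per row, as in the model: α[patient_code] + β[patient_code] * Weeks
fvc_est = alpha[patient_code] + beta[patient_code] * weeks
print(patient_code)
print(fvc_est)
```

Counting every site, the model has \(2 \times 176 + 2 + 3 = 357\) parameters. That number matches the rows of the summary table: \(\alpha_0,\dots,\alpha_{175}\) and \(\beta_0,\dots,\beta_{175}\), the two means, and the three standard deviations.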
The idea of Bayesian inference is this: given the model and the data, we infer the posterior distribution of the parameters. By Bayes' theorem, the posterior is proportional to the likelihood times the prior, \(p(\theta \mid y) \propto p(y \mid \theta)\, p(\theta)\).
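Here is a one-parameter grid approximation in numpy that makes "posterior ∝ likelihood × prior" concrete. It is only a toy, not what numpyro does: NUTS samples the full posterior rather than evaluating it on a grid.

The setup is as follows:

- a single slope \(\beta\), with the same Normal(0, 3) prior the model places on \(\mu_\beta\);
- a few made-up weekly changes, observed with known noise.

The grid result is checked against the closed-form conjugate-normal answer.

```python
import numpy as np

# Made-up data: observed weekly changes y_k ~ Normal(beta, noise_sd), with noise_sd known
y = np.array([-5.0, -3.5, -4.8, -2.9, -4.1])
noise_sd = 2.0
prior_mean, prior_sd = 0.0, 3.0           # beta ~ Normal(0, 3)

grid = np.linspace(-15.0, 15.0, 30001)    # candidate values of beta
log_prior = -0.5 * ((grid - prior_mean) / prior_sd) ** 2
log_lik = -0.5 * (((y[:, None] - grid[None, :]) / noise_sd) ** 2).sum(axis=0)

# posterior ∝ likelihood × prior: add on the log scale, exponentiate, normalize
log_post = log_prior + log_lik
post = np.exp(log_post - log_post.max())
post /= post.sum()
grid_mean = float((grid * post).sum())

# Closed form for a normal prior with a normal likelihood of known variance
precision = 1 / prior_sd**2 + y.size / noise_sd**2
exact_mean = (prior_mean / prior_sd**2 + y.sum() / noise_sd**2) / precision
print(grid_mean, exact_mean)
```

The posterior mean lands between the prior mean (0) and the sample mean of `y`. This is the same shrinkage that pulls each patient's \(\alpha_i\) and \(\beta_i\) toward the population means.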
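In this case the Bayesian model gives results similar to the random effects model:

- Population intercept: the posterior mean of \(\mu_{\alpha}\) is 2774.79, against 2810.28 from lmer.
- Population slope: the posterior mean of \(\mu_{\beta}\) is -4.17, against -4.26 from lmer.
- Residual standard deviation: 136.71, against 136.35.

The largest difference is in the standard deviation of the intercepts: 723.12 against 828.67. The Half-Normal(100) prior on \(\sigma_{\alpha}\) is tight relative to that scale and may be pulling the estimate down.

There is also a structural difference. lmer estimates a correlation between the random intercepts and slopes (-0.14), whereas this numpyro model draws \(\alpha_i\) and \(\beta_i\) independently.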
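Putting the two sets of estimates side by side makes the comparison easier to read. The values below are copied from the lmer summary and the numpyro summary above (posterior means). The snippet only computes relative differences and lists them from largest to smallest.

```python
# (lmer estimate, numpyro posterior mean), copied from the two summaries above
estimates = {
    "intercept (μ_α)": (2810.2843, 2774.79),
    "slope (μ_β)": (-4.2629, -4.17),
    "residual sd (σ)": (136.353, 136.71),
    "intercept sd (σ_α)": (828.668, 723.12),
    "slope sd (σ_β)": (5.087, 4.99),
}

# relative difference of the Bayesian estimate from the lmer estimate
rel_diff = {
    name: abs(bayes - freq) / abs(freq) for name, (freq, bayes) in estimates.items()
}
for name, d in sorted(rel_diff.items(), key=lambda kv: -kv[1]):
    print(f"{name:20s} {d:6.1%}")
```
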
numpyro.render_model(
    model=model,
    model_args=(patient_code, Weeks, FVC_obs),
    render_distributions=True,
    render_params=True,
)