Using numpyro

R
python
code
analysis
Author

Xiang Ao

Published

October 17, 2025

I just discovered numpyro, a probabilistic programming library for Python. It is built on JAX, which makes it fast and very powerful.

Basically, you set up a Bayesian model and then let numpyro do the MCMC sampling and inference for you.
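To give a feel for what "MCMC sampling" means, here is a minimal sketch of a random-walk Metropolis sampler in plain Python. It is not what numpyro does internally (numpyro uses the much more efficient NUTS sampler, as in the code later in this post). The toy data, the prior, and the function names `log_posterior` and `metropolis` are all made up for illustration.

```python
import math
import random


def log_posterior(mu, data, prior_sd=10.0, noise_sd=1.0):
    """Unnormalised log posterior of mu.

    Assumed toy model: mu ~ Normal(0, prior_sd), y ~ Normal(mu, noise_sd).
    """
    log_prior = -0.5 * (mu / prior_sd) ** 2
    log_lik = sum(-0.5 * ((y - mu) / noise_sd) ** 2 for y in data)
    return log_prior + log_lik


def metropolis(data, n_samples=5000, n_warmup=1000, step=0.5, seed=0):
    """Random-walk Metropolis: propose a nearby value, accept or reject it."""
    rng = random.Random(seed)
    mu = 0.0
    current = log_posterior(mu, data)
    draws = []
    for i in range(n_warmup + n_samples):
        proposal = mu + rng.gauss(0.0, step)
        candidate = log_posterior(proposal, data)
        # Accept with probability min(1, posterior ratio), done on the log scale.
        # 1 - random() lies in (0, 1], so the log is always defined.
        if math.log(1.0 - rng.random()) < candidate - current:
            mu, current = proposal, candidate
        if i >= n_warmup:  # keep draws only after the warmup phase
            draws.append(mu)
    return draws


# Hypothetical toy data: 50 noisy measurements around a true mean of 3
data_rng = random.Random(42)
toy_data = [data_rng.gauss(3.0, 1.0) for _ in range(50)]

draws = metropolis(toy_data)
posterior_mean = sum(draws) / len(draws)
print(f"posterior mean of mu: {posterior_mean:.2f}")
```

The retained draws approximate the posterior distribution of `mu`, so their average approximates the posterior mean. numpyro automates this whole loop (and tunes the step size during warmup) for models with many parameters.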

Here I take a numpyro example and compare it with a traditional (frequentist) model.

Example 1

I am using the numpyro example here: https://num.pyro.ai/en/stable/tutorials/bayesian_hierarchical_linear_regression.html

The data set is from https://www.kaggle.com/c/osic-pulmonary-fibrosis-progression

“Pulmonary fibrosis is a disorder with no known cause and no known cure, created by scarring of the lungs. In this competition, we were asked to predict a patient’s severity of decline in lung function. Lung function is assessed based on output from a spirometer, which measures the forced vital capacity (FVC), i.e. the volume of air exhaled.

In medical applications, it is useful to evaluate a model’s confidence in its decisions. Accordingly, the metric used to rank the teams was designed to reflect both the accuracy and certainty of each prediction.”

I read the data into R first.

library(tidyverse)
# read in csv file
library(readr)

data <- read_csv("./osic_pulmonary_fibrosis.csv") |> 
  arrange(Patient, Weeks)

data
# A tibble: 1,549 × 7
   Patient                   Weeks   FVC Percent   Age Sex   SmokingStatus
   <chr>                     <dbl> <dbl>   <dbl> <dbl> <chr> <chr>        
 1 ID00007637202177411956430    -4  2315    58.3    79 Male  Ex-smoker    
 2 ID00007637202177411956430     5  2214    55.7    79 Male  Ex-smoker    
 3 ID00007637202177411956430     7  2061    51.9    79 Male  Ex-smoker    
 4 ID00007637202177411956430     9  2144    54.0    79 Male  Ex-smoker    
 5 ID00007637202177411956430    11  2069    52.1    79 Male  Ex-smoker    
 6 ID00007637202177411956430    17  2101    52.9    79 Male  Ex-smoker    
 7 ID00007637202177411956430    29  2000    50.3    79 Male  Ex-smoker    
 8 ID00007637202177411956430    41  2064    51.9    79 Male  Ex-smoker    
 9 ID00007637202177411956430    57  2057    51.8    79 Male  Ex-smoker    
10 ID00009637202177434476278     8  3660    85.3    69 Male  Ex-smoker    
# ℹ 1,539 more rows

A random effects model

We fit a random effects model with a random intercept and a random slope on Weeks for each patient, on top of a linear fixed-effect trend in Weeks.
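Written out, the lme4 formula `FVC ~ Weeks + (1 + Weeks | Patient)` corresponds to the following model for observation j of patient i. The symbols are notation introduced here for illustration; they do not appear in the lme4 output.

```latex
\begin{aligned}
\mathrm{FVC}_{ij} &= (\gamma_0 + u_{0i}) + (\gamma_1 + u_{1i})\,\mathrm{Weeks}_{ij} + \varepsilon_{ij}, \\
(u_{0i},\, u_{1i})^\top &\sim \mathcal{N}(\mathbf{0},\, \Sigma), \qquad
\varepsilon_{ij} \sim \mathcal{N}(0,\, \sigma^2)
\end{aligned}
```

Here \gamma_0 and \gamma_1 are the fixed intercept and slope, u_{0i} and u_{1i} are the patient-specific deviations from them, and \Sigma holds the two random-effect variances and their correlation.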

# a random effects model with FVC as the DV and a linear time trend, with a random intercept and a random slope on Weeks by Patient
library(lme4)
reg1 <- lmer(FVC ~ Weeks + (1 + Weeks | Patient), data = data)
summary(reg1)
Linear mixed model fit by REML ['lmerMod']
Formula: FVC ~ Weeks + (1 + Weeks | Patient)
   Data: data

REML criterion at convergence: 20891.4

Scaled residuals: 
    Min      1Q  Median      3Q     Max 
-9.3209 -0.4257  0.0063  0.4458  5.6990 

Random effects:
 Groups   Name        Variance  Std.Dev. Corr 
 Patient  (Intercept) 686691.11 828.668       
          Weeks           25.88   5.087  -0.14
 Residual              18592.11 136.353       
Number of obs: 1549, groups:  Patient, 176

Fixed effects:
             Estimate Std. Error t value
(Intercept) 2810.2843    62.9497  44.643
Weeks         -4.2629     0.4377  -9.739

Correlation of Fixed Effects:
      (Intr)
Weeks -0.174

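Next we fit a similar model as a Bayesian hierarchical model. The results should be close. One difference: the numpyro model below draws the per-patient intercepts and slopes independently, whereas lmer also estimates their correlation.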
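To make the comparison easier, the lmer fixed effects can be turned into approximate 90% Wald intervals, which can be set next to the 5.0% and 95.0% columns of the numpyro summary further down. This is a rough sketch under a normal approximation; the helper name `wald_interval` is made up, and the numbers are copied by hand from the lmer output above.

```python
from statistics import NormalDist

# Fixed-effect estimates and standard errors copied from the lmer summary above
lmer_fixed = {
    "(Intercept)": (2810.2843, 62.9497),
    "Weeks": (-4.2629, 0.4377),
}


def wald_interval(estimate, std_error, level=0.90):
    """Normal-approximation interval: estimate +/- z * standard error."""
    z = NormalDist().inv_cdf(0.5 + level / 2)
    return estimate - z * std_error, estimate + z * std_error


intervals = {name: wald_interval(est, se) for name, (est, se) in lmer_fixed.items()}
for name, (lower, upper) in intervals.items():
    print(f"{name:12s} 90% interval: [{lower:.2f}, {upper:.2f}]")
```

If I read the two outputs correctly, `(Intercept)` lines up with μ_α and `Weeks` with μ_β, while the residual Std.Dev. lines up with σ and the Patient intercept Std.Dev. with σ_α.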

Hierarchical model in numpyro

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

train = pd.read_csv(
    "https://gist.githubusercontent.com/ucals/"
    "2cf9d101992cb1b78c2cdd6e3bac6a4b/raw/"
    "43034c39052dcf97d4b894d2ec1bc3f90f3623d9/"
    "osic_pulmonary_fibrosis.csv"
)


from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

assert numpyro.__version__.startswith("0.19.0")

def model(patient_code, Weeks, FVC_obs=None):
    # Population-level priors: mean and spread of the intercepts (α) and slopes (β)
    μ_α = numpyro.sample("μ_α", dist.Normal(0.0, 500.0))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.0))
    μ_β = numpyro.sample("μ_β", dist.Normal(0.0, 3.0))
    σ_β = numpyro.sample("σ_β", dist.HalfNormal(3.0))

    n_patients = len(np.unique(patient_code))

    # One intercept and one slope per patient, drawn from the population distributions
    with numpyro.plate("plate_i", n_patients):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β = numpyro.sample("β", dist.Normal(μ_β, σ_β))

    # Residual standard deviation and the per-observation linear predictor
    σ = numpyro.sample("σ", dist.HalfNormal(100.0))
    FVC_est = α[patient_code] + β[patient_code] * Weeks

    # Likelihood: observed FVC is Normal around each patient's own line
    with numpyro.plate("data", len(patient_code)):
        numpyro.sample("obs", dist.Normal(FVC_est, σ), obs=FVC_obs)
        
from sklearn.preprocessing import LabelEncoder

patient_encoder = LabelEncoder()
train["patient_code"] = patient_encoder.fit_transform(train["Patient"].values)

FVC_obs = train["FVC"].values
Weeks = train["Weeks"].values
patient_code = train["patient_code"].values


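# NUTS (No-U-Turn Sampler) is the MCMC kernel used to draw from the posterior
nuts_kernel = NUTS(model)

# 2000 warmup iterations, then 2000 retained posterior draws
mcmc = MCMC(nuts_kernel, num_samples=2000, num_warmup=2000, progress_bar=False)
rng_key = random.PRNGKey(0)  # fixed seed for reproducibility
mcmc.run(rng_key, patient_code, Weeks, FVC_obs=FVC_obs)

# Posterior draws as a dict of arrays, then a summary table for every parameter
posterior_samples = mcmc.get_samples()
mcmc.print_summary()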

                mean       std    median      5.0%     95.0%     n_eff     r_hat
      α[0]   2181.96     60.74   2182.26   2080.02   2276.12   3478.15      1.00
      α[1]   3783.49     73.96   3784.85   3667.70   3913.58   4928.66      1.00
      α[2]   3265.39     58.09   3264.75   3175.63   3366.59   4088.11      1.00
      α[3]   3478.90     66.67   3477.34   3371.61   3589.76   3993.87      1.00
      α[4]   3700.88    118.15   3699.46   3507.03   3892.74   4173.49      1.00
      α[5]   3724.77     59.33   3725.87   3623.35   3817.91   5760.40      1.00
      α[6]   3050.36     87.13   3052.33   2901.09   3187.95   4301.39      1.00
      α[7]   2151.55     78.65   2149.32   2031.82   2287.03   4117.25      1.00
      α[8]   2344.04     89.13   2344.16   2197.93   2493.25   4350.38      1.00
      α[9]   1496.10     59.90   1495.72   1400.21   1595.97   3821.66      1.00
     α[10]   3147.59     59.34   3146.73   3054.35   3247.87   3823.34      1.00
     α[11]   3061.16    134.88   3059.77   2859.88   3302.93   4108.39      1.00
     α[12]   2776.40     81.49   2777.66   2638.22   2904.36   5336.77      1.00
     α[13]   2681.60    125.34   2682.50   2470.12   2876.24   5991.18      1.00
     α[14]   5195.74    116.71   5196.11   5008.86   5385.35   4844.13      1.00
     α[15]   2780.42     75.71   2780.33   2663.12   2909.89   4404.98      1.00
     α[16]   3924.87     78.16   3923.05   3795.71   4053.21   4959.80      1.00
     α[17]   3613.71     61.09   3613.50   3513.57   3715.78   3498.32      1.00
     α[18]   3276.12     66.24   3276.51   3162.19   3382.03   2883.28      1.00
     α[19]   1477.57     69.84   1478.19   1352.44   1579.33   3962.45      1.00
     α[20]   1703.49     56.27   1704.27   1612.68   1796.76   2952.87      1.00
     α[21]   2984.97    218.20   2988.79   2625.00   3337.65   5150.51      1.00
     α[22]   3392.43     93.03   3391.44   3252.03   3554.48   3935.83      1.00
     α[23]   3637.04    103.11   3639.29   3468.56   3799.16   4211.06      1.00
     α[24]   2483.06    144.52   2483.37   2247.02   2721.52   5069.97      1.00
     α[25]   3453.93     66.54   3455.23   3345.72   3565.21   4795.24      1.00
     α[26]   2834.53     75.77   2834.91   2716.25   2961.84   4413.27      1.00
     α[27]   2419.90     83.80   2420.70   2289.84   2568.34   4094.23      1.00
     α[28]   2495.48     68.58   2494.98   2379.58   2601.87   5322.98      1.00
     α[29]   1828.49    109.39   1829.96   1650.90   2013.15   3854.51      1.00
     α[30]   2387.46     59.38   2387.19   2291.45   2484.52   4065.98      1.00
     α[31]   3258.15     64.09   3259.99   3156.46   3366.40   4339.93      1.00
     α[32]   1885.49     63.12   1886.79   1779.17   1986.83   3903.91      1.00
     α[33]   2946.67    109.14   2945.00   2761.04   3113.03   4230.53      1.00
     α[34]   3373.80     62.44   3373.85   3277.11   3478.82   4635.49      1.00
     α[35]   2537.18     72.11   2536.82   2427.51   2662.56   4037.66      1.00
     α[36]   3219.40     92.97   3219.59   3067.16   3361.41   3682.34      1.00
     α[37]   3578.86     64.44   3579.33   3473.77   3685.27   4499.79      1.00
     α[38]   4806.80     67.82   4806.16   4704.59   4926.04   4021.18      1.00
     α[39]   3142.70    128.24   3140.92   2944.80   3357.44   4483.11      1.00
     α[40]   2303.84     64.02   2304.47   2191.04   2403.73   4225.49      1.00
     α[41]   2567.85    104.45   2568.37   2408.97   2752.63   4547.57      1.00
     α[42]   2746.64    132.81   2749.36   2525.71   2960.18   4787.09      1.00
     α[43]   2856.82     79.62   2858.73   2734.38   2990.30   5177.49      1.00
     α[44]   2553.31    161.52   2556.90   2279.42   2807.01   4763.51      1.00
     α[45]   2028.61     99.29   2029.85   1870.11   2198.04   4007.81      1.00
     α[46]   2415.84     85.37   2412.85   2274.98   2551.99   4389.82      1.00
     α[47]   2559.61     91.70   2560.16   2415.14   2713.06   4703.63      1.00
     α[48]   2903.48     67.59   2903.82   2783.52   3003.76   3074.21      1.00
     α[49]   2916.36     61.66   2917.28   2815.96   3018.28   3231.64      1.00
     α[50]   2433.25     62.36   2433.10   2333.02   2535.01   4112.43      1.00
     α[51]   1874.57    149.88   1874.53   1636.37   2125.48   4466.82      1.00
     α[52]   3885.27    152.80   3884.64   3634.18   4129.37   4764.01      1.00
     α[53]   2453.30     74.16   2454.34   2332.25   2572.95   4032.64      1.00
     α[54]   2289.13     98.67   2291.88   2113.48   2432.82   4228.26      1.00
     α[55]   1968.45    148.75   1967.96   1727.68   2223.98   4738.61      1.00
     α[56]   2448.41    168.45   2449.80   2177.20   2731.69   4271.57      1.00
     α[57]   2284.98     61.21   2285.08   2188.96   2386.97   5302.78      1.00
     α[58]   1690.54     75.88   1692.69   1572.17   1820.44   4384.45      1.00
     α[59]   3652.09    175.83   3651.85   3371.15   3952.49   5137.58      1.00
     α[60]   2659.02     72.32   2658.87   2540.28   2776.74   4343.19      1.00
     α[61]   3162.24     59.04   3164.49   3066.44   3260.57   3628.39      1.00
     α[62]   2697.13     68.00   2697.47   2589.83   2809.31   3821.55      1.00
     α[63]   3943.73    119.42   3946.85   3729.60   4121.03   4423.43      1.00
     α[64]   3339.19     64.10   3340.30   3229.28   3435.80   3789.94      1.00
     α[65]   4422.48     85.06   4424.31   4280.59   4564.35   6713.01      1.00
     α[66]   4017.49     92.23   4016.28   3868.26   4169.53   4683.90      1.00
     α[67]   4089.30     82.06   4091.63   3959.81   4223.71   4983.39      1.00
     α[68]   3549.15    138.06   3549.95   3327.27   3776.16   3890.86      1.00
     α[69]   2901.59     73.02   2900.42   2790.57   3030.32   4683.16      1.00
     α[70]   2181.65    219.45   2175.94   1841.71   2541.47   4550.12      1.00
     α[71]   4770.96     74.12   4771.11   4654.07   4898.68   4209.58      1.00
     α[72]   3148.67     68.49   3149.51   3038.27   3261.98   4719.43      1.00
     α[73]   2568.25     84.10   2568.06   2422.14   2691.78   4620.57      1.00
     α[74]   2071.70    135.79   2069.03   1851.62   2289.57   4258.39      1.00
     α[75]   2805.05    126.29   2808.18   2587.92   3004.50   4962.74      1.00
     α[76]   2851.84     84.64   2854.08   2708.30   2982.13   4178.09      1.00
     α[77]   3069.15     75.47   3067.51   2947.36   3191.41   4765.39      1.00
     α[78]   2473.79    134.71   2472.45   2246.52   2689.86   4316.09      1.00
     α[79]   3176.65    112.70   3181.20   2992.57   3366.85   4780.32      1.00
     α[80]   2918.38     75.56   2919.69   2795.13   3045.32   4288.53      1.00
     α[81]   4380.73     73.78   4380.60   4256.85   4500.40   4449.85      1.00
     α[82]   1901.83     73.17   1900.47   1776.18   2016.42   3629.52      1.00
     α[83]   3939.29     74.29   3939.85   3822.65   4061.64   4840.37      1.00
     α[84]   2302.36     64.27   2301.19   2197.80   2406.84   3878.34      1.00
     α[85]   2803.07     66.43   2803.56   2706.09   2922.83   4427.60      1.00
     α[86]   4430.48     94.38   4430.86   4275.12   4581.68   4876.03      1.00
     α[87]   2862.12    122.95   2862.95   2663.85   3066.27   5159.40      1.00
     α[88]   2831.12     59.21   2830.93   2738.42   2929.85   3505.62      1.00
     α[89]   3174.82     69.33   3175.22   3065.44   3291.17   2897.17      1.00
     α[90]   2907.47     63.66   2908.82   2807.63   3013.57   3894.72      1.00
     α[91]   2321.38     73.67   2321.39   2197.67   2435.79   4540.13      1.00
     α[92]   2534.82     92.63   2533.36   2380.92   2685.90   4466.08      1.00
     α[93]   5897.89     61.90   5898.57   5798.89   6001.60   3792.77      1.00
     α[94]   1493.36     77.54   1494.47   1357.99   1613.48   4334.43      1.00
     α[95]   2553.11     64.25   2552.00   2446.32   2654.63   5198.04      1.00
     α[96]   2703.26     61.69   2704.33   2597.79   2800.33   3588.30      1.00
     α[97]   1254.25     84.35   1254.03   1111.30   1386.08   4096.89      1.00
     α[98]   1977.65     81.96   1975.93   1850.62   2117.53   4584.07      1.00
     α[99]   3626.00     94.04   3624.16   3453.35   3765.02   6152.51      1.00
    α[100]   1960.98     74.47   1960.78   1839.09   2081.26   5508.20      1.00
    α[101]   3618.52     56.55   3618.62   3520.58   3705.30   3322.28      1.00
    α[102]   1974.77     65.12   1974.23   1860.57   2073.28   3503.83      1.00
    α[103]   2105.64    128.71   2105.25   1886.33   2306.98   6616.46      1.00
    α[104]   2910.32     68.85   2910.23   2789.15   3014.76   4584.25      1.00
    α[105]   1401.37     63.49   1401.44   1303.52   1512.56   4920.60      1.00
    α[106]   1840.26     65.45   1840.61   1734.82   1946.79   4199.05      1.00
    α[107]   1381.50     80.50   1379.99   1243.57   1515.55   5144.95      1.00
    α[108]   3918.88    134.89   3916.91   3674.47   4118.55   4558.49      1.00
    α[109]   2231.49    103.29   2231.06   2077.49   2414.70   4677.99      1.00
    α[110]   1565.15     66.49   1567.76   1456.32   1674.67   3762.40      1.00
    α[111]   2075.41     87.26   2076.29   1941.11   2223.85   4193.44      1.00
    α[112]   1065.46     73.31   1067.54    938.10   1176.44   4130.45      1.00
    α[113]   2863.31     68.76   2865.82   2749.60   2974.66   4549.16      1.00
    α[114]   2517.66    146.55   2517.04   2289.99   2768.97   4708.89      1.00
    α[115]   2973.36     73.24   2973.90   2847.22   3086.66   4239.57      1.00
    α[116]   4160.73     70.88   4160.89   4043.80   4277.77   4756.33      1.00
    α[117]   1981.23     59.08   1981.86   1887.80   2081.02   3847.90      1.00
    α[118]   1402.20     63.46   1401.69   1303.28   1512.53   3938.85      1.00
    α[119]   2096.95    108.78   2098.23   1910.34   2266.26   4640.21      1.00
    α[120]   2363.06     80.42   2361.92   2237.11   2505.73   3919.36      1.00
    α[121]   4055.78     59.60   4055.97   3955.94   4151.83   3641.06      1.00
    α[122]   2289.47     66.27   2286.09   2188.57   2405.43   4131.95      1.00
    α[123]   2191.83     74.29   2192.15   2071.21   2315.32   3979.08      1.00
    α[124]   1923.23     66.87   1922.47   1815.24   2032.27   3807.17      1.00
    α[125]   2250.75    100.50   2249.12   2086.16   2409.58   5053.45      1.00
    α[126]   2915.61     75.51   2915.58   2794.77   3037.53   5279.26      1.00
    α[127]   3239.20     64.65   3237.52   3144.70   3354.14   4017.07      1.00
    α[128]   3222.20     66.91   3222.45   3109.25   3328.30   3365.91      1.00
    α[129]   1859.73     61.92   1858.81   1758.87   1964.10   4568.50      1.00
    α[130]   2614.03    115.54   2615.68   2420.21   2804.06   4947.98      1.00
    α[131]   3910.90     81.49   3911.47   3779.04   4048.18   4029.07      1.00
    α[132]   2618.84     63.93   2618.52   2509.26   2720.43   4529.40      1.00
    α[133]   1617.22    126.03   1614.94   1406.29   1819.48   4226.61      1.00
    α[134]   2891.23    130.63   2892.29   2684.80   3108.29   4700.87      1.00
    α[135]   3080.86    116.70   3078.24   2867.59   3246.76   5356.82      1.00
    α[136]   2736.10     65.37   2737.26   2632.36   2845.87   4798.27      1.00
    α[137]   2616.46     70.43   2617.43   2499.96   2731.58   4693.34      1.00
    α[138]   3931.41    113.71   3931.37   3744.14   4120.55   6092.02      1.00
    α[139]   3540.65     66.25   3541.71   3435.03   3652.82   4975.09      1.00
    α[140]   2319.48     64.50   2319.73   2214.62   2429.24   3795.68      1.00
    α[141]   2933.86    114.23   2931.55   2746.68   3115.25   4674.73      1.00
    α[142]   2405.33     99.42   2404.73   2248.27   2568.26   4378.19      1.00
    α[143]   1608.70     94.94   1607.20   1452.57   1764.13   4365.56      1.00
    α[144]   2065.78     66.48   2066.25   1962.33   2177.87   4630.61      1.00
    α[145]   2608.69     80.14   2607.56   2476.55   2740.32   5272.79      1.00
    α[146]   4202.96    105.58   4201.76   4034.35   4378.06   5225.38      1.00
    α[147]   2052.41     75.69   2052.85   1930.84   2175.75   4737.47      1.00
    α[148]   2888.84     66.31   2888.21   2785.91   3000.87   4069.44      1.00
    α[149]   3538.52    130.67   3538.85   3311.72   3734.16   5014.65      1.00
    α[150]   2671.02    152.86   2672.30   2421.11   2920.64   5058.42      1.00
    α[151]   1734.95    130.10   1735.16   1529.26   1966.11   4134.36      1.00
    α[152]   3036.62     77.43   3037.41   2913.63   3163.50   5443.08      1.00
    α[153]   2280.62     80.84   2280.98   2157.90   2419.70   6519.05      1.00
    α[154]   2535.44     67.69   2535.98   2424.18   2645.61   4251.41      1.00
    α[155]   4401.42    142.25   4403.37   4172.91   4646.13   4468.32      1.00
    α[156]   3468.31    103.94   3469.87   3318.57   3657.42   4614.79      1.00
    α[157]   2300.55     74.07   2301.65   2186.87   2430.85   4202.74      1.00
    α[158]   3354.23    157.23   3350.28   3104.44   3608.42   4578.01      1.00
    α[159]   3142.29     62.26   3142.13   3032.66   3240.83   4841.13      1.00
    α[160]   2204.44     70.52   2202.80   2089.87   2320.99   4692.09      1.00
    α[161]   3737.23     61.33   3736.79   3634.25   3833.27   3300.82      1.00
    α[162]   3000.90     86.93   3000.32   2858.64   3146.88   4496.44      1.00
    α[163]   3965.14    100.36   3964.14   3793.29   4120.08   3782.56      1.00
    α[164]   2020.35    142.58   2021.99   1791.25   2264.88   5086.22      1.00
    α[165]   1580.13     64.85   1579.95   1486.46   1700.85   2607.56      1.00
    α[166]   3250.36     94.84   3251.18   3103.08   3407.85   4652.29      1.00
    α[167]   2455.41     89.25   2455.88   2307.87   2606.59   4002.42      1.00
    α[168]   1770.46     70.57   1770.27   1659.72   1888.74   4205.60      1.00
    α[169]   2892.62     76.82   2890.41   2774.02   3024.54   3921.34      1.00
    α[170]   3357.47     69.56   3356.52   3252.30   3478.19   4614.54      1.00
    α[171]   2836.92     67.01   2839.11   2729.38   2944.73   4768.27      1.00
    α[172]   2828.49     81.06   2828.37   2687.76   2949.13   4523.36      1.00
    α[173]   1989.75     77.96   1989.84   1854.52   2106.57   3484.11      1.00
    α[174]   3007.70     85.53   3010.05   2869.75   3152.74   4488.10      1.00
    α[175]   2931.63     67.48   2931.62   2813.62   3032.52   4880.61      1.00
      β[0]     -3.49      2.13     -3.52     -6.86      0.11   3777.86      1.00
      β[1]     -8.02      2.36     -8.00    -11.68     -4.02   4356.88      1.00
      β[2]    -14.27      2.28    -14.26    -17.84    -10.21   3486.20      1.00
      β[3]     -4.27      2.31     -4.27     -7.77     -0.22   4751.57      1.00
      β[4]     -7.13      2.19     -7.13    -10.57     -3.43   4758.71      1.00
      β[5]    -11.69      2.32    -11.69    -15.47     -7.84   4411.74      1.00
      β[6]     -6.71      2.27     -6.70    -10.40     -3.02   4800.55      1.00
      β[7]     -5.62      2.20     -5.63     -9.54     -2.30   4401.94      1.00
      β[8]     -6.81      2.27     -6.88    -10.33     -2.83   4538.73      1.00
      β[9]     -3.06      2.28     -3.04     -6.67      0.83   5195.12      1.00
     β[10]     -9.32      2.33     -9.33    -13.17     -5.60   3699.40      1.00
     β[11]     -3.07      2.15     -3.07     -6.56      0.49   4320.79      1.00
     β[12]     -5.17      2.26     -5.14     -8.90     -1.54   4824.94      1.00
     β[13]     -4.64      2.92     -4.56     -9.56     -0.14   4782.43      1.00
     β[14]    -11.04      2.19    -11.04    -14.47     -7.42   4283.43      1.00
     β[15]     -8.45      2.30     -8.48    -12.12     -4.57   4260.51      1.00
     β[16]     -1.18      2.36     -1.12     -5.03      2.76   4967.28      1.00
     β[17]     -0.16      2.14     -0.19     -3.62      3.29   3726.29      1.00
     β[18]     -6.65      3.82     -6.71    -13.21     -0.86   3985.31      1.00
     β[19]     -2.16      2.20     -2.18     -5.76      1.43   3160.09      1.00
     β[20]     -3.71      2.20     -3.68     -7.30     -0.02   3877.82      1.00
     β[21]    -11.26      3.68    -11.34    -17.13     -5.13   4327.97      1.00
     β[22]     -5.95      2.40     -5.94     -9.97     -2.07   4703.96      1.00
     β[23]      5.92      2.37      5.92      2.08      9.67   4153.01      1.00
     β[24]     -4.07      2.26     -4.07     -7.76     -0.37   5563.09      1.00
     β[25]     -5.09      3.02     -5.11    -10.29     -0.54   3885.94      1.00
     β[26]     -6.07      2.34     -6.07    -10.20     -2.40   5647.53      1.00
     β[27]     -2.60      2.13     -2.55     -6.46      0.69   4187.02      1.00
     β[28]    -13.06      3.09    -13.13    -18.11     -7.93   4795.54      1.00
     β[29]     -7.67      2.38     -7.71    -11.57     -3.82   4117.66      1.00
     β[30]      8.04      2.25      8.04      4.13     11.55   4600.00      1.00
     β[31]      2.45      2.26      2.47     -1.03      6.30   4007.59      1.00
     β[32]     -6.58      2.23     -6.59    -10.21     -2.82   3832.25      1.00
     β[33]     -0.80      2.37     -0.81     -4.99      2.80   4361.64      1.00
     β[34]     -0.83      2.16     -0.83     -4.16      2.91   4025.73      1.00
     β[35]    -10.61      2.40    -10.61    -14.60     -6.78   3563.09      1.00
     β[36]     -5.19      2.17     -5.18     -8.89     -1.93   3299.45      1.00
     β[37]     -0.25      2.42     -0.19     -4.01      3.83   3704.36      1.00
     β[38]     -2.73      2.35     -2.70     -6.32      1.24   4342.30      1.00
     β[39]     -3.52      1.84     -3.51     -6.77     -0.84   3976.06      1.00
     β[40]     -1.34      2.30     -1.37     -5.12      2.42   3742.77      1.00
     β[41]    -11.93      2.96    -11.83    -16.93     -7.23   4682.23      1.00
     β[42]     -5.60      1.96     -5.62     -8.69     -2.24   4732.31      1.00
     β[43]     -4.12      2.35     -4.08     -7.66     -0.08   4616.45      1.00
     β[44]     -5.79      2.20     -5.80     -9.27     -2.09   4924.40      1.00
     β[45]    -19.96      3.06    -19.91    -24.63    -14.61   3679.62      1.00
     β[46]     -7.46      3.85     -7.50    -13.75     -1.59   4155.61      1.00
     β[47]     -5.62      2.99     -5.66    -10.20     -0.32   4985.98      1.00
     β[48]      0.74      2.23      0.77     -2.89      4.42   3883.18      1.00
     β[49]     -0.05      1.96     -0.03     -3.14      3.21   5014.63      1.00
     β[50]     -6.66      2.21     -6.70    -10.32     -3.04   3677.17      1.00
     β[51]     -4.38      2.23     -4.35     -8.12     -0.87   4395.64      1.00
     β[52]     -6.26      2.33     -6.21    -10.11     -2.64   4683.23      1.00
     β[53]    -11.29      3.00    -11.30    -16.50     -6.57   4172.36      1.00
     β[54]      0.53      2.20      0.49     -2.74      4.33   3909.65      1.00
     β[55]     -4.76      2.25     -4.77     -8.53     -1.01   4825.57      1.00
     β[56]     -6.44      3.85     -6.41    -12.56      0.06   4113.27      1.00
     β[57]     -2.80      2.27     -2.74     -6.67      0.70   3687.04      1.00
     β[58]     -7.49      2.26     -7.51    -10.84     -3.50   4276.53      1.00
     β[59]     -4.83      2.15     -4.82     -8.31     -1.30   5575.94      1.00
     β[60]     -1.40      2.24     -1.41     -4.98      2.35   4515.19      1.00
     β[61]     -3.11      2.28     -3.11     -6.95      0.55   4081.08      1.00
     β[62]     -8.59      2.31     -8.61    -12.35     -4.93   3377.48      1.00
     β[63]     -0.72      2.21     -0.73     -4.17      3.04   4547.00      1.00
     β[64]     -1.68      2.26     -1.63     -5.77      1.66   3385.73      1.00
     β[65]     -3.64      2.22     -3.64     -7.07      0.12   5571.42      1.00
     β[66]     -3.56      2.21     -3.63     -7.35     -0.16   4991.63      1.00
     β[67]     -1.55      1.90     -1.54     -4.73      1.41   4627.93      1.00
     β[68]     -7.49      2.09     -7.48    -11.06     -4.34   4079.94      1.00
     β[69]     -1.36      2.33     -1.39     -5.11      2.52   4569.82      1.00
     β[70]     -0.29      2.23     -0.25     -4.25      2.95   4242.55      1.00
     β[71]    -10.03      1.98    -10.06    -13.32     -6.94   4442.15      1.00
     β[72]     -7.53      2.24     -7.46    -11.36     -3.98   5482.66      1.00
     β[73]     -3.37      2.34     -3.39     -7.11      0.40   5019.89      1.00
     β[74]      2.41      2.23      2.45     -1.18      6.12   4508.04      1.00
     β[75]     -1.25      2.26     -1.33     -4.92      2.55   5646.07      1.00
     β[76]     -4.17      2.23     -4.17     -7.97     -0.66   5675.45      1.00
     β[77]     -4.32      2.26     -4.33     -8.17     -0.89   4923.77      1.00
     β[78]     -0.77      2.26     -0.75     -4.48      2.97   4750.75      1.00
     β[79]     -3.90      2.27     -3.89     -7.59     -0.15   4487.57      1.00
     β[80]    -12.61      2.64    -12.60    -17.07     -8.36   4246.54      1.00
     β[81]     -1.52      2.29     -1.45     -5.58      2.05   4651.25      1.00
     β[82]     -4.46      2.18     -4.50     -7.95     -0.89   4468.52      1.00
     β[83]     -1.38      2.24     -1.39     -4.98      2.40   5150.16      1.00
     β[84]     10.63      2.32     10.60      6.78     14.33   4505.47      1.00
     β[85]      0.58      2.27      0.59     -3.09      4.36   3539.17      1.00
     β[86]    -10.39      2.28    -10.40    -14.11     -6.87   4488.76      1.00
     β[87]     -3.61      2.29     -3.60     -7.36      0.10   5316.67      1.00
     β[88]     -0.30      2.22     -0.32     -3.80      3.35   3819.55      1.00
     β[89]     -1.93      2.32     -1.92     -5.96      1.63   4179.68      1.00
     β[90]     -2.15      2.37     -2.15     -5.91      1.70   3341.19      1.00
     β[91]     -3.70      2.26     -3.69     -7.61     -0.12   3904.55      1.00
     β[92]      2.18      2.20      2.17     -1.22      6.00   4908.78      1.00
     β[93]     -3.33      2.28     -3.41     -7.01      0.50   3571.76      1.00
     β[94]     -2.74      2.32     -2.75     -6.52      0.98   4606.86      1.00
     β[95]     -4.31      2.28     -4.40     -7.87     -0.61   4132.73      1.00
     β[96]     -0.62      2.29     -0.61     -4.24      3.25   3616.64      1.00
     β[97]     -7.27      2.28     -7.18    -11.10     -3.62   4140.66      1.00
     β[98]     -5.21      2.29     -5.23     -8.89     -1.43   5202.64      1.00
     β[99]     -3.25      1.91     -3.24     -6.55     -0.23   4539.26      1.00
    β[100]     -2.23      2.32     -2.18     -6.27      1.43   5201.07      1.00
    β[101]     -4.05      2.14     -4.01     -7.60     -0.70   3869.35      1.00
    β[102]     -3.24      2.97     -3.28     -7.89      1.76   3127.43      1.00
    β[103]      0.79      2.43      0.81     -2.78      5.17   6492.31      1.00
    β[104]     -4.63      2.39     -4.58     -8.28     -0.38   4463.70      1.00
    β[105]     -3.55      2.32     -3.48     -7.44      0.17   3883.25      1.00
    β[106]     -3.12      2.38     -3.04     -7.04      0.68   4317.92      1.00
    β[107]     -0.32      2.26     -0.27     -3.93      3.38   4878.10      1.00
    β[108]    -19.46      2.24    -19.44    -23.17    -15.80   3707.86      1.00
    β[109]     -5.55      2.08     -5.50     -8.91     -2.23   5596.44      1.00
    β[110]     -3.00      2.37     -2.99     -6.79      0.91   4362.65      1.00
    β[111]     -1.77      2.96     -1.76     -6.89      2.89   4287.54      1.00
    β[112]     -5.76      4.03     -5.67    -12.41      0.60   4568.87      1.00
    β[113]      0.30      2.26      0.29     -3.39      4.05   4201.45      1.00
    β[114]     -8.96      2.15     -8.93    -12.54     -5.55   4902.72      1.00
    β[115]     -6.70      3.75     -6.70    -12.80     -0.58   3675.94      1.00
    β[116]     -4.58      2.27     -4.57     -7.94     -0.57   4612.40      1.00
    β[117]     -0.47      2.21     -0.45     -4.30      2.92   4604.65      1.00
    β[118]     -0.88      2.17     -0.89     -4.45      2.67   3857.26      1.00
    β[119]      3.37      3.06      3.34     -1.63      8.40   4641.89      1.00
    β[120]     -7.61      2.76     -7.61    -12.08     -3.07   4584.45      1.00
    β[121]     -7.80      2.18     -7.84    -11.53     -4.31   4143.18      1.00
    β[122]     -4.47      2.28     -4.42     -8.62     -1.11   5003.13      1.00
    β[123]     -1.87      3.11     -1.85     -6.48      3.77   3837.84      1.00
    β[124]     -7.47      2.30     -7.42    -11.15     -3.56   3488.98      1.00
    β[125]     -2.96      2.30     -2.95     -6.51      0.86   4045.44      1.00
    β[126]     -4.62      2.32     -4.67     -8.37     -0.90   4501.75      1.00
    β[127]     -7.88      2.22     -7.87    -11.55     -4.17   4004.76      1.00
    β[128]     -4.28      2.31     -4.32     -7.78     -0.44   3548.34      1.00
    β[129]     -5.02      2.21     -5.08     -8.72     -1.32   3864.48      1.00
    β[130]     -9.61      2.21     -9.60    -13.21     -6.13   4645.20      1.00
    β[131]      5.26      2.23      5.33      1.71      9.07   3949.80      1.00
    β[132]     -5.85      2.23     -5.85     -9.48     -2.19   4969.66      1.00
    β[133]      0.48      3.14      0.54     -4.68      5.72   4701.96      1.00
    β[134]     -1.64      2.16     -1.65     -5.28      1.70   4533.16      1.00
    β[135]      1.34      2.32      1.36     -2.49      5.12   4581.77      1.00
    β[136]     -5.27      2.41     -5.27     -9.53     -1.62   4665.97      1.00
    β[137]     -5.41      2.89     -5.32    -10.61     -1.06   4848.17      1.00
    β[138]     -0.34      2.25     -0.35     -3.84      3.53   6135.90      1.00
    β[139]     -8.16      2.34     -8.12    -11.84     -4.21   4310.14      1.00
    β[140]     -3.98      2.25     -3.96     -7.82     -0.19   5176.38      1.00
    β[141]     -2.42      2.22     -2.46     -5.94      1.29   4938.04      1.00
    β[142]     -7.50      2.23     -7.48    -10.93     -3.83   5237.66      1.00
    β[143]     -1.22      1.95     -1.24     -4.48      1.89   4102.71      1.00
    β[144]     -6.33      3.78     -6.32    -12.02      0.48   3855.15      1.00
    β[145]      0.36      2.30      0.31     -3.04      4.50   5106.15      1.00
    β[146]      5.67      2.24      5.69      2.32      9.50   4880.28      1.00
    β[147]     -1.28      2.33     -1.32     -4.71      2.91   4092.35      1.00
    β[148]     -2.58      2.34     -2.60     -6.47      1.23   3386.65      1.00
    β[149]    -12.39      2.23    -12.40    -16.17     -8.91   5298.42      1.00
    β[150]     -6.77      3.07     -6.72    -11.85     -1.95   4858.98      1.00
    β[151]     -7.95      2.38     -7.92    -11.65     -3.74   4461.45      1.00
    β[152]     -2.49      2.33     -2.50     -6.22      1.21   3764.36      1.00
    β[153]     -1.30      2.37     -1.29     -5.04      2.67   6336.56      1.00
    β[154]     -3.86      2.20     -3.87     -7.43     -0.39   3279.35      1.00
    β[155]     -2.65      2.41     -2.66     -7.01      0.97   4424.41      1.00
    β[156]     -3.71      2.22     -3.71     -7.33     -0.15   4825.16      1.00
    β[157]     -8.63      3.94     -8.60    -14.57     -1.90   3967.71      1.00
    β[158]     -3.99      2.31     -3.98     -7.49     -0.07   4349.28      1.00
    β[159]     -4.94      1.98     -4.90     -8.41     -1.91   4700.91      1.00
    β[160]     -6.93      2.43     -6.99    -10.98     -2.99   4785.30      1.00
    β[161]     -3.47      2.13     -3.47     -7.00     -0.16   3972.92      1.00
    β[162]    -15.86      3.11    -15.80    -21.10    -10.84   4183.32      1.00
    β[163]     -0.94      1.87     -0.95     -4.08      2.08   4370.92      1.00
    β[164]     -2.01      2.22     -2.00     -5.44      1.98   5639.13      1.00
    β[165]     -2.00      2.22     -1.98     -5.55      1.68   3479.20      1.00
    β[166]     -4.23      2.27     -4.25     -7.71     -0.32   4338.94      1.00
    β[167]     -1.70      3.12     -1.79     -6.56      3.57   4373.63      1.00
    β[168]     -1.91      2.28     -1.90     -5.93      1.68   4487.30      1.00
    β[169]    -17.05      2.31    -17.07    -20.95    -13.41   4857.17      1.00
    β[170]     -1.69      2.27     -1.68     -5.50      1.95   4521.26      1.00
    β[171]     -2.54      2.31     -2.57     -6.12      1.41   4106.66      1.00
    β[172]     -1.94      1.91     -1.91     -4.93      1.32   3996.98      1.00
    β[173]     -4.88      2.92     -4.92     -9.76     -0.26   4223.72      1.00
    β[174]     -7.95      2.24     -7.97    -11.47     -4.07   4681.34      1.00
    β[175]     -1.66      2.25     -1.69     -4.95      2.32   3655.75      1.00
       μ_α   2774.79     54.96   2775.49   2686.53   2863.70   3474.50      1.00
       μ_β     -4.17      0.41     -4.17     -4.85     -3.51   2474.19      1.00
         σ    136.71      2.74    136.64    132.14    141.13   1906.90      1.00
       σ_α    723.12     31.96    722.08    672.81    774.22   3737.96      1.00
       σ_β      4.99      0.36      4.96      4.39      5.55   2482.51      1.00

Number of divergences: 0

import arviz as az

data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True, figsize=(15, 25));

Let’s see how this is done:

The model is described by the following equations:

\[ \begin{align} \mu_{\alpha} &\sim \text{Normal}(0, 500) \\ \sigma_{\alpha} &\sim \text{Half-Normal}(100) \\ \mu_{\beta} &\sim \text{Normal}(0, 3) \\ \sigma_{\beta} &\sim \text{Half-Normal}(3) \\ \alpha_i &\sim \text{Normal}(\mu_{\alpha}, \sigma_{\alpha}) \\ \beta_i &\sim \text{Normal}(\mu_{\beta}, \sigma_{\beta}) \\ \sigma &\sim \text{Half-Normal}(100) \\ FVC_{ij} &\sim \text{Normal}(\alpha_i + t_{ij} \beta_i, \sigma) \end{align} \] where \(t_{ij}\) is the week of the \(j\)-th observation for patient \(i\). The second argument of each Normal is a standard deviation, not a variance, which matches how the model is coded in numpyro.

There are 176 patients, each with multiple observations taken at different weeks. So there are 176 \(\alpha_i\) and 176 \(\beta_i\) (the random effects for the intercept and for the slope). There are two “fixed effects”, \(\mu_{\alpha}\) and \(\mu_{\beta}\). Finally, there is one standard deviation \(\sigma\) for the residuals, and two standard deviations \(\sigma_{\alpha}\) and \(\sigma_{\beta}\) for the random effects.
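To make the equations concrete, here is a minimal sketch that draws one synthetic data set from this generative model using plain numpy (no numpyro). The number of visits per patient and the range of weeks are made-up values for illustration only; the real data has a varying number of visits per patient.

```python
import numpy as np

rng = np.random.default_rng(0)

n_patients = 176       # number of patients, as in the text
obs_per_patient = 9    # hypothetical: same number of visits for everyone

# Population-level draws; a Half-Normal(s) draw is |Normal(0, s)|
mu_alpha = rng.normal(0.0, 500.0)
sigma_alpha = abs(rng.normal(0.0, 100.0))
mu_beta = rng.normal(0.0, 3.0)
sigma_beta = abs(rng.normal(0.0, 3.0))
sigma = abs(rng.normal(0.0, 100.0))

# One intercept and one slope per patient (the "random effects")
alpha = rng.normal(mu_alpha, sigma_alpha, size=n_patients)
beta = rng.normal(mu_beta, sigma_beta, size=n_patients)

# Long-format index: which patient each observation belongs to
patient_code = np.repeat(np.arange(n_patients), obs_per_patient)
weeks = rng.integers(-5, 134, size=patient_code.size)  # hypothetical week range

# Likelihood: FVC_ij ~ Normal(alpha_i + t_ij * beta_i, sigma)
fvc_mean = alpha[patient_code] + beta[patient_code] * weeks
fvc = rng.normal(fvc_mean, sigma)

# 176 intercepts + 176 slopes + 5 population-level parameters
n_params = alpha.size + beta.size + 5
print(n_params)
```

The indexing `alpha[patient_code]` is the same trick the numpyro model uses: it expands the per-patient parameters to one value per row of the data.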

The idea of Bayesian inference is that, given this model and the data, we want to infer the posterior distribution of the parameters. The posterior distribution is proportional to the likelihood times the prior distribution.
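As a small illustration of "posterior is proportional to likelihood times prior", here is a sketch on a toy problem with a single unknown mean, not the FVC model. It evaluates prior and likelihood on a grid and normalises the product. The data, the Normal(0, 3) prior, and the known residual standard deviation of 1 are all assumptions chosen so that the answer can be checked against the closed-form conjugate result. MCMC in numpyro targets the same kind of posterior, but by sampling rather than by a grid.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy data: 20 draws from Normal(true_mean, 1); values are made up for illustration
true_mean = 2.0
y = rng.normal(true_mean, 1.0, size=20)

# Grid of candidate values for the unknown mean
grid = np.linspace(-5.0, 5.0, 2001)

# Log prior: mean ~ Normal(0, 3), up to an additive constant
log_prior = -0.5 * (grid / 3.0) ** 2

# Log likelihood: y_k ~ Normal(mean, 1), summed over observations
log_lik = -0.5 * ((y[:, None] - grid[None, :]) ** 2).sum(axis=0)

# Unnormalised log posterior = log likelihood + log prior
log_post = log_lik + log_prior
post = np.exp(log_post - log_post.max())  # subtract the max for numerical stability
post /= post.sum()                        # normalise so the grid weights sum to 1

post_mean = (grid * post).sum()

# Closed-form posterior mean for this conjugate Normal-Normal setup
n = y.size
exact_mean = y.sum() / (n + 1.0 / 9.0)

print(post_mean, exact_mean)
```

A grid like this only works for one or two parameters. With 357 parameters, as in the model above, it is hopeless, which is why a sampler such as NUTS is used instead.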

We can see that, in this case, the frequentist random effects model gives results similar to those of the Bayesian model.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

train = pd.read_csv(
    "https://gist.githubusercontent.com/ucals/"
    "2cf9d101992cb1b78c2cdd6e3bac6a4b/raw/"
    "43034c39052dcf97d4b894d2ec1bc3f90f3623d9/"
    "osic_pulmonary_fibrosis.csv"
)


from jax import random

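def model(patient_code, Weeks, FVC_obs=None):
    # Population-level hyperpriors for the intercepts (α) and slopes (β)
    μ_α = numpyro.sample("μ_α", dist.Normal(0.0, 500.0))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.0))
    μ_β = numpyro.sample("μ_β", dist.Normal(0.0, 3.0))
    σ_β = numpyro.sample("σ_β", dist.HalfNormal(3.0))

    n_patients = len(np.unique(patient_code))

    # One intercept and one slope per patient
    with numpyro.plate("plate_i", n_patients):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β = numpyro.sample("β", dist.Normal(μ_β, σ_β))

    # Residual standard deviation, shared by all observations
    σ = numpyro.sample("σ", dist.HalfNormal(100.0))
    # Expected FVC for each row: index the per-patient parameters by patient code
    FVC_est = α[patient_code] + β[patient_code] * Weeks

    # Likelihood; with FVC_obs=None this site is sampled instead of conditioned on
    with numpyro.plate("data", len(patient_code)):
        numpyro.sample("obs", dist.Normal(FVC_est, σ), obs=FVC_obs)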
       
from sklearn.preprocessing import LabelEncoder

patient_encoder = LabelEncoder()
train["patient_code"] = patient_encoder.fit_transform(train["Patient"].values)

FVC_obs = train["FVC"].values
Weeks = train["Weeks"].values
patient_code = train["patient_code"].values
        
        
numpyro.render_model(
    model=model,
    model_args=(patient_code, Weeks, FVC_obs),
    render_distributions=True,
    render_params=True,
)
<graphviz.graphs.Digraph object at 0x7cfcfdb6a510>