Iowa Gambling Task

In this tutorial, we will fit the PVL-Delta model to data from the Iowa Gambling Task (IGT) using the ActionModels package. In the IGT, participants choose cards from four decks, each with different reward and loss probabilities, and must learn over time which decks are advantageous. We will use data from Ahn et al. (2014), which includes healthy controls as well as participants with heroin or amphetamine addictions. More details about the data can be found in docs/example_data/ahn_et_al_2014/ReadMe.txt.

Loading data

First, we load the ActionModels package. We also load CSV and DataFrames for loading the data, and StatsPlots for plotting the results.

using ActionModels
using CSV, DataFrames
using StatsPlots

Then we load the Ahn et al. (2014) data, which is available in the docs/example_data/ahn_et_al_2014 folder.

#Import data
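#Note: docs_path is assumed to point to the docs folder of the ActionModels repository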
data_healthy = CSV.read(
    joinpath(docs_path, "example_data", "ahn_et_al_2014", "IGTdata_healthy_control.txt"),
    DataFrame,
)
data_healthy[!, :clinical_group] .= "1_control"
data_heroin = CSV.read(
    joinpath(docs_path, "example_data", "ahn_et_al_2014", "IGTdata_heroin.txt"),
    DataFrame,
)
data_heroin[!, :clinical_group] .= "2_heroin"
data_amphetamine = CSV.read(
    joinpath(docs_path, "example_data", "ahn_et_al_2014", "IGTdata_amphetamine.txt"),
    DataFrame,
)
data_amphetamine[!, :clinical_group] .= "3_amphetamine"

#Combine into one dataframe
ahn_data = vcat(data_healthy, data_heroin, data_amphetamine)
ahn_data[!, :subjID] = string.(ahn_data[!, :subjID])

#Make column with the net reward (gain plus loss) on each trial
ahn_data[!, :reward] = Float64.(ahn_data[!, :gain] + ahn_data[!, :loss]);

show(ahn_data)
12883×7 DataFrame
   Row │ trial  deck   gain   loss   subjID  clinical_group  reward
       │ Int64  Int64  Int64  Int64  String  String          Float64
───────┼─────────────────────────────────────────────────────────────
     1 │     1      3     50      0  103     1_control          50.0
     2 │     2      3     60      0  103     1_control          60.0
     3 │     3      3     40    -50  103     1_control         -10.0
     4 │     4      4     50      0  103     1_control          50.0
     5 │     5      3     55      0  103     1_control          55.0
     6 │     6      1    100      0  103     1_control         100.0
     7 │     7      1    120      0  103     1_control         120.0
     8 │     8      2    100      0  103     1_control         100.0
   ⋮   │   ⋮      ⋮      ⋮      ⋮      ⋮           ⋮            ⋮
 12877 │    94      4     70      0  344     3_amphetamine      70.0
 12878 │    95      2    150      0  344     3_amphetamine     150.0
 12879 │    96      4     50      0  344     3_amphetamine      50.0
 12880 │    97      2    110      0  344     3_amphetamine     110.0
 12881 │    98      4     70   -275  344     3_amphetamine    -205.0
 12882 │    99      2    150      0  344     3_amphetamine     150.0
 12883 │   100      3     70    -25  344     3_amphetamine      45.0
                                                   12868 rows omitted

For this example, we will subset the data to only include two subjects from each clinical group. This makes the runtime much shorter. Simply skip this step if you want to use the full dataset.

ahn_data =
    filter(row -> row[:subjID] in ["103", "117", "105", "136", "130", "149"], ahn_data);

Creating the model

Then we construct the model to be fitted to the data. We use the PVL-Delta action model, which is a classic model for the IGT. The PVL-Delta is a type of reinforcement learning model that learns the expected value of each deck in the IGT. First, the observed reward is transformed with a prospect theory-based utility curve. This means that the subjective value of a reward increases sub-linearly with reward magnitude, and that losses are weighted more heavily than gains. The expected value of each deck is then updated using a Rescorla-Wagner-like update rule. Finally, the probability of selecting each deck is calculated using a softmax function over the expected values of the decks, scaled by an inverse action noise parameter. In summary, the PVL-Delta has four parameters: the learning rate $\alpha$, the reward sensitivity $A$, the loss aversion $w$, and the action noise $\beta$. The update equations are sketched below; see the section on the PVL-Delta premade model in the documentation for more details.
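
As a sketch, in the standard PVL-Delta formulation (with the action noise $\beta$ entering the softmax inversely, as described above), the net reward $r_t$ on trial $t$ is transformed into a subjective utility

$$u_t = \begin{cases} r_t^{A} & \text{if } r_t \geq 0 \\ -w \cdot |r_t|^{A} & \text{if } r_t < 0 \end{cases}$$

the expected value of the chosen deck $d$ is updated with the delta rule

$$EV_{d,t+1} = EV_{d,t} + \alpha \, (u_t - EV_{d,t})$$

and the probability of choosing deck $d$ is given by a softmax over the expected values

$$P(d) = \frac{\exp(EV_{d,t} / \beta)}{\sum_{d'} \exp(EV_{d',t} / \beta)}$$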

We create the PVL-Delta using the premade model from ActionModels. We specify the number of decks, and also that actions are selected before the expected values are updated. This is because in the IGT, at least as structured in this dataset, participants select a deck before they receive the reward and update expectations.

action_model = ActionModel(PVLDelta(n_options = 4, act_before_update = true))
-- ActionModel --
Action model function: pvl_delta_act_before_update
Number of parameters: 5
Number of states: 1
Number of observations: 2
Number of actions: 1

We then specify which column in the data corresponds to the action (deck choice) and which columns correspond to the observations (deck and reward). We also specify the columns that uniquely identify each session.

action_cols = :deck
observation_cols = (chosen_option = :deck, reward = :reward)
session_cols = :subjID;

Finally, we create the full model. We use a hierarchical regression model to predict the parameters of the PVL-Delta model from the clinical group (healthy, heroin, or amphetamine). First, we set appropriate priors for the regression coefficients. For the action noise and the loss aversion, the outcome of the regression will be exponentiated before it is used in the model, so pre-transformed outcomes around 2 (exp(2) ≈ 7) are among the most extreme values to be expected. For the learning rate and reward sensitivity, we will use a logistic transformation, so pre-transformed outcomes around 5 (logistic(5) ≈ 0.993) are among the most extreme values to be expected.

regression_prior = RegressionPrior(
    β = [Normal(0, 0.2), Normal(0, 0.3), Normal(0, 0.3)],
    σ = truncated(Normal(0, 0.3), lower = 0),
)

plot(regression_prior.β[1], label = "Intercept")
plot!(regression_prior.β[2], label = "Effect of clinical group")
plot!(regression_prior.σ, label = "Random intercept std")
title!("Regression priors for the Rescorla-Wagner model")
xlabel!("Regression coefficient")
ylabel!("Density")

We can now create the population model, and finally create the full model object.

population_model = [
    Regression(
        @formula(learning_rate ~ clinical_group + (1 | subjID)),
        logistic,
        regression_prior,
    ),
    Regression(
        @formula(reward_sensitivity ~ clinical_group + (1 | subjID)),
        logistic,
        regression_prior,
    ),
    Regression(
        @formula(loss_aversion ~ clinical_group + (1 | subjID)),
        exp,
        regression_prior,
    ),
    Regression(
        @formula(action_noise ~ clinical_group + (1 | subjID)),
        exp,
        regression_prior,
    ),
]

model = create_model(
    action_model,
    population_model,
    ahn_data,
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
)
-- ModelFit object --
Action model: pvl_delta_act_before_update
Linear regression population model
4 estimated action model parameters, 6 sessions
Posterior not sampled
Prior not sampled

Fitting the model

We are now ready to fit the model to the data. For this model, we will use the Enzyme automatic differentiation backend, which is a high-performance automatic differentiation library. Additionally, to keep the runtime of this tutorial short, we will only fit two chains with 500 samples each. We pass MCMCThreads() in order to parallelize the sampling across the two chains. This should take 10-15 minutes on a standard laptop. Switch to a single thread for a clearer progress bar.

#Set AD backend
using ADTypes: AutoEnzyme
import Enzyme: set_runtime_activity, Reverse
ad_type = AutoEnzyme(; mode = set_runtime_activity(Reverse, true));

#Fit model
chns = sample_posterior!(
    model,
    MCMCThreads(),
    init_params = :MAP,
    n_chains = 2,
    n_samples = 500,
    ad_type = ad_type,
)
Chains MCMC chain (500×52×2 Array{Float64, 3}):

Iterations        = 251:1:750
Number of chains  = 2
Samples per chain = 500
Wall duration     = 1087.99 seconds
Compute duration  = 1607.0 seconds
parameters        = learning_rate.β[1], learning_rate.β[2], learning_rate.β[3], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], learning_rate.ranef_1.r[4], learning_rate.ranef_1.r[5], learning_rate.ranef_1.r[6], reward_sensitivity.β[1], reward_sensitivity.β[2], reward_sensitivity.β[3], reward_sensitivity.ranef_1.σ[1], reward_sensitivity.ranef_1.r[1], reward_sensitivity.ranef_1.r[2], reward_sensitivity.ranef_1.r[3], reward_sensitivity.ranef_1.r[4], reward_sensitivity.ranef_1.r[5], reward_sensitivity.ranef_1.r[6], loss_aversion.β[1], loss_aversion.β[2], loss_aversion.β[3], loss_aversion.ranef_1.σ[1], loss_aversion.ranef_1.r[1], loss_aversion.ranef_1.r[2], loss_aversion.ranef_1.r[3], loss_aversion.ranef_1.r[4], loss_aversion.ranef_1.r[5], loss_aversion.ranef_1.r[6], action_noise.β[1], action_noise.β[2], action_noise.β[3], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3], action_noise.ranef_1.r[4], action_noise.ranef_1.r[5], action_noise.ranef_1.r[6]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                       parameters      mean       std      mcse    ess_bulk   ess_tail      rhat   ess_per_sec
                           Symbol   Float64   Float64   Float64     Float64    Float64   Float64       Float64

               learning_rate.β[1]   -0.0772    0.1360    0.0179    487.3965    49.2938    2.2348        0.3033
               learning_rate.β[2]    0.1384    0.2277    0.0677     27.1742    88.4432    2.2461        0.0169
               learning_rate.β[3]   -0.2957    0.3313    0.1921      3.2763     2.1370    2.2399        0.0020
       learning_rate.ranef_1.σ[1]    0.7915    0.2666    0.0303     89.3771    92.1977    2.2334        0.0556
       learning_rate.ranef_1.r[1]   -2.5884    1.2142    0.3600     89.8795   129.6627    2.2395        0.0559
       learning_rate.ranef_1.r[2]    0.3158    0.5824    0.0952    253.2225    63.9834    2.2538        0.1576
       learning_rate.ranef_1.r[3]   -0.4122    0.5462    0.1593      9.6980    46.1141    2.2327        0.0060
       learning_rate.ranef_1.r[4]    0.0266    0.8071    0.3003      8.4308   112.9638    2.2336        0.0052
       learning_rate.ranef_1.r[5]   -0.5968    0.8017    0.1527     14.9271   113.9311    2.2606        0.0093
       learning_rate.ranef_1.r[6]    0.5159    0.9456    0.4666      4.0995   147.9296    2.2560        0.0026
          reward_sensitivity.β[1]   -0.3721    0.1417    0.0062    454.0315   171.0156    2.2324        0.2825
          reward_sensitivity.β[2]   -0.2280    0.1980    0.0084    784.1333   186.5376    2.2410        0.4879
          reward_sensitivity.β[3]   -0.1269    0.2567    0.1102      7.1871    73.9149    2.2373        0.0045
  reward_sensitivity.ranef_1.σ[1]    0.3080    0.3458    0.2031      2.5866     2.0325    2.2411        0.0016
  reward_sensitivity.ranef_1.r[1]   -0.0281    0.5865    0.0434    133.8821    64.2951    2.2494        0.0833
  reward_sensitivity.ranef_1.r[2]   -0.1471    0.4197    0.1410     12.6425   102.8753    2.2323        0.0079
  reward_sensitivity.ranef_1.r[3]   -0.3805    0.6003    0.2973      4.3431   190.7738    2.2324        0.0027
  reward_sensitivity.ranef_1.r[4]   -0.2724    0.5417    0.1801     15.5057    95.5693    2.2349        0.0096
  reward_sensitivity.ranef_1.r[5]   -0.3043    0.5930    0.2429      7.0046   169.8609    2.2336        0.0044
  reward_sensitivity.ranef_1.r[6]   -0.3014    0.5066    0.1829     13.1553    68.7551    2.2457        0.0082
               loss_aversion.β[1]   -0.2295    0.1325    0.0240     93.2163    62.4714    2.2355        0.0580
               loss_aversion.β[2]   -0.5151    0.2477    0.1015      8.8090   128.3562    2.2343        0.0055
               loss_aversion.β[3]   -0.1075    0.2213    0.0068   1021.8084   190.9917    2.2451        0.6358
       loss_aversion.ranef_1.σ[1]    0.3433    0.1558    0.0509     13.5332    48.3537    2.2489        0.0084
       loss_aversion.ranef_1.r[1]    0.0845    0.2773    0.1465      3.4900    55.3386    2.2434        0.0022
       loss_aversion.ranef_1.r[2]   -0.4048    0.3150    0.0720     13.5790   241.7611    2.2362        0.0084
       loss_aversion.ranef_1.r[3]   -0.0612    0.2270    0.0280     21.3886    43.6145    2.2566        0.0133
       loss_aversion.ranef_1.r[4]    0.1964    0.3587    0.2081      3.0690   213.5619    2.2427        0.0019
       loss_aversion.ranef_1.r[5]   -0.3142    0.3231    0.1390      4.8800     2.3495    2.2407        0.0030
       loss_aversion.ranef_1.r[6]   -0.0454    0.2132    0.0086    647.8554    40.0776    2.2729        0.4031
                action_noise.β[1]    0.4864    0.2020    0.0983      5.1146   182.6105    2.2332        0.0032
                action_noise.β[2]    0.3061    0.2145    0.0086    833.6584    58.4463    2.2369        0.5188
                action_noise.β[3]    0.0251    0.3679    0.2245      2.8257     2.0637    2.2348        0.0018
        action_noise.ranef_1.σ[1]    0.8973    0.2060    0.0764     12.6386    65.1240    2.2329        0.0079
        action_noise.ranef_1.r[1]    0.0396    0.6841    0.3466      4.3725     2.2788    2.2326        0.0027
        action_noise.ranef_1.r[2]    0.3654    0.3704    0.0147    680.1321    75.5211    2.2518        0.4232
        action_noise.ranef_1.r[3]    1.0046    0.4627    0.1700     12.5407    70.6206    2.2328        0.0078
        action_noise.ranef_1.r[4]    2.0815    0.9658    0.6018      2.7899   165.1563    2.2351        0.0017
        action_noise.ranef_1.r[5]    1.3146    0.6855    0.2200     14.1099   127.2067    2.2324        0.0088
        action_noise.ranef_1.r[6]    1.7732    0.9773    0.6145      3.0343   329.0925    2.2337        0.0019

Quantiles
                       parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                           Symbol   Float64   Float64   Float64   Float64   Float64

               learning_rate.β[1]   -0.4074   -0.1173   -0.0434   -0.0434    0.1852
               learning_rate.β[2]   -0.4053    0.0263    0.2354    0.2354    0.4982
               learning_rate.β[3]   -0.5593   -0.5593   -0.5593   -0.0211    0.4548
       learning_rate.ranef_1.σ[1]    0.1538    0.8148    0.8305    0.8305    1.3092
       learning_rate.ranef_1.r[1]   -4.2693   -3.1422   -3.1422   -2.2011    0.0920
       learning_rate.ranef_1.r[2]   -0.6299    0.1503    0.1503    0.3207    1.8329
       learning_rate.ranef_1.r[3]   -1.3694   -0.6314   -0.6314   -0.1268    0.8657
       learning_rate.ranef_1.r[4]   -2.3442   -0.1559    0.4395    0.4395    0.9060
       learning_rate.ranef_1.r[5]   -2.4989   -0.8487   -0.8487   -0.1064    1.1653
       learning_rate.ranef_1.r[6]   -1.7711   -0.0519    1.1499    1.1499    1.2904
          reward_sensitivity.β[1]   -0.6891   -0.3840   -0.3840   -0.3590   -0.0410
          reward_sensitivity.β[2]   -0.7182   -0.2274   -0.2036   -0.2036    0.1915
          reward_sensitivity.β[3]   -0.7447   -0.2807    0.0229    0.0229    0.2032
  reward_sensitivity.ranef_1.σ[1]    0.0270    0.0270    0.0452    0.5964    1.0571
  reward_sensitivity.ranef_1.r[1]   -1.2953   -0.1284    0.0141    0.0141    1.4797
  reward_sensitivity.ranef_1.r[2]   -1.3006   -0.2604    0.0455    0.0455    0.3986
  reward_sensitivity.ranef_1.r[3]   -1.8838   -0.7646    0.0282    0.0282    0.0696
  reward_sensitivity.ranef_1.r[4]   -1.7846   -0.3824   -0.0145   -0.0145    0.3254
  reward_sensitivity.ranef_1.r[5]   -1.9911   -0.4873    0.0264    0.0264    0.2105
  reward_sensitivity.ranef_1.r[6]   -1.7051   -0.4443   -0.0452   -0.0452    0.1989
               loss_aversion.β[1]   -0.5610   -0.2702   -0.1925   -0.1925    0.0337
               loss_aversion.β[2]   -0.8350   -0.6539   -0.6539   -0.3910    0.1362
               loss_aversion.β[3]   -0.5636   -0.1210   -0.1210   -0.1056    0.4021
       loss_aversion.ranef_1.σ[1]    0.0343    0.2382    0.4046    0.4046    0.7093
       loss_aversion.ranef_1.r[1]   -0.6530   -0.0540    0.2837    0.2837    0.2837
       loss_aversion.ranef_1.r[2]   -1.2822   -0.5041   -0.5041   -0.1663    0.1338
       loss_aversion.ranef_1.r[3]   -0.5646   -0.1184   -0.1184   -0.0061    0.5145
       loss_aversion.ranef_1.r[4]   -0.6620   -0.0378    0.4822    0.4822    0.4822
       loss_aversion.ranef_1.r[5]   -0.7576   -0.5087   -0.5087   -0.0540    0.3040
       loss_aversion.ranef_1.r[6]   -0.6355   -0.0375   -0.0375   -0.0145    0.3657
                action_noise.β[1]    0.0141    0.3540    0.6209    0.6209    0.7048
                action_noise.β[2]   -0.2105    0.3098    0.3134    0.3134    0.7585
                action_noise.β[3]   -0.2856   -0.2856   -0.2856    0.3352    0.8054
        action_noise.ranef_1.σ[1]    0.3496    0.7904    1.0078    1.0078    1.1864
        action_noise.ranef_1.r[1]   -0.5347   -0.4402   -0.4402    0.5198    1.5679
        action_noise.ranef_1.r[2]   -0.3802    0.3142    0.3142    0.3987    1.3214
        action_noise.ranef_1.r[3]   -0.0569    0.7098    1.2434    1.2434    1.6849
        action_noise.ranef_1.r[4]    0.1318    1.2417    2.9216    2.9216    2.9216
        action_noise.ranef_1.r[5]    0.2197    0.9941    0.9941    1.6238    2.9385
        action_noise.ranef_1.r[6]   -0.0495    0.8779    2.6332    2.6332    2.6332

We can now inspect the results of the fitting process. We can plot the posterior distributions over the β parameters of the regression model. We can see indications of lower reward sensitivity and lower loss aversion, as well as higher action noise, in the heroin and amphetamine groups compared to the healthy controls. Note that the posteriors would be more precise if we had used the full dataset.

plot(
    plot(
        plot(
            title = "Learning rate",
            grid = false,
            showaxis = false,
            bottom_margin = -30Plots.px,
        ),
        density(
            chns[Symbol("learning_rate.β[2]")],
            title = "Heroin",
            label = nothing,
        ),
        density(
            chns[Symbol("learning_rate.β[3]")],
            title = "Amphetamine",
            label = nothing,
        ),
        layout = @layout([A{0.01h}; [B C]])
    ),
    plot(
        plot(
            title = "Reward sensitivity",
            grid = false,
            showaxis = false,
            bottom_margin = -50Plots.px,
        ),
        density(
            chns[Symbol("reward_sensitivity.β[2]")],
            label = nothing,
        ),
        density(
            chns[Symbol("reward_sensitivity.β[3]")],
            label = nothing,
        ),
        layout = @layout([A{0.01h}; [B C]])
    ),
    plot(
        plot(
            title = "Loss aversion",
            grid = false,
            showaxis = false,
            bottom_margin = -50Plots.px,
        ),
        density(
            chns[Symbol("loss_aversion.β[2]")],
            label = nothing,
        ),
        density(
            chns[Symbol("loss_aversion.β[3]")],
            label = nothing,
        ),
        layout = @layout([A{0.01h}; [B C]])
    ),
    plot(
        plot(
            title = "Action noise",
            grid = false,
            showaxis = false,
            bottom_margin = -50Plots.px,
        ),
        density(
            chns[Symbol("action_noise.β[2]")],
            label = nothing,
        ),
        density(
            chns[Symbol("action_noise.β[3]")],
            label = nothing,
        ),
        layout = @layout([A{0.01h}; [B C]])
    ),
    layout = (4,1), size = (800, 1000),
)

We can also extract the session parameters from the model.

session_parameters = get_session_parameters!(model);

#TODO: plot the session parameters

And we can extract a dataframe with the median of the posterior for each session parameter.

parameters_df = summarize(session_parameters, median)
show(parameters_df)
6×5 DataFrame
 Row │ subjID  action_noise  learning_rate  loss_aversion  reward_sensitivity
     │ String  Float64       Float64        Float64        Float64
─────┼────────────────────────────────────────────────────────────────────────
   1 │ 103          1.19806      0.0397095       1.09555             0.408584
   2 │ 117          6.45123      0.337413        0.732784            0.411977
   3 │ 105          3.48517      0.584723        0.259128            0.367704
   4 │ 136          6.87877      0.341462        0.257926            0.363276
   5 │ 130         25.9703       0.459288        1.18377             0.407202
   6 │ 149         19.4623       0.633474        0.703993            0.399811
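
As a minimal sketch (assuming the parameters_df columns shown above), we could for example plot the per-subject posterior medians of the learning rate with StatsPlots:

#Scatter the posterior median learning rate per subject (illustrative example)
@df parameters_df scatter(
    :subjID,
    :learning_rate,
    label = nothing,
    xlabel = "Subject",
    ylabel = "Posterior median learning rate",
)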

We can also look at the implied state trajectories, in this case the expected value.

state_trajectories = get_state_trajectories!(model, :expected_value);

#TODO: plot the state trajectories

These can also be summarized in a dataframe for downstream analysis.

states_df = summarize(state_trajectories, median)
show(states_df)
606×6 DataFrame
 Row │ subjID  timestep  expected_value[1]  expected_value[2]  expected_value[ ⋯
     │ String  Int64     Any                Any                Any             ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ 103            0  0.0                0.0                0.0             ⋯
   2 │ 103            1  0.0                0.0                0.196366
   3 │ 103            2  0.0                0.0                0.40012
   4 │ 103            3  0.0                0.0                0.272774
   5 │ 103            4  0.0                0.0                0.272774        ⋯
   6 │ 103            5  0.0                0.0                0.466106
   7 │ 103            6  0.260652           0.0                0.466106
   8 │ 103            7  0.531113           0.0                0.466106
  ⋮  │   ⋮        ⋮              ⋮                  ⋮                  ⋮       ⋱
 600 │ 149           94  -3.41555           6.66582            4.15543         ⋯
 601 │ 149           95  -3.41555           6.66582            3.39351
 602 │ 149           96  -3.41555           6.66582            3.39351
 603 │ 149           97  -3.41555           6.66582            3.39351
 604 │ 149           98  -3.41555           6.66582            2.83433         ⋯
 605 │ 149           99  -3.41555           6.73865            2.83433
 606 │ 149          100  2.8967             6.73865            2.83433
                                                  2 columns and 591 rows omitted
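
As a minimal sketch (assuming the states_df columns shown above), we could plot the posterior median expected value of each deck across trials for a single subject:

#Subset the trajectory summary to one subject (illustrative example)
subject_states = filter(row -> row.subjID == "103", states_df)

#Plot the median expected value of each of the four decks over trials
plot(
    subject_states.timestep,
    [Float64.(subject_states[!, "expected_value[$i]"]) for i = 1:4],
    label = ["Deck 1" "Deck 2" "Deck 3" "Deck 4"],
    xlabel = "Trial",
    ylabel = "Expected value",
    title = "Subject 103",
)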

Comparing to a simple random model

We can also compare the PVL-Delta to a simple random model, which randomly samples actions from a fixed Categorical distribution.

function categorical_random(
    attributes::ModelAttributes,
    chosen_option::Int64,
    reward::Float64,
)

    #Load the fixed action probabilities
    action_probabilities = load_parameters(attributes).action_probabilities

    return Categorical(action_probabilities)
end

action_model = ActionModel(
    categorical_random,
    observations = (chosen_option = Observation(Int64), reward = Observation()),
    actions = (; deck = Action(Categorical)),
    parameters = (; action_probabilities = Parameter([0.3, 0.3, 0.3, 0.1])),
)
-- ActionModel --
Action model function: categorical_random
Number of parameters: 1
Number of states: 0
Number of observations: 2
Number of actions: 1

Fitting the model

For this model, we use an independent session population model. We set the prior for the action probabilities to be a Dirichlet distribution, which is a common prior for categorical distributions.

population_model = (; action_probabilities = Dirichlet([1, 1, 1, 1]),)

simple_model = create_model(
    action_model,
    population_model,
    ahn_data,
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
)

#Set AD backend
using ADTypes: AutoEnzyme
import Enzyme: set_runtime_activity, Reverse
ad_type = AutoEnzyme(; mode = set_runtime_activity(Reverse, true));

#Fit model
chns = sample_posterior!(simple_model, n_chains = 1, n_samples = 500, ad_type = ad_type)

#TODO: plot the results
Chains MCMC chain (500×36×1 Array{Float64, 3}):

Iterations        = 251:1:750
Number of chains  = 1
Samples per chain = 500
Wall duration     = 15.38 seconds
Compute duration  = 15.38 seconds
parameters        = action_probabilities.session[1, 1], action_probabilities.session[2, 1], action_probabilities.session[3, 1], action_probabilities.session[4, 1], action_probabilities.session[1, 2], action_probabilities.session[2, 2], action_probabilities.session[3, 2], action_probabilities.session[4, 2], action_probabilities.session[1, 3], action_probabilities.session[2, 3], action_probabilities.session[3, 3], action_probabilities.session[4, 3], action_probabilities.session[1, 4], action_probabilities.session[2, 4], action_probabilities.session[3, 4], action_probabilities.session[4, 4], action_probabilities.session[1, 5], action_probabilities.session[2, 5], action_probabilities.session[3, 5], action_probabilities.session[4, 5], action_probabilities.session[1, 6], action_probabilities.session[2, 6], action_probabilities.session[3, 6], action_probabilities.session[4, 6]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                          parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
                              Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

  action_probabilities.session[1, 1]    0.2127    0.0437    0.0020   499.7635   264.2679    1.0202       32.4838
  action_probabilities.session[2, 1]    0.2225    0.0420    0.0018   566.9277   258.8146    1.0044       36.8494
  action_probabilities.session[3, 1]    0.1339    0.0345    0.0015   530.6587   323.6614    0.9988       34.4920
  action_probabilities.session[4, 1]    0.4309    0.0498    0.0022   509.9878   362.1328    1.0011       33.1484
  action_probabilities.session[1, 2]    0.1076    0.0303    0.0011   836.8070   318.1140    1.0023       54.3911
  action_probabilities.session[2, 2]    0.2874    0.0445    0.0017   713.0993   441.4632    1.0126       46.3503
  action_probabilities.session[3, 2]    0.3563    0.0456    0.0018   612.0289   408.3394    0.9993       39.7809
  action_probabilities.session[4, 2]    0.2487    0.0410    0.0018   500.2803   400.5173    1.0012       32.5174
  action_probabilities.session[1, 3]    0.1825    0.0356    0.0015   534.1960   408.3394    1.0001       34.7219
  action_probabilities.session[2, 3]    0.3469    0.0445    0.0017   646.2085   366.2060    0.9995       42.0025
  action_probabilities.session[3, 3]    0.1239    0.0335    0.0013   610.8795   283.5594    0.9991       39.7062
  action_probabilities.session[4, 3]    0.3466    0.0477    0.0021   549.8255   274.6839    0.9998       35.7378
  action_probabilities.session[1, 4]    0.3339    0.0496    0.0023   447.1155   285.2636    0.9992       29.0618
  action_probabilities.session[2, 4]    0.2102    0.0395    0.0017   535.5974   297.4440    1.0018       34.8130
  action_probabilities.session[3, 4]    0.2034    0.0425    0.0018   607.4729   377.4493    1.0036       39.4848
  action_probabilities.session[4, 4]    0.2526    0.0410    0.0017   597.4313   229.7026    0.9995       38.8321
  action_probabilities.session[1, 5]    0.2181    0.0428    0.0018   525.8366   298.1951    1.0073       34.1785
  action_probabilities.session[2, 5]    0.2728    0.0449    0.0016   766.2143   400.7875    0.9996       49.8027
  action_probabilities.session[3, 5]    0.3544    0.0502    0.0023   474.3856   470.5002    1.0031       30.8343
  action_probabilities.session[4, 5]    0.1547    0.0371    0.0016   543.4503   308.5591    1.0010       35.3234
  action_probabilities.session[1, 6]    0.1613    0.0367    0.0015   615.0265   367.4394    1.0003       39.9757
  action_probabilities.session[2, 6]    0.2803    0.0420    0.0019   510.9267   425.4347    0.9988       33.2094
  action_probabilities.session[3, 6]    0.2588    0.0431    0.0018   549.6302   414.4278    1.0068       35.7251
  action_probabilities.session[4, 6]    0.2997    0.0445    0.0019   521.1288   314.9595    0.9987       33.8725

Quantiles
                          parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                              Symbol   Float64   Float64   Float64   Float64   Float64

  action_probabilities.session[1, 1]    0.1334    0.1822    0.2070    0.2417    0.3036
  action_probabilities.session[2, 1]    0.1477    0.1956    0.2200    0.2502    0.3105
  action_probabilities.session[3, 1]    0.0732    0.1100    0.1313    0.1524    0.2107
  action_probabilities.session[4, 1]    0.3273    0.3993    0.4334    0.4654    0.5225
  action_probabilities.session[1, 2]    0.0558    0.0863    0.1053    0.1270    0.1702
  action_probabilities.session[2, 2]    0.2003    0.2574    0.2842    0.3162    0.3784
  action_probabilities.session[3, 2]    0.2728    0.3246    0.3531    0.3854    0.4552
  action_probabilities.session[4, 2]    0.1739    0.2183    0.2459    0.2778    0.3276
  action_probabilities.session[1, 3]    0.1192    0.1583    0.1796    0.2055    0.2582
  action_probabilities.session[2, 3]    0.2645    0.3153    0.3456    0.3796    0.4290
  action_probabilities.session[3, 3]    0.0640    0.0996    0.1214    0.1425    0.1949
  action_probabilities.session[4, 3]    0.2640    0.3130    0.3428    0.3787    0.4435
  action_probabilities.session[1, 4]    0.2437    0.2983    0.3327    0.3659    0.4304
  action_probabilities.session[2, 4]    0.1370    0.1841    0.2074    0.2353    0.2901
  action_probabilities.session[3, 4]    0.1253    0.1757    0.1997    0.2296    0.2974
  action_probabilities.session[4, 4]    0.1801    0.2235    0.2523    0.2781    0.3388
  action_probabilities.session[1, 5]    0.1414    0.1899    0.2159    0.2452    0.3125
  action_probabilities.session[2, 5]    0.1881    0.2389    0.2732    0.3035    0.3591
  action_probabilities.session[3, 5]    0.2518    0.3214    0.3543    0.3849    0.4532
  action_probabilities.session[4, 5]    0.0890    0.1303    0.1534    0.1766    0.2413
  action_probabilities.session[1, 6]    0.0995    0.1350    0.1584    0.1858    0.2361
  action_probabilities.session[2, 6]    0.2021    0.2494    0.2794    0.3071    0.3661
  action_probabilities.session[3, 6]    0.1818    0.2295    0.2581    0.2866    0.3446
  action_probabilities.session[4, 6]    0.2124    0.2693    0.2996    0.3295    0.3903
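
As a minimal sketch of the plotting left as a TODO above, we could inspect the posterior over the action probabilities for a single session (here the first):

#Posterior densities for the four action probabilities of the first session (illustrative example)
plot(
    [
        density(
            chns[Symbol("action_probabilities.session[$deck, 1]")],
            title = "Deck $deck",
            label = nothing,
        ) for deck = 1:4
    ]...,
    layout = (2, 2),
)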

We can also compare how well the PVL-Delta model fits the data compared to the simple random model.

#TODO: model comparison
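
A common approach would be to compare the models' out-of-sample predictive performance, for example with PSIS-LOO cross-validation (available in Julia via, e.g., the ParetoSmooth.jl package), although we do not demonstrate that here.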

This page was generated using Literate.jl.