Iowa Gambling Task

In this tutorial, we will fit the PVL-Delta model to data from the Iowa Gambling Task (IGT) using the ActionModels package. In the IGT, participants choose cards from four decks, each with different reward and loss probabilities, and must learn over time which decks are advantageous. We will use data from Ahn et al. (2014), which includes healthy controls as well as participants with heroin or amphetamine addictions. More details about the dataset can be found in docs/example_data/ahn_et_al_2014/ReadMe.txt.

Loading data

First, we load the ActionModels package. We also load CSV and DataFrames for loading the data, and StatsPlots for plotting the results.

using ActionModels
using CSV, DataFrames
using StatsPlots

Then we load the Ahn et al. (2014) data, which is available in the docs/example_data/ahn_et_al_2014 folder.

#Import data (docs_path points to the ActionModels documentation folder)
data_healthy = CSV.read(
    joinpath(docs_path, "example_data", "ahn_et_al_2014", "IGTdata_healthy_control.txt"),
    DataFrame,
)
data_healthy[!, :clinical_group] .= "1_control"
data_heroin = CSV.read(
    joinpath(docs_path, "example_data", "ahn_et_al_2014", "IGTdata_heroin.txt"),
    DataFrame,
)
data_heroin[!, :clinical_group] .= "2_heroin"
data_amphetamine = CSV.read(
    joinpath(docs_path, "example_data", "ahn_et_al_2014", "IGTdata_amphetamine.txt"),
    DataFrame,
)
data_amphetamine[!, :clinical_group] .= "3_amphetamine"

#Combine into one dataframe
ahn_data = vcat(data_healthy, data_heroin, data_amphetamine)
ahn_data[!, :subjID] = string.(ahn_data[!, :subjID])

#Make column with total reward
ahn_data[!, :reward] = Float64.(ahn_data[!, :gain] + ahn_data[!, :loss]);

show(ahn_data)
12883×7 DataFrame
   Row │ trial  deck   gain   loss   subjID  clinical_group  reward
       │ Int64  Int64  Int64  Int64  String  String          Float64
───────┼─────────────────────────────────────────────────────────────
     1 │     1      3     50      0  103     1_control          50.0
     2 │     2      3     60      0  103     1_control          60.0
     3 │     3      3     40    -50  103     1_control         -10.0
     4 │     4      4     50      0  103     1_control          50.0
     5 │     5      3     55      0  103     1_control          55.0
     6 │     6      1    100      0  103     1_control         100.0
     7 │     7      1    120      0  103     1_control         120.0
     8 │     8      2    100      0  103     1_control         100.0
   ⋮   │   ⋮      ⋮      ⋮      ⋮      ⋮           ⋮            ⋮
 12877 │    94      4     70      0  344     3_amphetamine      70.0
 12878 │    95      2    150      0  344     3_amphetamine     150.0
 12879 │    96      4     50      0  344     3_amphetamine      50.0
 12880 │    97      2    110      0  344     3_amphetamine     110.0
 12881 │    98      4     70   -275  344     3_amphetamine    -205.0
 12882 │    99      2    150      0  344     3_amphetamine     150.0
 12883 │   100      3     70    -25  344     3_amphetamine      45.0
                                                   12868 rows omitted
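
As a quick sanity check (an illustrative aside using standard DataFrames operations), we can count the subjects and trials in each clinical group:

#Count subjects and trials per clinical group
combine(
    groupby(ahn_data, :clinical_group),
    :subjID => (x -> length(unique(x))) => :n_subjects,
    nrow => :n_trials,
)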

For this example, we will subset the data to only include two subjects from each clinical group. This makes the runtime much shorter. Simply skip this step if you want to use the full dataset.

ahn_data =
    filter(row -> row[:subjID] in ["103", "117", "105", "136", "130", "149"], ahn_data);

Creating the model

Then we construct the model to be fitted to the data. We use the PVL-Delta action model, which is a classic model for the IGT. The PVL-Delta is a type of reinforcement learning model that learns the expected value of each deck in the IGT. First, the observed reward is transformed with a prospect theory-based utility curve. This means that the subjective value of a reward increases sub-linearly with reward magnitude, and that losses are weighted more heavily than gains. The expected value of each deck is then updated using a Rescorla-Wagner-like update rule. Finally, the probability of selecting each deck is calculated using a softmax function over the expected values of the decks, scaled by an inverse action noise parameter. In summary, the PVL-Delta has four parameters: the learning rate $\alpha$, the reward sensitivity $A$, the loss aversion $w$, and the action noise $\beta$. See the section on the PVL-Delta premade model in the documentation for more details.
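
In summary form, the model can be sketched with the following equations (a rough rendering based on the description above; see the premade model documentation for the exact parametrization used in ActionModels):

$$u_t = \begin{cases} r_t^{A} & \text{if } r_t \geq 0 \\ -w \, |r_t|^{A} & \text{if } r_t < 0 \end{cases}$$

$$E_{t+1,d} = E_{t,d} + \alpha \, (u_t - E_{t,d})$$

$$P(a_t = d) = \frac{\exp(E_{t,d} / \beta)}{\sum_{d'} \exp(E_{t,d'} / \beta)}$$

Here $u_t$ is the subjective utility of reward $r_t$, the expected value update applies to the chosen deck $d$, and larger action noise $\beta$ flattens the choice probabilities.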

We create the PVL-Delta using the premade model from ActionModels. We specify the number of decks, and also that actions are selected before the expected values are updated. This is because in the IGT, at least as structured in this dataset, participants select a deck before they receive the reward and update expectations.

action_model = ActionModel(PVLDelta(n_options = 4, act_before_update = true))
-- ActionModel --
Action model function: pvl_delta_act_before_update
Number of parameters: 5
Number of states: 1
Number of observations: 2
Number of actions: 1

We then specify which column in the data corresponds to the action (deck choice) and which columns correspond to the observations (deck and reward). We also specify the columns that uniquely identify each session.

action_cols = :deck
observation_cols = (chosen_option = :deck, reward = :reward)
session_cols = :subjID;

Finally, we create the full model. We use a hierarchical regression model to predict the parameters of the PVL-Delta model based on the clinical group (healthy, heroin, or amphetamine). First, we set appropriate priors for the regression coefficients. For the action noise and the loss aversion, the outcome of the regression will be exponentiated before it is used in the model, so pre-transformed outcomes around 2 (exp(2) ≈ 7) are among the most extreme values to be expected. For the learning rate and reward sensitivity, we will use a logistic transformation, so pre-transformed outcomes around 5 (logistic(5) ≈ 0.993) are among the most extreme values to be expected.
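
For intuition, these transformed extremes can be checked directly (assuming the `logistic` function is in scope, as it is used in the regressions below):

exp(2.0)       # ≈ 7.39: an extreme value for action noise or loss aversion
logistic(5.0)  # ≈ 0.993: an extreme value for learning rate or reward sensitivity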

regression_prior = RegressionPrior(
    β = [Normal(0, 0.2), Normal(0, 0.3), Normal(0,0.3)],
    σ = truncated(Normal(0, 0.3), lower = 0),
)

plot(regression_prior.β[1], label = "Intercept")
plot!(regression_prior.β[2], label = "Effect of clinical group")
plot!(regression_prior.σ, label = "Random intercept std")
title!("Regression priors for the PVL-Delta model")
xlabel!("Regression coefficient")
ylabel!("Density")

We can now create the population model, and finally the full model object.

population_model = [
    Regression(
        @formula(learning_rate ~ clinical_group + (1 | subjID)),
        logistic,
        regression_prior,
    ),
    Regression(
        @formula(reward_sensitivity ~ clinical_group + (1 | subjID)),
        logistic,
        regression_prior,
    ),
    Regression(
        @formula(loss_aversion ~ clinical_group + (1 | subjID)),
        exp,
        regression_prior,
    ),
    Regression(
        @formula(action_noise ~ clinical_group + (1 | subjID)),
        exp,
        regression_prior,
    ),
]

model = create_model(
    action_model,
    population_model,
    ahn_data,
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
)
-- ModelFit object --
Action model: pvl_delta_act_before_update
Linear regression population model
4 estimated action model parameters, 6 sessions
Posterior not sampled
Prior not sampled

Fitting the model

We are now ready to fit the model to the data. For this model, we will use the Enzyme automatic differentiation backend, which is a high-performance automatic differentiation library. Additionally, to keep the runtime of this tutorial short, we will only fit two chains with 500 samples. We pass MCMCThreads() in order to parallelize the sampling across the two chains. This should take 10-15 minutes on a standard laptop. Switch to a single thread for a clearer progress bar.

#Set AD backend
using ADTypes: AutoEnzyme
import Enzyme: set_runtime_activity, Reverse
ad_type = AutoEnzyme(; mode = set_runtime_activity(Reverse, true));

#Fit model
chns = sample_posterior!(
    model,
    MCMCThreads(),
    init_params = :MAP,
    n_chains = 2,
    n_samples = 500,
    ad_type = ad_type,
)
Chains MCMC chain (500×52×2 Array{Float64, 3}):

Iterations        = 251:1:750
Number of chains  = 2
Samples per chain = 500
Wall duration     = 3183.22 seconds
Compute duration  = 3988.36 seconds
parameters        = learning_rate.β[1], learning_rate.β[2], learning_rate.β[3], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], learning_rate.ranef_1.r[4], learning_rate.ranef_1.r[5], learning_rate.ranef_1.r[6], reward_sensitivity.β[1], reward_sensitivity.β[2], reward_sensitivity.β[3], reward_sensitivity.ranef_1.σ[1], reward_sensitivity.ranef_1.r[1], reward_sensitivity.ranef_1.r[2], reward_sensitivity.ranef_1.r[3], reward_sensitivity.ranef_1.r[4], reward_sensitivity.ranef_1.r[5], reward_sensitivity.ranef_1.r[6], loss_aversion.β[1], loss_aversion.β[2], loss_aversion.β[3], loss_aversion.ranef_1.σ[1], loss_aversion.ranef_1.r[1], loss_aversion.ranef_1.r[2], loss_aversion.ranef_1.r[3], loss_aversion.ranef_1.r[4], loss_aversion.ranef_1.r[5], loss_aversion.ranef_1.r[6], action_noise.β[1], action_noise.β[2], action_noise.β[3], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3], action_noise.ranef_1.r[4], action_noise.ranef_1.r[5], action_noise.ranef_1.r[6]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                       parameters      mean       std      mcse    ess_bulk   ess_tail      rhat   ess_per_sec
                           Symbol   Float64   Float64   Float64     Float64    Float64   Float64       Float64

               learning_rate.β[1]   -0.1458    0.2100    0.0075    788.0090   684.2905    1.0026        0.1976
               learning_rate.β[2]    0.0258    0.2978    0.0128    555.0488   473.9848    1.0016        0.1392
               learning_rate.β[3]   -0.0388    0.3037    0.0092   1084.7973   681.9421    1.0017        0.2720
       learning_rate.ranef_1.σ[1]    0.8265    0.3166    0.0446     55.5097   149.7110    1.0257        0.0139
       learning_rate.ranef_1.r[1]   -2.2781    1.3366    0.1821     63.9628   287.6264    1.0194        0.0160
       learning_rate.ranef_1.r[2]    0.6344    0.7523    0.0642    129.4021   298.4270    1.0246        0.0324
       learning_rate.ranef_1.r[3]   -0.2557    0.8156    0.0812     69.7595   110.2648    1.0572        0.0175
       learning_rate.ranef_1.r[4]   -0.3105    0.9908    0.0575    319.8078   264.4876    1.0030        0.0802
       learning_rate.ranef_1.r[5]   -0.3655    1.1238    0.1418     84.8017   100.9919    1.0166        0.0213
       learning_rate.ranef_1.r[6]   -0.0655    0.8830    0.0473    324.6192   337.1385    1.0027        0.0814
          reward_sensitivity.β[1]   -0.3467    0.2017    0.0075    731.9576   750.4011    1.0016        0.1835
          reward_sensitivity.β[2]   -0.2743    0.2754    0.0084   1085.5429   501.6289    0.9991        0.2722
          reward_sensitivity.β[3]   -0.2852    0.2698    0.0085    998.7901   746.8395    0.9991        0.2504
  reward_sensitivity.ranef_1.σ[1]    0.5639    0.2822    0.0216    166.6972   274.8063    1.0244        0.0418
  reward_sensitivity.ranef_1.r[1]   -0.0097    0.7782    0.0542    199.3430   499.4656    1.0139        0.0500
  reward_sensitivity.ranef_1.r[2]   -0.3440    0.4772    0.0242    386.4615   567.3087    1.0097        0.0969
  reward_sensitivity.ranef_1.r[3]   -0.7425    0.6589    0.0391    264.5037   659.1312    1.0169        0.0663
  reward_sensitivity.ranef_1.r[4]   -0.5864    0.6479    0.0341    360.9692   746.4718    1.0096        0.0905
  reward_sensitivity.ranef_1.r[5]   -0.6366    0.6637    0.0330    404.4595   845.9577    1.0081        0.1014
  reward_sensitivity.ranef_1.r[6]   -0.5395    0.5876    0.0315    346.4858   803.3882    1.0121        0.0869
               loss_aversion.β[1]   -0.2400    0.1774    0.0067    681.1574   631.0527    0.9997        0.1708
               loss_aversion.β[2]   -0.3553    0.2788    0.0086   1043.7526   835.4924    1.0068        0.2617
               loss_aversion.β[3]   -0.1182    0.2846    0.0085   1127.7070   677.8353    1.0058        0.2827
       loss_aversion.ranef_1.σ[1]    0.3315    0.2254    0.0210    103.8532   175.4735    1.0099        0.0260
       loss_aversion.ranef_1.r[1]   -0.1244    0.3454    0.0261    276.6112   156.6776    1.0114        0.0694
       loss_aversion.ranef_1.r[2]   -0.3756    0.4912    0.0338    229.5728   463.3372    1.0058        0.0576
       loss_aversion.ranef_1.r[3]    0.0172    0.4176    0.0262    331.8085   200.5155    1.0096        0.0832
       loss_aversion.ranef_1.r[4]   -0.1442    0.3947    0.0206    585.8721   266.3164    1.0066        0.1469
       loss_aversion.ranef_1.r[5]   -0.1656    0.4171    0.0195    595.0356   310.7070    1.0026        0.1492
       loss_aversion.ranef_1.r[6]   -0.0540    0.3672    0.0134    834.8783   493.4550    1.0009        0.2093
                action_noise.β[1]    0.3594    0.1950    0.0075    671.6886   702.4170    1.0031        0.1684
                action_noise.β[2]    0.2839    0.2819    0.0104    752.5935   650.2428    1.0049        0.1887
                action_noise.β[3]    0.3276    0.2775    0.0101    767.9811   747.2489    1.0014        0.1926
        action_noise.ranef_1.σ[1]    0.7755    0.2636    0.0297     85.8663   117.7332    1.0320        0.0215
        action_noise.ranef_1.r[1]    0.4668    0.6362    0.0557    137.7539   422.5892    1.0107        0.0345
        action_noise.ranef_1.r[2]    0.4040    0.4937    0.0206    586.2093   813.7125    1.0067        0.1470
        action_noise.ranef_1.r[3]    0.7979    0.5807    0.0414    188.6492   432.9618    1.0235        0.0473
        action_noise.ranef_1.r[4]    1.2002    0.7684    0.0654    125.4608   428.7655    1.0207        0.0315
        action_noise.ranef_1.r[5]    1.5845    0.8334    0.0901     77.8296   122.1488    1.0309        0.0195
        action_noise.ranef_1.r[6]    0.9108    0.6571    0.0546    127.1097   475.1995    1.0204        0.0319

Quantiles
                       parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                           Symbol   Float64   Float64   Float64   Float64   Float64

               learning_rate.β[1]   -0.5417   -0.2902   -0.1459    0.0056    0.2452
               learning_rate.β[2]   -0.5368   -0.1872    0.0277    0.2393    0.5740
               learning_rate.β[3]   -0.6108   -0.2470   -0.0356    0.1748    0.5579
       learning_rate.ranef_1.σ[1]    0.1831    0.6430    0.8718    1.0481    1.3719
       learning_rate.ranef_1.r[1]   -4.4500   -3.2627   -2.5507   -1.3297    0.1445
       learning_rate.ranef_1.r[2]   -0.7627    0.1135    0.6070    1.1111    2.1959
       learning_rate.ranef_1.r[3]   -2.1956   -0.6280   -0.1933    0.1858    1.3038
       learning_rate.ranef_1.r[4]   -2.9022   -0.7483   -0.2089    0.2736    1.4440
       learning_rate.ranef_1.r[5]   -3.2458   -0.8141   -0.1636    0.3024    1.4904
       learning_rate.ranef_1.r[6]   -1.8845   -0.5593   -0.0443    0.3578    1.8946
          reward_sensitivity.β[1]   -0.7254   -0.4893   -0.3399   -0.2093    0.0400
          reward_sensitivity.β[2]   -0.8266   -0.4668   -0.2792   -0.0873    0.2650
          reward_sensitivity.β[3]   -0.8187   -0.4540   -0.2895   -0.1030    0.2495
  reward_sensitivity.ranef_1.σ[1]    0.0836    0.3542    0.5673    0.7629    1.1182
  reward_sensitivity.ranef_1.r[1]   -1.6819   -0.4480   -0.0311    0.3670    1.6934
  reward_sensitivity.ranef_1.r[2]   -1.4332   -0.6060   -0.2561   -0.0235    0.4284
  reward_sensitivity.ranef_1.r[3]   -2.1907   -1.1728   -0.6198   -0.2149    0.1553
  reward_sensitivity.ranef_1.r[4]   -2.1377   -0.9732   -0.4718   -0.0986    0.3562
  reward_sensitivity.ranef_1.r[5]   -2.2113   -1.0776   -0.4997   -0.1037    0.3214
  reward_sensitivity.ranef_1.r[6]   -1.9096   -0.8795   -0.4236   -0.1083    0.3144
               loss_aversion.β[1]   -0.6091   -0.3505   -0.2379   -0.1268    0.1131
               loss_aversion.β[2]   -0.8753   -0.5557   -0.3521   -0.1715    0.1882
               loss_aversion.β[3]   -0.6753   -0.3068   -0.1143    0.0771    0.4047
       loss_aversion.ranef_1.σ[1]    0.0402    0.1435    0.2940    0.4755    0.8554
       loss_aversion.ranef_1.r[1]   -0.9862   -0.2421   -0.0540    0.0640    0.4161
       loss_aversion.ranef_1.r[2]   -1.5874   -0.6068   -0.2265   -0.0200    0.1943
       loss_aversion.ranef_1.r[3]   -0.9246   -0.1311    0.0057    0.1855    0.9497
       loss_aversion.ranef_1.r[4]   -1.1359   -0.2926   -0.0542    0.0565    0.4857
       loss_aversion.ranef_1.r[5]   -1.2414   -0.3027   -0.0696    0.0484    0.4907
       loss_aversion.ranef_1.r[6]   -0.9308   -0.1926   -0.0231    0.1012    0.6751
                action_noise.β[1]   -0.0249    0.2262    0.3689    0.4897    0.7407
                action_noise.β[2]   -0.2989    0.1041    0.2818    0.4732    0.8300
                action_noise.β[3]   -0.1970    0.1313    0.3349    0.5195    0.8583
        action_noise.ranef_1.σ[1]    0.2123    0.6081    0.7742    0.9622    1.2752
        action_noise.ranef_1.r[1]   -0.7337    0.0262    0.3956    0.8898    1.7589
        action_noise.ranef_1.r[2]   -0.5064    0.0545    0.3613    0.7266    1.4281
        action_noise.ranef_1.r[3]   -0.1885    0.3539    0.7901    1.1799    1.9866
        action_noise.ranef_1.r[4]   -0.0710    0.6157    1.1314    1.7229    2.8707
        action_noise.ranef_1.r[5]    0.0771    1.0067    1.5401    2.1184    3.2687
        action_noise.ranef_1.r[6]   -0.1384    0.4052    0.8403    1.3284    2.2996

We can now inspect the results of the fitting process. We can plot the posterior distributions over the β parameters of the regression model. We can see indications of lower reward sensitivity and lower loss aversion, as well as higher action noise, in the heroin and amphetamine groups compared to the healthy controls. Note that the posteriors would be more precise if we had used the full dataset.

plot(
    plot(
        plot(
            title = "Learning rate",
            grid = false,
            showaxis = false,
            bottom_margin = -30Plots.px,
        ),
        density(
            chns[Symbol("learning_rate.β[2]")],
            title = "Heroin",
            label = nothing,
        ),
        density(
            chns[Symbol("learning_rate.β[3]")],
            title = "Amphetamine",
            label = nothing,
        ),
        layout = @layout([A{0.01h}; [B C]])
    ),
    plot(
        plot(
            title = "Reward sensitivity",
            grid = false,
            showaxis = false,
            bottom_margin = -50Plots.px,
        ),
        density(
            chns[Symbol("reward_sensitivity.β[2]")],
            label = nothing,
        ),
        density(
            chns[Symbol("reward_sensitivity.β[3]")],
            label = nothing,
        ),
        layout = @layout([A{0.01h}; [B C]])
    ),
    plot(
        plot(
            title = "Loss aversion",
            grid = false,
            showaxis = false,
            bottom_margin = -50Plots.px,
        ),
        density(
            chns[Symbol("loss_aversion.β[2]")],
            label = nothing,
        ),
        density(
            chns[Symbol("loss_aversion.β[3]")],
            label = nothing,
        ),
        layout = @layout([A{0.01h}; [B C]])
    ),
    plot(
        plot(
            title = "Action noise",
            grid = false,
            showaxis = false,
            bottom_margin = -50Plots.px,
        ),
        density(
            chns[Symbol("action_noise.β[2]")],
            label = nothing,
        ),
        density(
            chns[Symbol("action_noise.β[3]")],
            label = nothing,
        ),
        layout = @layout([A{0.01h}; [B C]])
    ),
    layout = (4,1), size = (800, 1000),
)

We can also extract the session parameters from the model.

session_parameters = get_session_parameters!(model);

#TODO: plot the session parameters

And we can extract a dataframe with the posterior median for each session parameter.

parameters_df = summarize(session_parameters, median)
show(parameters_df)
6×5 DataFrame
 Row │ subjID  action_noise  learning_rate  loss_aversion  reward_sensitivity
     │ String  Float64       Float64        Float64        Float64
─────┼────────────────────────────────────────────────────────────────────────
   1 │ 103          2.17007      0.0614588       0.731244            0.395373
   2 │ 105          2.84362      0.620093        0.406255            0.280489
   3 │ 117          3.13918      0.416547        0.77889             0.269566
   4 │ 130          6.23599      0.414338        0.619374            0.240845
   5 │ 136          9.02988      0.427001        0.47848             0.233851
   6 │ 149          4.65828      0.440308        0.678762            0.248081
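
As a simple visualization, we can plot these posterior medians per subject with StatsPlots. This is a minimal sketch based only on the summarized medians, so it does not convey posterior uncertainty (the full posteriors are in the chains above):

#Scatter the posterior-median parameters per subject
@df parameters_df scatter(
    :subjID,
    :learning_rate,
    label = "learning rate",
    xlabel = "Subject",
    ylabel = "Posterior median",
)
@df parameters_df scatter!(:subjID, :reward_sensitivity, label = "reward sensitivity")
@df parameters_df scatter!(:subjID, :loss_aversion, label = "loss aversion")

Action noise is left out here, since its scale (roughly 2-9 in the table above) would dwarf the unit-scale parameters.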

We can also look at the implied state trajectories, in this case the expected value.

state_trajectories = get_state_trajectories!(model, :expected_value);

#TODO: plot the state trajectories

These can also be summarized in a dataframe, for downstream analysis.

states_df = summarize(state_trajectories, median)
show(states_df)
606×6 DataFrame
 Row │ subjID  timestep  expected_value[1]  expected_value[2]  expected_value[ ⋯
     │ String  Int64     Any                Any                Any             ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ 103            0  0.0                0.0                0.0             ⋯
   2 │ 103            1  0.0                0.0                0.341685
   3 │ 103            2  0.0                0.0                0.685106
   4 │ 103            3  0.0                0.0                0.450515
   5 │ 103            4  0.0                0.0                0.450515        ⋯
   6 │ 103            5  0.0                0.0                0.806783
   7 │ 103            6  0.457852           0.0                0.806783
   8 │ 103            7  0.927495           0.0                0.806783
  ⋮  │   ⋮        ⋮              ⋮                  ⋮                  ⋮       ⋱
 600 │ 149           94  -1.0957            2.99291            2.17897         ⋯
 601 │ 149           95  -1.0957            2.99291            2.04027
 602 │ 149           96  -1.0957            2.99291            2.04027
 603 │ 149           97  -1.0957            2.99291            2.04027
 604 │ 149           98  -1.0957            2.99291            1.89753         ⋯
 605 │ 149           99  -1.0957            3.05706            1.89753
 606 │ 149          100  0.720223           3.05706            1.89753
                                                  2 columns and 591 rows omitted
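
For a quick look at the learned values, we can plot the median expected value of each deck across trials for a single subject. This is a minimal sketch using the summarized dataframe above (column names follow the printed table):

#Median expected value per deck over trials for subject 103
subject_states = filter(row -> row.subjID == "103", states_df)
plot(
    subject_states.timestep,
    [Float64.(subject_states[!, Symbol("expected_value[$i]")]) for i = 1:4],
    label = ["Deck 1" "Deck 2" "Deck 3" "Deck 4"],
    xlabel = "Trial",
    ylabel = "Median expected value",
)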

Comparing to a simple random model

We can also compare the PVL-Delta to a simple random model, which randomly samples actions from a fixed Categorical distribution.

function categorical_random(
    attributes::ModelAttributes,
    chosen_option::Int64,
    reward::Float64,
)

    action_probabilities = load_parameters(attributes).action_probabilities

    return Categorical(action_probabilities)
end

action_model = ActionModel(
    categorical_random,
    observations = (chosen_option = Observation(Int64), reward = Observation()),
    actions = (; deck = Action(Categorical)),
    parameters = (; action_probabilities = Parameter([0.3, 0.3, 0.3, 0.1])),
)
-- ActionModel --
Action model function: categorical_random
Number of parameters: 1
Number of states: 0
Number of observations: 2
Number of actions: 1

Fitting the model

For this model, we use an independent session population model. We set the prior for the action probabilities to a Dirichlet distribution (here with all concentration parameters equal to 1, i.e., uniform over the probability simplex), which is the standard prior for categorical distributions.

population_model = (; action_probabilities = Dirichlet([1, 1, 1, 1]),)

simple_model = create_model(
    action_model,
    population_model,
    ahn_data,
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
)

#Set AD backend
using ADTypes: AutoEnzyme
import Enzyme: set_runtime_activity, Reverse
ad_type = AutoEnzyme(; mode = set_runtime_activity(Reverse, true));

#Fit model
chns = sample_posterior!(simple_model, n_chains = 1, n_samples = 500, ad_type = ad_type)

#TODO: plot the results
Chains MCMC chain (500×36×1 Array{Float64, 3}):

Iterations        = 251:1:750
Number of chains  = 1
Samples per chain = 500
Wall duration     = 10.18 seconds
Compute duration  = 10.18 seconds
parameters        = action_probabilities.session[1, 1], action_probabilities.session[2, 1], action_probabilities.session[3, 1], action_probabilities.session[4, 1], action_probabilities.session[1, 2], action_probabilities.session[2, 2], action_probabilities.session[3, 2], action_probabilities.session[4, 2], action_probabilities.session[1, 3], action_probabilities.session[2, 3], action_probabilities.session[3, 3], action_probabilities.session[4, 3], action_probabilities.session[1, 4], action_probabilities.session[2, 4], action_probabilities.session[3, 4], action_probabilities.session[4, 4], action_probabilities.session[1, 5], action_probabilities.session[2, 5], action_probabilities.session[3, 5], action_probabilities.session[4, 5], action_probabilities.session[1, 6], action_probabilities.session[2, 6], action_probabilities.session[3, 6], action_probabilities.session[4, 6]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                          parameters      mean       std      mcse    ess_bulk   ess_tail      rhat   ess_per_sec
                              Symbol   Float64   Float64   Float64     Float64    Float64   Float64       Float64

  action_probabilities.session[1, 1]    0.2105    0.0401    0.0011   1330.2875   349.8661    0.9981      130.6894
  action_probabilities.session[2, 1]    0.2210    0.0401    0.0011   1273.2539   384.7125    1.0001      125.0863
  action_probabilities.session[3, 1]    0.1347    0.0329    0.0011    917.2124   306.8386    0.9987       90.1083
  action_probabilities.session[4, 1]    0.4338    0.0481    0.0013   1257.3936   425.5089    0.9992      123.5282
  action_probabilities.session[1, 2]    0.1825    0.0422    0.0013   1023.3875   311.1505    1.0016      100.5391
  action_probabilities.session[2, 2]    0.3462    0.0495    0.0018    763.2401   243.8135    1.0011       74.9818
  action_probabilities.session[3, 2]    0.1244    0.0309    0.0011    716.8152   353.1196    0.9992       70.4210
  action_probabilities.session[4, 2]    0.3469    0.0439    0.0014   1010.2876   364.0013    1.0002       99.2521
  action_probabilities.session[1, 3]    0.1068    0.0309    0.0008   1174.0060   387.6724    0.9991      115.3361
  action_probabilities.session[2, 3]    0.2904    0.0448    0.0016    906.6387   310.2592    1.0017       89.0695
  action_probabilities.session[3, 3]    0.3553    0.0497    0.0015   1004.3480   320.3494    1.0011       98.6686
  action_probabilities.session[4, 3]    0.2475    0.0415    0.0014    914.4503   199.4877    1.0015       89.8369
  action_probabilities.session[1, 4]    0.2197    0.0400    0.0012   1087.2613   461.4608    0.9984      106.8142
  action_probabilities.session[2, 4]    0.2702    0.0389    0.0011   1190.6484   399.3389    1.0031      116.9711
  action_probabilities.session[3, 4]    0.3560    0.0442    0.0014    954.4264   398.1367    1.0127       93.7643
  action_probabilities.session[4, 4]    0.1540    0.0357    0.0013    765.7528   332.4821    1.0088       75.2287
  action_probabilities.session[1, 5]    0.3367    0.0479    0.0018    754.7404   392.5808    1.0032       74.1468
  action_probabilities.session[2, 5]    0.2119    0.0401    0.0015    732.0202   371.1017    0.9989       71.9147
  action_probabilities.session[3, 5]    0.2018    0.0387    0.0013    854.1399   448.9903    1.0071       83.9120
  action_probabilities.session[4, 5]    0.2496    0.0429    0.0012   1217.3508   491.6032    1.0035      119.5943
  action_probabilities.session[1, 6]    0.1661    0.0389    0.0012   1094.3367   353.1196    0.9981      107.5093
  action_probabilities.session[2, 6]    0.2780    0.0418    0.0015    878.3651   359.4191    1.0017       86.2919
  action_probabilities.session[3, 6]    0.2586    0.0412    0.0014    835.3493   362.4674    0.9995       82.0659
  action_probabilities.session[4, 6]    0.2973    0.0430    0.0015    842.8835   279.0662    1.0086       82.8061

Quantiles
                          parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                              Symbol   Float64   Float64   Float64   Float64   Float64

  action_probabilities.session[1, 1]    0.1438    0.1817    0.2080    0.2357    0.2913
  action_probabilities.session[2, 1]    0.1510    0.1918    0.2181    0.2478    0.3041
  action_probabilities.session[3, 1]    0.0756    0.1120    0.1335    0.1532    0.2039
  action_probabilities.session[4, 1]    0.3435    0.4013    0.4308    0.4645    0.5285
  action_probabilities.session[1, 2]    0.1096    0.1516    0.1788    0.2094    0.2715
  action_probabilities.session[2, 2]    0.2547    0.3119    0.3442    0.3790    0.4447
  action_probabilities.session[3, 2]    0.0711    0.1025    0.1231    0.1435    0.1894
  action_probabilities.session[4, 2]    0.2514    0.3197    0.3484    0.3737    0.4335
  action_probabilities.session[1, 3]    0.0561    0.0836    0.1036    0.1290    0.1697
  action_probabilities.session[2, 3]    0.2077    0.2580    0.2860    0.3201    0.3786
  action_probabilities.session[3, 3]    0.2627    0.3199    0.3528    0.3937    0.4485
  action_probabilities.session[4, 3]    0.1709    0.2177    0.2488    0.2755    0.3284
  action_probabilities.session[1, 4]    0.1518    0.1897    0.2170    0.2491    0.2983
  action_probabilities.session[2, 4]    0.2032    0.2397    0.2704    0.2964    0.3418
  action_probabilities.session[3, 4]    0.2719    0.3261    0.3559    0.3861    0.4372
  action_probabilities.session[4, 4]    0.0927    0.1285    0.1505    0.1772    0.2315
  action_probabilities.session[1, 5]    0.2460    0.3021    0.3348    0.3693    0.4360
  action_probabilities.session[2, 5]    0.1421    0.1819    0.2084    0.2381    0.2925
  action_probabilities.session[3, 5]    0.1338    0.1740    0.1990    0.2280    0.2869
  action_probabilities.session[4, 5]    0.1727    0.2190    0.2479    0.2787    0.3344
  action_probabilities.session[1, 6]    0.0979    0.1363    0.1650    0.1908    0.2490
  action_probabilities.session[2, 6]    0.2017    0.2470    0.2760    0.3056    0.3611
  action_probabilities.session[3, 6]    0.1873    0.2292    0.2560    0.2861    0.3401
  action_probabilities.session[4, 6]    0.2148    0.2725    0.2977    0.3226    0.3893

We can also compare how well the PVL-Delta model fits the data relative to the simple random model.

#TODO: model comparison

This page was generated using Literate.jl.