Jumping Gaussian Estimation Task

In this tutorial, we will fit a Rescorla-Wagner model to data from the Jumping Gaussian Estimation Task (JGET). In the JGET, participants observe continuous outcomes sampled from a Gaussian distribution whose mean occasionally "jumps" to a new location. On each trial, participants must predict the next outcome based on the outcomes seen so far. The data we will use come from a study on schizotypy (Mikus et al., 2025), in which participants completed the JGET and also filled out the Peters Delusions Inventory (PDI). The PDI is a self-report questionnaire that measures delusional ideation, and we will use it as a predictor in our model. The dataset contains several sessions per participant, which will be modeled as separate experimental sessions. The data are taken from https://github.com/nacemikus/jget-schizotypy, where more information can also be found.

Loading data

First, we load the ActionModels package. We also load CSV and DataFrames for loading the data, and StatsPlots for plotting the results.

using ActionModels
using CSV, DataFrames
using StatsPlots

Then we load the JGET data, which is available in the docs/example_data/JGET folder. The code assumes that docs_path holds the path to the docs folder.

#Trial-level data
JGET_data = CSV.read(
    joinpath(docs_path, "example_data", "JGET", "JGET_data_trial_preprocessed.csv"),
    DataFrame,
    missingstring = ["NaN", ""],
)
JGET_data = select(JGET_data, [:trials, :ID, :session, :outcome, :response, :confidence])

#Subject-level data
subject_data = CSV.read(
    joinpath(docs_path, "example_data", "JGET", "JGET_data_sub_preprocessed.csv"),
    DataFrame,
    missingstring = ["NaN", ""],
)
subject_data = select(subject_data, [:ID, :session, :pdi_total, :Age, :Gender, :Education])

#Join the data
JGET_data = innerjoin(JGET_data, subject_data, on = [:ID, :session])

#Remove sessions with missing responses or missing PDI scores
JGET_data = combine(
    groupby(JGET_data, [:ID, :session]),
    subdata -> any(ismissing, subdata.response) ? DataFrame() : subdata,
)
JGET_data = combine(
    groupby(JGET_data, [:ID, :session]),
    subdata -> any(ismissing, subdata.pdi_total) ? DataFrame() : subdata,
)
#No missing values remain in these columns, so narrow their element types
disallowmissing!(JGET_data, [:pdi_total, :response])

#Make the outcome into a Float64
JGET_data.outcome = Float64.(JGET_data.outcome)

show(JGET_data)
42480×10 DataFrame
   Row │ ID     session  trials  outcome  response  confidence  pdi_total  Age ⋯
       │ Int64  Int64    Int64   Float64  Float64   Float64?    Float64    Flo ⋯
───────┼────────────────────────────────────────────────────────────────────────
     1 │    20        1       1     62.0      15.0   0.0953106  0.0535714      ⋯
     2 │    20        1       2     55.0      57.0   0.0680593  0.0535714
     3 │    20        1       3     52.0      62.0   0.934176   0.0535714
     4 │    20        1       4     50.0      55.0   1.40098    0.0535714
     5 │    20        1       5     65.0      49.0   0.0674173  0.0535714      ⋯
     6 │    20        1       6     55.0      61.0   0.667615   0.0535714
     7 │    20        1       7     43.0      62.0   0.0677037  0.0535714
     8 │    20        1       8     59.0      38.0   0.0341998  0.0535714
   ⋮   │   ⋮       ⋮       ⋮        ⋮        ⋮          ⋮           ⋮          ⋱
 42474 │    74        4     234     23.0      18.0   0.600938   0.0863095      ⋯
 42475 │    74        4     235     17.0      26.0   0.934273   0.0863095
 42476 │    74        4     236     10.0      25.0   0.78428    0.0863095
 42477 │    74        4     237     16.0      18.0   0.600921   0.0863095
 42478 │    74        4     238     12.0      20.0   0.750922   0.0863095      ⋯
 42479 │    74        4     239     13.0      20.0   0.83429    0.0863095
 42480 │    74        4     240      8.0      23.0   0.734223   0.0863095
                                                3 columns and 42465 rows omitted

For this example, we subset the data to the participants with IDs 20, 40, 60 and 70, across the four experimental sessions. This makes the runtime much shorter. In the outputs below, only IDs 20, 40 and 60 appear, giving 12 sessions in total. Simply skip this step if you want to use the full dataset.

JGET_data = filter(row -> row[:ID] in [20, 40, 60, 70], JGET_data);

Creating the model

Then we construct the model to be fitted to the data. We will use a classic Rescorla-Wagner model with a Gaussian report action model. The Rescorla-Wagner model is a simple reinforcement learning model that updates an expected value in proportion to the prediction error, which is the difference between the observed outcome and the current expectation. The Gaussian report action model assumes that the agent reports a continuous value sampled from a Gaussian distribution, whose mean is the current expected value and whose standard deviation is a noise parameter. The action model therefore has two parameters: the learning rate $\alpha$ and the action noise $\beta$ (not to be confused with the regression coefficients β used below). See the section on the Rescorla-Wagner model for more details.
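The sketch below shows the computation just described. It is written in plain Julia rather than with ActionModels. The function names rw_update, normal_logpdf and rw_loglikelihood are hypothetical, and the initial expected value of 0 is an assumption. The learner moves its expected value a fraction α toward each observed outcome. Its report is then scored under a Gaussian centered on that expected value. Updating before reporting follows the "act after update" ordering suggested by the name rescorla_wagner_act_after_update.

```julia
# Minimal sketch of a continuous Rescorla-Wagner learner with a Gaussian report.
# Plain Julia, not the ActionModels implementation; all names are hypothetical.

# Move the expected value a fraction `learning_rate` toward the observed outcome.
rw_update(expected_value, outcome, learning_rate) =
    expected_value + learning_rate * (outcome - expected_value)

# Log-density of a Normal(μ, σ) distribution evaluated at x.
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Log-likelihood of a sequence of reports: update on each outcome first,
# then score the report under a Gaussian centered on the updated expectation.
function rw_loglikelihood(outcomes, responses, learning_rate, action_noise;
                          initial_value = 0.0)
    expected_value = initial_value
    total = 0.0
    for (outcome, response) in zip(outcomes, responses)
        expected_value = rw_update(expected_value, outcome, learning_rate)
        total += normal_logpdf(response, expected_value, action_noise)
    end
    return total
end
```

For example, rw_update(0.0, 62.0, 0.5) returns 31.0. A higher action noise flattens the Gaussian, so reports far from the expectation are penalized less.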

We create the Rescorla-Wagner action model using the premade model from ActionModels.

action_model = ActionModel(RescorlaWagner())
-- ActionModel --
Action model function: rescorla_wagner_act_after_update
Number of parameters: 1
Number of states: 0
Number of observations: 1
Number of actions: 1
submodel type: ActionModels.ContinuousRescorlaWagner

We then specify which column in the data corresponds to the action (the response) and which column corresponds to the observation (the outcome). Two columns jointly identify each session: the participant ID and the experimental session number.

action_cols = :response
observation_cols = :outcome
session_cols = [:ID, :session];

We then create the full model. We use a hierarchical regression model to predict the parameters of the Rescorla-Wagner model from the PDI score and the session number, with a random intercept for each participant. First, we set appropriate priors for the regression coefficients. For the action noise, the output of the regression is exponentiated before it is used in the model. Untransformed values around 3 (exp(3) ≈ 20) are therefore among the most extreme to be expected. For the learning rate, the output of the regression is passed through a logistic function, so untransformed values around 5 (logistic(5) ≈ 0.993) are among the most extreme to be expected. We should therefore keep the priors fairly narrow. This stops the linear regression from wandering into implausible regions of parameter space, which would increase the runtime of the fitting process.
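To make these ranges concrete, the following sketch evaluates both inverse link functions at a few untransformed values. inverse_logit is a hypothetical local name for the logistic function. It is defined here so that the block runs on its own and does not clash with the logistic already in scope in the tutorial.

```julia
# Inverse link functions used in the population model.
# `inverse_logit` is a hypothetical local name for the logistic function;
# `exp` comes from Base.
inverse_logit(x) = 1 / (1 + exp(-x))

# Untransformed regression outputs and the parameter values they map to.
for x in (0.0, 1.0, 3.0, 5.0)
    println("x = ", x,
            "  exp(x) = ", round(exp(x), digits = 2),
            "  logistic(x) = ", round(inverse_logit(x), digits = 4))
end
```

Going from 3 to 5 on the untransformed scale multiplies the action noise by exp(2) ≈ 7.4, while the learning rate only moves from about 0.953 to 0.993. This shows how quickly values beyond these ranges become implausible for the action noise.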

regression_prior = RegressionPrior(
    β = [Normal(0, 0.3), Normal(0, 0.2), Normal(0, 0.5)],
    σ = truncated(Normal(0, 0.3), lower = 0),
)

#Coefficients follow the formula order: intercept, pdi_total, session
plot(regression_prior.β[1], label = "Intercept")
plot!(regression_prior.β[2], label = "Effect of PDI")
plot!(regression_prior.β[3], label = "Effect of session number")
plot!(regression_prior.σ, label = "Random intercept std")
title!("Regression priors for the Rescorla-Wagner model")
xlabel!("Regression coefficient")
ylabel!("Density")

We then create the population model, which consists of two regression models: one for the learning rate and one for the action noise. We can then also create the full model using the create_model function.

population_model = [
    Regression(
        @formula(learning_rate ~ 1 + pdi_total + session + (1 | ID)),
        logistic,
        regression_prior,
    ),
    Regression(
        @formula(action_noise ~ 1 + pdi_total + session + (1 | ID)),
        exp,
        regression_prior,
    ),
]

model = create_model(
    action_model,
    population_model,
    JGET_data,
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
)
-- ModelFit object --
Action model: rescorla_wagner_act_after_update
Linear regression population model
2 estimated action model parameters, 12 sessions
Posterior not sampled
Prior not sampled

Fitting the model

We are now ready to fit the model to the data. For this model, we will use Enzyme, a high-performance automatic differentiation library, as the automatic differentiation backend. To keep the runtime of this tutorial short, we will only fit two chains with 500 samples each. We also pass MCMCThreads() to fit the two chains in parallel. The chains only run in parallel if Julia was started with more than one thread (for example with julia --threads 2). Fitting should take up to 5 minutes on a standard laptop.

# Set AD backend ##
using ADTypes: AutoEnzyme
import Enzyme: set_runtime_activity, Reverse
ad_type = AutoEnzyme(; mode = set_runtime_activity(Reverse, true));

# Fit model ##
chns = sample_posterior!(model, MCMCThreads(), n_chains = 2, n_samples = 500, ad_type = ad_type)
Chains MCMC chain (500×26×2 Array{Float64, 3}):

Iterations        = 251:1:750
Number of chains  = 2
Samples per chain = 500
Wall duration     = 251.35 seconds
Compute duration  = 382.03 seconds
parameters        = learning_rate.β[1], learning_rate.β[2], learning_rate.β[3], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], action_noise.β[1], action_noise.β[2], action_noise.β[3], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                  parameters      mean       std      mcse    ess_bulk   ess_tail      rhat   ess_per_sec
                      Symbol   Float64   Float64   Float64     Float64    Float64   Float64       Float64

          learning_rate.β[1]   -0.4333    0.1758    0.0091    377.5332   449.4098    1.0027        0.9882
          learning_rate.β[2]    0.0051    0.1847    0.0076    596.7839   535.7030    1.0031        1.5621
          learning_rate.β[3]    0.0053    0.0376    0.0014    736.7688   563.1750    1.0015        1.9285
  learning_rate.ranef_1.σ[1]    0.2339    0.1323    0.0065    358.7000   618.4885    1.0113        0.9389
  learning_rate.ranef_1.r[1]   -0.2462    0.1732    0.0101    314.5063   358.1045    1.0097        0.8232
  learning_rate.ranef_1.r[2]   -0.0005    0.1599    0.0094    301.7176   356.0504    1.0109        0.7898
  learning_rate.ranef_1.r[3]   -0.0207    0.1650    0.0088    370.6401   411.7057    1.0037        0.9702
           action_noise.β[1]    0.5958    0.3146    0.0174    332.5194   298.1507    1.0028        0.8704
           action_noise.β[2]    0.0731    0.1962    0.0068    845.5846   679.9730    1.0013        2.2134
           action_noise.β[3]   -0.0887    0.0123    0.0004   1132.4104   632.2110    1.0059        2.9642
   action_noise.ranef_1.σ[1]    1.0351    0.1622    0.0062    672.0246   662.7026    1.0022        1.7591
   action_noise.ranef_1.r[1]    2.1531    0.3178    0.0174    339.9609   321.7117    1.0027        0.8899
   action_noise.ranef_1.r[2]    2.2231    0.3219    0.0178    333.0812   313.1892    1.0026        0.8719
   action_noise.ranef_1.r[3]    2.3175    0.3406    0.0185    340.9825   357.8952    1.0028        0.8925

Quantiles
                  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                      Symbol   Float64   Float64   Float64   Float64   Float64

          learning_rate.β[1]   -0.7635   -0.5583   -0.4439   -0.3180   -0.0758
          learning_rate.β[2]   -0.3629   -0.1218    0.0109    0.1367    0.3399
          learning_rate.β[3]   -0.0699   -0.0191    0.0057    0.0301    0.0773
  learning_rate.ranef_1.σ[1]    0.0629    0.1403    0.2102    0.2911    0.5654
  learning_rate.ranef_1.r[1]   -0.6798   -0.3372   -0.2234   -0.1223    0.0140
  learning_rate.ranef_1.r[2]   -0.3924   -0.0705    0.0174    0.0997    0.2630
  learning_rate.ranef_1.r[3]   -0.4201   -0.0948   -0.0062    0.0790    0.2546
           action_noise.β[1]   -0.0103    0.3794    0.5930    0.8125    1.2564
           action_noise.β[2]   -0.2886   -0.0563    0.0703    0.1973    0.4542
           action_noise.β[3]   -0.1136   -0.0964   -0.0887   -0.0811   -0.0645
   action_noise.ranef_1.σ[1]    0.7408    0.9171    1.0288    1.1441    1.3669
   action_noise.ranef_1.r[1]    1.4943    1.9344    2.1524    2.3710    2.7706
   action_noise.ranef_1.r[2]    1.5421    1.9985    2.2233    2.4505    2.8474
   action_noise.ranef_1.r[3]    1.5834    2.0942    2.3247    2.5506    2.9578

We can now inspect the results of the fitting process. We can plot the posterior distributions of the β coefficients for the effect of PDI on the action model parameters. For the action noise, the posterior is centered slightly above zero (mean ≈ 0.07). This may weakly suggest that action noise increases with PDI score, but the 95% interval (about -0.29 to 0.45) comfortably includes zero. With only a subset of the data, the posterior is not very informative.

title = plot(
    title = "Posterior over effect of PDI",
    grid = false,
    showaxis = false,
    bottom_margin = -30Plots.px,
)
plot(
    title,
    density(chns[Symbol("learning_rate.β[2]")], title = "learning rate", label = nothing),
    density(chns[Symbol("action_noise.β[2]")], title = "action noise", label = nothing),
    layout = @layout([A{0.01h}; [B C]])
)

We can also plot the posterior over the effect of session number on the action model parameters. Session number appears to have a negative effect on the action noise (posterior mean ≈ -0.089, 95% interval about -0.114 to -0.065). This suggests that participants become more consistent in their responses across sessions.

title = plot(
    title = "Posterior over effect of session",
    grid = false,
    showaxis = false,
    bottom_margin = -30Plots.px,
)
plot(
    title,
    density(chns[Symbol("learning_rate.β[3]")], title = "learning rate", label = nothing),
    density(chns[Symbol("action_noise.β[3]")], title = "action noise", label = nothing),
    layout = @layout([A{0.01h}; [B C]])
)
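The action noise is the exponential of a linear predictor. Each additional session therefore multiplies the action noise by exp(β) for the session coefficient. The sketch below plugs in the posterior mean of action_noise.β[3] from the summary table. It is a back-of-the-envelope calculation, not a full posterior summary.

```julia
# With an exp link, the action noise for a session is
#     exp(β₁ + β_pdi * pdi + β_session * session + r_ID),
# so adding one to `session` multiplies the action noise by exp(β_session).
β_session = -0.0887          # posterior mean of action_noise.β[3] (table above)
ratio_per_session = exp(β_session)
percent_change = 100 * (ratio_per_session - 1)
println("Action noise changes by ", round(percent_change, digits = 1), "% per session")
```

This gives a reduction of roughly 8.5% per session. It is in line with the session-to-session ratios in the session parameter table, for example 13.1284 / 14.3482 ≈ 0.915 for ID 20.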

We can also extract the session parameters from the model.

session_parameters = get_session_parameters!(model);


And we can extract a DataFrame with the posterior median of each session parameter.

parameters_df = summarize(session_parameters, median)
show(parameters_df)
12×4 DataFrame
 Row │ ID      session  action_noise  learning_rate
     │ String  String   Float64       Float64
─────┼──────────────────────────────────────────────
   1 │ 20      1             14.3482       0.337215
   2 │ 20      2             13.1284       0.338779
   3 │ 20      3             12.0228       0.339856
   4 │ 20      4             10.9936       0.340807
   5 │ 40      1             15.5294       0.395413
   6 │ 40      2             14.206        0.395648
   7 │ 40      3             12.999        0.396763
   8 │ 40      4             11.8839       0.397486
   9 │ 60      1             17.5411       0.390897
  10 │ 60      2             16.0591       0.391724
  11 │ 60      3             14.6971       0.393122
  12 │ 60      4             13.4685       0.393396
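These session parameters come from the population regressions. The sketch below rebuilds ID 20's session 1 parameters by hand from the posterior means in the summary table. It assumes that learning_rate.ranef_1.r[1] and action_noise.ranef_1.r[1] are the random intercepts for ID 20. The result is an approximation, because the table above reports posterior medians of the transformed parameters, not the transform of posterior means.

```julia
# Rebuild one session's parameters from the population regression,
# using posterior means from the summary table (an approximation).
inverse_logit(x) = 1 / (1 + exp(-x))   # hypothetical local name for logistic

# Linear predictor: intercept + PDI effect + session effect + participant intercept.
linear_predictor(β, pdi, session, r_id) = β[1] + β[2] * pdi + β[3] * session + r_id

pdi_id20 = 0.0535714   # pdi_total for ID 20 (from the data table)

# Assumes ranef_1.r[1] is the random intercept for ID 20.
noise = exp(linear_predictor((0.5958, 0.0731, -0.0887), pdi_id20, 1, 2.1531))
rate = inverse_logit(linear_predictor((-0.4333, 0.0051, 0.0053), pdi_id20, 1, -0.2462))

println("ID 20, session 1: action noise ≈ ", round(noise, digits = 2),
        ", learning rate ≈ ", round(rate, digits = 3))
```

This gives about 14.36 and 0.338, close to the medians 14.3482 and 0.337215 for ID 20, session 1.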

We can also look at the implied state trajectories, in this case the expected value.

state_trajectories = get_state_trajectories!(model, :expected_value);


These can also be summarized in a dataframe, for downstream analysis.

states_df = summarize(state_trajectories, median)
show(states_df)
2892×4 DataFrame
  Row │ ID      session  timestep  expected_value
      │ String  String   Int64     Float64
──────┼───────────────────────────────────────────
    1 │ 20      1               0          0.0
    2 │ 20      1               1         20.9073
    3 │ 20      1               2         32.4039
    4 │ 20      1               3         39.012
    5 │ 20      1               4         42.7173
    6 │ 20      1               5         50.2314
    7 │ 20      1               6         51.8394
    8 │ 20      1               7         48.8586
  ⋮   │   ⋮        ⋮        ⋮            ⋮
 2886 │ 60      4             234         17.5325
 2887 │ 60      4             235         17.323
 2888 │ 60      4             236         14.4422
 2889 │ 60      4             237         15.055
 2890 │ 60      4             238         13.8532
 2891 │ 60      4             239         13.5175
 2892 │ 60      4             240         11.347
                                 2877 rows omitted
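As a sanity check on these trajectories, the sketch below replays the Rescorla-Wagner update on the first three outcomes for ID 20, session 1 (62, 55 and 52 in the data table). It uses the median learning rate for that session (0.337215) and an initial expected value of 0. The function name expected_value_trajectory is hypothetical.

```julia
# Replay the Rescorla-Wagner update from an initial expected value,
# returning the expectation before any outcome and after each outcome.
function expected_value_trajectory(outcomes, learning_rate; initial_value = 0.0)
    trajectory = [initial_value]
    for outcome in outcomes
        previous = trajectory[end]
        push!(trajectory, previous + learning_rate * (outcome - previous))
    end
    return trajectory
end

trajectory = expected_value_trajectory([62.0, 55.0, 52.0], 0.337215)
println(round.(trajectory, digits = 3))
```

The values (0.0, 20.907, 32.404, 39.012) agree with the first rows of states_df to the precision shown. In general, though, a median trajectory need not equal the trajectory under the median learning rate.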

Comparing to a simple random model

We can compare the Rescorla-Wagner model to a simple random model, which samples actions randomly from a Gaussian distribution with a fixed mean $\mu$ and standard deviation $\sigma$.

#First we create the simple model
function gaussian_random(attributes::ModelAttributes, observation::Float64)

    parameters = load_parameters(attributes)

    σ = parameters.std
    μ = parameters.mean

    return Normal(μ, σ)
end

action_model = ActionModel(
    gaussian_random,
    observations = (; observation = Observation()),
    actions = (; report = Action(Normal)),
    parameters = (std = Parameter(1), mean = Parameter(50)),
)
-- ActionModel --
Action model function: gaussian_random
Number of parameters: 2
Number of states: 0
Number of observations: 1
Number of actions: 1
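For contrast with the Rescorla-Wagner likelihood, the sketch below computes the random model's log-likelihood in plain Julia. The helper names are hypothetical, and μ = 50 and σ = 15 are illustrative values only. Every report is scored against the same Normal(μ, σ), so the observations play no role.

```julia
# Log-density of a Normal(μ, σ) distribution evaluated at x.
normal_logpdf(x, μ, σ) = -0.5 * log(2π) - log(σ) - 0.5 * ((x - μ) / σ)^2

# Gaussian random model: all reports share one fixed mean and standard deviation.
random_loglikelihood(responses, μ, σ) = sum(r -> normal_logpdf(r, μ, σ), responses)

responses = [15.0, 57.0, 62.0, 55.0]   # first reports of ID 20, session 1
println(random_loglikelihood(responses, 50.0, 15.0))
```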

Fitting the model

We also set priors for this simpler model. Here we set the priors separately for the mean and the noise, since they are on very different scales. We center the prior for the mean intercept at 50, as this is the middle of the range of actions. The priors for the noise are the same as those used with the Rescorla-Wagner model.

mean_regression_prior = RegressionPrior(
    β = [Normal(50, 10), Normal(0, 10), Normal(0, 10)],
    σ = truncated(Normal(0, 10), lower = 0),
)
noise_regression_prior = RegressionPrior(
    β = [Normal(0, 0.3), Normal(0, 0.2), Normal(0, 0.5)],
    σ = truncated(Normal(0, 0.3), lower = 0),
)

population_model = [
    Regression(@formula(mean ~ 1 + pdi_total + session + (1 | ID)), mean_regression_prior),
    Regression(
        @formula(std ~ 1 + pdi_total + session + (1 | ID)),
        exp,
        noise_regression_prior,
    ),
]

simple_model = create_model(
    action_model,
    population_model,
    JGET_data,
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
)

# Set AD backend ##
using ADTypes: AutoEnzyme
import Enzyme: set_runtime_activity, Reverse
ad_type = AutoEnzyme(; mode = set_runtime_activity(Reverse, true));

# Fit model ##
chns = sample_posterior!(simple_model, MCMCThreads(), n_chains = 2, n_samples = 500, ad_type = ad_type)

#Plot the posteriors
plot(
    plot(
        plot(
            title = "Posterior over effect of PDI",
            grid = false,
            showaxis = false,
            bottom_margin = -30Plots.px,
        ),
        density(
            chns[Symbol("mean.β[2]")],
            title = "mean",
            label = nothing,
        ),
        density(
            chns[Symbol("std.β[2]")],
            title = "std",
            label = nothing,
        ),
        layout = @layout([A{0.01h}; [B C]])
    ),
    plot(
        plot(
            title = "Posterior over effect of session",
            grid = false,
            showaxis = false,
            bottom_margin = -30Plots.px,
        ),
        density(chns[Symbol("mean.β[3]")], title = "mean", label = nothing),
        density(chns[Symbol("std.β[3]")], title = "std", label = nothing),
        layout = @layout([A{0.01h}; [B C]])
    ),
    layout = (2,1)
)

And we can use model comparison to compare the two models.

#TODO: model comparison
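How to obtain pointwise log-likelihoods from a fitted ActionModels model is not covered here. The sketch below is a generic example of WAIC (widely applicable information criterion). It assumes a matrix of pointwise log-likelihoods for each model, with one row per posterior sample and one column per observation. The function names are hypothetical. Under WAIC, the model with the higher expected log pointwise predictive density (elpd) would be preferred.

```julia
using Statistics  # stdlib: mean, var

# Numerically stable log(mean(exp.(x))).
function log_mean_exp(x)
    m = maximum(x)
    return m + log(mean(exp.(x .- m)))
end

# WAIC elpd from a (samples × observations) matrix of pointwise log-likelihoods:
# log pointwise predictive density minus the variance-based complexity penalty.
function waic_elpd(loglik::AbstractMatrix)
    lppd = sum(log_mean_exp(col) for col in eachcol(loglik))
    p_waic = sum(var(col) for col in eachcol(loglik))
    return lppd - p_waic
end
```

Calling waic_elpd on each model's matrix and comparing the results gives a simple comparison. A dedicated package would also report standard errors for the difference.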

This page was generated using Literate.jl.