Jumping Gaussian Estimation Task

In this tutorial, we will fit a Rescorla-Wagner model to data from the Jumping Gaussian Estimation Task (JGET). In the JGET, participants observe continuous outcomes sampled from a Gaussian distribution which on some trials "jumps" to be centered somewhere else. Participants must predict the outcome of the next trial based on the previous outcomes. The data we will use is from a study on schizotypy (Mikus et al., 2025), in which participants completed the JGET and also filled out the Peters Delusions Inventory (PDI). The PDI is a self-report questionnaire that measures delusional ideation, and we will use it as a predictor in our model. There are several sessions per participant in the dataset, which will be modeled as separate experimental sessions. The data is taken from https://github.com/nacemikus/jget-schizotypy, where more information can also be found.
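
To make the task structure concrete, below is a purely illustrative simulation of a JGET-like outcome sequence. The trial count, jump probability, and noise level are made-up values for illustration only and do not match the actual experiment.

using Distributions, Random

#Purely illustrative: simulate outcomes from a Gaussian whose mean occasionally "jumps"
function simulate_jget(n_trials; jump_probability = 0.05, outcome_sd = 5.0)
    mean_value = rand(Uniform(10, 90))            #initial generating mean
    outcomes = Vector{Float64}(undef, n_trials)
    for t in 1:n_trials
        if rand() < jump_probability
            mean_value = rand(Uniform(10, 90))    #the distribution jumps to a new center
        end
        outcomes[t] = rand(Normal(mean_value, outcome_sd))
    end
    return outcomes
end

Random.seed!(1)
simulated_outcomes = simulate_jget(240);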

Loading data

First, we load the ActionModels package. We also load CSV and DataFrames for loading the data, and StatsPlots for plotting the results.

using ActionModels
using CSV, DataFrames
using StatsPlots

Then we load the JGET data, which is available in the docs/example_data/JGET folder.

#Trial-level data
JGET_data = CSV.read(
    joinpath(docs_path, "example_data", "JGET", "JGET_data_trial_preprocessed.csv"),
    DataFrame,
    missingstring = ["NaN", ""],
)
JGET_data = select(JGET_data, [:trials, :ID, :session, :outcome, :response, :confidence])

#Subject-level data
subject_data = CSV.read(
    joinpath(docs_path, "example_data", "JGET", "JGET_data_sub_preprocessed.csv"),
    DataFrame,
    missingstring = ["NaN", ""],
)
subject_data = select(subject_data, [:ID, :session, :pdi_total, :Age, :Gender, :Education])

#Join the data
JGET_data = innerjoin(JGET_data, subject_data, on = [:ID, :session])

#Remove sessions with missing responses or missing PDI scores
JGET_data = combine(
    groupby(JGET_data, [:ID, :session]),
    subdata -> any(ismissing, Matrix(subdata[!, [:response]])) ? DataFrame() : subdata,
)
JGET_data = combine(
    groupby(JGET_data, [:ID, :session]),
    subdata -> any(ismissing, Matrix(subdata[!, [:pdi_total]])) ? DataFrame() : subdata,
)
disallowmissing!(JGET_data, [:pdi_total, :response])

#Make the outcome into a Float64
JGET_data.outcome = Float64.(JGET_data.outcome)

show(JGET_data)
42480×10 DataFrame
   Row │ ID     session  trials  outcome  response  confidence  pdi_total  Age ⋯
       │ Int64  Int64    Int64   Float64  Float64   Float64?    Float64    Flo ⋯
───────┼────────────────────────────────────────────────────────────────────────
     1 │    20        1       1     62.0      15.0   0.0953106  0.0535714      ⋯
     2 │    20        1       2     55.0      57.0   0.0680593  0.0535714
     3 │    20        1       3     52.0      62.0   0.934176   0.0535714
     4 │    20        1       4     50.0      55.0   1.40098    0.0535714
     5 │    20        1       5     65.0      49.0   0.0674173  0.0535714      ⋯
     6 │    20        1       6     55.0      61.0   0.667615   0.0535714
     7 │    20        1       7     43.0      62.0   0.0677037  0.0535714
     8 │    20        1       8     59.0      38.0   0.0341998  0.0535714
   ⋮   │   ⋮       ⋮       ⋮        ⋮        ⋮          ⋮           ⋮          ⋱
 42474 │    74        4     234     23.0      18.0   0.600938   0.0863095      ⋯
 42475 │    74        4     235     17.0      26.0   0.934273   0.0863095
 42476 │    74        4     236     10.0      25.0   0.78428    0.0863095
 42477 │    74        4     237     16.0      18.0   0.600921   0.0863095
 42478 │    74        4     238     12.0      20.0   0.750922   0.0863095      ⋯
 42479 │    74        4     239     13.0      20.0   0.83429    0.0863095
 42480 │    74        4     240      8.0      23.0   0.734223   0.0863095
                                                3 columns and 42465 rows omitted

For this example, we will subset the data to only include four subjects, across all of their experimental sessions. This makes the runtime much shorter. Simply skip this step if you want to use the full dataset.

JGET_data = filter(row -> row[:ID] in [20, 40, 60, 70], JGET_data);

Creating the model

Then we construct the model to be fitted to the data. We will use a classic Rescorla-Wagner model with a Gaussian report action model. The Rescorla-Wagner model is a simple reinforcement learning model that updates an expected value based on the prediction error between the expected and the observed outcome. The Gaussian report action model assumes that the agent reports a continuous value sampled from a Gaussian distribution, where the mean is the current expected value and the standard deviation is a noise parameter. There are two parameters in the action model: the learning rate $\alpha$ and the action noise $\beta$. See the section on the Rescorla-Wagner model for more details.
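
Conceptually, the model combines the standard Rescorla-Wagner update with a Gaussian report. The sketch below only illustrates these two steps; the function name and example values are for illustration and are not part of the ActionModels API.

using Distributions

#Rescorla-Wagner update: Vₜ₊₁ = Vₜ + α * (outcomeₜ - Vₜ)
rescorla_wagner_step(V, outcome, α) = V + α * (outcome - V)

let V = 0.0, α = 0.3, β = 5.0
    for outcome in [62.0, 55.0, 52.0, 50.0]       #example outcomes
        #The premade model acts after updating: first update the expectation...
        V = rescorla_wagner_step(V, outcome, α)
        #...then report a noisy prediction sampled around the expectation
        report = rand(Normal(V, β))
        println("expected value: $(round(V, digits = 2)), report: $(round(report, digits = 2))")
    end
end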

We create the Rescorla-Wagner action model using the premade model from ActionModels.

action_model = ActionModel(RescorlaWagner())
-- ActionModel --
Action model function: rescorla_wagner_act_after_update
Number of parameters: 1
Number of states: 0
Number of observations: 1
Number of actions: 1
submodel type: ActionModels.ContinuousRescorlaWagner

We then specify which column in the data corresponds to the action (the response) and which corresponds to the observation (the outcome). Two columns jointly specify the sessions: the ID of the participant and the experimental session number.

action_cols = :response
observation_cols = :outcome
session_cols = [:ID, :session];

We then create the full model. We use a hierarchical regression model to predict the parameters of the Rescorla-Wagner model from the PDI score. First we set appropriate priors for the regression coefficients. For the action noise, the outcome of the regression is exponentiated before it is used in the model, so pre-transformed outcomes around 3 (exp(3) ≈ 20) are among the most extreme values to be expected. For the learning rate, the outcome of the regression is passed through a logistic function, so pre-transformed outcomes around 5 (logistic(5) ≈ 0.993) are among the most extreme values to be expected. This means that we should keep the priors fairly narrow, so that the linear regression does not stray into inappropriate parameter space, which would slow down the fitting process.
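
To get a feel for these scales, we can map a few pre-transformed values through the two link functions. The specific values below are arbitrary; logistic is used later in the regression specification and is assumed to be available in the session.

#Pre-transformed values mapped to the action noise scale
exp.([0.0, 1.0, 3.0])        #≈ [1.0, 2.7, 20.1]
#Pre-transformed values mapped to the learning rate scale
logistic.([0.0, 2.0, 5.0])   #≈ [0.5, 0.88, 0.99]

With these scales in mind, we set the priors.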

regression_prior = RegressionPrior(
    β = [Normal(0, 0.3), Normal(0, 0.2), Normal(0, 0.5)],
    σ = truncated(Normal(0, 0.3), lower = 0),
)

plot(regression_prior.β[1], label = "Intercept")
plot!(regression_prior.β[2], label = "Effect of PDI")
plot!(regression_prior.β[3], label = "Effect of session number")
plot!(regression_prior.σ, label = "Random intercept std")
title!("Regression priors for the Rescorla-Wagner model")
xlabel!("Regression coefficient")
ylabel!("Density")

We then create the population model, which consists of two regression models: one for the learning rate and one for the action noise. We can then also create the full model using the create_model function.

population_model = [
    Regression(
        @formula(learning_rate ~ 1 + pdi_total + session + (1 | ID)),
        logistic,
        regression_prior,
    ),
    Regression(
        @formula(action_noise ~ 1 + pdi_total + session + (1 | ID)),
        exp,
        regression_prior,
    ),
]

model = create_model(
    action_model,
    population_model,
    JGET_data,
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
)
-- ModelFit object --
Action model: rescorla_wagner_act_after_update
Linear regression population model
2 estimated action model parameters, 12 sessions
Posterior not sampled
Prior not sampled

Fitting the model

We are now ready to fit the model to the data. For this model, we will use the Enzyme automatic differentiation backend, a high-performance AD library. Additionally, to keep the runtime of this tutorial short, we will only fit two chains with 500 samples each, and we pass MCMCThreads() to sample the two chains in parallel. This should take up to 5 minutes on a standard laptop.

# Set AD backend ##
using ADTypes: AutoEnzyme
import Enzyme: set_runtime_activity, Reverse
ad_type = AutoEnzyme(; mode = set_runtime_activity(Reverse, true));

# Fit model ##
chns = sample_posterior!(model, MCMCThreads(), n_chains = 2, n_samples = 500, ad_type = ad_type)
Chains MCMC chain (500×26×2 Array{Float64, 3}):

Iterations        = 251:1:750
Number of chains  = 2
Samples per chain = 500
Wall duration     = 250.41 seconds
Compute duration  = 378.43 seconds
parameters        = learning_rate.β[1], learning_rate.β[2], learning_rate.β[3], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], action_noise.β[1], action_noise.β[2], action_noise.β[3], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                  parameters      mean       std      mcse    ess_bulk   ess_tail      rhat   ess_per_sec
                      Symbol   Float64   Float64   Float64     Float64    Float64   Float64       Float64

          learning_rate.β[1]   -0.4484    0.1671    0.0103    290.6699   250.3161    1.0063        0.7681
          learning_rate.β[2]    0.0257    0.1920    0.0069    773.0777   701.4439    1.0030        2.0428
          learning_rate.β[3]    0.0077    0.0365    0.0015    593.3413   436.5822    1.0006        1.5679
  learning_rate.ranef_1.σ[1]    0.2180    0.1292    0.0079    242.8907   212.0532    1.0109        0.6418
  learning_rate.ranef_1.r[1]   -0.2335    0.1645    0.0115    230.7800   244.0296    1.0000        0.6098
  learning_rate.ranef_1.r[2]    0.0018    0.1414    0.0097    277.1926   262.3570    1.0055        0.7325
  learning_rate.ranef_1.r[3]   -0.0269    0.1523    0.0105    266.8553   241.5766    1.0026        0.7052
           action_noise.β[1]    0.5774    0.2950    0.0149    389.2905   434.4031    1.0018        1.0287
           action_noise.β[2]    0.0672    0.1920    0.0065    883.7596   733.7073    1.0006        2.3353
           action_noise.β[3]   -0.0893    0.0121    0.0003   1428.9009   707.7241    1.0047        3.7758
   action_noise.ranef_1.σ[1]    1.0542    0.1594    0.0052    948.9911   825.3080    1.0003        2.5077
   action_noise.ranef_1.r[1]    2.1736    0.2979    0.0151    396.1277   434.5011    1.0004        1.0468
   action_noise.ranef_1.r[2]    2.2442    0.2988    0.0153    383.5330   459.7353    1.0014        1.0135
   action_noise.ranef_1.r[3]    2.3396    0.3136    0.0158    399.3336   434.0173    1.0007        1.0552

Quantiles
                  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                      Symbol   Float64   Float64   Float64   Float64   Float64

          learning_rate.β[1]   -0.7400   -0.5616   -0.4611   -0.3464   -0.0917
          learning_rate.β[2]   -0.3707   -0.1063    0.0347    0.1590    0.4028
          learning_rate.β[3]   -0.0603   -0.0189    0.0086    0.0331    0.0812
  learning_rate.ranef_1.σ[1]    0.0333    0.1233    0.1939    0.2862    0.5280
  learning_rate.ranef_1.r[1]   -0.6103   -0.3277   -0.2153   -0.1132    0.0075
  learning_rate.ranef_1.r[2]   -0.3287   -0.0578    0.0149    0.0923    0.2282
  learning_rate.ranef_1.r[3]   -0.4058   -0.0961   -0.0112    0.0641    0.2422
           action_noise.β[1]    0.0130    0.3714    0.5805    0.7818    1.1742
           action_noise.β[2]   -0.2944   -0.0686    0.0651    0.1988    0.4376
           action_noise.β[3]   -0.1123   -0.0974   -0.0887   -0.0813   -0.0672
   action_noise.ranef_1.σ[1]    0.7688    0.9389    1.0454    1.1577    1.3809
   action_noise.ranef_1.r[1]    1.5811    1.9583    2.1736    2.3790    2.7450
   action_noise.ranef_1.r[2]    1.6506    2.0350    2.2416    2.4450    2.8145
   action_noise.ranef_1.r[3]    1.7413    2.1220    2.3348    2.5520    2.9468

We can now inspect the results of the fitting process. First, we plot the posterior distribution of the regression coefficient for the effect of PDI on each action model parameter. There may be a weak indication that action noise increases with PDI score, although without the full dataset the posterior is not very informative.

title = plot(
    title = "Posterior over effect of PDI",
    grid = false,
    showaxis = false,
    bottom_margin = -30Plots.px,
)
plot(
    title,
    density(chns[Symbol("learning_rate.β[2]")], title = "learning rate", label = nothing),
    density(chns[Symbol("action_noise.β[2]")], title = "action noise", label = nothing),
    layout = @layout([A{0.01h}; [B C]])
)

We can also plot the posterior over the effect of session number on the action model parameters. Here, it looks like there is a negative effect of session number on the action noise, so that participants become more consistent in their responses over the course of the experiment.

title = plot(
    title = "Posterior over effect of session",
    grid = false,
    showaxis = false,
    bottom_margin = -30Plots.px,
)
plot(
    title,
    density(chns[Symbol("learning_rate.β[3]")], title = "learning rate", label = nothing),
    density(chns[Symbol("action_noise.β[3]")], title = "action noise", label = nothing),
    layout = @layout([A{0.01h}; [B C]])
)

We can also extract the session parameters from the model.

session_parameters = get_session_parameters!(model);

#TODO: plot the session parameters

And we can extract a dataframe with the posterior median of each session parameter.

parameters_df = summarize(session_parameters, median)
show(parameters_df)
12×4 DataFrame
 Row │ ID      session  action_noise  learning_rate
     │ String  String   Float64       Float64
─────┼──────────────────────────────────────────────
   1 │ 20      1             14.3628       0.336918
   2 │ 20      2             13.128        0.338494
   3 │ 20      3             12.0153       0.341259
   4 │ 20      4             10.9915       0.343494
   5 │ 40      1             15.5426       0.392538
   6 │ 40      2             14.2164       0.394518
   7 │ 40      3             12.9941       0.397139
   8 │ 40      4             11.8951       0.398332
   9 │ 60      1             17.5744       0.388351
  10 │ 60      2             16.0665       0.390544
  11 │ 60      3             14.6931       0.391864
  12 │ 60      4             13.4293       0.393892
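
One simple way to visualise these medians is to plot them directly from the summarized dataframe. The sketch below uses StatsPlots (already loaded) on the columns shown above, parsing the string-typed session column to integers for the x-axis.

#Plot the posterior median action noise across sessions, one line per participant
@df parameters_df plot(
    parse.(Int, :session),
    :action_noise,
    group = :ID,
    marker = :circle,
    xlabel = "Session",
    ylabel = "Median action noise",
    title = "Posterior median action noise per session",
)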

We can also look at the implied state trajectories, in this case the expected value.

state_trajectories = get_state_trajectories!(model, :expected_value);

#TODO: plot the state trajectories

These can also be summarized in a dataframe, for downstream analysis.

states_df = summarize(state_trajectories, median)
show(states_df)
2892×4 DataFrame
  Row │ ID      session  timestep  expected_value
      │ String  String   Int64     Float64
──────┼───────────────────────────────────────────
    1 │ 20      1               0          0.0
    2 │ 20      1               1         20.8889
    3 │ 20      1               2         32.3815
    4 │ 20      1               3         38.9913
    5 │ 20      1               4         42.7003
    6 │ 20      1               5         50.2135
    7 │ 20      1               6         51.8262
    8 │ 20      1               7         48.8525
  ⋮   │   ⋮        ⋮        ⋮            ⋮
 2886 │ 60      4             234         17.5383
 2887 │ 60      4             235         17.3263
 2888 │ 60      4             236         14.4405
 2889 │ 60      4             237         15.0548
 2890 │ 60      4             238         13.8515
 2891 │ 60      4             239         13.5161
 2892 │ 60      4             240         11.3434
                                 2877 rows omitted
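
Similarly, the median state trajectories can be plotted directly from this dataframe. The sketch below shows the posterior median expected value over trials for a single participant and session, using the string-typed ID and session columns shown above.

#Plot the posterior median expected value over trials for one participant and session
single_session = filter(row -> row.ID == "20" && row.session == "1", states_df)
@df single_session plot(
    :timestep,
    :expected_value,
    xlabel = "Trial",
    ylabel = "Median expected value",
    label = nothing,
    title = "Expected value trajectory (ID 20, session 1)",
)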

Comparing to a simple random model

We can compare the Rescorla-Wagner model to a simple random model, which samples actions randomly from a Gaussian distribution with a fixed mean $\mu$ and standard deviation $\sigma$.

#First we create the simple model
function gaussian_random(attributes::ModelAttributes, observation::Float64)

    #Load the parameters set for the current session
    parameters = load_parameters(attributes)

    σ = parameters.std
    μ = parameters.mean

    #Return the action distribution: a Gaussian that does not depend on the observation
    return Normal(μ, σ)
end

action_model = ActionModel(
    gaussian_random,
    observations = (; observation = Observation()),
    actions = (; report = Action(Normal)),
    parameters = (std = Parameter(1), mean = Parameter(50)),
)
-- ActionModel --
Action model function: gaussian_random
Number of parameters: 2
Number of states: 0
Number of observations: 1
Number of actions: 1

Fitting the model

We also set priors for this simpler model. Here we set the priors separately for the mean and the noise, since they are on very different scales. We center the priors for the mean at 50, as this is the middle of the range of actions. The priors for the noise are similar to those used with the Rescorla-Wagner model.

mean_regression_prior = RegressionPrior(
    β = [Normal(50, 10), Normal(0, 10), Normal(0, 10)],
    σ = truncated(Normal(0, 10), lower = 0),
)
noise_regression_prior = RegressionPrior(
    β = [Normal(0, 0.3), Normal(0, 0.2), Normal(0, 0.5)],
    σ = truncated(Normal(0, 0.3), lower = 0),
)

population_model = [
    Regression(@formula(mean ~ 1 + pdi_total + session + (1 | ID)), mean_regression_prior),
    Regression(
        @formula(std ~ 1 + pdi_total + session + (1 | ID)),
        exp,
        noise_regression_prior,
    ),
]

simple_model = create_model(
    action_model,
    population_model,
    JGET_data,
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
)

# Set AD backend ##
using ADTypes: AutoEnzyme
import Enzyme: set_runtime_activity, Reverse
ad_type = AutoEnzyme(; mode = set_runtime_activity(Reverse, true));

# Fit model ##
chns = sample_posterior!(simple_model, MCMCThreads(), n_chains = 2, n_samples = 500, ad_type = ad_type)

#Plot the posteriors
plot(
    plot(
        plot(
            title = "Posterior over effect of PDI",
            grid = false,
            showaxis = false,
            bottom_margin = -30Plots.px,
        ),
        density(
            chns[Symbol("mean.β[2]")],
            title = "mean",
            label = nothing,
        ),
        density(
            chns[Symbol("std.β[2]")],
            title = "std",
            label = nothing,
        ),
        layout = @layout([A{0.01h}; [B C]])
    ),
    plot(
        plot(
            title = "Posterior over effect of session",
            grid = false,
            showaxis = false,
            bottom_margin = -30Plots.px,
        ),
        density(chns[Symbol("mean.β[3]")], title = "mean", label = nothing),
        density(chns[Symbol("std.β[3]")], title = "std", label = nothing),
        layout = @layout([A{0.01h}; [B C]])
    ),
    layout = (2,1)
)

And we can use model comparison to compare the two models.

#TODO: model comparison

This page was generated using Literate.jl.