Jumping Gaussian Estimation Task
In this tutorial, we will fit a Rescorla-Wagner model to data from the Jumping Gaussian Estimation Task (JGET). In the JGET, participants observe continuous outcomes sampled from a Gaussian distribution whose mean occasionally "jumps" to a new location. Participants must predict the outcome of the next trial based on the previous outcomes. The data we will use is from a study on schizotypy (Mikus et al., 2025), in which participants completed the JGET and also filled out the Peters Delusions Inventory (PDI). The PDI is a self-report questionnaire that measures delusional ideation, and we will use it as a predictor in our model. There are several sessions per participant in the dataset; these will be modeled as separate experimental sessions. The data is taken from https://github.com/nacemikus/jget-schizotypy, where more information can also be found.
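To build intuition for the task, below is a minimal sketch of how outcomes in a jumping Gaussian task could be simulated. The jump probability, outcome range, and noise level are illustrative assumptions and do not reproduce the exact trial schedule of the original study.
using Distributions

#Illustrative simulation of a jumping Gaussian outcome sequence
function simulate_jget_outcomes(n_trials; jump_probability = 0.05, outcome_sd = 10.0)
    mean_outcome = rand(Uniform(10, 90))
    outcomes = Vector{Float64}(undef, n_trials)
    for t = 1:n_trials
        #With some probability, the generating distribution jumps to a new mean
        if rand() < jump_probability
            mean_outcome = rand(Uniform(10, 90))
        end
        outcomes[t] = rand(Normal(mean_outcome, outcome_sd))
    end
    return outcomes
end

simulate_jget_outcomes(5)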
Loading data
First, we load the ActionModels package. We also load CSV and DataFrames for loading the data, and StatsPlots for plotting the results.
using ActionModels
using CSV, DataFrames
using StatsPlots
Then we load the JGET data, which is available in the docs/example_data/JGET folder.
#Trial-level data
JGET_data = CSV.read(
joinpath(docs_path, "example_data", "JGET", "JGET_data_trial_preprocessed.csv"),
DataFrame,
missingstring = ["NaN", ""],
)
JGET_data = select(JGET_data, [:trials, :ID, :session, :outcome, :response, :confidence])
#Subject-level data
subject_data = CSV.read(
joinpath(docs_path, "example_data", "JGET", "JGET_data_sub_preprocessed.csv"),
DataFrame,
missingstring = ["NaN", ""],
)
subject_data = select(subject_data, [:ID, :session, :pdi_total, :Age, :Gender, :Education])
#Join the data
JGET_data = innerjoin(JGET_data, subject_data, on = [:ID, :session])
#Remove sessions (ID and session combinations) with missing actions or missing PDI scores
JGET_data = combine(
groupby(JGET_data, [:ID, :session]),
subdata -> any(ismissing, Matrix(subdata[!, [:response]])) ? DataFrame() : subdata,
)
JGET_data = combine(
groupby(JGET_data, [:ID, :session]),
subdata -> any(ismissing, Matrix(subdata[!, [:pdi_total]])) ? DataFrame() : subdata,
)
disallowmissing!(JGET_data, [:pdi_total, :response])
#Make the outcome into a Float64
JGET_data.outcome = Float64.(JGET_data.outcome)
show(JGET_data)
42480×10 DataFrame
Row │ ID session trials outcome response confidence pdi_total Age ⋯
│ Int64 Int64 Int64 Float64 Float64 Float64? Float64 Flo ⋯
───────┼────────────────────────────────────────────────────────────────────────
1 │ 20 1 1 62.0 15.0 0.0953106 0.0535714 ⋯
2 │ 20 1 2 55.0 57.0 0.0680593 0.0535714
3 │ 20 1 3 52.0 62.0 0.934176 0.0535714
4 │ 20 1 4 50.0 55.0 1.40098 0.0535714
5 │ 20 1 5 65.0 49.0 0.0674173 0.0535714 ⋯
6 │ 20 1 6 55.0 61.0 0.667615 0.0535714
7 │ 20 1 7 43.0 62.0 0.0677037 0.0535714
8 │ 20 1 8 59.0 38.0 0.0341998 0.0535714
⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱
42474 │ 74 4 234 23.0 18.0 0.600938 0.0863095 ⋯
42475 │ 74 4 235 17.0 26.0 0.934273 0.0863095
42476 │ 74 4 236 10.0 25.0 0.78428 0.0863095
42477 │ 74 4 237 16.0 18.0 0.600921 0.0863095
42478 │ 74 4 238 12.0 20.0 0.750922 0.0863095 ⋯
42479 │ 74 4 239 13.0 20.0 0.83429 0.0863095
42480 │ 74 4 240 8.0 23.0 0.734223 0.0863095
3 columns and 42465 rows omitted
For this example, we will subset the data to only include four subjects in total, across all of their experimental sessions. This makes the runtime much shorter. Simply skip this step if you want to use the full dataset.
JGET_data = filter(row -> row[:ID] in [20, 40, 60, 70], JGET_data);
Creating the model
Then we construct the model to be fitted to the data. We will use a classic Rescorla-Wagner model with a Gaussian report action model. The Rescorla-Wagner model is a simple reinforcement learning model that updates an expected value in proportion to the prediction error between the observed outcome and the current expectation. The Gaussian report action model assumes that the agent reports a continuous value sampled from a Gaussian distribution, where the mean is the current expected value and the standard deviation is a noise parameter. There are two parameters in the action model: the learning rate $\alpha$ and the action noise $\beta$. See the section on the Rescorla-Wagner model for more details.
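Concretely, on each trial the expected value $V$ is updated with the prediction error, $V \leftarrow V + \alpha (o - V)$, and the report is drawn from a Gaussian centered on the updated expectation with standard deviation $\beta$. The following is only an illustrative sketch of this update and report rule, not the implementation that is fitted below:
#Illustrative sketch of the Rescorla-Wagner update and the Gaussian report rule
function rescorla_wagner_sketch(expected_value, outcome, α, β)
    #Move the expectation towards the observed outcome, scaled by the learning rate
    expected_value += α * (outcome - expected_value)
    #The report is sampled from a Gaussian centered on the updated expectation
    report_distribution = Normal(expected_value, β)
    return expected_value, report_distribution
end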
We create the Rescorla-Wagner action model using the premade model from ActionModels.
action_model = ActionModel(RescorlaWagner())
-- ActionModel --
Action model function: rescorla_wagner_act_after_update
Number of parameters: 1
Number of states: 0
Number of observations: 1
Number of actions: 1
submodel type: ActionModels.ContinuousRescorlaWagner
We then specify which column in the data corresponds to the action (the response) and which column corresponds to the observation (the outcome). There are two columns which jointly specify the sessions: the ID of the participant and the experimental session number.
action_cols = :response
observation_cols = :outcome
session_cols = [:ID, :session];
We then create the full model. We use a hierarchical regression model to predict the parameters of the Rescorla-Wagner model based on the PDI score. First we will set appropriate priors for the regression coefficients. For the action noise, the outcome of the regression will be exponentiated before it is used in the model, so pre-transformed outcomes around 3 (exp(3) ≈ 20) are among the most extreme values to be expected. For the learning rate, the outcome of the regression will be passed through a logistic function, so pre-transformed outcomes around 5 (logistic(5) ≈ 0.993) are among the most extreme values to be expected. This means that we should keep the priors fairly narrow, so that the linear regression does not stray too far into inappropriate parameter space, which would increase the runtime of the fitting process.
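As a quick check of these scales (assuming logistic here is the standard logistic function, which is also used as the inverse link below):
exp(3)      #≈ 20.1, a very high action noise
logistic(5) #≈ 0.993, a learning rate close to 1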
regression_prior = RegressionPrior(
β = [Normal(0, 0.3), Normal(0, 0.2), Normal(0, 0.5)],
σ = truncated(Normal(0, 0.3), lower = 0),
)
plot(regression_prior.β[1], label = "Intercept")
plot!(regression_prior.β[2], label = "Effect of session number")
plot!(regression_prior.β[3], label = "Effect of PDI")
plot!(regression_prior.σ, label = "Random intercept std")
title!("Regression priors for the Rescorla-Wagner model")
xlabel!("Regression coefficient")
ylabel!("Density")
We then create the population model, which consists of two regression models: one for the learning rate and one for the action noise. We can then also create the full model using the create_model function.
population_model = [
Regression(
@formula(learning_rate ~ 1 + pdi_total + session + (1 | ID)),
logistic,
regression_prior,
),
Regression(
@formula(action_noise ~ 1 + pdi_total + session + (1 | ID)),
exp,
regression_prior,
),
]
model = create_model(
action_model,
population_model,
JGET_data,
action_cols = action_cols,
observation_cols = observation_cols,
session_cols = session_cols,
)
-- ModelFit object --
Action model: rescorla_wagner_act_after_update
Linear regression population model
2 estimated action model parameters, 12 sessions
Posterior not sampled
Prior not sampled
Fitting the model
We are now ready to fit the model to the data. For this model, we will use the Enzyme automatic differentiation backend, which is a high-performance AD library. Additionally, to keep the runtime of this tutorial short, we will only fit two chains with 500 samples each. We also pass MCMCThreads() to fit the two chains in parallel. This should take up to 5 minutes on a standard laptop.
# Set AD backend ##
using ADTypes: AutoEnzyme
import Enzyme: set_runtime_activity, Reverse
ad_type = AutoEnzyme(; mode = set_runtime_activity(Reverse, true));
# Fit model ##
chns = sample_posterior!(model, MCMCThreads(), n_chains = 2, n_samples = 500, ad_type = ad_type)
Chains MCMC chain (500×26×2 Array{Float64, 3}):
Iterations = 251:1:750
Number of chains = 2
Samples per chain = 500
Wall duration = 250.41 seconds
Compute duration = 378.43 seconds
parameters = learning_rate.β[1], learning_rate.β[2], learning_rate.β[3], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], action_noise.β[1], action_noise.β[2], action_noise.β[3], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
learning_rate.β[1] -0.4484 0.1671 0.0103 290.6699 250.3161 1.0063 0.7681
learning_rate.β[2] 0.0257 0.1920 0.0069 773.0777 701.4439 1.0030 2.0428
learning_rate.β[3] 0.0077 0.0365 0.0015 593.3413 436.5822 1.0006 1.5679
learning_rate.ranef_1.σ[1] 0.2180 0.1292 0.0079 242.8907 212.0532 1.0109 0.6418
learning_rate.ranef_1.r[1] -0.2335 0.1645 0.0115 230.7800 244.0296 1.0000 0.6098
learning_rate.ranef_1.r[2] 0.0018 0.1414 0.0097 277.1926 262.3570 1.0055 0.7325
learning_rate.ranef_1.r[3] -0.0269 0.1523 0.0105 266.8553 241.5766 1.0026 0.7052
action_noise.β[1] 0.5774 0.2950 0.0149 389.2905 434.4031 1.0018 1.0287
action_noise.β[2] 0.0672 0.1920 0.0065 883.7596 733.7073 1.0006 2.3353
action_noise.β[3] -0.0893 0.0121 0.0003 1428.9009 707.7241 1.0047 3.7758
action_noise.ranef_1.σ[1] 1.0542 0.1594 0.0052 948.9911 825.3080 1.0003 2.5077
action_noise.ranef_1.r[1] 2.1736 0.2979 0.0151 396.1277 434.5011 1.0004 1.0468
action_noise.ranef_1.r[2] 2.2442 0.2988 0.0153 383.5330 459.7353 1.0014 1.0135
action_noise.ranef_1.r[3] 2.3396 0.3136 0.0158 399.3336 434.0173 1.0007 1.0552
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
learning_rate.β[1] -0.7400 -0.5616 -0.4611 -0.3464 -0.0917
learning_rate.β[2] -0.3707 -0.1063 0.0347 0.1590 0.4028
learning_rate.β[3] -0.0603 -0.0189 0.0086 0.0331 0.0812
learning_rate.ranef_1.σ[1] 0.0333 0.1233 0.1939 0.2862 0.5280
learning_rate.ranef_1.r[1] -0.6103 -0.3277 -0.2153 -0.1132 0.0075
learning_rate.ranef_1.r[2] -0.3287 -0.0578 0.0149 0.0923 0.2282
learning_rate.ranef_1.r[3] -0.4058 -0.0961 -0.0112 0.0641 0.2422
action_noise.β[1] 0.0130 0.3714 0.5805 0.7818 1.1742
action_noise.β[2] -0.2944 -0.0686 0.0651 0.1988 0.4376
action_noise.β[3] -0.1123 -0.0974 -0.0887 -0.0813 -0.0672
action_noise.ranef_1.σ[1] 0.7688 0.9389 1.0454 1.1577 1.3809
action_noise.ranef_1.r[1] 1.5811 1.9583 2.1736 2.3790 2.7450
action_noise.ranef_1.r[2] 1.6506 2.0350 2.2416 2.4450 2.8145
action_noise.ranef_1.r[3] 1.7413 2.1220 2.3348 2.5520 2.9468
We can now inspect the results of the fitting process. We can plot the posterior distributions of the regression coefficients for the effect of PDI on each of the action model parameters. Here there may be a weak indication that action noise increases with PDI score, although without the full dataset the posterior is not very informative.
title = plot(
title = "Posterior over effect of PDI",
grid = false,
showaxis = false,
bottom_margin = -30Plots.px,
)
plot(
title,
density(chns[Symbol("learning_rate.β[2]")], title = "learning rate", label = nothing),
density(chns[Symbol("action_noise.β[2]")], title = "action noise", label = nothing),
layout = @layout([A{0.01h}; [B C]])
)
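To quantify this impression, one could for example compute the posterior probability that the PDI effect on the action noise is positive. This is an illustrative check, not part of the original analysis:
using Statistics: mean

#Proportion of posterior samples in which the PDI effect on action noise is positive
pdi_effect_samples = vec(Array(chns[Symbol("action_noise.β[2]")]))
mean(pdi_effect_samples .> 0)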
We can also plot the posterior over the effect of session number on the action model parameters. Here, it looks like there is a negative effect of session number on the action noise, meaning that participants become more consistent in their responses over the course of the experiment.
title = plot(
title = "Posterior over effect of session",
grid = false,
showaxis = false,
bottom_margin = -30Plots.px,
)
plot(
title,
density(chns[Symbol("learning_rate.β[3]")], title = "learning rate", label = nothing),
density(chns[Symbol("action_noise.β[3]")], title = "action noise", label = nothing),
layout = @layout([A{0.01h}; [B C]])
)
We can also extract the session parameters from the model.
session_parameters = get_session_parameters!(model);
#TODO: plot the session parameters
And we can extract a dataframe with the posterior median for each session parameter.
parameters_df = summarize(session_parameters, median)
show(parameters_df)
12×4 DataFrame
Row │ ID session action_noise learning_rate
│ String String Float64 Float64
─────┼──────────────────────────────────────────────
1 │ 20 1 14.3628 0.336918
2 │ 20 2 13.128 0.338494
3 │ 20 3 12.0153 0.341259
4 │ 20 4 10.9915 0.343494
5 │ 40 1 15.5426 0.392538
6 │ 40 2 14.2164 0.394518
7 │ 40 3 12.9941 0.397139
8 │ 40 4 11.8951 0.398332
9 │ 60 1 17.5744 0.388351
10 │ 60 2 16.0665 0.390544
11 │ 60 3 14.6931 0.391864
12 │ 60 4 13.4293 0.393892
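As a simple illustration (not part of the original tutorial code), the summarized medians can be plotted directly with StatsPlots, for example as a scatter of learning rate against action noise for each session:
@df parameters_df scatter(
    :learning_rate,
    :action_noise,
    group = :ID,
    xlabel = "Learning rate (posterior median)",
    ylabel = "Action noise (posterior median)",
    title = "Session parameter medians per participant",
)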
We can also look at the implied state trajectories, in this case the expected value.
state_trajectories = get_state_trajectories!(model, :expected_value);
#TODO: plot the state trajectories
These can also be summarized in a dataframe for downstream analysis.
states_df = summarize(state_trajectories, median)
show(states_df)
2892×4 DataFrame
Row │ ID session timestep expected_value
│ String String Int64 Float64
──────┼───────────────────────────────────────────
1 │ 20 1 0 0.0
2 │ 20 1 1 20.8889
3 │ 20 1 2 32.3815
4 │ 20 1 3 38.9913
5 │ 20 1 4 42.7003
6 │ 20 1 5 50.2135
7 │ 20 1 6 51.8262
8 │ 20 1 7 48.8525
⋮ │ ⋮ ⋮ ⋮ ⋮
2886 │ 60 4 234 17.5383
2887 │ 60 4 235 17.3263
2888 │ 60 4 236 14.4405
2889 │ 60 4 237 15.0548
2890 │ 60 4 238 13.8515
2891 │ 60 4 239 13.5161
2892 │ 60 4 240 11.3434
2877 rows omitted
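As an illustrative example (again not part of the original tutorial code), the median expected value trajectories can be plotted directly from this dataframe, here for a single participant with one line per session:
@df filter(row -> row.ID == "20", states_df) plot(
    :timestep,
    :expected_value,
    group = :session,
    xlabel = "Timestep",
    ylabel = "Expected value (posterior median)",
    title = "Expected value trajectories for participant 20",
)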
Comparing to a simple random model
We can compare the Rescorla-Wagner model to a simple random model, which samples actions randomly from a Gaussian distribution with a fixed mean $\mu$ and standard deviation $\sigma$.
#First we create the simple model
function gaussian_random(attributes::ModelAttributes, observation::Float64)
parameters = load_parameters(attributes)
σ = parameters.std
μ = parameters.mean
return Normal(μ, σ)
end
action_model = ActionModel(
gaussian_random,
observations = (; observation = Observation()),
actions = (; report = Action(Normal)),
parameters = (std = Parameter(1), mean = Parameter(50)),
)
-- ActionModel --
Action model function: gaussian_random
Number of parameters: 2
Number of states: 0
Number of observations: 1
Number of actions: 1
Fitting the model
We also set priors for this simpler model. Here we set the priors separately for the mean and the noise, since they are on very different scales. We center the priors for the mean at 50, as this is the middle of the range of actions. The priors for the noise are similar to those used with the Rescorla-Wagner model.
mean_regression_prior = RegressionPrior(
β = [Normal(50, 10), Normal(0, 10), Normal(0, 10)],
σ = truncated(Normal(0, 10), lower = 0),
)
noise_regression_prior = RegressionPrior(
β = [Normal(0, 0.3), Normal(0, 0.2), Normal(0, 0.5)],
σ = truncated(Normal(0, 0.3), lower = 0),
)
population_model = [
Regression(@formula(mean ~ 1 + pdi_total + session + (1 | ID)), mean_regression_prior),
Regression(
@formula(std ~ 1 + pdi_total + session + (1 | ID)),
exp,
noise_regression_prior,
),
]
simple_model = create_model(
action_model,
population_model,
JGET_data,
action_cols = action_cols,
observation_cols = observation_cols,
session_cols = session_cols,
)
# Set AD backend ##
using ADTypes: AutoEnzyme
import Enzyme: set_runtime_activity, Reverse
ad_type = AutoEnzyme(; mode = set_runtime_activity(Reverse, true));
# Fit model ##
chns = sample_posterior!(simple_model, MCMCThreads(), n_chains = 2, n_samples = 500, ad_type = ad_type)
#Plot the posteriors
plot(
plot(
plot(
title = "Posterior over effect of PDI",
grid = false,
showaxis = false,
bottom_margin = -30Plots.px,
),
density(
chns[Symbol("mean.β[2]")],
title = "mean",
label = nothing,
),
density(
chns[Symbol("std.β[2]")],
title = "std",
label = nothing,
),
layout = @layout([A{0.01h}; [B C]])
),
plot(
plot(
title = "Posterior over effect of session",
grid = false,
showaxis = false,
bottom_margin = -30Plots.px,
),
density(chns[Symbol("mean.β[3]")], title = "mean", label = nothing),
density(chns[Symbol("std.β[3]")], title = "std", label = nothing),
layout = @layout([A{0.01h}; [B C]])
),
layout = (2,1)
)
Finally, we can use model comparison to assess which of the two models better explains the data.
#TODO: model comparison
This page was generated using Literate.jl.