Population models

Apart from the action model, which describes how actions are chosen and states are updated on a timestep-by-timestep basis, it is also necessary to specify a population model when fitting to data. The population model describes how the parameters of the action model are distributed across the sessions in the dataset (i.e. the population). ActionModels provides two types of pre-made population models that the user can choose from. One is the independent sessions population model, in which the parameters of each session are assumed to be independent of each other, but share the same prior. The other is the linear regression population model, in which the parameters of each session are assumed to be related to external variables through a linear regression. The linear regression population model also allows for specifying a hierarchical structure (i.e. random effects), which is often used to improve parameter estimation in cognitive modelling. Finally, users can also specify their own custom population model as a Turing model. If there is only a single session in the dataset, only the independent sessions population model is appropriate, since there is no session structure to model.
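
In brief, the three options look as follows; each is described in detail below, and the regression line here is only a sketch of the syntax introduced later on this page:

population_model = (learning_rate = LogitNormal(), action_noise = LogNormal())  #independent sessions: a NamedTuple of priors
population_model = [Regression(@formula(learning_rate ~ treatment), logistic)]  #linear regression: a vector of Regression objects
#custom: any conditioned Turing model that returns a tuple of parameter values per session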

First, we load the ActionModels package, as well as StatsPlots for plotting the results.

using ActionModels
using StatsPlots

We then specify the data that we want to fit our model to. For this example, we will use a simple manually created dataset in which three participants have completed an experiment where they must predict the next location of a moving target. Each participant has completed the experiment twice: once in a control condition and once under an experimental treatment.

using DataFrames

data = DataFrame(
    observations = repeat([1.0, 1, 1, 2, 2, 2], 6),
    actions = vcat(
        [0, 0.2, 0.3, 0.4, 0.5, 0.6],
        [0, 0.5, 0.8, 1, 1.5, 1.8],
        [0, 2, 0.5, 4, 5, 3],
        [0, 0.1, 0.15, 0.2, 0.25, 0.3],
        [0, 0.2, 0.4, 0.7, 1.0, 1.1],
        [0, 2, 0.5, 4, 5, 3],
    ),
    id = vcat(
        repeat(["A"], 6),
        repeat(["B"], 6),
        repeat(["C"], 6),
        repeat(["A"], 6),
        repeat(["B"], 6),
        repeat(["C"], 6),
    ),
    treatment = vcat(repeat(["control"], 18), repeat(["treatment"], 18)),
)

show(data)
36×4 DataFrame
 Row │ observations  actions  id      treatment
     │ Float64       Float64  String  String
─────┼──────────────────────────────────────────
   1 │          1.0      0.0  A       control
   2 │          1.0      0.2  A       control
   3 │          1.0      0.3  A       control
   4 │          2.0      0.4  A       control
   5 │          2.0      0.5  A       control
   6 │          2.0      0.6  A       control
   7 │          1.0      0.0  B       control
   8 │          1.0      0.5  B       control
  ⋮  │      ⋮           ⋮       ⋮         ⋮
  30 │          2.0      1.1  B       treatment
  31 │          1.0      0.0  C       treatment
  32 │          1.0      2.0  C       treatment
  33 │          1.0      0.5  C       treatment
  34 │          2.0      4.0  C       treatment
  35 │          2.0      5.0  C       treatment
  36 │          2.0      3.0  C       treatment
                                 21 rows omitted

We specify which columns in the data correspond to the actions, observations, and session identifiers.

action_cols = :actions;
observation_cols = :observations;
session_cols = [:id, :treatment];

Finally, we specify the action model. Here we use the premade Rescorla-Wagner action model provided by ActionModels. This is identical to the model described in the defining action models REF section.

action_model = ActionModel(RescorlaWagner())
-- ActionModel --
Action model function: rescorla_wagner_act_after_update
Number of parameters: 1
Number of states: 0
Number of observations: 1
Number of actions: 1
submodel type: ActionModels.ContinuousRescorlaWagner

Independent sessions population model

To specify the independent sessions population model, we only need to specify a prior distribution for each parameter to estimate. This is done as a NamedTuple with the parameter names as keys and the prior distributions as values. We here select a LogitNormal prior for the learning rate, since it is constrained to be between 0 and 1, and a LogNormal prior for the action noise, since it is constrained to be positive.

population_model = (learning_rate = LogitNormal(), action_noise = LogNormal());

We can then create the full model, and sample from the posterior

model = create_model(
    action_model,
    population_model,
    data;
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
)

chns = sample_posterior!(model)
Chains MCMC chain (1000×24×2 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 2
Samples per chain = 1000
Wall duration     = 7.9 seconds
Compute duration  = 7.07 seconds
parameters        = learning_rate.session[1], learning_rate.session[2], learning_rate.session[3], learning_rate.session[4], learning_rate.session[5], learning_rate.session[6], action_noise.session[1], action_noise.session[2], action_noise.session[3], action_noise.session[4], action_noise.session[5], action_noise.session[6]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ess_per_sec
                    Symbol   Float64   Float64   Float64     Float64     Float64   Float64       Float64

  learning_rate.session[1]    0.0865    0.0125    0.0004   1366.4855    668.6563    1.0007      193.1974
  learning_rate.session[2]    0.0412    0.0073    0.0004    588.4418    327.0229    1.0082       83.1955
  learning_rate.session[3]    0.3563    0.0760    0.0030   1066.7402    607.5336    1.0039      150.8186
  learning_rate.session[4]    0.1748    0.0245    0.0007   1462.0254    959.0508    1.0002      206.7051
  learning_rate.session[5]    0.6182    0.1760    0.0037   2262.2098   1396.9692    0.9996      319.8374
  learning_rate.session[6]    0.6143    0.1787    0.0037   2278.9793   1114.1031    1.0005      322.2083
   action_noise.session[1]    0.1035    0.0568    0.0021   1124.2329    791.6993    1.0008      158.9471
   action_noise.session[2]    0.0635    0.0412    0.0020    815.2990    496.8743    1.0035      115.2692
   action_noise.session[3]    0.2799    0.1192    0.0036   1514.1985    988.0564    1.0031      214.0815
   action_noise.session[4]    0.1585    0.0758    0.0022   1724.4417   1302.5529    0.9996      243.8063
   action_noise.session[5]    1.9314    0.6278    0.0163   1897.4371   1282.2082    1.0009      268.2648
   action_noise.session[6]    1.9462    0.6155    0.0139   2521.5779   1578.5808    1.0014      356.5075

Quantiles
                parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                    Symbol   Float64   Float64   Float64   Float64   Float64

  learning_rate.session[1]    0.0666    0.0789    0.0852    0.0918    0.1169
  learning_rate.session[2]    0.0307    0.0372    0.0401    0.0435    0.0590
  learning_rate.session[3]    0.2401    0.3090    0.3443    0.3882    0.5415
  learning_rate.session[4]    0.1343    0.1604    0.1719    0.1855    0.2304
  learning_rate.session[5]    0.2459    0.5009    0.6336    0.7498    0.9147
  learning_rate.session[6]    0.2500    0.4881    0.6260    0.7536    0.9088
   action_noise.session[1]    0.0449    0.0670    0.0880    0.1223    0.2583
   action_noise.session[2]    0.0248    0.0393    0.0524    0.0733    0.1840
   action_noise.session[3]    0.1344    0.1957    0.2509    0.3333    0.6034
   action_noise.session[4]    0.0729    0.1072    0.1395    0.1850    0.3549
   action_noise.session[5]    1.0990    1.5096    1.8189    2.1916    3.5158
   action_noise.session[6]    1.1299    1.5109    1.8168    2.2500    3.5459

We can see in the Chains object that each parameter is estimated separately for each session. This means that the population model parameters in this case are identical to the session parameters. Therefore, the standard plotting functions will plot the session parameters directly.

#TODO: plot(model)

We see that the learning rate is (by construction) lower in the treatment condition, and that participants A, B, and C have the lowest, middle, and highest learning rates, respectively.

We can also extract the session parameters as a DataFrame for further analysis.

parameter_df = summarize(get_session_parameters!(model))

show(parameter_df)
6×4 DataFrame
 Row │ id      treatment  action_noise  learning_rate
     │ String  String     Float64       Float64
─────┼────────────────────────────────────────────────
   1 │ A       control       0.0880035      0.0851549
   2 │ A       treatment     0.0523786      0.0401189
   3 │ B       control       0.250855       0.344341
   4 │ B       treatment     0.139543       0.171876
   5 │ C       control       1.8189         0.633597
   6 │ C       treatment     1.8168         0.625952
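
As a sketch of such an analysis (assuming the column layout shown above), we could for example reshape the DataFrame to compare the learning rate across conditions within each participant:

#Reshape to one row per participant, with one learning rate column per condition
wide_df = unstack(parameter_df, :id, :treatment, :learning_rate)

#Per-participant difference between the conditions
wide_df.difference = wide_df.treatment .- wide_df.control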

Linear regression population models

Often, the goal of cognitive modelling is to relate differences in parameter values to external variables, such as treatment conditions or participant characteristics. This is often done with a linear regression, where the parameters are predicted from the external variables. In that case, it is statistically advantageous to fit the linear regression as part of the population model, rather than first estimating session parameters separately and then regressing their point estimates on the external variables, since the joint fit propagates the uncertainty of the parameter estimates into the regression. ActionModels provides a pre-made linear regression population model for this purpose.

To specify a linear regression population model, we create a vector of Regression objects, where each Regression specifies a regression model for one of the parameters to estimate. The regression is specified with standard LMER syntax, where the formula is given as a @formula object. Here, we predict each parameter from the treatment condition, and add a random intercept for each participant ID (making this a classic hierarchical model). For each regression, we can also specify an inverse link function. This function transforms the output of the regression, and is commonly used to ensure that the resulting parameter values are in the correct range. Here, we use the logistic function for the learning rate (to ensure it is between 0 and 1) and the exp function for the action noise (to ensure it is positive). If no inverse link function is specified, the identity function is used by default.

population_model = [
    Regression(@formula(learning_rate ~ 1 + treatment + (1 | id)), logistic),
    Regression(@formula(action_noise ~ 1 + treatment + (1 | id)), exp),
];

The priors for the regression population model can be customized with the RegressionPrior constructor. Priors can be specified for the regression coefficients (β) and for the standard deviation of the random effects (σ). If a single distribution is specified for a given type, it is used for all parameters of that type; to set separate priors for each predictor or each random effect, a vector of distributions can be passed instead. Here, we specify a Student's t-distribution with 3 degrees of freedom for the intercept, a standard normal distribution for the treatment effect, and an Exponential distribution with rate 1 for the standard deviation of the random intercepts. Note that these priors apply on the regression scale, before the inverse link function is applied. This means that values above 5 are fairly extreme, so these priors are probably too broad for most use cases.

prior = RegressionPrior(β = [TDist(3), Normal()], σ = Exponential(1))

population_model = [
    Regression(@formula(learning_rate ~ 1 + treatment + (1 | id)), logistic, prior),
    Regression(@formula(action_noise ~ 1 + treatment + (1 | id)), exp, prior),
]

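We can plot the priors to visualize their scale:
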
plot(prior.β[1], label = "Intercept")
plot!(prior.β[2], label = "Effect of treatment")
plot!(prior.σ, label = "Random intercept std")
title!("Regression priors for the Rescorla-Wagner model")
xlabel!("Regression coefficient")
ylabel!("Density")

We can then create the full model, and sample from the posterior

model = create_model(
    action_model,
    population_model,
    data;
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
)

chns = sample_posterior!(model)
Chains MCMC chain (1000×24×2 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 2
Samples per chain = 1000
Wall duration     = 51.16 seconds
Compute duration  = 50.88 seconds
parameters        = learning_rate.β[1], learning_rate.β[2], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], action_noise.β[1], action_noise.β[2], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ess_per_sec
                      Symbol   Float64   Float64   Float64     Float64     Float64   Float64       Float64

          learning_rate.β[1]   -0.3006    0.8432    0.0324    695.3430    460.7344    0.9997       13.6671
          learning_rate.β[2]   -0.8272    0.1016    0.0023   1893.4854   1402.7090    0.9996       37.2169
  learning_rate.ranef_1.σ[1]    1.9811    0.8986    0.0267   1180.7318   1424.4527    1.0006       23.2076
  learning_rate.ranef_1.r[1]   -2.1025    0.8467    0.0325    693.6386    450.0751    0.9997       13.6336
  learning_rate.ranef_1.r[2]   -0.4400    0.8497    0.0325    698.8130    468.4971    0.9996       13.7353
  learning_rate.ranef_1.r[3]    2.2159    1.6280    0.0517   1280.5595   1104.1716    0.9999       25.1697
           action_noise.β[1]   -0.6999    0.7865    0.0324    597.5387    662.5751    1.0039       11.7448
           action_noise.β[2]   -0.3713    0.2525    0.0061   1731.1649   1356.4465    1.0011       34.0265
   action_noise.ranef_1.σ[1]    1.8612    0.7058    0.0217   1130.7256   1297.5853    1.0010       22.2247
   action_noise.ranef_1.r[1]   -2.2087    0.8036    0.0329    605.0450    649.9050    1.0040       11.8923
   action_noise.ranef_1.r[2]   -1.0330    0.8071    0.0333    591.1629    742.0487    1.0041       11.6195
   action_noise.ranef_1.r[3]    1.5202    0.8025    0.0318    650.8611    681.5342    1.0031       12.7928

Quantiles
                  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                      Symbol   Float64   Float64   Float64   Float64   Float64

          learning_rate.β[1]   -2.0397   -0.8233   -0.3228    0.1926    1.5154
          learning_rate.β[2]   -1.0363   -0.8946   -0.8261   -0.7589   -0.6255
  learning_rate.ranef_1.σ[1]    0.7980    1.3531    1.7950    2.4433    4.2117
  learning_rate.ranef_1.r[1]   -3.9109   -2.6149   -2.0878   -1.5699   -0.3291
  learning_rate.ranef_1.r[2]   -2.2391   -0.9412   -0.4200    0.0790    1.2997
  learning_rate.ranef_1.r[3]   -0.2247    1.1695    1.9590    3.0523    6.3246
           action_noise.β[1]   -2.3032   -1.1910   -0.7058   -0.2124    0.8310
           action_noise.β[2]   -0.8665   -0.5402   -0.3642   -0.1986    0.1154
   action_noise.ranef_1.σ[1]    0.9157    1.3703    1.7311    2.1761    3.6765
   action_noise.ranef_1.r[1]   -3.8289   -2.7313   -2.2051   -1.6915   -0.6299
   action_noise.ranef_1.r[2]   -2.6619   -1.5472   -1.0398   -0.5153    0.6172
   action_noise.ranef_1.r[3]   -0.1090    1.0379    1.5007    2.0333    3.1320

We can see that there are, for each parameter, two $\beta$ estimates: the first is the intercept and the second is the treatment effect. For each random effect, there is a $\sigma$ estimate, which is the standard deviation of the random intercepts, as well as the sampled random intercepts themselves, one for each participant. We can also see that the effect of the treatment on the learning rate is estimated to be negative, consistent with how the data was constructed.
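
We can visualize the posteriors over the treatment effects for both parameters: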

title = plot(
    title = "Posterior over treatment effect",
    grid = false,
    showaxis = false,
    bottom_margin = -30Plots.px,
)
plot(
    title,
    density(chns[Symbol("learning_rate.β[2]")], title = "learning rate", label = nothing),
    density(chns[Symbol("action_noise.β[2]")], title = "action noise", label = nothing),
    layout = @layout([A{0.01h}; [B C]])
)
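
Beyond plotting, the treatment effects can also be summarized numerically. As a small sketch, the following computes the posterior probability that each treatment effect is negative:

using Statistics: mean

#Posterior probability that the treatment effect is negative for each parameter
p_lr = mean(vec(chns[Symbol("learning_rate.β[2]")]) .< 0)
p_an = mean(vec(chns[Symbol("action_noise.β[2]")]) .< 0)
println("P(learning rate effect < 0) = $p_lr")
println("P(action noise effect < 0) = $p_an")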

When using the linear regression population model, the session parameters are not directly estimated; instead, the regression coefficients and random effects are. We can still extract the session parameters as before, however.

parameter_df = summarize(get_session_parameters!(model))

show(parameter_df)
6×4 DataFrame
 Row │ id      treatment  action_noise  learning_rate
     │ String  String     Float64       Float64
─────┼────────────────────────────────────────────────
   1 │ A       control       0.0540524      0.0829063
   2 │ A       treatment     0.0371126      0.038081
   3 │ B       control       0.174928       0.323045
   4 │ B       treatment     0.121          0.172382
   5 │ C       control       2.24936        0.843154
   6 │ C       treatment     1.55008        0.703965

Custom population models

Finally, it is also possible to specify a custom population model for use with ActionModels. This is done by creating a conditioned Turing model that describes the population model and returns the sampled parameters for each session. The model is used as a Turing submodel, so its return value must be iterable, yielding for each session a tuple with the sampled parameter values. The sessions are matched in the same order as they appear in the data, so it is important that the model returns the parameters in that order. Additionally, the parameter names are passed separately as a Tuple, so that the sampled values can be matched to the correct parameters; the order of the names must match the order of the values in each returned tuple. Here, we create a custom population model where the learning rates and action noises for each session are sampled from multivariate normal distributions and then transformed to their appropriate ranges.

#Load the Turing module from ActionModels
using ActionModels: Turing
#Load the identity matrix I, used below as covariance for the multivariate normals
using LinearAlgebra: I

#Get the number of sessions in the data
n_sessions = nrow(unique(data, session_cols))

#Create the Turing model for the custom population model
@model function custom_population_model(n_sessions::Int64)

    #Sample parameters for each session
    learning_rates ~ MvNormal(zeros(n_sessions), I)
    action_noises ~ MvNormal(zeros(n_sessions), I)

    #Transform the parameters to the correct range
    learning_rates = logistic.(learning_rates)
    action_noises = exp.(action_noises)

    #Return the parameters as an iterable of tuples
    return zip(learning_rates, action_noises)
end

#Condition the Turing model
population_model = custom_population_model(n_sessions)

#Specify which parameters are estimated
parameters_to_estimate = (:learning_rate, :action_noise);
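
Before fitting, we can sanity-check the custom population model on its own: calling a conditioned Turing model runs it with parameters drawn from the prior and returns its return value, here the iterable of per-session parameter tuples.

#Draw one set of session parameters from the prior as a quick check
prior_draw = collect(population_model())
prior_draw[1]  #(learning_rate, action_noise) for the first session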

We can then create the full model, and sample from the posterior

model = create_model(
    action_model,
    population_model,
    data;
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
    parameters_to_estimate = parameters_to_estimate,
)

chns = sample_posterior!(model)
Chains MCMC chain (1000×24×2 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 2
Samples per chain = 1000
Wall duration     = 7.27 seconds
Compute duration  = 6.47 seconds
parameters        = learning_rates[1], learning_rates[2], learning_rates[3], learning_rates[4], learning_rates[5], learning_rates[6], action_noises[1], action_noises[2], action_noises[3], action_noises[4], action_noises[5], action_noises[6]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
         parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ess_per_sec
             Symbol   Float64   Float64   Float64     Float64     Float64   Float64       Float64

  learning_rates[1]   -2.3631    0.1665    0.0070   1190.2449    641.5536    1.0023      184.0490
  learning_rates[2]   -3.1511    0.1796    0.0078    820.0419    491.5452    1.0014      126.8041
  learning_rates[3]   -0.5988    0.3260    0.0105   1344.8261    796.4016    0.9999      207.9521
  learning_rates[4]   -1.5591    0.1783    0.0075    928.5766    494.7849    1.0021      143.5869
  learning_rates[5]    0.5467    0.8773    0.0167   2761.6116   1345.4202    1.0012      427.0313
  learning_rates[6]    0.5703    0.9102    0.0203   2017.4966   1383.6599    1.0003      311.9679
   action_noises[1]   -2.3653    0.4704    0.0140   1383.0767    774.0021    1.0005      213.8668
   action_noises[2]   -2.8549    0.5075    0.0182    924.9947    766.0870    1.0013      143.0330
   action_noises[3]   -1.3557    0.3989    0.0111   1468.5047    759.7253    1.0000      227.0767
   action_noises[4]   -1.9022    0.4234    0.0144   1018.8588    686.9066    1.0004      157.5474
   action_noises[5]    0.6136    0.2904    0.0069   2136.1178   1095.7191    1.0002      330.3105
   action_noises[6]    0.6073    0.2917    0.0068   2059.3249   1384.2353    1.0000      318.4359

Quantiles
         parameters      2.5%     25.0%     50.0%     75.0%     97.5%
             Symbol   Float64   Float64   Float64   Float64   Float64

  learning_rates[1]   -2.6624   -2.4520   -2.3716   -2.2895   -2.0178
  learning_rates[2]   -3.4262   -3.2553   -3.1736   -3.0801   -2.6914
  learning_rates[3]   -1.1424   -0.7879   -0.6292   -0.4484    0.1568
  learning_rates[4]   -1.8740   -1.6635   -1.5731   -1.4734   -1.1630
  learning_rates[5]   -1.1348   -0.0297    0.5020    1.1434    2.2824
  learning_rates[6]   -1.1673   -0.0185    0.5452    1.1621    2.4473
   action_noises[1]   -3.1154   -2.7044   -2.4232   -2.0863   -1.2958
   action_noises[2]   -3.6674   -3.2175   -2.9199   -2.5713   -1.7242
   action_noises[3]   -2.0348   -1.6417   -1.4048   -1.1052   -0.4297
   action_noises[4]   -2.6011   -2.1947   -1.9471   -1.6578   -0.9132
   action_noises[5]    0.1338    0.4041    0.5926    0.7891    1.2523
   action_noises[6]    0.0904    0.4059    0.5903    0.7867    1.2349

We can see that the parameters are estimated for each session, in the untransformed space. We can still extract the session parameters as before; this returns the parameters in the form in which they are passed to the action model, i.e., after the transformation.

parameter_df = summarize(get_session_parameters!(model))

show(parameter_df)
6×4 DataFrame
 Row │ id      treatment  action_noise  learning_rate
     │ String  String     Float64       Float64
─────┼────────────────────────────────────────────────
   1 │ A       control       0.0886413      0.0853674
   2 │ A       treatment     0.0539416      0.0401718
   3 │ B       control       0.245404       0.347691
   4 │ B       treatment     0.142689       0.171773
   5 │ C       control       1.80875        0.622939
   6 │ C       treatment     1.80449        0.633018
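
As a sanity check, applying the transformations to the posterior medians of the latent parameters recovers the values in this table; for example, for participant A in the control condition (session 1 in the summary above):

logistic(-2.3716)  #≈ 0.085, the learning rate for A in the control condition
exp(-2.4232)       #≈ 0.089, the action noise for A in the control condition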

Single session population model

Notably, it is also possible to use ActionModels with only a single session. In this case, the observations and actions can be passed as two vectors, instead of passing a DataFrame. For single-session models, only the independent sessions population model is appropriate, since there is no session structure to model.

#Create a single session dataset
observations = [1.0, 1, 1, 2, 2, 2]
actions = [0, 0.2, 0.3, 0.4, 0.5, 0.6]

#Create a standard independent sessions population model
population_model = (learning_rate = LogitNormal(), action_noise = LogNormal())

#Create full model
model = create_model(action_model, population_model, observations, actions)

#Sample from the posterior
chns = sample_posterior!(model)
Chains MCMC chain (1000×14×2 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 2
Samples per chain = 1000
Wall duration     = 2.46 seconds
Compute duration  = 2.3 seconds
parameters        = learning_rate.session[1], action_noise.session[1]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
                    Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

  learning_rate.session[1]    0.0865    0.0114    0.0004   818.9929   712.8329    1.0029      355.9291
   action_noise.session[1]    0.1018    0.0548    0.0022   529.5521   705.4878    1.0010      230.1400

Quantiles
                parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                    Symbol   Float64   Float64   Float64   Float64   Float64

  learning_rate.session[1]    0.0671    0.0796    0.0852    0.0914    0.1152
   action_noise.session[1]    0.0431    0.0654    0.0867    0.1204    0.2561

We can still plot the estimated parameters for the session

#TODO: plot(model)

This page was generated using Literate.jl.