Population models

Apart from the action model, which describes how actions are chosen and states updated on a timestep-by-timestep basis, it is also necessary to specify a population model when fitting to data. The population model describes how the parameters of the action model are distributed across the sessions in the dataset (i.e. the population). ActionModels provides two types of pre-created population models that the user can choose from. One is the independent sessions population model, in which the parameters of each session are assumed to be independent of each other, but share the same prior. The other is the linear regression population model, in which the parameters of each session are assumed to be related to external variables through a linear regression. The linear regression population model also allows for specifying a hierarchical structure (i.e. a random effect), which is often used to improve parameter estimation in cognitive modelling. Finally, users can also specify their own custom population model as a Turing model. If there is only a single session in the dataset, only the independent sessions population model is appropriate, since there is no session structure to model.

First, we load the ActionModels package, and StatsPlots for plotting results

using ActionModels
using StatsPlots

We then specify the data that we want to fit our model to. For this example, we will use a simple manually created dataset, in which three participants have completed an experiment where they must predict the next location of a moving target. Each participant has completed the experiment twice: once in a control condition, and once under an experimental treatment.

using DataFrames

data = DataFrame(
    observations = repeat([1.0, 1, 1, 2, 2, 2], 6),
    actions = vcat(
        [0, 0.2, 0.3, 0.4, 0.5, 0.6],
        [0, 0.5, 0.8, 1, 1.5, 1.8],
        [0, 2, 0.5, 4, 5, 3],
        [0, 0.1, 0.15, 0.2, 0.25, 0.3],
        [0, 0.2, 0.4, 0.7, 1.0, 1.1],
        [0, 2, 0.5, 4, 5, 3],
    ),
    id = vcat(
        repeat(["A"], 6),
        repeat(["B"], 6),
        repeat(["C"], 6),
        repeat(["A"], 6),
        repeat(["B"], 6),
        repeat(["C"], 6),
    ),
    treatment = vcat(repeat(["control"], 18), repeat(["treatment"], 18)),
)

show(data)
36×4 DataFrame
 Row │ observations  actions  id      treatment
     │ Float64       Float64  String  String
─────┼──────────────────────────────────────────
   1 │          1.0      0.0  A       control
   2 │          1.0      0.2  A       control
   3 │          1.0      0.3  A       control
   4 │          2.0      0.4  A       control
   5 │          2.0      0.5  A       control
   6 │          2.0      0.6  A       control
   7 │          1.0      0.0  B       control
   8 │          1.0      0.5  B       control
  ⋮  │      ⋮           ⋮       ⋮         ⋮
  30 │          2.0      1.1  B       treatment
  31 │          1.0      0.0  C       treatment
  32 │          1.0      2.0  C       treatment
  33 │          1.0      0.5  C       treatment
  34 │          2.0      4.0  C       treatment
  35 │          2.0      5.0  C       treatment
  36 │          2.0      3.0  C       treatment
                                 21 rows omitted

We specify which columns in the data correspond to the actions, observations, and session identifiers.

action_cols = :actions;
observation_cols = :observations;
session_cols = [:id, :treatment];
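
Each unique combination of these columns defines one session; with three participants and two conditions, this gives six sessions. As a quick check (using only DataFrames):

#Count the unique sessions defined by the session columns
nrow(unique(data, session_cols))  # 6 sessions: 3 participants × 2 conditions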

Finally, we specify the action model. We here use the premade Rescorla-Wagner action model provided by ActionModels. This is identical to the model described in the defining action models REF section.

action_model = ActionModel(RescorlaWagner())
-- ActionModel --
Action model function: rescorla_wagner_act_after_update
Number of parameters: 1
Number of states: 0
Number of observations: 1
Number of actions: 1
submodel type: ActionModels.ContinuousRescorlaWagner

Independent sessions population model

To specify the independent sessions population model, we only need to specify the prior distribution for each parameter to estimate. This is done as a NamedTuple with the parameter names as keys and the prior distributions as values. We here select a LogitNormal prior for the learning rate, since it is constrained to be between 0 and 1, and a LogNormal prior for the action noise, since it is constrained to be positive.

population_model = (learning_rate = LogitNormal(), action_noise = LogNormal());

We can then create the full model, and sample from the posterior

model = create_model(
    action_model,
    population_model,
    data;
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
)

chns = sample_posterior!(model)
Chains MCMC chain (1000×24×2 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 2
Samples per chain = 1000
Wall duration     = 8.38 seconds
Compute duration  = 7.5 seconds
parameters        = learning_rate.session[1], learning_rate.session[2], learning_rate.session[3], learning_rate.session[4], learning_rate.session[5], learning_rate.session[6], action_noise.session[1], action_noise.session[2], action_noise.session[3], action_noise.session[4], action_noise.session[5], action_noise.session[6]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ess_per_sec
                    Symbol   Float64   Float64   Float64     Float64     Float64   Float64       Float64

  learning_rate.session[1]    0.0865    0.0133    0.0006   1100.2745    548.2045    1.0020      146.7815
  learning_rate.session[2]    0.3596    0.0814    0.0040    910.2911    396.7386    1.0038      121.4369
  learning_rate.session[3]    0.6179    0.1752    0.0040   1920.7375   1453.7957    1.0008      256.2350
  learning_rate.session[4]    0.0427    0.0149    0.0008   1012.8125    483.9725    1.0054      135.1137
  learning_rate.session[5]    0.1737    0.0223    0.0006   1498.2800   1094.3198    0.9999      199.8773
  learning_rate.session[6]    0.6092    0.1751    0.0045   1560.7516   1392.8291    1.0020      208.2113
   action_noise.session[1]    0.1030    0.0613    0.0025    975.0942    772.9122    1.0010      130.0819
   action_noise.session[2]    0.2734    0.1171    0.0041   1186.3347    886.4857    1.0029      158.2624
   action_noise.session[3]    1.9514    0.6357    0.0170   1663.6013   1160.1601    1.0016      221.9319
   action_noise.session[4]    0.0722    0.0740    0.0038    951.0474    518.4142    1.0063      126.8740
   action_noise.session[5]    0.1586    0.0780    0.0023   1710.1374   1384.9489    1.0017      228.1400
   action_noise.session[6]    1.9343    0.5885    0.0142   2119.5175   1626.8434    1.0024      282.7531

Quantiles
                parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                    Symbol   Float64   Float64   Float64   Float64   Float64

  learning_rate.session[1]    0.0655    0.0795    0.0848    0.0916    0.1156
  learning_rate.session[2]    0.2468    0.3115    0.3450    0.3885    0.5667
  learning_rate.session[3]    0.2441    0.5004    0.6325    0.7516    0.9039
  learning_rate.session[4]    0.0308    0.0374    0.0402    0.0441    0.0684
  learning_rate.session[5]    0.1347    0.1601    0.1719    0.1847    0.2231
  learning_rate.session[6]    0.2442    0.4877    0.6297    0.7445    0.8993
   action_noise.session[1]    0.0433    0.0663    0.0869    0.1194    0.2735
   action_noise.session[2]    0.1359    0.1923    0.2449    0.3177    0.5894
   action_noise.session[3]    1.1383    1.5179    1.8224    2.2330    3.5232
   action_noise.session[4]    0.0257    0.0401    0.0536    0.0788    0.2253
   action_noise.session[5]    0.0716    0.1079    0.1401    0.1864    0.3477
   action_noise.session[6]    1.1345    1.5145    1.8265    2.2193    3.3286

We can see in the Chains object that each parameter is estimated separately for each session. This means that the population model parameters in this case are identical to the session parameters. Therefore, the standard plotting functions will plot the session parameters directly.

#TODO: plot(model)
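
In the meantime, the posterior can be inspected with the generic MCMCChains plotting recipe from StatsPlots, which draws trace and density plots for every parameter in the chain (this is the standard recipe, not a dedicated ActionModels plot):

#Trace and density plots for all session parameters
plot(chns)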

We see that the learning rate (by construction) is lower in the treatment condition, and that sessions A, B, and C have the lowest, middle, and highest learning rates, respectively.

We can also extract the session parameters as a DataFrame for further analysis.

parameter_df = summarize(get_session_parameters!(model))

show(parameter_df)
6×4 DataFrame
 Row │ id      treatment  action_noise  learning_rate
     │ String  String     Float64       Float64
─────┼────────────────────────────────────────────────
   1 │ A       control       0.0869008      0.0847564
   2 │ B       control       0.244907       0.345021
   3 │ C       control       1.8224         0.632463
   4 │ A       treatment     0.0536363      0.0401632
   5 │ B       treatment     0.14009        0.171898
   6 │ C       treatment     1.82646        0.629654
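
As a minimal sketch of such further analysis (using only DataFrames and the Statistics standard library), we can for example compare the average estimated learning rate across conditions:

using Statistics: mean

#Average the per-session learning rate estimates within each treatment condition
combine(groupby(parameter_df, :treatment), :learning_rate => mean => :mean_learning_rate)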

Linear regression population models

Often, the goal of cognitive modelling is to relate differences in parameter values to external variables, such as treatment conditions or participant characteristics. This is often done with a linear regression, where point estimates of the parameters are predicted from the external variables. It is statistically advantageous, however, to fit the linear regression as part of the population model, rather than first fitting the cognitive model separately and then regressing on the resulting point estimates, since the joint fit propagates the uncertainty of the parameter estimates into the regression. ActionModels provides a pre-made linear regression population model that can be used for this purpose.

To specify a linear regression population model, we create a vector of Regression objects, where each Regression object specifies a regression model for one of the parameters to estimate. The regression is specified with standard LMER-style syntax, where the formula is created with the @formula macro. Here, we predict each parameter from the treatment condition, and add a random intercept for each participant (id), making this a classic hierarchical model. For each regression, we can also specify an inverse link function. This function transforms the output of the regression, and is commonly used to ensure that the resulting parameter values are in the correct range. Here, we use the logistic function for the learning rate (to ensure it is between 0 and 1) and the exp function for the action noise (to ensure it is positive). If no inverse link function is specified, the identity function is used by default.

population_model = [
    Regression(@formula(learning_rate ~ 1 + treatment + (1 | id)), logistic),
    Regression(@formula(action_noise ~ 1 + treatment + (1 | id)), exp),
];
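
The two inverse link functions map the unconstrained regression output into each parameter's range: logistic maps the real line to (0, 1), and exp maps it to (0, ∞). For example:

#The inverse link functions evaluated at a regression output of zero
logistic(0.0)  # 0.5
exp(0.0)       # 1.0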

The priors for the regression population model can be customised with the RegressionPrior constructor, which takes priors for the regression coefficients (β) and for the standard deviation of the random effects (σ). If a single distribution is specified for a given type, it is used for all parameters of that type; to set separate priors for each predictor or each random effect, a vector of distributions can be passed instead. Here, we specify a Student's t-distribution with 3 degrees of freedom for the intercept and a standard normal for the treatment effect (β), as well as an Exponential distribution with rate 1 for the standard deviation of the random intercepts (σ). These priors apply in regression space, before the inverse link function is applied. This means that values above 5 are somewhat extreme, so these priors are probably too broad for most use cases. Below, we plot the priors to inspect their shapes.

prior = RegressionPrior(β = [TDist(3), Normal()], σ = Exponential(1))

population_model = [
    Regression(@formula(learning_rate ~ 1 + treatment + (1 | id)), logistic, prior),
    Regression(@formula(action_noise ~ 1 + treatment + (1 | id)), exp, prior),
]


plot(prior.β[1], label = "Intercept")
plot!(prior.β[2], label = "Effect of treatment")
plot!(prior.σ, label = "Random intercept std")
title!("Regression priors for the Rescorla-Wagner model")
xlabel!("Regression coefficient")
ylabel!("Density")

We can then create the full model, and sample from the posterior

model = create_model(
    action_model,
    population_model,
    data;
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
)

chns = sample_posterior!(model)
Chains MCMC chain (1000×24×2 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 2
Samples per chain = 1000
Wall duration     = 67.02 seconds
Compute duration  = 66.72 seconds
parameters        = learning_rate.β[1], learning_rate.β[2], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], action_noise.β[1], action_noise.β[2], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ess_per_sec
                      Symbol   Float64   Float64   Float64     Float64     Float64   Float64       Float64

          learning_rate.β[1]   -0.2447    0.8438    0.0317    748.6932    879.6147    1.0025       11.2206
          learning_rate.β[2]   -0.8272    0.1060    0.0025   1752.7039   1311.4515    1.0007       26.2676
  learning_rate.ranef_1.σ[1]    2.0307    0.9141    0.0268   1249.2890   1172.1578    1.0020       18.7230
  learning_rate.ranef_1.r[1]   -2.1582    0.8480    0.0319    745.2320    896.0456    1.0026       11.1687
  learning_rate.ranef_1.r[2]   -0.4967    0.8526    0.0322    745.1787    865.4898    1.0027       11.1679
  learning_rate.ranef_1.r[3]    2.3309    1.7053    0.0645   1023.1370    471.6576    1.0027       15.3336
           action_noise.β[1]   -0.6776    0.8144    0.0314    679.4431    687.8172    1.0034       10.1827
           action_noise.β[2]   -0.3709    0.2525    0.0064   1565.5306   1300.7278    1.0005       23.4624
   action_noise.ranef_1.σ[1]    1.8472    0.6982    0.0199   1350.2708   1481.9488    1.0008       20.2364
   action_noise.ranef_1.r[1]   -2.2267    0.8347    0.0327    659.0436    568.3449    1.0031        9.8770
   action_noise.ranef_1.r[2]   -1.0500    0.8389    0.0322    690.3257    687.2904    1.0052       10.3458
   action_noise.ranef_1.r[3]    1.4827    0.8310    0.0320    685.0286    726.0597    1.0043       10.2664

Quantiles
                  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                      Symbol   Float64   Float64   Float64   Float64   Float64

          learning_rate.β[1]   -1.8345   -0.7613   -0.2704    0.2894    1.3731
          learning_rate.β[2]   -1.0342   -0.8958   -0.8259   -0.7588   -0.6153
  learning_rate.ranef_1.σ[1]    0.7969    1.4027    1.8508    2.4517    4.2686
  learning_rate.ranef_1.r[1]   -3.7939   -2.7044   -2.1391   -1.6385   -0.5618
  learning_rate.ranef_1.r[2]   -2.1398   -1.0442   -0.4797    0.0461    1.1324
  learning_rate.ranef_1.r[3]   -0.3037    1.1883    2.0907    3.1718    6.2939
           action_noise.β[1]   -2.2828   -1.2129   -0.6713   -0.1540    0.9660
           action_noise.β[2]   -0.8629   -0.5446   -0.3749   -0.2058    0.1451
   action_noise.ranef_1.σ[1]    0.9100    1.3467    1.6960    2.1821    3.6452
   action_noise.ranef_1.r[1]   -3.8580   -2.7452   -2.2008   -1.6915   -0.6322
   action_noise.ranef_1.r[2]   -2.7424   -1.5682   -1.0371   -0.5169    0.6406
   action_noise.ranef_1.r[3]   -0.2574    0.9721    1.4966    2.0197    3.0506

We can see that there are, for each parameter, two $\beta$ estimates: the first is the intercept, and the second is the treatment effect. For each random effect, there is a $\sigma$ estimate (the standard deviation of the random intercepts), as well as the sampled random intercept for each level of the grouping variable (here, each participant id). Finally, the effect of the treatment on the learning rate is estimated to be negative, as was built into the dataset.

title = plot(
    title = "Posterior over treatment effect",
    grid = false,
    showaxis = false,
    bottom_margin = -30Plots.px,
)
plot(
    title,
    density(chns[Symbol("learning_rate.β[2]")], title = "learning rate", label = nothing),
    density(chns[Symbol("action_noise.β[2]")], title = "action noise", label = nothing),
    layout = @layout([A{0.01h}; [B C]])
)
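
Since the full posterior samples are available, such effects can also be quantified directly. For example (a sketch using standard MCMCChains indexing), the posterior probability that the treatment effect on the learning rate is negative:

using Statistics: mean

#Proportion of posterior samples in which the treatment effect is below zero
mean(Array(chns[Symbol("learning_rate.β[2]")]) .< 0)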

When using the linear regression population model, the session parameters are not directly estimated, but rather the regression coefficients and random effects. We can still extract the session parameters as before, however.

parameter_df = summarize(get_session_parameters!(model))

show(parameter_df)
6×4 DataFrame
 Row │ id      treatment  action_noise  learning_rate
     │ String  String     Float64       Float64
─────┼────────────────────────────────────────────────
   1 │ A       control       0.0538537      0.0830313
   2 │ B       control       0.174386       0.323102
   3 │ C       control       2.20466        0.856355
   4 │ A       treatment     0.0374602      0.0381377
   5 │ B       treatment     0.12162        0.171993
   6 │ C       treatment     1.52124        0.727108

Custom population models

Finally, it is also possible to specify a custom population model for use with ActionModels. This is done by creating a Turing model (with its arguments set) that describes the population model and returns the sampled parameters for each session. The return value of the model, which is used as a Turing submodel, must be iterable, and must yield one tuple of sampled parameter values per session. The sessions will be matched in the same order as they appear in the data, so it is important that the model returns the parameters in the correct order. Additionally, the names of the parameters are passed separately as a Tuple, so that the sampled values can be matched to the parameters of the action model; the order of the names in this Tuple must match the order of the values in the returned tuples. Here, we create a custom population model where the learning rate and action noise for each session are sampled from multivariate normal distributions and then transformed to their appropriate ranges.

#Load Turing, and the identity matrix from the LinearAlgebra standard library
using ActionModels: Turing
using LinearAlgebra: I

#Get the number of sessions in the data
n_sessions = nrow(unique(data, session_cols))

#Create the Turing model for the custom population model
Turing.@model function custom_population_model(n_sessions::Int64)

    #Sample parameters for each session
    learning_rates ~ MvNormal(zeros(n_sessions), I)
    action_noises ~ MvNormal(zeros(n_sessions), I)

    #Transform the parameters to the correct range
    learning_rates = logistic.(learning_rates)
    action_noises = exp.(action_noises)

    #Return the parameters as an iterable of tuples
    return zip(learning_rates, action_noises)
end

#Instantiate the Turing model with the number of sessions
population_model = custom_population_model(n_sessions)

#Specify which parameters are estimated
parameters_to_estimate = (:learning_rate, :action_noise);
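
Before plugging the custom population model into create_model, it can be useful to sanity-check it in isolation. Calling a Turing model instance directly runs it with values drawn from the prior and returns its return value, so we can inspect the per-session parameter tuples it produces (a quick sketch, not required for fitting):

#Draw once from the prior and collect the per-session (learning_rate, action_noise) tuples
collect(population_model())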

We can then create the full model, and sample from the posterior

model = create_model(
    action_model,
    population_model,
    data;
    action_cols = action_cols,
    observation_cols = observation_cols,
    session_cols = session_cols,
    parameters_to_estimate = parameters_to_estimate,
)

chns = sample_posterior!(model)
Chains MCMC chain (1000×24×2 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 2
Samples per chain = 1000
Wall duration     = 7.52 seconds
Compute duration  = 6.68 seconds
parameters        = learning_rates[1], learning_rates[2], learning_rates[3], learning_rates[4], learning_rates[5], learning_rates[6], action_noises[1], action_noises[2], action_noises[3], action_noises[4], action_noises[5], action_noises[6]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
         parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ess_per_sec
             Symbol   Float64   Float64   Float64     Float64     Float64   Float64       Float64

  learning_rates[1]   -2.3661    0.1560    0.0058   1000.0683    610.9035    1.0020      149.6436
  learning_rates[2]   -0.5878    0.3573    0.0150   1164.8785    486.1691    1.0054      174.3047
  learning_rates[3]    0.5570    0.8448    0.0166   2608.6847   1646.0775    1.0015      390.3464
  learning_rates[4]   -3.1636    0.1664    0.0054   1153.5685    704.6482    1.0002      172.6124
  learning_rates[5]   -1.5564    0.1743    0.0056   1212.3206    695.9803    1.0018      181.4037
  learning_rates[6]    0.4975    0.9012    0.0185   2428.4861   1356.9940    1.0014      363.3826
   action_noises[1]   -2.3770    0.4574    0.0137   1420.9694    687.1504    1.0047      212.6245
   action_noises[2]   -1.3505    0.3924    0.0116   1313.3550    870.6694    0.9999      196.5218
   action_noises[3]    0.6021    0.2987    0.0059   2738.6077   1452.3791    1.0030      409.7872
   action_noises[4]   -2.8810    0.5156    0.0158   1186.8953    898.4864    1.0009      177.5992
   action_noises[5]   -1.9010    0.4251    0.0145    992.9820    679.5898    1.0032      148.5833
   action_noises[6]    0.6086    0.2880    0.0059   2606.2684   1309.6396    1.0031      389.9848

Quantiles
         parameters      2.5%     25.0%     50.0%     75.0%     97.5%
             Symbol   Float64   Float64   Float64   Float64   Float64

  learning_rates[1]   -2.6324   -2.4608   -2.3767   -2.2889   -2.0195
  learning_rates[2]   -1.1586   -0.7861   -0.6270   -0.4489    0.1720
  learning_rates[3]   -1.0665    0.0069    0.5103    1.1194    2.2484
  learning_rates[4]   -3.4770   -3.2542   -3.1745   -3.0881   -2.7836
  learning_rates[5]   -1.8871   -1.6548   -1.5670   -1.4653   -1.1979
  learning_rates[6]   -1.2515   -0.1013    0.4807    1.0833    2.2703
   action_noises[1]   -3.1374   -2.7052   -2.4252   -2.0962   -1.3192
   action_noises[2]   -1.9959   -1.6336   -1.3927   -1.1163   -0.4938
   action_noises[3]    0.0834    0.4005    0.5783    0.7753    1.2710
   action_noises[4]   -3.7143   -3.2549   -2.9330   -2.5840   -1.7098
   action_noises[5]   -2.6170   -2.2126   -1.9549   -1.6383   -0.9480
   action_noises[6]    0.0987    0.4043    0.5921    0.7863    1.2480

We can see that the parameters are estimated for each session, but in unconstrained (non-transformed) space. We can still extract the session parameters as before; this returns them in the form in which they are passed to the action model, i.e., after the transformation.
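
As a quick check, applying the inverse link by hand approximately recovers the values reported below; for instance, for the first session's learning rate (posterior mean of about -2.37 in unconstrained space):

#Transforming the unconstrained estimate back to the (0, 1) learning rate scale
logistic(-2.37)  # ≈ 0.085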

parameter_df = summarize(get_session_parameters!(model))

show(parameter_df)
6×4 DataFrame
 Row │ id      treatment  action_noise  learning_rate
     │ String  String     Float64       Float64
─────┼────────────────────────────────────────────────
   1 │ A       control       0.0884572      0.0849647
   2 │ B       control       0.248401       0.348196
   3 │ C       control       1.78307        0.624868
   4 │ A       treatment     0.0532394      0.0401376
   5 │ B       treatment     0.141572       0.172647
   6 │ C       treatment     1.8077         0.617914

Single session population model

Notably, it is also possible to use ActionModels with only a single session. In this case, the actions and observations can be passed as two vectors, instead of passing a DataFrame. For single session models, only the independent sessions population model is appropriate, since there is no session structure to model.

#Create a single session dataset
observations = [1.0, 1, 1, 2, 2, 2]
actions = [0, 0.2, 0.3, 0.4, 0.5, 0.6]

#Create a standard independent sessions population model
population_model = (learning_rate = LogitNormal(), action_noise = LogNormal())

#Create full model
model = create_model(action_model, population_model, observations, actions)

#Sample from the posterior
chns = sample_posterior!(model)
Chains MCMC chain (1000×14×2 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 2
Samples per chain = 1000
Wall duration     = 2.52 seconds
Compute duration  = 2.36 seconds
parameters        = learning_rate.session[1], action_noise.session[1]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
                    Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

  learning_rate.session[1]    0.0873    0.0147    0.0007   600.5320   474.8520    1.0008      255.0030
   action_noise.session[1]    0.1037    0.0631    0.0030   514.5651   531.6188    1.0015      218.4990

Quantiles
                parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                    Symbol   Float64   Float64   Float64   Float64   Float64

  learning_rate.session[1]    0.0669    0.0798    0.0852    0.0916    0.1200
   action_noise.session[1]    0.0433    0.0647    0.0871    0.1191    0.2825

We can still plot the estimated parameters for the session

#TODO: plot(model)
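
As before, the parameter posteriors can meanwhile be plotted directly from the chains with the generic StatsPlots recipes, for example:

#Posterior density for the learning rate of the single session
density(chns[Symbol("learning_rate.session[1]")], label = "learning rate")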

This page was generated using Literate.jl.