Fitting action models to data

In this section, we will cover how to fit action models to data with ActionModels. First, we will demonstrate how to set up the model, which is fit under the hood using Turing. This consists of specifying three things: an action model, a population model, and the data to fit the model to. Then, we will show how to sample from the posterior and prior distributions of the model parameters. Finally, we will show the tools in ActionModels for inspecting and extracting the results of the model fitting.

Setting up the model

First we import ActionModels, as well as StatsPlots for plotting the results later.

using ActionModels, StatsPlots

Defining the action model

We will here use the premade Rescorla-Wagner action model provided by ActionModels. This is identical to the model described in the defining action models REF section.

action_model = ActionModel(RescorlaWagner())
-- ActionModel --
Action model function: rescorla_wagner_act_after_update
Number of parameters: 1
Number of states: 0
Number of observations: 1
Number of actions: 1
submodel type: ActionModels.ContinuousRescorlaWagner
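For reference, the premade model can be summarized by the equations below. They are a sketch under assumptions, not taken from the library source. Here α is learning_rate, σ is action_noise (treated as the standard deviation of the Gaussian report), o_t and a_t are the observation and action at timestep t, and the expected value starts at V_0 = 0, as the state trajectories at the end of this section do. The action is reported after the update, as the function name rescorla_wagner_act_after_update suggests.

```latex
V_0 = 0, \qquad
V_t = V_{t-1} + \alpha \,(o_t - V_{t-1}), \qquad
a_t \sim \mathcal{N}\!\left(V_t, \sigma^2\right)
```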

Loading the data

We will then specify the data that we want to fit the model to. For this example, we will use a simple, manually created dataset in which three participants have completed an experiment where they must predict the next location of a moving target. Each participant has completed the experiment twice: once in a control condition and once under an experimental treatment.

using DataFrames

data = DataFrame(
    observations = repeat([1.0, 1, 1, 2, 2, 2], 6),
    actions = vcat(
        [0, 0.2, 0.3, 0.4, 0.5, 0.6],
        [0, 0.5, 0.8, 1, 1.5, 1.8],
        [0, 2, 0.5, 4, 5, 3],
        [0, 0.1, 0.15, 0.2, 0.25, 0.3],
        [0, 0.2, 0.4, 0.7, 1.0, 1.1],
        [0, 2, 0.5, 4, 5, 3],
    ),
    id = vcat(
        repeat(["A"], 6),
        repeat(["B"], 6),
        repeat(["C"], 6),
        repeat(["A"], 6),
        repeat(["B"], 6),
        repeat(["C"], 6),
    ),
    treatment = vcat(repeat(["control"], 18), repeat(["treatment"], 18)),
)

show(data)
36×4 DataFrame
 Row │ observations  actions  id      treatment
     │ Float64       Float64  String  String
─────┼──────────────────────────────────────────
   1 │          1.0      0.0  A       control
   2 │          1.0      0.2  A       control
   3 │          1.0      0.3  A       control
   4 │          2.0      0.4  A       control
   5 │          2.0      0.5  A       control
   6 │          2.0      0.6  A       control
   7 │          1.0      0.0  B       control
   8 │          1.0      0.5  B       control
  ⋮  │      ⋮           ⋮       ⋮         ⋮
  30 │          2.0      1.1  B       treatment
  31 │          1.0      0.0  C       treatment
  32 │          1.0      2.0  C       treatment
  33 │          1.0      0.5  C       treatment
  34 │          2.0      4.0  C       treatment
  35 │          2.0      5.0  C       treatment
  36 │          2.0      3.0  C       treatment
                                 21 rows omitted

Specifying the population model

Finally, we will specify a population model, which describes how parameters vary between the different sessions in the data. There are various options for this, which are described in the population model REF section. Here, we will use a regression population model, where we assume that the learning rate and action noise parameters depend linearly on the experimental treatment. It is a hierarchical model, which means that the parameters (on the scale of the linear predictor) are assumed to be drawn from a Gaussian distribution whose mean is a linear function of the treatment condition; the (1 | id) term adds a random intercept for each participant. The model is specified with standard LMER-style formula syntax. We use a logistic inverse link function for the learning rate to ensure that it is between 0 and 1, and an exponential inverse link function for the action noise to ensure that it is positive.

population_model = [
    Regression(@formula(learning_rate ~ treatment + (1 | id)), logistic),
    Regression(@formula(action_noise ~ treatment + (1 | id)), exp),
];
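To make the regression population model concrete, here is a minimal sketch of how such a model could map regression coefficients to session-level parameters. It is not ActionModels' implementation: the helper names inv_logit and regression_session_parameters are hypothetical. The design matrices X (intercept and treatment dummy) and Z (participant membership) are the ones that appear in the Turing model printout further down this page, with sessions ordered A, B, C in the control condition followed by A, B, C in the treatment condition.

```julia
# Hypothetical sketch of a regression population model with a random intercept per participant.
# Not ActionModels' implementation.

# Inverse logit (logistic) link, named to avoid clashing with the exported `logistic`
inv_logit(x) = 1 / (1 + exp(-x))

# Fixed-effects design: intercept column and treatment dummy (0 = control, 1 = treatment)
X = [1.0 0.0; 1.0 0.0; 1.0 0.0; 1.0 1.0; 1.0 1.0; 1.0 1.0]

# Random-effects design: one column per participant (A, B, C), marking each session's participant
Z = [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0; 1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0]

# Linear predictor X * β + Z * r, mapped through the inverse link to the parameter scale
regression_session_parameters(β, r, inv_link) = inv_link.(X * β .+ Z * r)

# A negative treatment effect on the logit scale lowers the learning rate in treatment sessions
learning_rates = regression_session_parameters([0.0, -1.0], zeros(3), inv_logit)
# With all coefficients at zero, the exponential link gives an action noise of 1 everywhere
action_noises = regression_session_parameters([0.0, 0.0], zeros(3), exp)
```

With this parameterization, β[2] shifts every treatment session by the same amount on the link scale, while each participant's random intercept shifts both of that participant's sessions.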

Creating the full model

We can now combine the three components into a full model that can be fit to the data. This is done with the create_model function, which takes the action model, population model, and data as arguments. Additionally, we specify which columns in the data contain the actions, observations, and session identifiers. This creates an ActionModels.ModelFit object, which contains the full model and will contain the sampling results after fitting.

model = create_model(
    action_model,
    population_model,
    data;
    action_cols = :actions,
    observation_cols = :observations,
    session_cols = [:id, :treatment],
)
-- ModelFit object --
Action model: rescorla_wagner_act_after_update
Linear regression population model
2 estimated action model parameters, 6 sessions
Posterior not sampled
Prior not sampled
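The six sessions come from the unique combinations of the session columns id and treatment. A minimal sketch of that bookkeeping, using plain vectors with the same layout as the data above, is shown below. The label format mirrors the session_ids visible in the Turing model printout further down, but this is an illustration, not the library's code.

```julia
# Same layout as the `id` and `treatment` columns of the example data
ids        = vcat(repeat(["A", "B", "C"], inner = 6), repeat(["A", "B", "C"], inner = 6))
treatments = vcat(repeat(["control"], 18), repeat(["treatment"], 18))

# One session per unique (id, treatment) combination, in order of first appearance
session_keys   = unique(zip(ids, treatments))
session_labels = ["id:$(id).treatment:$(treatment)" for (id, treatment) in session_keys]
```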

If there are multiple actions or observations, we can specify them as a NamedTuple mapping each action or observation to a column in the data. For example, if we had two actions and two observations, we could specify them as follows:

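action_cols = (action_name1 = :action_column_name1, action_name2 = :action_column_name2);
observation_cols = (observation_name1 = :observation_column_name1, observation_name2 = :observation_column_name2);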

The column names can also be specified as a vector of symbols, in which case it will be assumed that the order matches the order of actions or observations in the action model.
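When the columns are given as a vector, the pairing is purely positional. The sketch below shows that positional pairing with a hypothetical helper, columns_by_position; it is not an ActionModels function. The action name report is taken from the Turing model printout further down this page, while the two-action names are placeholders.

```julia
# Hypothetical helper: pair model-side names with data columns by position
function columns_by_position(names::AbstractVector{Symbol}, columns::AbstractVector{Symbol})
    length(names) == length(columns) ||
        throw(ArgumentError("expected $(length(names)) columns, got $(length(columns))"))
    return NamedTuple{Tuple(names)}(Tuple(columns))
end

# Single action, as in the Rescorla-Wagner model
columns_by_position([:report], [:actions])
# Two placeholder actions: the first column goes to the first action, and so on
columns_by_position([:action_name1, :action_name2], [:action_column_name1, :action_column_name2])
```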

Finally, the dataset may contain missing data. Missing actions can be imputed by ActionModels by setting the impute_actions keyword argument of create_model to true. If impute_actions is not set to true, missing actions are simply skipped during sampling instead. This is a problem for action models that depend on their previous actions, since the skipped actions are then not available to the model.
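To illustrate what skipping a missing action means, here is a hypothetical sketch of a Rescorla-Wagner log-likelihood. It assumes the Gaussian report described for this model, with action_noise as a standard deviation. The expected value is still updated from every observation, but missing actions contribute no likelihood term. This is not ActionModels' implementation, and rw_loglikelihood is an illustrative name.

```julia
# Hypothetical sketch, not ActionModels' implementation
function rw_loglikelihood(observations, actions; learning_rate, action_noise, initial_value = 0.0)
    value = initial_value
    loglik = 0.0
    for (observation, action) in zip(observations, actions)
        # The state update depends only on the observation, so it runs even when the action is missing
        value += learning_rate * (observation - value)
        # Skip the likelihood term for a missing action
        ismissing(action) && continue
        # Gaussian log-density of the reported action around the updated expected value
        z = (action - value) / action_noise
        loglik += -0.5 * z^2 - log(action_noise) - 0.5 * log(2π)
    end
    return loglik
end
```

In this sketch the state never reads the previous action, so skipping is harmless; a model whose state update used the previous action would have no input at that step.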

Fitting the model

Now that the model is created, we are ready to fit it. This is done under the hood using MCMC sampling, which is provided by the Turing.jl framework. ActionModels provides the sample_posterior! function, which fits the model in this way with sensible defaults.

chns = sample_posterior!(model)
Chains MCMC chain (1000×24×2 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 2
Samples per chain = 1000
Wall duration     = 46.86 seconds
Compute duration  = 43.46 seconds
parameters        = learning_rate.β[1], learning_rate.β[2], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], action_noise.β[1], action_noise.β[2], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ess_per_sec
                      Symbol   Float64   Float64   Float64     Float64     Float64   Float64       Float64

          learning_rate.β[1]   -0.2739    0.8515    0.0349    615.0436    702.9782    1.0003       14.1513
          learning_rate.β[2]   -0.8287    0.1087    0.0028   1552.6906   1179.8918    0.9993       35.7252
  learning_rate.ranef_1.σ[1]    1.9737    1.0101    0.0369    770.7671    916.0555    0.9999       17.7343
  learning_rate.ranef_1.r[1]   -2.1287    0.8499    0.0349    610.3512    708.6159    1.0004       14.0433
  learning_rate.ranef_1.r[2]   -0.4670    0.8537    0.0350    608.3110    713.1852    1.0002       13.9964
  learning_rate.ranef_1.r[3]    2.2017    1.6325    0.0515   1270.6294    950.6482    0.9997       29.2354
           action_noise.β[1]   -0.7178    0.8120    0.0312    692.1745    795.3975    1.0025       15.9260
           action_noise.β[2]   -0.3544    0.2600    0.0060   1854.8970   1229.0322    1.0018       42.6786
   action_noise.ranef_1.σ[1]    1.8676    0.7534    0.0211   1660.8644   1146.8143    0.9998       38.2142
   action_noise.ranef_1.r[1]   -2.1855    0.8294    0.0322    678.3393    844.5278    1.0035       15.6076
   action_noise.ranef_1.r[2]   -1.0133    0.8285    0.0312    721.8426    979.8111    1.0010       16.6086
   action_noise.ranef_1.r[3]    1.5280    0.8368    0.0316    721.1078    918.0065    1.0031       16.5917

Quantiles
                  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                      Symbol   Float64   Float64   Float64   Float64   Float64

          learning_rate.β[1]   -1.8422   -0.8220   -0.2961    0.1968    1.4791
          learning_rate.β[2]   -1.0520   -0.8968   -0.8262   -0.7595   -0.6147
  learning_rate.ranef_1.σ[1]    0.7538    1.2994    1.7527    2.3639    4.5534
  learning_rate.ranef_1.r[1]   -3.8665   -2.6113   -2.1039   -1.5854   -0.5701
  learning_rate.ranef_1.r[2]   -2.1903   -0.9553   -0.4542    0.0838    1.1048
  learning_rate.ranef_1.r[3]   -0.3609    1.1288    1.9308    3.0140    6.0367
           action_noise.β[1]   -2.3611   -1.2004   -0.7072   -0.2003    0.7812
           action_noise.β[2]   -0.8566   -0.5237   -0.3545   -0.1755    0.1469
   action_noise.ranef_1.σ[1]    0.9175    1.3385    1.7087    2.2094    3.7160
   action_noise.ranef_1.r[1]   -3.7912   -2.7124   -2.1897   -1.6827   -0.5126
   action_noise.ranef_1.r[2]   -2.5468   -1.5531   -1.0344   -0.4884    0.6330
   action_noise.ranef_1.r[3]    0.0030    0.9953    1.5201    2.0551    3.2351

This returns an MCMCChains Chains object, which contains the samples from the posterior distribution of the model parameters. For each action model parameter, there are posteriors for the β values (the intercept and the treatment effect), for the standard deviation σ of the random effects, and for the individual random effects r. Notably, with a full dataset, the posterior will contain a large number of parameters. We can see that the second β value for the learning rate (the dependence on the treatment condition) is negative. The dataset has been constructed to have lower learning rates in the treatment condition, so this is expected.
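Because the learning rate uses a logistic inverse link, its β values are on the logit scale. As a rough reading aid, the posterior means from the summary table can be pushed back through the link to get the learning rate of a participant whose random intercept is zero. This is only a sketch: transforming a posterior mean is not the same as taking the posterior mean of the transformed quantity, and a fuller analysis would transform each draw. The helper name inv_logit is hypothetical.

```julia
# Inverse logit (logistic) link, named to avoid clashing with the exported `logistic`
inv_logit(x) = 1 / (1 + exp(-x))

# Posterior means of the learning-rate coefficients, copied from the summary table above
β_intercept = -0.2739   # learning_rate.β[1]
β_treatment = -0.8287   # learning_rate.β[2]

# Learning rate for a participant with a random intercept of zero
control_rate   = inv_logit(β_intercept)                # ≈ 0.43
treatment_rate = inv_logit(β_intercept + β_treatment)  # ≈ 0.25
```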

Notably, sample_posterior! has many options for how to sample the posterior, which can be set with keyword arguments. If we pass either MCMCThreads() or MCMCDistributed() as the second argument, Turing will use multithreading or distributed sampling to parallelise between chains. MCMCThreads is recommended, but note that Julia must be started with the --threads flag to enable multithreading. We can specify the number of samples and chains with the n_samples and n_chains keyword arguments.

The init_params keyword argument specifies how the initial parameters for the chains are set. It can be set to :MAP or :MLE to use the maximum a posteriori or maximum likelihood estimates as the initial parameters, respectively. It can be set to :sample_prior to draw a single sample from the prior distribution, or to nothing to use Turing's default of random values between -2 and 2 as the initial parameters. Finally, a vector of initial parameters can be passed, which will be used as the initial parameters for all chains.

Other arguments for the sampling can also be passed. These include the automatic differentiation backend, which can be set with the ad_type keyword argument, and the sampler, which can be set with the sampler keyword argument. Notably, sample_posterior! will return the already sampled Chains object if the posterior has already been sampled. Set resample = true to override the already sampled posterior.

chns = sample_posterior!(
    model,
    MCMCThreads(),
    n_samples = 500,
    n_chains = 4,
    init_params = :MAP,
    ad_type = AutoForwardDiff(),
    sampler = NUTS(),
    resample = true,
);
Sampling (2 threads)   0%|                              |  ETA: N/A
┌ Info: Found initial step size
└   ϵ = 0.2
┌ Info: Found initial step size
└   ϵ = 0.0125
┌ Info: Found initial step size
└   ϵ = 0.4
Sampling (2 threads)  25%|███████▌                      |  ETA: 0:00:44
Sampling (2 threads)  50%|███████████████               |  ETA: 0:00:24
Sampling (2 threads)  75%|██████████████████████▌       |  ETA: 0:00:08
┌ Info: Found initial step size
└   ϵ = 0.4
Sampling (2 threads) 100%|██████████████████████████████| Time: 0:00:35

Finally, some users may wish to use Turing's own interface for sampling from the posterior instead. The Turing interface is more flexible in general, but requires more boilerplate code to set up. For this case, the ActionModels.ModelFit object contains the Turing model that is used under the hood. Users can extract it and use it like any other Turing model.

turing_model = model.model
DynamicPPL.Model{typeof(ActionModels.full_model), (:action_model, :population_model, :session_model, :observations_per_session, :actions_per_session, :estimated_parameter_names, :session_ids, :initial_states, Symbol("##arg#230"), Symbol("##arg#231")), (), (), Tuple{ActionModel{ActionModels.ContinuousRescorlaWagner}, DynamicPPL.Model{typeof(ActionModels.regression_population_model), (:linear_submodels, :estimated_parameter_names, :n_sessions), (), (), Tuple{Vector{DynamicPPL.Model}, Vector{Symbol}, Int64}, Tuple{}, DynamicPPL.DefaultContext}, ActionModels.var"#session_model#92"{ActionModels.var"#session_model#86#93"}, Vector{Vector{Tuple{Float64}}}, Vector{Vector{Tuple{Float64}}}, Tuple{Symbol, Symbol}, Vector{String}, @NamedTuple{}, DynamicPPL.TypeWrap{Float64}, DynamicPPL.TypeWrap{Int64}}, Tuple{}, DynamicPPL.DefaultContext}(ActionModels.full_model, (action_model = ActionModel{ActionModels.ContinuousRescorlaWagner}(ActionModels.var"#rescorla_wagner_act_after_update#213"{ActionModels.var"#gaussian_report#209"}(ActionModels.var"#gaussian_report#209"()), (action_noise = Parameter{Float64}(1.0, Float64),), NamedTuple(), (observation = Observation{Float64}(Float64),), (report = Action{Float64, Normal}(Float64, Normal),), ActionModels.ContinuousRescorlaWagner(0.0, 0.1)), population_model = DynamicPPL.Model{typeof(ActionModels.regression_population_model), (:linear_submodels, :estimated_parameter_names, :n_sessions), (), (), Tuple{Vector{DynamicPPL.Model}, Vector{Symbol}, Int64}, Tuple{}, DynamicPPL.DefaultContext}(ActionModels.regression_population_model, (linear_submodels = DynamicPPL.Model[DynamicPPL.Model{typeof(ActionModels.linear_model), (:X, :ranef_info), (:inv_link, :prior, :has_ranef), (), Tuple{Matrix{Float64}, @NamedTuple{Z::Vector{Matrix{Float64}}, n_ranef_categories::Vector{Int64}, n_ranef_params::Vector{Int64}}}, Tuple{typeof(logistic), ActionModels.RegPrior, Bool}, DynamicPPL.DefaultContext}(ActionModels.linear_model, (X = [1.0 0.0; 1.0 0.0; 1.0 0.0; 1.0 
1.0; 1.0 1.0; 1.0 1.0], ranef_info = (Z = [[1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0; 1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0]], n_ranef_categories = [3], n_ranef_params = [1])), (inv_link = LogExpFunctions.logistic, prior = ActionModels.RegPrior(Product{Continuous, TDist{Float64}, FillArrays.Fill{TDist{Float64}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(TDist{Float64}(ν=3.0), 2)), Distribution[Product{Continuous, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, FillArrays.Fill{Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(Truncated(TDist{Float64}(ν=3.0); lower=0.0), 1))]), has_ranef = true), DynamicPPL.DefaultContext()), DynamicPPL.Model{typeof(ActionModels.linear_model), (:X, :ranef_info), (:inv_link, :prior, :has_ranef), (), Tuple{Matrix{Float64}, @NamedTuple{Z::Vector{Matrix{Float64}}, n_ranef_categories::Vector{Int64}, n_ranef_params::Vector{Int64}}}, Tuple{typeof(exp), ActionModels.RegPrior, Bool}, DynamicPPL.DefaultContext}(ActionModels.linear_model, (X = [1.0 0.0; 1.0 0.0; 1.0 0.0; 1.0 1.0; 1.0 1.0; 1.0 1.0], ranef_info = (Z = [[1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0; 1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0]], n_ranef_categories = [3], n_ranef_params = [1])), (inv_link = exp, prior = ActionModels.RegPrior(Product{Continuous, TDist{Float64}, FillArrays.Fill{TDist{Float64}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(TDist{Float64}(ν=3.0), 2)), Distribution[Product{Continuous, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, FillArrays.Fill{Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(Truncated(TDist{Float64}(ν=3.0); lower=0.0), 1))]), has_ranef = true), DynamicPPL.DefaultContext())], estimated_parameter_names = [:learning_rate, :action_noise], n_sessions = 6), NamedTuple(), DynamicPPL.DefaultContext()), session_model = 
ActionModels.var"#session_model#92"{ActionModels.var"#session_model#86#93"}(ActionModels.var"#session_model#86#93"(Core.Box(ActionModels.var"#sample_actions_one_idx#91"(Core.Box(#= circular reference @-2 =#)))), Core.Box(ActionModels.var"#session_model#89#96"(Core.Box(ActionModels.var"#session_model#92"{ActionModels.var"#session_model#86#93"}(#= circular reference @-4 =#)))), Core.Box(([0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.0, 0.5, 0.8, 1.0, 1.5, 1.8, 0.0, 2.0, 0.5, 4.0, 5.0, 3.0, 0.0, 0.1, 0.15, 0.2, 0.25, 0.3, 0.0, 0.2, 0.4, 0.7, 1.0, 1.1, 0.0, 2.0, 0.5, 4.0, 5.0, 3.0],))), observations_per_session = [[(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)]], actions_per_session = [[(0.0,), (0.2,), (0.3,), (0.4,), (0.5,), (0.6,)], [(0.0,), (0.5,), (0.8,), (1.0,), (1.5,), (1.8,)], [(0.0,), (2.0,), (0.5,), (4.0,), (5.0,), (3.0,)], [(0.0,), (0.1,), (0.15,), (0.2,), (0.25,), (0.3,)], [(0.0,), (0.2,), (0.4,), (0.7,), (1.0,), (1.1,)], [(0.0,), (2.0,), (0.5,), (4.0,), (5.0,), (3.0,)]], estimated_parameter_names = (:learning_rate, :action_noise), session_ids = ["id:A.treatment:control", "id:B.treatment:control", "id:C.treatment:control", "id:A.treatment:treatment", "id:B.treatment:treatment", "id:C.treatment:treatment"], initial_states = NamedTuple(), var"##arg#230" = DynamicPPL.TypeWrap{Float64}(), var"##arg#231" = DynamicPPL.TypeWrap{Int64}()), NamedTuple(), DynamicPPL.DefaultContext())

If users want to sample from the model themselves but still want to draw on the rest of the ActionModels API, they can store their results in the ModelFit object by wrapping the resulting Chains in an ActionModels.ModelFitResult object. This should be assigned to either the posterior or the prior field of the ModelFit object, after which it will interface with the ActionModels API as normal.

using ActionModels: Turing
chns = Turing.sample(turing_model, Turing.NUTS(), 1000; progress = false);

model.posterior = ActionModels.ModelFitResult(; chains = chns);
┌ Info: Found initial step size
└   ϵ = 0.025

In addition to sampling from the posterior, ActionModels also provides functionality for sampling from the prior distribution of the model parameters. This is done with the sample_prior! function, which works in a similar way to sample_posterior!. It is much simpler, however, since it does not require a complex sampler, and it therefore only takes n_chains and n_samples as keyword arguments.

prior_chns = sample_prior!(model, n_chains = 1, n_samples = 1000)
Chains MCMC chain (1000×13×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 1.65 seconds
Compute duration  = 1.65 seconds
parameters        = learning_rate.β[1], learning_rate.β[2], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], action_noise.β[1], action_noise.β[2], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3]
internals         = lp

Summary Statistics
                  parameters      mean       std      mcse    ess_bulk   ess_tail      rhat   ess_per_sec
                      Symbol   Float64   Float64   Float64     Float64    Float64   Float64       Float64

          learning_rate.β[1]    0.1000    1.6996    0.0561    939.0323   773.7738    1.0035      569.1105
          learning_rate.β[2]   -0.0300    1.6328    0.0540    879.9462   944.1954    0.9997      533.3007
  learning_rate.ranef_1.σ[1]    1.0364    1.3250    0.0437    974.8590   903.5752    1.0004      590.8236
  learning_rate.ranef_1.r[1]   -0.0336    1.6861    0.0488   1041.2117   880.7797    1.0004      631.0374
  learning_rate.ranef_1.r[2]    0.0192    1.7416    0.0536   1064.8222   877.7412    0.9994      645.3468
  learning_rate.ranef_1.r[3]   -0.0389    1.4450    0.0449   1035.5275   975.2328    1.0009      627.5924
           action_noise.β[1]    0.0328    1.6734    0.0543    904.9334   948.6935    1.0005      548.4445
           action_noise.β[2]   -0.0282    1.5394    0.0508    923.2031   873.7012    1.0029      559.5170
   action_noise.ranef_1.σ[1]    1.1720    1.6563    0.0543    851.1144   936.7700    0.9994      515.8269
   action_noise.ranef_1.r[1]   -0.0654    1.9133    0.0612   1054.8695   876.9116    0.9999      639.3149
   action_noise.ranef_1.r[2]    0.0625    2.6637    0.0883   1018.5414   975.2328    0.9995      617.2978
   action_noise.ranef_1.r[3]   -0.0972    1.6989    0.0542   1008.7916   958.4968    0.9993      611.3888

Quantiles
                  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                      Symbol   Float64   Float64   Float64   Float64   Float64

          learning_rate.β[1]   -3.0365   -0.7435    0.0110    0.8361    3.4479
          learning_rate.β[2]   -3.2650   -0.7785    0.0331    0.7756    3.1641
  learning_rate.ranef_1.σ[1]    0.0354    0.3245    0.6846    1.3208    4.0512
  learning_rate.ranef_1.r[1]   -3.1597   -0.4616   -0.0079    0.3661    3.1456
  learning_rate.ranef_1.r[2]   -3.3306   -0.3309   -0.0030    0.4502    2.6127
  learning_rate.ranef_1.r[3]   -2.8876   -0.4038   -0.0059    0.4012    2.6831
           action_noise.β[1]   -3.3063   -0.6863    0.0646    0.9014    3.0670
           action_noise.β[2]   -3.3695   -0.7467    0.0520    0.7749    2.8471
   action_noise.ranef_1.σ[1]    0.0324    0.3611    0.7793    1.4660    4.3725
   action_noise.ranef_1.r[1]   -3.4997   -0.4769   -0.0112    0.4034    3.0606
   action_noise.ranef_1.r[2]   -2.8894   -0.3918    0.0047    0.4611    3.0907
   action_noise.ranef_1.r[3]   -3.4951   -0.4197   -0.0043    0.3671    2.6893
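Sampling from the prior is simpler because each draw can be generated directly, parameter by parameter, with no Markov chain involved. The sketch below illustrates this ancestral sampling for one regression submodel. The Student-t priors with ν = 3 for β and the Student-t truncated at zero for σ are read from the Turing model printout above. Drawing the random intercepts r from Normal(0, σ) is an assumption, and the function names are hypothetical. Only the Random standard library is used, so the t draws are built from standard normals.

```julia
using Random

# Student-t draw with integer ν degrees of freedom: z / sqrt(χ²_ν / ν)
function rand_student_t(rng::AbstractRNG, ν::Integer)
    z = randn(rng)
    chi2 = sum(randn(rng)^2 for _ in 1:ν)
    return z / sqrt(chi2 / ν)
end

# Hypothetical ancestral sampler for one regression submodel: intercept and treatment effect,
# random-intercept scale, and one random intercept per participant
function sample_regression_prior(rng::AbstractRNG, n_draws::Integer; n_participants::Integer = 3, ν::Integer = 3)
    β = [rand_student_t(rng, ν) for i in 1:n_draws, j in 1:2]
    # A Student-t truncated at zero is distributed like the absolute value of a Student-t draw
    σ = [abs(rand_student_t(rng, ν)) for i in 1:n_draws]
    # Assumption: random intercepts drawn from Normal(0, σ), using each draw's own σ
    r = [σ[i] * randn(rng) for i in 1:n_draws, j in 1:n_participants]
    return (β = β, σ = σ, r = r)
end

prior_draws = sample_regression_prior(MersenneTwister(1), 1000)
```

Because the draws are independent, no warm-up phase is needed.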

Investigating the results

Population model parameters

The first step in investigating model fitting results is often to look at the population model parameters. Population model parameters and how to visualize them will depend on the type of population model used. See the population model REF section for more details on how to interpret results from different population models. But in general, the Chains object returned by sample_posterior! will contain the posterior distribution of the population model parameters. These can be visualized with various plotting functions; see the MCMCChains documentation for an overview. Here, we just use the standard plot function to visualize the posterior distribution over the beta values of interest:

#Fit the model
chns = sample_posterior!(model);

Plot the posterior distribution of the learning rate and action noise beta parameters.

plot(chns[[Symbol("learning_rate.β[1]"), Symbol("learning_rate.β[2]"), Symbol("action_noise.β[1]"), Symbol("action_noise.β[2]")]])

#TODO: ArviZ support / specialized plotting functions
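Beyond plotting, a common numerical summary is the posterior probability that an effect is negative, estimated as the fraction of draws below zero. A minimal sketch with a hypothetical helper is shown below. Obtaining the draws as a plain vector with vec(chns[Symbol("learning_rate.β[2]")]) is an assumption about the MCMCChains indexing API.

```julia
using Statistics

# Fraction of posterior draws below zero, an estimate of P(effect < 0 | data)
prob_negative(draws::AbstractArray{<:Real}) = mean(<(0), draws)

# Assumed usage with the fitted chains:
# prob_negative(vec(chns[Symbol("learning_rate.β[2]")]))
```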

Parameters per session

Beyond the population model parameters, users will often be interested in the parameter estimates for each session in the data. The session parameters can be extracted with the get_session_parameters! function, which returns an ActionModels.SessionParameters object. Whether session parameter estimates should be extracted for the posterior or prior distribution can be specified as the second argument.

#Extract posterior and prior estimates for session parameters
session_parameters = get_session_parameters!(model)
prior_session_parameters = get_session_parameters!(model, :prior)
-- Session parameters object --
6 sessions, 1 chains, 1000 samples
2 estimated parameters:
   learning_rate
   action_noise

ActionModels provides convenient functionality for plotting the session parameters.

#TODO: plot(session_parameters)

Users can also access the full distribution over the session parameters, for use in manual downstream analysis. The ActionModels.SessionParameters object contains the distributions for each parameter and each session. These can be found in the value field, which contains nested NamedTuples for each parameter and session, and ultimately an AxisArray with the samples for each chain.

learning_rate_singlesession = getfield(session_parameters.value.learning_rate, Symbol("id:A.treatment:control"));

The user may visualize, summarize or analyze the samples in whichever way they prefer.

using Statistics
#Calculate the mean
mean(learning_rate_singlesession)
#Plot the distribution
density(learning_rate_singlesession, title = "Learning rate for session A in control condition")
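Besides the mean, an equal-tailed credible interval is often reported. The hypothetical helper below computes the median and a 95% interval from any collection of draws. Passing learning_rate_singlesession to it assumes that collect turns the AxisArray into a plain array of samples.

```julia
using Statistics

# Median and equal-tailed credible interval of a collection of posterior draws
function interval_summary(draws; level::Real = 0.95)
    x = vec(collect(draws))
    tail = (1 - level) / 2
    return (median = median(x), lower = quantile(x, tail), upper = quantile(x, 1 - tail))
end

# Assumed usage:
# interval_summary(learning_rate_singlesession)
```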

ActionModels provides a convenient function for summarizing all the session parameters as a DataFrame, which can be used for further analysis.

median_df = summarize(session_parameters)

show(median_df)
6×4 DataFrame
 Row │ id      treatment  action_noise  learning_rate
     │ String  String     Float64       Float64
─────┼────────────────────────────────────────────────
   1 │ A       control        0.053729      0.0830947
   2 │ B       control        0.176646      0.322814
   3 │ C       control        2.2181        0.855045
   4 │ A       treatment      0.036657      0.0379923
   5 │ B       treatment      0.120573      0.171572
   6 │ C       treatment      1.51952       0.724694
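As one example of such further analysis, the per-participant change in median learning rate between conditions can be computed directly from the table above. The values are copied from its learning_rate column, and the variable names are illustrative. Note that a difference of medians is only a point summary; the posterior of the difference itself would be computed draw by draw.

```julia
# Posterior medians of the learning rate, copied from the learning_rate column above
lr_control   = Dict("A" => 0.0830947, "B" => 0.322814, "C" => 0.855045)
lr_treatment = Dict("A" => 0.0379923, "B" => 0.171572, "C" => 0.724694)

# Change from control to treatment for each participant, sorted by participant id
lr_change = [id => lr_treatment[id] - lr_control[id] for id in sort(collect(keys(lr_control)))]
```

In this toy dataset, all three participants have a lower median learning rate under treatment, in line with the negative learning_rate.β[2].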

This returns the median of each parameter for each session. Other functions can be passed for summarizing the samples, for example std to calculate the standard deviation of the posterior.

std_df = summarize(session_parameters, std)

show(std_df)
6×4 DataFrame
 Row │ id      treatment  action_noise  learning_rate
     │ String  String     Float64       Float64
─────┼────────────────────────────────────────────────
   1 │ A       control       0.0162926     0.00548921
   2 │ B       control       0.0444528     0.0271187
   3 │ C       control       0.705664      0.185973
   4 │ A       treatment     0.0113222     0.00306329
   5 │ B       treatment     0.0360526     0.0160724
   6 │ C       treatment     0.397858      0.240445

This can be saved to disk or used for plotting or analysis in whichever way the user prefers.

State trajectories per session

Users can also extract the estimated trajectory of states for each session with the get_state_trajectories! function, which returns an ActionModels.StateTrajectories object. State trajectories are often used to correlate with some external measure, such as neuroimaging data. The second argument specifies which state to extract. Again, the prior state trajectories can be extracted instead by passing :prior as the third argument.

state_trajectories = get_state_trajectories!(model, :expected_value)

prior_state_trajectories = get_state_trajectories!(model, :expected_value, :prior)
-- State trajectories object --
6 sessions, 1 chains, 1000 samples
1 estimated states:
   expected_value
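Since state trajectories are often correlated with an external measure, here is a minimal sketch of that last step with a hypothetical helper. It assumes both inputs are already plain vectors aligned timestep by timestep (for example, a median trajectory for one session and a per-timestep external measurement); how to extract and align them is left open.

```julia
using Statistics

# Pearson correlation between a per-timestep state trajectory and an external per-timestep measure
function trajectory_correlation(trajectory::AbstractVector{<:Real}, external::AbstractVector{<:Real})
    length(trajectory) == length(external) ||
        throw(DimensionMismatch("trajectory and external measure must have the same length"))
    return cor(trajectory, external)
end
```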

ActionModels also provides functionality for plotting the state trajectories.

#TODO: plot(state_trajectories)

The ActionModels.StateTrajectories object contains the prior or posterior distribution over state trajectories for each session and for each of the specified states. This is stored in the value field, which contains nested NamedTuple objects with state names and then session names as keys, and ultimately an AxisArray with the samples across timesteps. This AxisArray has three axes: the first is the sample index, the second is the chain index, and the third is the timestep index.

expected_value_singlesession = getfield(state_trajectories.value.expected_value, Symbol("id:A.treatment:control"));

Again, we can visualize, summarize or analyze the samples in whichever way we prefer. Here, we choose the second timestep of session A in the control condition, calculate the mean, and plot the distribution.

mean(expected_value_singlesession[timestep=2])
density(expected_value_singlesession[timestep=2], title = "Expectation at time 2 session A control condition")

This can also be summarized neatly in a DataFrame:

median_df = summarize(state_trajectories, median)

show(median_df)
#And from here, it can be used for plotting or further analysis as desired by the user.
42×4 DataFrame
 Row │ id      treatment  timestep  expected_value
     │ String  String     Int64     Float64
─────┼─────────────────────────────────────────────
   1 │ A       control           0       0.0
   2 │ A       control           1       0.0830947
   3 │ A       control           2       0.159285
   4 │ A       control           3       0.229144
   5 │ A       control           4       0.376292
   6 │ A       control           5       0.511214
   7 │ A       control           6       0.634924
   8 │ B       control           0       0.0
  ⋮  │   ⋮         ⋮         ⋮            ⋮
  36 │ C       treatment         0       0.0
  37 │ C       treatment         1       0.724694
  38 │ C       treatment         2       0.924206
  39 │ C       treatment         3       0.979133
  40 │ C       treatment         4       1.71895
  41 │ C       treatment         5       1.92262
  42 │ C       treatment         6       1.9787
                                    27 rows omitted
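As a consistency check, the median trajectory for session A in the control condition can be reproduced from that session's median learning rate (0.0830947, from the session parameter summary above) with the Rescorla-Wagner update V_t = V_{t-1} + α(o_t - V_{t-1}), starting from V_0 = 0. This works because, for this observation sequence, V_t increases monotonically with α, so the median of V_t is V_t evaluated at the median α. The helper name rw_trajectory is hypothetical, and the update rule is our assumption about the premade model, one that these numbers agree with to the printed precision.

```julia
# Rescorla-Wagner expected-value trajectory, assuming V_0 = initial_value and
# V_t = V_{t-1} + α (o_t - V_{t-1})
function rw_trajectory(observations, learning_rate; initial_value = 0.0)
    expected_values = [initial_value]
    for observation in observations
        previous = expected_values[end]
        push!(expected_values, previous + learning_rate * (observation - previous))
    end
    return expected_values   # timesteps 0, 1, ..., length(observations)
end

# Median learning rate for session A in the control condition, from the session summary table
trajectory = rw_trajectory([1.0, 1.0, 1.0, 2.0, 2.0, 2.0], 0.0830947)
```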

This page was generated using Literate.jl.