Fitting action models to data

In this section, we will cover how to fit action models to data with ActionModels. First, we will demonstrate how to set up the model that is fit under the hood by relying on Turing. This consists of specifying three things: an action model, a population model, and the data to fit the model to. Then, we will show how to sample from the posterior and prior distributions of the model parameters. Finally, we will show the tools that ActionModels provides for inspecting and extracting the results of the model fitting.

Setting up the model

First we import ActionModels, as well as StatsPlots for plotting the results later.

using ActionModels, StatsPlots

Defining the action model

Here, we will use the premade Rescorla-Wagner action model provided by ActionModels. This is identical to the model described in the defining action models REF section.

action_model = ActionModel(RescorlaWagner())
-- ActionModel --
Action model function: rescorla_wagner_act_after_update
Number of parameters: 1
Number of states: 0
Number of observations: 1
Number of actions: 1
submodel type: ActionModels.ContinuousRescorlaWagner
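
For reference, the continuous Rescorla-Wagner model tracks an expected value $V_t$, which is updated after each observation $o_t$ and reported with Gaussian noise. A sketch of the standard formulation (see the defining action models REF section for the exact definition used here):

$$V_t = V_{t-1} + \alpha \, (o_t - V_{t-1}), \qquad a_t \sim \mathcal{N}(V_t, \beta)$$

where $\alpha$ is the learning_rate parameter and $\beta$ the action_noise parameter.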

Loading the data

We will then specify the data that we want to fit the model to. For this example, we will use a simple manually created dataset, in which three participants have completed an experiment where they must predict the next location of a moving target. Each participant has completed the experiment twice: once in a control condition and once under an experimental treatment.

using DataFrames

data = DataFrame(
    observations = repeat([1.0, 1, 1, 2, 2, 2], 6),
    actions = vcat(
        [0, 0.2, 0.3, 0.4, 0.5, 0.6],
        [0, 0.5, 0.8, 1, 1.5, 1.8],
        [0, 2, 0.5, 4, 5, 3],
        [0, 0.1, 0.15, 0.2, 0.25, 0.3],
        [0, 0.2, 0.4, 0.7, 1.0, 1.1],
        [0, 2, 0.5, 4, 5, 3],
    ),
    id = vcat(
        repeat(["A"], 6),
        repeat(["B"], 6),
        repeat(["C"], 6),
        repeat(["A"], 6),
        repeat(["B"], 6),
        repeat(["C"], 6),
    ),
    treatment = vcat(repeat(["control"], 18), repeat(["treatment"], 18)),
)

show(data)
36×4 DataFrame
 Row │ observations  actions  id      treatment
     │ Float64       Float64  String  String
─────┼──────────────────────────────────────────
   1 │          1.0      0.0  A       control
   2 │          1.0      0.2  A       control
   3 │          1.0      0.3  A       control
   4 │          2.0      0.4  A       control
   5 │          2.0      0.5  A       control
   6 │          2.0      0.6  A       control
   7 │          1.0      0.0  B       control
   8 │          1.0      0.5  B       control
  ⋮  │      ⋮           ⋮       ⋮         ⋮
  30 │          2.0      1.1  B       treatment
  31 │          1.0      0.0  C       treatment
  32 │          1.0      2.0  C       treatment
  33 │          1.0      0.5  C       treatment
  34 │          2.0      4.0  C       treatment
  35 │          2.0      5.0  C       treatment
  36 │          2.0      3.0  C       treatment
                                 21 rows omitted

Specifying the population model

Next, we will specify a population model, which describes how parameters vary between the different sessions in the data. There are various options here, which are described in the population model REF section. We will use a regression population model, where we assume that the learning rate and action noise parameters depend linearly on the experimental treatment. This is a hierarchical model: the session parameters are assumed to be sampled from a Gaussian distribution whose mean is a linear function of the treatment condition. The regression is specified with standard LMER syntax, and we use a logistic inverse link function for the learning rate to ensure that it lies between 0 and 1, and an exponential inverse link function for the action noise to ensure that it is positive.

population_model = [
    Regression(@formula(learning_rate ~ treatment + (1 | id)), logistic),
    Regression(@formula(action_noise ~ treatment + (1 | id)), exp),
];
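
Concretely, for session $s$ this corresponds to roughly the following model for the learning rate (the action noise is analogous, with $\exp$ as the inverse link):

$$\text{learning\_rate}_s = \text{logistic}\left(\beta_0 + \beta_1 \cdot \text{treatment}_s + r_{\text{id}[s]}\right), \qquad r_j \sim \mathcal{N}(0, \sigma^2)$$

where $\beta_0$ and $\beta_1$ are the regression coefficients, and the $r_j$ are per-participant random intercepts.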

Creating the full model

Finally, we can combine the three components into a full model that can be fit to the data. This is done with the create_model function, which takes the action model, population model, and data as arguments. Additionally, we specify which columns in the data contain the actions, observations, and session identifiers. This creates an ActionModels.ModelFit object, which contains the full model and will contain the sampling results after fitting.

model = create_model(
    action_model,
    population_model,
    data;
    action_cols = :actions,
    observation_cols = :observations,
    session_cols = [:id, :treatment],
)
-- ModelFit object --
Action model: rescorla_wagner_act_after_update
Linear regression population model
2 estimated action model parameters, 6 sessions
Posterior not sampled
Prior not sampled

If there are multiple actions or observations, we can specify them as a NamedTuple mapping each action or observation to a column in the data. For example, if we had two actions, we could specify the action columns as follows:

(action_name1 = :action_column_name1, action_name2 = :action_column_name2);

The column names can also be specified as a vector of symbols, in which case it will be assumed that the order matches the order of actions or observations in the action model.
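
For example, the action columns could equivalently be specified as a vector (using the same hypothetical column names as above):

[:action_column_name1, :action_column_name2];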

Finally, there may be missing data in the dataset. If actions are missing, they can be imputed by ActionModels. This is done by setting the impute_actions argument to true in the create_model function, as shown below. If impute_actions is not set to true, missing actions will instead be skipped during sampling. Note that this is a problem for action models whose computations depend on their previous actions.
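
As an example, the model above could be created with imputation enabled as follows (a sketch; identical to the create_model call above apart from the extra keyword argument):

model = create_model(
    action_model,
    population_model,
    data;
    action_cols = :actions,
    observation_cols = :observations,
    session_cols = [:id, :treatment],
    impute_actions = true,
)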

Fitting the model

Now that the model is created, we are ready to fit it. This is done under the hood using MCMC sampling, which is provided by the Turing.jl framework. ActionModels provides the sample_posterior! function, which fits the model in this way with sensible defaults.

chns = sample_posterior!(model)
Chains MCMC chain (1000×24×2 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 2
Samples per chain = 1000
Wall duration     = 50.48 seconds
Compute duration  = 47.23 seconds
parameters        = learning_rate.β[1], learning_rate.β[2], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], action_noise.β[1], action_noise.β[2], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ess_per_sec
                      Symbol   Float64   Float64   Float64     Float64     Float64   Float64       Float64

          learning_rate.β[1]   -0.3154    0.8416    0.0315    728.7682    762.2479    1.0031       15.4315
          learning_rate.β[2]   -0.8246    0.1064    0.0026   1715.2282   1308.6368    0.9997       36.3196
  learning_rate.ranef_1.σ[1]    1.9781    1.0277    0.0363    942.5629    920.2904    1.0027       19.9586
  learning_rate.ranef_1.r[1]   -2.0840    0.8421    0.0315    735.2904    758.3671    1.0035       15.5696
  learning_rate.ranef_1.r[2]   -0.4278    0.8494    0.0314    752.2694    735.9333    1.0027       15.9291
  learning_rate.ranef_1.r[3]    2.2329    1.7321    0.0616   1194.1897    768.2138    1.0053       25.2867
           action_noise.β[1]   -0.6679    0.7988    0.0317    640.4252    721.9865    1.0044       13.5609
           action_noise.β[2]   -0.3776    0.2648    0.0060   1979.8653   1459.2444    1.0007       41.9232
   action_noise.ranef_1.σ[1]    1.8625    0.7660    0.0238   1164.3755   1172.8452    1.0004       24.6554
   action_noise.ranef_1.r[1]   -2.2267    0.8319    0.0324    661.7851    603.9202    1.0048       14.0132
   action_noise.ranef_1.r[2]   -1.0530    0.8153    0.0321    649.9414    716.0384    1.0041       13.7624
   action_noise.ranef_1.r[3]    1.4891    0.8043    0.0313    665.3593    763.1122    1.0065       14.0888

Quantiles
                  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                      Symbol   Float64   Float64   Float64   Float64   Float64

          learning_rate.β[1]   -1.9708   -0.8473   -0.3137    0.1844    1.4482
          learning_rate.β[2]   -1.0408   -0.8885   -0.8246   -0.7579   -0.6199
  learning_rate.ranef_1.σ[1]    0.7248    1.3202    1.7670    2.3601    4.3748
  learning_rate.ranef_1.r[1]   -3.8343   -2.5796   -2.0743   -1.5661   -0.4185
  learning_rate.ranef_1.r[2]   -2.1652   -0.9422   -0.4193    0.0960    1.2150
  learning_rate.ranef_1.r[3]   -0.1302    1.1545    1.9234    3.0218    6.5012
           action_noise.β[1]   -2.2887   -1.2022   -0.6817   -0.1366    0.8935
           action_noise.β[2]   -0.8808   -0.5618   -0.3771   -0.1967    0.1382
   action_noise.ranef_1.σ[1]    0.8861    1.3353    1.6998    2.2056    3.7326
   action_noise.ranef_1.r[1]   -3.8839   -2.7801   -2.1843   -1.6727   -0.6017
   action_noise.ranef_1.r[2]   -2.6883   -1.5855   -1.0347   -0.5230    0.5002
   action_noise.ranef_1.r[3]   -0.1857    0.9591    1.5061    2.0202    3.0710

This returns an MCMCChains Chains object, which contains the samples from the posterior distribution of the model parameters. For each action model parameter, there is a posterior over the β values (the intercept and the treatment effect), as well as over the random effects deviation σ and the individual random effects r. Notably, with a full dataset, the posterior will contain a large number of parameters. We can see that the second β value for the learning rate (the dependence on the treatment condition) is negative. The dataset has been constructed to have lower learning rates in the treatment condition, so this is expected.
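
As an illustration, the strength of this effect can be quantified as the proportion of posterior samples in which the treatment effect on the learning rate is negative. A minimal sketch using standard MCMCChains indexing:

#Proportion of posterior samples with a negative treatment effect on the learning rate
mean(chns[Symbol("learning_rate.β[2]")] .< 0)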

Notably, sample_posterior! has many options for how the posterior is sampled, which can be set with keyword arguments. If we pass either MCMCThreads() or MCMCDistributed() as the second argument, Turing will use multithreading or distributed sampling to parallelise between chains. MCMCThreads is recommended for multithreading, but note that Julia must be started with the --threads flag for it to take effect. The number of samples and chains can be set with the n_samples and n_chains keyword arguments. The init_params keyword argument specifies how the initial parameters for the chains are chosen: :MAP or :MLE uses the maximum a posteriori or maximum likelihood estimate, respectively; :sample_prior draws a single sample from the prior distribution; and nothing uses Turing's default of random values between -2 and 2. Alternatively, a vector of initial parameters can be passed, which will be used for all chains. Other sampling arguments can also be passed, including the autodifferentiation backend with the ad_type keyword argument and the sampler with the sampler keyword argument. Finally, note that sample_posterior! returns the already sampled Chains object if the posterior has already been sampled; set resample = true to override this and sample anew.

chns = sample_posterior!(
    model,
    MCMCThreads(),
    n_samples = 500,
    n_chains = 4,
    init_params = :MAP,
    ad_type = AutoForwardDiff(),
    sampler = NUTS(),
    resample = true,
);
Sampling (2 threads)   0%|                              |  ETA: N/A
┌ Info: Found initial step size
└   ϵ = 0.05
┌ Info: Found initial step size
└   ϵ = 0.0125
┌ Info: Found initial step size
└   ϵ = 0.4
Sampling (2 threads)  25%|███████▌                      |  ETA: 0:00:42
Sampling (2 threads)  50%|███████████████               |  ETA: 0:00:24
Sampling (2 threads)  75%|██████████████████████▌       |  ETA: 0:00:09
┌ Info: Found initial step size
└   ϵ = 0.05
Sampling (2 threads) 100%|██████████████████████████████| Time: 0:00:34
Sampling (2 threads) 100%|██████████████████████████████| Time: 0:00:34

Finally, some users may wish to use Turing's own interface for sampling from the posterior instead. The Turing interface is more flexible in general, but requires more boilerplate code to set up. For this case, the ActionModels.ModelFit object contains the Turing model that is used under the hood. Users can extract it and use it as any other Turing model.

turing_model = model.model
DynamicPPL.Model{typeof(ActionModels.full_model), (:action_model, :population_model, :session_model, :observations_per_session, :actions_per_session, :estimated_parameter_names, :session_ids, :initial_states, Symbol("##arg#230"), Symbol("##arg#231")), (), (), Tuple{ActionModel{ActionModels.ContinuousRescorlaWagner}, DynamicPPL.Model{typeof(ActionModels.regression_population_model), (:linear_submodels, :estimated_parameter_names, :n_sessions), (), (), Tuple{Vector{DynamicPPL.Model}, Vector{Symbol}, Int64}, Tuple{}, DynamicPPL.DefaultContext}, ActionModels.var"#session_model#92"{ActionModels.var"#session_model#86#93"}, Vector{Vector{Tuple{Float64}}}, Vector{Vector{Tuple{Float64}}}, Tuple{Symbol, Symbol}, Vector{String}, @NamedTuple{}, DynamicPPL.TypeWrap{Float64}, DynamicPPL.TypeWrap{Int64}}, Tuple{}, DynamicPPL.DefaultContext}(ActionModels.full_model, (action_model = ActionModel{ActionModels.ContinuousRescorlaWagner}(ActionModels.var"#rescorla_wagner_act_after_update#213"{ActionModels.var"#gaussian_report#209"}(ActionModels.var"#gaussian_report#209"()), (action_noise = Parameter{Float64}(1.0, Float64),), NamedTuple(), (observation = Observation{Float64}(Float64),), (report = Action{Float64, Normal}(Float64, Normal),), ActionModels.ContinuousRescorlaWagner(0.0, 0.1)), population_model = DynamicPPL.Model{typeof(ActionModels.regression_population_model), (:linear_submodels, :estimated_parameter_names, :n_sessions), (), (), Tuple{Vector{DynamicPPL.Model}, Vector{Symbol}, Int64}, Tuple{}, DynamicPPL.DefaultContext}(ActionModels.regression_population_model, (linear_submodels = DynamicPPL.Model[DynamicPPL.Model{typeof(ActionModels.linear_model), (:X, :ranef_info), (:inv_link, :prior, :has_ranef), (), Tuple{Matrix{Float64}, @NamedTuple{Z::Vector{Matrix{Float64}}, n_ranef_categories::Vector{Int64}, n_ranef_params::Vector{Int64}}}, Tuple{typeof(logistic), ActionModels.RegPrior, Bool}, DynamicPPL.DefaultContext}(ActionModels.linear_model, (X = [1.0 0.0; 1.0 1.0; 1.0 0.0; 1.0 1.0; 1.0 0.0; 1.0 1.0], ranef_info = (Z = [[1.0 0.0 0.0; 1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0; 0.0 0.0 1.0]], n_ranef_categories = [3], n_ranef_params = [1])), (inv_link = LogExpFunctions.logistic, prior = ActionModels.RegPrior(Product{Continuous, TDist{Float64}, FillArrays.Fill{TDist{Float64}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(TDist{Float64}(ν=3.0), 2)), Distribution[Product{Continuous, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, FillArrays.Fill{Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(Truncated(TDist{Float64}(ν=3.0); lower=0.0), 1))]), has_ranef = true), DynamicPPL.DefaultContext()), DynamicPPL.Model{typeof(ActionModels.linear_model), (:X, :ranef_info), (:inv_link, :prior, :has_ranef), (), Tuple{Matrix{Float64}, @NamedTuple{Z::Vector{Matrix{Float64}}, n_ranef_categories::Vector{Int64}, n_ranef_params::Vector{Int64}}}, Tuple{typeof(exp), ActionModels.RegPrior, Bool}, DynamicPPL.DefaultContext}(ActionModels.linear_model, (X = [1.0 0.0; 1.0 1.0; 1.0 0.0; 1.0 1.0; 1.0 0.0; 1.0 1.0], ranef_info = (Z = [[1.0 0.0 0.0; 1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0; 0.0 0.0 1.0]], n_ranef_categories = [3], n_ranef_params = [1])), (inv_link = exp, prior = ActionModels.RegPrior(Product{Continuous, TDist{Float64}, FillArrays.Fill{TDist{Float64}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(TDist{Float64}(ν=3.0), 2)), Distribution[Product{Continuous, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, 
FillArrays.Fill{Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(Truncated(TDist{Float64}(ν=3.0); lower=0.0), 1))]), has_ranef = true), DynamicPPL.DefaultContext())], estimated_parameter_names = [:learning_rate, :action_noise], n_sessions = 6), NamedTuple(), DynamicPPL.DefaultContext()), session_model = ActionModels.var"#session_model#92"{ActionModels.var"#session_model#86#93"}(ActionModels.var"#session_model#86#93"(Core.Box(ActionModels.var"#sample_actions_one_idx#91"(Core.Box(#= circular reference @-2 =#)))), Core.Box(ActionModels.var"#session_model#89#96"(Core.Box(ActionModels.var"#session_model#92"{ActionModels.var"#session_model#86#93"}(#= circular reference @-4 =#)))), Core.Box(([0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.0, 0.1, 0.15, 0.2, 0.25, 0.3, 0.0, 0.5, 0.8, 1.0, 1.5, 1.8, 0.0, 0.2, 0.4, 0.7, 1.0, 1.1, 0.0, 2.0, 0.5, 4.0, 5.0, 3.0, 0.0, 2.0, 0.5, 4.0, 5.0, 3.0],))), observations_per_session = [[(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)]], actions_per_session = [[(0.0,), (0.2,), (0.3,), (0.4,), (0.5,), (0.6,)], [(0.0,), (0.1,), (0.15,), (0.2,), (0.25,), (0.3,)], [(0.0,), (0.5,), (0.8,), (1.0,), (1.5,), (1.8,)], [(0.0,), (0.2,), (0.4,), (0.7,), (1.0,), (1.1,)], [(0.0,), (2.0,), (0.5,), (4.0,), (5.0,), (3.0,)], [(0.0,), (2.0,), (0.5,), (4.0,), (5.0,), (3.0,)]], estimated_parameter_names = (:learning_rate, :action_noise), session_ids = ["id:A.treatment:control", "id:A.treatment:treatment", "id:B.treatment:control", "id:B.treatment:treatment", "id:C.treatment:control", "id:C.treatment:treatment"], initial_states = NamedTuple(), var"##arg#230" = DynamicPPL.TypeWrap{Float64}(), var"##arg#231" = DynamicPPL.TypeWrap{Int64}()), NamedTuple(), DynamicPPL.DefaultContext())

If users want to sample from the model themselves, but still want to draw on the rest of the ActionModels API, they can store the result in the ModelFit object by creating an ActionModels.ModelFitResult object. This should be assigned to either the posterior or prior field of the ModelFit object, after which it will interface with the ActionModels API as normal.

using ActionModels: Turing
chns = sample(turing_model, NUTS(), 1000, progress = false);

model.posterior = ActionModels.ModelFitResult(; chains = chns);
┌ Info: Found initial step size
└   ϵ = 0.05

In addition to sampling from the posterior, ActionModels also provides functionality for sampling from the prior distribution of the model parameters. This is done with the sample_prior! function, which works similarly to sample_posterior!. It is much simpler, however, since it does not require a complex sampler, and therefore only takes n_chains and n_samples as keyword arguments.

prior_chns = sample_prior!(model, n_chains = 1, n_samples = 1000)
Chains MCMC chain (1000×13×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 1.55 seconds
Compute duration  = 1.55 seconds
parameters        = learning_rate.β[1], learning_rate.β[2], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], action_noise.β[1], action_noise.β[2], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3]
internals         = lp

Summary Statistics
                  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ess_per_sec
                      Symbol   Float64   Float64   Float64     Float64     Float64   Float64       Float64

          learning_rate.β[1]    0.1078    1.6154    0.0511   1013.6934   1026.9433    1.0012      655.6878
          learning_rate.β[2]    0.0046    1.6574    0.0549    919.0261    872.3115    0.9998      594.4541
  learning_rate.ranef_1.σ[1]    1.0686    1.1949    0.0378   1053.1572    901.4391    1.0000      681.2143
  learning_rate.ranef_1.r[1]    0.0381    1.6265    0.0517    877.2720   1023.6250    1.0011      567.4463
  learning_rate.ranef_1.r[2]   -0.0473    1.5406    0.0480   1168.0094   1004.5908    1.0002      755.5041
  learning_rate.ranef_1.r[3]    0.0065    1.4651    0.0490   1038.5887    909.2811    0.9991      671.7909
           action_noise.β[1]    0.0415    1.5681    0.0522    985.3916    846.0586    1.0002      637.3814
           action_noise.β[2]   -0.1086    1.6265    0.0510   1085.4447   1011.8213    1.0014      702.0988
   action_noise.ranef_1.σ[1]    1.0785    1.4350    0.0443    948.6106   1005.6649    0.9995      613.5903
   action_noise.ranef_1.r[1]    0.1459    1.8373    0.0596    981.4397    849.0263    0.9997      634.8252
   action_noise.ranef_1.r[2]   -0.0014    2.3224    0.0735   1039.4146    899.3407    0.9997      672.3251
   action_noise.ranef_1.r[3]   -0.1607    1.7652    0.0541   1029.3644    930.3599    1.0006      665.8243

Quantiles
                  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                      Symbol   Float64   Float64   Float64   Float64   Float64

          learning_rate.β[1]   -2.7545   -0.7067    0.0075    0.7814    3.8691
          learning_rate.β[2]   -2.8610   -0.7771   -0.0150    0.7416    2.9962
  learning_rate.ranef_1.σ[1]    0.0339    0.3327    0.7470    1.3921    3.8353
  learning_rate.ranef_1.r[1]   -2.9347   -0.3956   -0.0039    0.4373    3.2240
  learning_rate.ranef_1.r[2]   -3.0841   -0.4594   -0.0141    0.4280    2.8816
  learning_rate.ranef_1.r[3]   -2.6300   -0.3852    0.0046    0.4450    2.8627
           action_noise.β[1]   -2.8339   -0.6689    0.0361    0.7336    3.3632
           action_noise.β[2]   -3.1493   -0.8757   -0.1010    0.6632    3.0071
   action_noise.ranef_1.σ[1]    0.0291    0.3382    0.7587    1.3581    4.0530
   action_noise.ranef_1.r[1]   -2.5686   -0.3444    0.0221    0.4976    3.9512
   action_noise.ranef_1.r[2]   -3.1065   -0.3830   -0.0014    0.3770    3.3298
   action_noise.ranef_1.r[3]   -3.4338   -0.4859   -0.0282    0.3014    2.6379

Investigating the results

Population model parameters

The first step in investigating model fitting results is often to look at the population model parameters. Which parameters the population model has, and how to visualize them, depends on the type of population model used; see the population model REF section for details on how to interpret results from different population models. In general, however, the Chains object returned by sample_posterior! contains the posterior distribution of the population model parameters. These can be visualized with various plotting functions; see the MCMCChains documentation for an overview. Here, we just use the standard plot function to visualize the posterior distribution over the β values of interest:

#Fit the model
chns = sample_posterior!(model);

Plot the posterior distribution of the learning rate and action noise beta parameters.

plot(chns[[Symbol("learning_rate.β[1]"), Symbol("learning_rate.β[2]"), Symbol("action_noise.β[1]"), Symbol("action_noise.β[2]")]])

#TODO: ArviZ support / specialized plotting functions

Parameter per session

Beyond the population model parameters, users will often be interested in the parameter estimates for each session in the data. The session parameters can be extracted with the get_session_parameters! function, which returns an ActionModels.SessionParameters object. Whether session parameter estimates should be extracted from the posterior or the prior distribution can be specified as the second argument.

#Extract posterior and prior estimates for session parameters
session_parameters = get_session_parameters!(model)
prior_session_parameters = get_session_parameters!(model, :prior)
-- Session parameters object --
6 sessions, 1 chains, 1000 samples
2 estimated parameters:
   learning_rate
   action_noise

ActionModels provides convenient functionality for plotting the session parameters.

#TODO: plot(session_parameters)

Users can also access the full distribution over the session parameters, for use in manual downstream analysis. The ActionModels.SessionParameters object contains the distributions for each parameter and each session. These can be found in the value field, which contains nested NamedTuples for each parameter and session, and ultimately an AxisArray with the samples for each chain.

learning_rate_singlesession = getfield(session_parameters.value.learning_rate, Symbol("id:A.treatment:control"));

The user may visualize, summarize or analyze the samples in whichever way they prefer.

#Calculate the mean
mean(learning_rate_singlesession)
#Plot the distribution
density(learning_rate_singlesession, title = "Learning rate for session A in control condition")

ActionModels provides a convenient function for summarizing all the session parameters as a DataFrame that can be used for further analysis.

median_df = summarize(session_parameters)

show(median_df)
6×4 DataFrame
 Row │ id      treatment  action_noise  learning_rate
     │ String  String     Float64       Float64
─────┼────────────────────────────────────────────────
   1 │ A       control       0.0537693      0.083032
   2 │ A       treatment     0.0373557      0.0380976
   3 │ B       control       0.176047       0.323365
   4 │ B       treatment     0.120045       0.171555
   5 │ C       control       2.17634        0.845226
   6 │ C       treatment     1.52351        0.700247

This returns the median of each parameter for each session. The user can also pass another function for summarizing the samples, for example std to calculate the standard deviation of the posterior.

std_df = summarize(session_parameters, std)

show(std_df)
6×4 DataFrame
 Row │ id      treatment  action_noise  learning_rate
     │ String  String     Float64       Float64
─────┼────────────────────────────────────────────────
   1 │ A       control       0.0174819     0.00579464
   2 │ A       treatment     0.0124243     0.00328018
   3 │ B       control       0.046299      0.0268604
   4 │ B       treatment     0.0365466     0.0155541
   5 │ C       control       0.64765       0.186856
   6 │ C       treatment     0.355189      0.237741

This can be saved to disk or used for plotting or analysis in whichever way the user prefers.
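
For example, assuming the CSV package is installed, the summary DataFrames can be written to disk as follows:

using CSV
CSV.write("session_parameters_median.csv", median_df);
CSV.write("session_parameters_std.csv", std_df);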

State trajectories per session

Users can also extract the estimated trajectories of states for each session with the get_state_trajectories! function, which returns an ActionModels.StateTrajectories object. State trajectories are often used to correlate with some external measure, such as neuroimaging data. The second argument specifies which state to extract. Again, the user can have the prior state trajectories extracted instead by passing :prior as the third argument.

state_trajectories = get_state_trajectories!(model, :expected_value)

prior_state_trajectories = get_state_trajectories!(model, :expected_value, :prior)
-- State trajectories object --
6 sessions, 1 chains, 1000 samples
1 estimated states:
   expected_value

ActionModels also provides functionality for plotting the state trajectories.

#TODO: plot(state_trajectories)

The ActionModels.StateTrajectories object contains the prior or posterior distribution over state trajectories for each session, for each of the specified states. This is stored in the value field, which contains nested NamedTuple objects with state names and then session names as keys, and ultimately an AxisArray with the samples across timesteps. This AxisArray has three axes: the first is the sample index, the second is the chain index, and the third is the timestep index.

expected_value_singlesession = getfield(state_trajectories.value.expected_value, Symbol("id:A.treatment:control"));

Again, we can visualize, summarize or analyze the samples in whichever way we prefer. Here, we choose the second timestep of the first session, calculate the mean, and plot the distribution.

mean(expected_value_singlesession[timestep=2])
density(expected_value_singlesession[timestep=2], title = "Expectation at time 2 session A control condition")

This can also be summarized neatly in a DataFrame:

median_df = summarize(state_trajectories, median)

show(median_df)
#And from here, it can be used for plotting or further analysis as desired by the user.
42×4 DataFrame
 Row │ id      treatment  timestep  expected_value
     │ String  String     Int64     Float64
─────┼─────────────────────────────────────────────
   1 │ A       control           0       0.0
   2 │ A       control           1       0.083032
   3 │ A       control           2       0.15917
   4 │ A       control           3       0.228986
   5 │ A       control           4       0.376037
   6 │ A       control           5       0.510878
   7 │ A       control           6       0.634522
   8 │ A       treatment         0       0.0
  ⋮  │   ⋮         ⋮         ⋮            ⋮
  36 │ C       treatment         0       0.0
  37 │ C       treatment         1       0.700247
  38 │ C       treatment         2       0.910148
  39 │ C       treatment         3       0.973067
  40 │ C       treatment         4       1.69217
  41 │ C       treatment         5       1.90773
  42 │ C       treatment         6       1.97234
                                    27 rows omitted

This page was generated using Literate.jl.