Fitting action models to data
In this section, we will cover how to fit action models to data with ActionModels. First, we will demonstrate how to set up the model that is fit under the hood by relying on Turing. This consists of specifying three things: an action model, a population model, and the data to fit the model to. Then, we will show how to sample from the posterior and prior distributions of the model parameters. Finally, we will show the tools ActionModels provides for inspecting and extracting the results of the model fitting.
Setting up the model
First we import ActionModels, as well as StatsPlots for plotting the results later.
using ActionModels, StatsPlots
Defining the action model
Here we will use the premade Rescorla-Wagner action model provided by ActionModels. It is identical to the model described in the defining action models REF section.
action_model = ActionModel(RescorlaWagner())
-- ActionModel --
Action model function: rescorla_wagner_act_after_update
Number of parameters: 1
Number of states: 0
Number of observations: 1
Number of actions: 1
submodel type: ActionModels.ContinuousRescorlaWagner
Loading the data
We will then specify the data that we want to fit the model to. For this example, we use a simple, manually created dataset in which three participants have completed an experiment where they must predict the next location of a moving target. Each participant completed the experiment twice: once in a control condition and once under an experimental treatment.
using DataFrames
data = DataFrame(
observations = repeat([1.0, 1, 1, 2, 2, 2], 6),
actions = vcat(
[0, 0.2, 0.3, 0.4, 0.5, 0.6],
[0, 0.5, 0.8, 1, 1.5, 1.8],
[0, 2, 0.5, 4, 5, 3],
[0, 0.1, 0.15, 0.2, 0.25, 0.3],
[0, 0.2, 0.4, 0.7, 1.0, 1.1],
[0, 2, 0.5, 4, 5, 3],
),
id = vcat(
repeat(["A"], 6),
repeat(["B"], 6),
repeat(["C"], 6),
repeat(["A"], 6),
repeat(["B"], 6),
repeat(["C"], 6),
),
treatment = vcat(repeat(["control"], 18), repeat(["treatment"], 18)),
)
show(data)
36×4 DataFrame
Row │ observations actions id treatment
│ Float64 Float64 String String
─────┼──────────────────────────────────────────
1 │ 1.0 0.0 A control
2 │ 1.0 0.2 A control
3 │ 1.0 0.3 A control
4 │ 2.0 0.4 A control
5 │ 2.0 0.5 A control
6 │ 2.0 0.6 A control
7 │ 1.0 0.0 B control
8 │ 1.0 0.5 B control
⋮ │ ⋮ ⋮ ⋮ ⋮
30 │ 2.0 1.1 B treatment
31 │ 1.0 0.0 C treatment
32 │ 1.0 2.0 C treatment
33 │ 1.0 0.5 C treatment
34 │ 2.0 4.0 C treatment
35 │ 2.0 5.0 C treatment
36 │ 2.0 3.0 C treatment
21 rows omitted
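In practice, the dataset will usually be loaded from disk rather than created by hand. A minimal sketch, assuming the data is stored in a CSV file with the same column names (the file name here is hypothetical):
using CSV
#Load a dataset with observations, actions, id and treatment columns (hypothetical file)
data = CSV.read("experiment_data.csv", DataFrame)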
Specifying the population model
Finally, we will specify a population model, which is the model of how parameters vary between the different sessions in the data. There are various options for this, which are described in the population model REF section. Here, we use a regression population model, in which the learning rate and action noise parameters depend linearly on the experimental treatment. This is a hierarchical model: the session parameters are assumed to be drawn from a Gaussian distribution whose mean is a linear function of the treatment condition. The regressions are specified with standard LMER syntax, and we use a logistic inverse link function for the learning rate to ensure that it is between 0 and 1, and an exponential inverse link function for the action noise to ensure that it is positive.
population_model = [
Regression(@formula(learning_rate ~ treatment + (1 | id)), logistic),
Regression(@formula(action_noise ~ treatment + (1 | id)), exp),
];
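To build intuition for the inverse link functions, note that logistic maps any real number to the interval (0, 1), while exp maps any real number to a positive number. A quick check (the input values below are arbitrary):
#The logistic inverse link maps regression predictions to the (0, 1) interval
logistic(0.0)  #0.5
logistic(-2.0) #≈ 0.12
#The exponential inverse link maps regression predictions to positive numbers
exp(0.0)  #1.0
exp(-2.0) #≈ 0.14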
Creating the full model
Finally, we can combine the three components into a full model that can be fit to the data. This is done with the create_model function, which takes the action model, population model, and data as arguments. Additionally, we specify which columns in the data contain the actions, observations, and session identifiers. This creates an ActionModels.ModelFit object, which contains the full model, and will contain the sampling results after fitting.
model = create_model(
action_model,
population_model,
data;
action_cols = :actions,
observation_cols = :observations,
session_cols = [:id, :treatment],
)
-- ModelFit object --
Action model: rescorla_wagner_act_after_update
Linear regression population model
2 estimated action model parameters, 6 sessions
Posterior not sampled
Prior not sampled
If there are multiple actions or observations, we can specify them as a NamedTuple mapping each action or observation name to a column in the data. For example, if we had two actions, we could specify them as follows:
(action_name1 = :action_column_name1, action_name2 = :action_column_name2);
The column names can also be specified as a vector of symbols, in which case it will be assumed that the order matches the order of actions or observations in the action model.
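As a concrete sketch, a call for a hypothetical dataset with two actions and two observations might look as follows (the action, observation, and column names here are made up for illustration):
#Hypothetical example with two actions and two observations
multi_action_model = create_model(
    action_model,
    population_model,
    data;
    action_cols = (action_name1 = :action_column_name1, action_name2 = :action_column_name2),
    observation_cols = (observation_name1 = :observation_column_name1, observation_name2 = :observation_column_name2),
    session_cols = [:id, :treatment],
)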
Finally, there may be missing data in the dataset. If actions are missing, they can be imputed by ActionModels. This is done by setting the impute_actions argument to true in the create_model function. If impute_actions is not set to true, missing actions will simply be skipped during sampling instead. This can be a problem for action models that depend on their own previous actions.
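For instance, a model that imputes missing actions could be created as follows (a sketch; the example dataset used here contains no missing actions):
#Create a model where missing actions are imputed during sampling
imputing_model = create_model(
    action_model,
    population_model,
    data;
    action_cols = :actions,
    observation_cols = :observations,
    session_cols = [:id, :treatment],
    impute_actions = true,
)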
Fitting the model
Now that the model is created, we are ready to fit it. This is done under the hood with MCMC sampling, which is provided by the Turing.jl framework. ActionModels provides the sample_posterior! function, which fits the model in this way with sensible defaults.
chns = sample_posterior!(model)
Chains MCMC chain (1000×24×2 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 2
Samples per chain = 1000
Wall duration = 50.48 seconds
Compute duration = 47.23 seconds
parameters = learning_rate.β[1], learning_rate.β[2], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], action_noise.β[1], action_noise.β[2], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
learning_rate.β[1] -0.3154 0.8416 0.0315 728.7682 762.2479 1.0031 15.4315
learning_rate.β[2] -0.8246 0.1064 0.0026 1715.2282 1308.6368 0.9997 36.3196
learning_rate.ranef_1.σ[1] 1.9781 1.0277 0.0363 942.5629 920.2904 1.0027 19.9586
learning_rate.ranef_1.r[1] -2.0840 0.8421 0.0315 735.2904 758.3671 1.0035 15.5696
learning_rate.ranef_1.r[2] -0.4278 0.8494 0.0314 752.2694 735.9333 1.0027 15.9291
learning_rate.ranef_1.r[3] 2.2329 1.7321 0.0616 1194.1897 768.2138 1.0053 25.2867
action_noise.β[1] -0.6679 0.7988 0.0317 640.4252 721.9865 1.0044 13.5609
action_noise.β[2] -0.3776 0.2648 0.0060 1979.8653 1459.2444 1.0007 41.9232
action_noise.ranef_1.σ[1] 1.8625 0.7660 0.0238 1164.3755 1172.8452 1.0004 24.6554
action_noise.ranef_1.r[1] -2.2267 0.8319 0.0324 661.7851 603.9202 1.0048 14.0132
action_noise.ranef_1.r[2] -1.0530 0.8153 0.0321 649.9414 716.0384 1.0041 13.7624
action_noise.ranef_1.r[3] 1.4891 0.8043 0.0313 665.3593 763.1122 1.0065 14.0888
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
learning_rate.β[1] -1.9708 -0.8473 -0.3137 0.1844 1.4482
learning_rate.β[2] -1.0408 -0.8885 -0.8246 -0.7579 -0.6199
learning_rate.ranef_1.σ[1] 0.7248 1.3202 1.7670 2.3601 4.3748
learning_rate.ranef_1.r[1] -3.8343 -2.5796 -2.0743 -1.5661 -0.4185
learning_rate.ranef_1.r[2] -2.1652 -0.9422 -0.4193 0.0960 1.2150
learning_rate.ranef_1.r[3] -0.1302 1.1545 1.9234 3.0218 6.5012
action_noise.β[1] -2.2887 -1.2022 -0.6817 -0.1366 0.8935
action_noise.β[2] -0.8808 -0.5618 -0.3771 -0.1967 0.1382
action_noise.ranef_1.σ[1] 0.8861 1.3353 1.6998 2.2056 3.7326
action_noise.ranef_1.r[1] -3.8839 -2.7801 -2.1843 -1.6727 -0.6017
action_noise.ranef_1.r[2] -2.6883 -1.5855 -1.0347 -0.5230 0.5002
action_noise.ranef_1.r[3] -0.1857 0.9591 1.5061 2.0202 3.0710
This returns an MCMCChains Chains object, which contains the samples from the posterior distribution of the model parameters. For each parameter, there is a posterior for the β values (the intercept and the treatment effect), as well as for the standard deviation of the random effects (σ) and the individual random effects (r). Note that with a full dataset, the posterior will contain a large number of parameters. We can see that the second β value for the learning rate (its dependence on the treatment condition) is negative. The dataset was constructed to have lower learning rates in the treatment condition, so this is expected.
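Since the regression coefficients are on the link scale, they can be mapped back to the parameter scale with the inverse link functions. A rough sketch using the posterior means shown above:
#Extract the posterior samples of the treatment effect on the learning rate
β_treatment = chns[Symbol("learning_rate.β[2]")]
mean(β_treatment) #≈ -0.82
#Transform the posterior mean predictions to the learning rate scale
logistic(-0.3154)          #average control session: ≈ 0.42
logistic(-0.3154 - 0.8246) #average treatment session: ≈ 0.24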
Notably, sample_posterior! has many options for how to sample the posterior, which can be set with keyword arguments. If we pass either MCMCThreads() or MCMCDistributed() as the second argument, Turing will use multithreading or distributed sampling to parallelise between chains. It is recommended to use MCMCThreads for multithreading, but note that Julia must be started with the --threads flag to enable multithreading. We can specify the number of samples and chains to sample with the n_samples and n_chains keyword arguments. The init_params keyword argument controls how the initial parameters for the chains are set. It can be set to :MAP or :MLE to use the maximum a posteriori or maximum likelihood estimate as the initial parameters, respectively; to :sample_prior to draw a single sample from the prior distribution; or to nothing to use Turing's default of random values between -2 and 2. Finally, a vector of initial parameters can be passed, which will be used as the initial parameters for all chains. Other sampling arguments can also be passed, including the autodifferentiation backend, set with the ad_type keyword argument, and the sampler, set with the sampler keyword argument. Notably, sample_posterior! will return the already sampled Chains object if the posterior has already been sampled; set resample = true to override a previously sampled posterior.
chns = sample_posterior!(
model,
MCMCThreads(),
n_samples = 500,
n_chains = 4,
init_params = :MAP,
ad_type = AutoForwardDiff(),
sampler = NUTS(),
resample = true,
);
Sampling (2 threads) 0%| | ETA: N/A
┌ Info: Found initial step size
└ ϵ = 0.05
┌ Info: Found initial step size
└ ϵ = 0.0125
┌ Info: Found initial step size
└ ϵ = 0.4
Sampling (2 threads) 25%|███████▌ | ETA: 0:00:42
Sampling (2 threads) 50%|███████████████ | ETA: 0:00:24
Sampling (2 threads) 75%|██████████████████████▌ | ETA: 0:00:09
┌ Info: Found initial step size
└ ϵ = 0.05
Sampling (2 threads) 100%|██████████████████████████████| Time: 0:00:34
Finally, some users may wish to use Turing's own interface for sampling from the posterior instead. The Turing interface is more flexible in general, but requires more boilerplate code to set up. For this case, the ActionModels.ModelFit object contains the Turing model that is used under the hood. Users can extract it and use it as any other Turing model, if they wish.
turing_model = model.model
DynamicPPL.Model{typeof(ActionModels.full_model), (:action_model, :population_model, :session_model, :observations_per_session, :actions_per_session, :estimated_parameter_names, :session_ids, :initial_states, Symbol("##arg#230"), Symbol("##arg#231")), (), (), Tuple{ActionModel{ActionModels.ContinuousRescorlaWagner}, DynamicPPL.Model{typeof(ActionModels.regression_population_model), (:linear_submodels, :estimated_parameter_names, :n_sessions), (), (), Tuple{Vector{DynamicPPL.Model}, Vector{Symbol}, Int64}, Tuple{}, DynamicPPL.DefaultContext}, ActionModels.var"#session_model#92"{ActionModels.var"#session_model#86#93"}, Vector{Vector{Tuple{Float64}}}, Vector{Vector{Tuple{Float64}}}, Tuple{Symbol, Symbol}, Vector{String}, @NamedTuple{}, DynamicPPL.TypeWrap{Float64}, DynamicPPL.TypeWrap{Int64}}, Tuple{}, DynamicPPL.DefaultContext}(ActionModels.full_model, (action_model = ActionModel{ActionModels.ContinuousRescorlaWagner}(ActionModels.var"#rescorla_wagner_act_after_update#213"{ActionModels.var"#gaussian_report#209"}(ActionModels.var"#gaussian_report#209"()), (action_noise = Parameter{Float64}(1.0, Float64),), NamedTuple(), (observation = Observation{Float64}(Float64),), (report = Action{Float64, Normal}(Float64, Normal),), ActionModels.ContinuousRescorlaWagner(0.0, 0.1)), population_model = DynamicPPL.Model{typeof(ActionModels.regression_population_model), (:linear_submodels, :estimated_parameter_names, :n_sessions), (), (), Tuple{Vector{DynamicPPL.Model}, Vector{Symbol}, Int64}, Tuple{}, DynamicPPL.DefaultContext}(ActionModels.regression_population_model, (linear_submodels = DynamicPPL.Model[DynamicPPL.Model{typeof(ActionModels.linear_model), (:X, :ranef_info), (:inv_link, :prior, :has_ranef), (), Tuple{Matrix{Float64}, @NamedTuple{Z::Vector{Matrix{Float64}}, n_ranef_categories::Vector{Int64}, n_ranef_params::Vector{Int64}}}, Tuple{typeof(logistic), ActionModels.RegPrior, Bool}, DynamicPPL.DefaultContext}(ActionModels.linear_model, (X = [1.0 0.0; 1.0 1.0; 1.0 0.0; 1.0 1.0; 1.0 0.0; 1.0 1.0], ranef_info = (Z = [[1.0 0.0 0.0; 1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0; 0.0 0.0 1.0]], n_ranef_categories = [3], n_ranef_params = [1])), (inv_link = LogExpFunctions.logistic, prior = ActionModels.RegPrior(Product{Continuous, TDist{Float64}, FillArrays.Fill{TDist{Float64}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(TDist{Float64}(ν=3.0), 2)), Distribution[Product{Continuous, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, FillArrays.Fill{Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(Truncated(TDist{Float64}(ν=3.0); lower=0.0), 1))]), has_ranef = true), DynamicPPL.DefaultContext()), DynamicPPL.Model{typeof(ActionModels.linear_model), (:X, :ranef_info), (:inv_link, :prior, :has_ranef), (), Tuple{Matrix{Float64}, @NamedTuple{Z::Vector{Matrix{Float64}}, n_ranef_categories::Vector{Int64}, n_ranef_params::Vector{Int64}}}, Tuple{typeof(exp), ActionModels.RegPrior, Bool}, DynamicPPL.DefaultContext}(ActionModels.linear_model, (X = [1.0 0.0; 1.0 1.0; 1.0 0.0; 1.0 1.0; 1.0 0.0; 1.0 1.0], ranef_info = (Z = [[1.0 0.0 0.0; 1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0; 0.0 0.0 1.0]], n_ranef_categories = [3], n_ranef_params = [1])), (inv_link = exp, prior = ActionModels.RegPrior(Product{Continuous, TDist{Float64}, FillArrays.Fill{TDist{Float64}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(TDist{Float64}(ν=3.0), 2)), Distribution[Product{Continuous, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, 
FillArrays.Fill{Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(Truncated(TDist{Float64}(ν=3.0); lower=0.0), 1))]), has_ranef = true), DynamicPPL.DefaultContext())], estimated_parameter_names = [:learning_rate, :action_noise], n_sessions = 6), NamedTuple(), DynamicPPL.DefaultContext()), session_model = ActionModels.var"#session_model#92"{ActionModels.var"#session_model#86#93"}(ActionModels.var"#session_model#86#93"(Core.Box(ActionModels.var"#sample_actions_one_idx#91"(Core.Box(#= circular reference @-2 =#)))), Core.Box(ActionModels.var"#session_model#89#96"(Core.Box(ActionModels.var"#session_model#92"{ActionModels.var"#session_model#86#93"}(#= circular reference @-4 =#)))), Core.Box(([0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.0, 0.1, 0.15, 0.2, 0.25, 0.3, 0.0, 0.5, 0.8, 1.0, 1.5, 1.8, 0.0, 0.2, 0.4, 0.7, 1.0, 1.1, 0.0, 2.0, 0.5, 4.0, 5.0, 3.0, 0.0, 2.0, 0.5, 4.0, 5.0, 3.0],))), observations_per_session = [[(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)]], actions_per_session = [[(0.0,), (0.2,), (0.3,), (0.4,), (0.5,), (0.6,)], [(0.0,), (0.1,), (0.15,), (0.2,), (0.25,), (0.3,)], [(0.0,), (0.5,), (0.8,), (1.0,), (1.5,), (1.8,)], [(0.0,), (0.2,), (0.4,), (0.7,), (1.0,), (1.1,)], [(0.0,), (2.0,), (0.5,), (4.0,), (5.0,), (3.0,)], [(0.0,), (2.0,), (0.5,), (4.0,), (5.0,), (3.0,)]], estimated_parameter_names = (:learning_rate, :action_noise), session_ids = ["id:A.treatment:control", "id:A.treatment:treatment", "id:B.treatment:control", "id:B.treatment:treatment", "id:C.treatment:control", "id:C.treatment:treatment"], initial_states = NamedTuple(), var"##arg#230" = DynamicPPL.TypeWrap{Float64}(), var"##arg#231" = DynamicPPL.TypeWrap{Int64}()), NamedTuple(), DynamicPPL.DefaultContext())
If users want to sample from the model themselves, but still want to draw on the rest of the ActionModels API, they can create an ActionModels.ModelFitResult object and assign it to either the posterior or the prior field of the ModelFit object, after which it will interface with the ActionModels API as normal.
using ActionModels: Turing
chns = sample(turing_model, NUTS(), 1000, progress = false);
model.posterior = ActionModels.ModelFitResult(; chains = chns);
┌ Info: Found initial step size
└ ϵ = 0.05
In addition to sampling from the posterior, ActionModels also provides functionality for sampling from the prior distribution of the model parameters. This is done with the sample_prior! function, which works in a similar way to sample_posterior!. Notably, it is much simpler, as it does not require a complex sampler, so it only takes n_chains and n_samples as keyword arguments.
prior_chns = sample_prior!(model, n_chains = 1, n_samples = 1000)
Chains MCMC chain (1000×13×1 Array{Float64, 3}):
Iterations = 1:1:1000
Number of chains = 1
Samples per chain = 1000
Wall duration = 1.55 seconds
Compute duration = 1.55 seconds
parameters = learning_rate.β[1], learning_rate.β[2], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], action_noise.β[1], action_noise.β[2], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3]
internals = lp
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
learning_rate.β[1] 0.1078 1.6154 0.0511 1013.6934 1026.9433 1.0012 655.6878
learning_rate.β[2] 0.0046 1.6574 0.0549 919.0261 872.3115 0.9998 594.4541
learning_rate.ranef_1.σ[1] 1.0686 1.1949 0.0378 1053.1572 901.4391 1.0000 681.2143
learning_rate.ranef_1.r[1] 0.0381 1.6265 0.0517 877.2720 1023.6250 1.0011 567.4463
learning_rate.ranef_1.r[2] -0.0473 1.5406 0.0480 1168.0094 1004.5908 1.0002 755.5041
learning_rate.ranef_1.r[3] 0.0065 1.4651 0.0490 1038.5887 909.2811 0.9991 671.7909
action_noise.β[1] 0.0415 1.5681 0.0522 985.3916 846.0586 1.0002 637.3814
action_noise.β[2] -0.1086 1.6265 0.0510 1085.4447 1011.8213 1.0014 702.0988
action_noise.ranef_1.σ[1] 1.0785 1.4350 0.0443 948.6106 1005.6649 0.9995 613.5903
action_noise.ranef_1.r[1] 0.1459 1.8373 0.0596 981.4397 849.0263 0.9997 634.8252
action_noise.ranef_1.r[2] -0.0014 2.3224 0.0735 1039.4146 899.3407 0.9997 672.3251
action_noise.ranef_1.r[3] -0.1607 1.7652 0.0541 1029.3644 930.3599 1.0006 665.8243
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
learning_rate.β[1] -2.7545 -0.7067 0.0075 0.7814 3.8691
learning_rate.β[2] -2.8610 -0.7771 -0.0150 0.7416 2.9962
learning_rate.ranef_1.σ[1] 0.0339 0.3327 0.7470 1.3921 3.8353
learning_rate.ranef_1.r[1] -2.9347 -0.3956 -0.0039 0.4373 3.2240
learning_rate.ranef_1.r[2] -3.0841 -0.4594 -0.0141 0.4280 2.8816
learning_rate.ranef_1.r[3] -2.6300 -0.3852 0.0046 0.4450 2.8627
action_noise.β[1] -2.8339 -0.6689 0.0361 0.7336 3.3632
action_noise.β[2] -3.1493 -0.8757 -0.1010 0.6632 3.0071
action_noise.ranef_1.σ[1] 0.0291 0.3382 0.7587 1.3581 4.0530
action_noise.ranef_1.r[1] -2.5686 -0.3444 0.0221 0.4976 3.9512
action_noise.ranef_1.r[2] -3.1065 -0.3830 -0.0014 0.3770 3.3298
action_noise.ranef_1.r[3] -3.4338 -0.4859 -0.0282 0.3014 2.6379
Investigating the results
Population model parameters
The first step in investigating model fitting results is often to look at the population model parameters. These parameters, and how to visualize them, depend on the type of population model used; see the population model REF section for details on how to interpret results from different population models. In general, however, the Chains object returned by sample_posterior! contains the posterior distribution of the population model parameters. These can be visualized with various plotting functions; see the MCMCChains documentation for an overview. Here, we just use the standard plot function to visualize the posterior distribution over the β values of interest:
#Fit the model
chns = sample_posterior!(model);
#Plot the posterior distribution of the learning rate and action noise beta parameters
plot(chns[[Symbol("learning_rate.β[1]"), Symbol("learning_rate.β[2]"), Symbol("action_noise.β[1]"), Symbol("action_noise.β[2]")]])
#TODO: ArviZ support / specialized plotting functions
Parameter per session
Beyond the population model parameters, users will often be interested in the parameter estimates for each session in the data. The session parameters can be extracted with the get_session_parameters! function, which returns an ActionModels.SessionParameters object. Whether session parameter estimates are extracted from the posterior or the prior distribution can be specified as the second argument.
#Extract posterior and prior estimates for session parameters
session_parameters = get_session_parameters!(model)
prior_session_parameters = get_session_parameters!(model, :prior)
-- Session parameters object --
6 sessions, 1 chains, 1000 samples
2 estimated parameters:
learning_rate
action_noise
ActionModels provides convenient functionality for plotting the session parameters.
#TODO: plot(session_parameters)
Users can also access the full distribution over the session parameters, for use in manual downstream analysis. The ActionModels.SessionParameters object contains the distributions for each parameter and each session. These can be found in the value field, which contains nested NamedTuples for each parameter and session, and ultimately an AxisArray with the samples for each chain.
learning_rate_singlesession = getfield(session_parameters.value.learning_rate, Symbol("id:A.treatment:control"));
The user may visualize, summarize or analyze the samples in whichever way they prefer.
#Calculate the mean
mean(learning_rate_singlesession)
#Plot the distribution
density(learning_rate_singlesession, title = "Learning rate for session A in control condition")
ActionModels provides a convenient function for summarizing all the session parameters as a DataFrame that can be used for further analysis.
median_df = summarize(session_parameters)
show(median_df)
6×4 DataFrame
Row │ id treatment action_noise learning_rate
│ String String Float64 Float64
─────┼────────────────────────────────────────────────
1 │ A control 0.0537693 0.083032
2 │ A treatment 0.0373557 0.0380976
3 │ B control 0.176047 0.323365
4 │ B treatment 0.120045 0.171555
5 │ C control 2.17634 0.845226
6 │ C treatment 1.52351 0.700247
This returns the median of each parameter for each session. The user can pass other functions for summarizing the samples, for example std to calculate the standard deviation of the posterior.
std_df = summarize(session_parameters, std)
show(std_df)
6×4 DataFrame
Row │ id treatment action_noise learning_rate
│ String String Float64 Float64
─────┼────────────────────────────────────────────────
1 │ A control 0.0174819 0.00579464
2 │ A treatment 0.0124243 0.00328018
3 │ B control 0.046299 0.0268604
4 │ B treatment 0.0365466 0.0155541
5 │ C control 0.64765 0.186856
6 │ C treatment 0.355189 0.237741
This can be saved to disk or used for plotting or analysis in whichever way the user prefers.
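For example, the summary DataFrame can be written to disk as a CSV file (assuming the CSV package is available; the file name is arbitrary):
using CSV
#Save the summarized session parameters to disk
CSV.write("session_parameters.csv", std_df)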
State trajectories per session
Users can also extract the estimated trajectory of states for each session with the get_state_trajectories!
function, which returns an ActionModels.StateTrajectories
object. State trajectories are often used to correlate with some external measure, such as neuroimaging data. The second argument specifies which state to extract. Again, the user can also specify to have the prior state trajectories extracted by passing :prior
as the second argument.
state_trajectories = get_state_trajectories!(model, :expected_value)
prior_state_trajectories = get_state_trajectories!(model, :expected_value, :prior)
-- State trajectories object --
6 sessions, 1 chains, 1000 samples
1 estimated states:
expected_value
ActionModels also provides functionality for plotting the state trajectories.
#TODO: plot(state_trajectories)
The ActionModels.StateTrajectories object contains the prior or posterior distribution over state trajectories for each session, for each of the specified states. This is stored in the value field, which contains nested NamedTuple objects with state names and then session names as keys, and ultimately an AxisArray with the samples across timesteps. This AxisArray has three axes: the first is the sample index, the second is the chain index, and the third is the timestep index.
expected_value_singlesession = getfield(state_trajectories.value.expected_value, Symbol("id:A.treatment:control"));
Again, we can visualize, summarize, or analyze the samples in whichever way we prefer. Here, we choose the second timestep of the first session, calculate the mean, and plot the distribution.
mean(expected_value_singlesession[timestep=2])
density(expected_value_singlesession[timestep=2], title = "Expectation at time 2 session A control condition")
This can also be summarized neatly in a DataFrame:
median_df = summarize(state_trajectories, median)
show(median_df)
#And from here, it can be used for plotting or further analysis as desired by the user.
42×4 DataFrame
Row │ id treatment timestep expected_value
│ String String Int64 Float64
─────┼─────────────────────────────────────────────
1 │ A control 0 0.0
2 │ A control 1 0.083032
3 │ A control 2 0.15917
4 │ A control 3 0.228986
5 │ A control 4 0.376037
6 │ A control 5 0.510878
7 │ A control 6 0.634522
8 │ A treatment 0 0.0
⋮ │ ⋮ ⋮ ⋮ ⋮
36 │ C treatment 0 0.0
37 │ C treatment 1 0.700247
38 │ C treatment 2 0.910148
39 │ C treatment 3 0.973067
40 │ C treatment 4 1.69217
41 │ C treatment 5 1.90773
42 │ C treatment 6 1.97234
27 rows omitted
This page was generated using Literate.jl.