Fitting action models to data
In this section, we will cover how to fit action models to data with ActionModels. First, we will demonstrate how to set up the model that is fit under the hood by relying on Turing. This consists of specifying three things: an action model, a population model, and the data to fit the model to. Then, we will show how to sample from the posterior and prior distributions of the model parameters. Finally, we will show the tools ActionModels provides for inspecting and extracting the results of the model fitting.
Setting up the model
First we import ActionModels, as well as StatsPlots for plotting the results later.
using ActionModels, StatsPlots
Defining the action model
We will here use the premade Rescorla-Wagner action model provided by ActionModels. This is identical to the model described in the defining action models REF section.
action_model = ActionModel(RescorlaWagner())
-- ActionModel --
Action model function: rescorla_wagner_act_after_update
Number of parameters: 1
Number of states: 0
Number of observations: 1
Number of actions: 1
submodel type: ActionModels.ContinuousRescorlaWagner
Loading the data
We will then specify the data that we want to fit the model to. For this example, we will use a simple, manually created dataset in which three participants have completed an experiment where they must predict the next location of a moving target. Each participant has completed the experiment twice: once in a control condition and once under an experimental treatment.
using DataFrames
data = DataFrame(
observations = repeat([1.0, 1, 1, 2, 2, 2], 6),
actions = vcat(
[0, 0.2, 0.3, 0.4, 0.5, 0.6],
[0, 0.5, 0.8, 1, 1.5, 1.8],
[0, 2, 0.5, 4, 5, 3],
[0, 0.1, 0.15, 0.2, 0.25, 0.3],
[0, 0.2, 0.4, 0.7, 1.0, 1.1],
[0, 2, 0.5, 4, 5, 3],
),
id = vcat(
repeat(["A"], 6),
repeat(["B"], 6),
repeat(["C"], 6),
repeat(["A"], 6),
repeat(["B"], 6),
repeat(["C"], 6),
),
treatment = vcat(repeat(["control"], 18), repeat(["treatment"], 18)),
)
show(data)
36×4 DataFrame
Row │ observations actions id treatment
│ Float64 Float64 String String
─────┼──────────────────────────────────────────
1 │ 1.0 0.0 A control
2 │ 1.0 0.2 A control
3 │ 1.0 0.3 A control
4 │ 2.0 0.4 A control
5 │ 2.0 0.5 A control
6 │ 2.0 0.6 A control
7 │ 1.0 0.0 B control
8 │ 1.0 0.5 B control
⋮ │ ⋮ ⋮ ⋮ ⋮
30 │ 2.0 1.1 B treatment
31 │ 1.0 0.0 C treatment
32 │ 1.0 2.0 C treatment
33 │ 1.0 0.5 C treatment
34 │ 2.0 4.0 C treatment
35 │ 2.0 5.0 C treatment
36 │ 2.0 3.0 C treatment
21 rows omitted
Specifying the population model
Next, we will specify a population model, which describes how parameters vary between the different sessions in the data. There are various options here, which are described in the population model REF section. We will use a regression population model, where we assume that the learning rate and action noise parameters depend linearly on the experimental treatment. This is a hierarchical model: the session parameters are assumed to be sampled from a Gaussian distribution whose mean is a linear function of the treatment condition. The regression is specified with standard LMER syntax. We use a logistic inverse link function for the learning rate to ensure that it lies between 0 and 1, and an exponential inverse link function for the action noise to ensure that it is positive.
population_model = [
Regression(@formula(learning_rate ~ treatment + (1 | id)), logistic),
Regression(@formula(action_noise ~ treatment + (1 | id)), exp),
];
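To make the inverse link functions concrete, here is a minimal sketch of how a single session's learning rate would be derived from the regression terms, using made-up coefficient values; ActionModels performs the equivalent computation internally during fitting.
#Illustrative values only, not fitted estimates
β0 = -0.3    #intercept
β1 = -0.8    #treatment effect
r = -0.5     #random intercept for this session
x = 1.0      #treatment indicator (1.0 = treatment, 0.0 = control)
η = β0 + β1 * x + r          #linear predictor on the latent scale
learning_rate = logistic(η)  #inverse logit maps η into (0, 1)
#action_noise is modelled analogously, with exp as the inverse link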
Creating the full model
Finally, we can combine the three components into a full model that can be fit to the data. This is done with the create_model function, which takes the action model, population model, and data as arguments. Additionally, we specify which columns in the data contain the actions, observations, and session identifiers. This creates an ActionModels.ModelFit object, which contains the full model, and which will contain the sampling results after fitting.
model = create_model(
action_model,
population_model,
data;
action_cols = :actions,
observation_cols = :observations,
session_cols = [:id, :treatment],
)
-- ModelFit object --
Action model: rescorla_wagner_act_after_update
Linear regression population model
2 estimated action model parameters, 6 sessions
Posterior not sampled
Prior not sampled
If there are multiple actions or observations, we can specify them as a NamedTuple mapping each action or observation to a column in the data. For example, if we had two actions and two observations, we could specify them as follows:
(action_name1 = :action_column_name1, action_name2 = :action_column_name2);
The column names can also be specified as a vector of symbols, in which case it will be assumed that the order matches the order of actions or observations in the action model.
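For instance, a call with two actions and two observations might look like the following sketch, where all action, observation, and column names are hypothetical placeholders:
model_multi = create_model(
    action_model,
    population_model,
    data;
    action_cols = (action_name1 = :action_column_name1, action_name2 = :action_column_name2),
    observation_cols = (observation_name1 = :observation_column_name1, observation_name2 = :observation_column_name2),
    session_cols = [:id, :treatment],
)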
Finally, there may be missing data in the dataset. If actions are missing, they can be imputed by ActionModels. This is done by setting the impute_actions argument to true in the create_model function. If impute_actions is not set to true, missing actions will simply be skipped during sampling, which is a problem for action models that depend on their previous actions.
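As a sketch, imputation would be enabled like this, reusing the model components defined above:
model_imputed = create_model(
    action_model,
    population_model,
    data;
    action_cols = :actions,
    observation_cols = :observations,
    session_cols = [:id, :treatment],
    impute_actions = true,    #impute missing actions instead of skipping them
)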
Fitting the model
Now that the model is created, we are ready to fit it. This is done under the hood using MCMC sampling, which is provided by the Turing.jl framework. ActionModels provides the sample_posterior! function, which fits the model in this way with sensible defaults.
chns = sample_posterior!(model)
Chains MCMC chain (1000×24×2 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 2
Samples per chain = 1000
Wall duration = 46.86 seconds
Compute duration = 43.46 seconds
parameters = learning_rate.β[1], learning_rate.β[2], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], action_noise.β[1], action_noise.β[2], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
learning_rate.β[1] -0.2739 0.8515 0.0349 615.0436 702.9782 1.0003 14.1513
learning_rate.β[2] -0.8287 0.1087 0.0028 1552.6906 1179.8918 0.9993 35.7252
learning_rate.ranef_1.σ[1] 1.9737 1.0101 0.0369 770.7671 916.0555 0.9999 17.7343
learning_rate.ranef_1.r[1] -2.1287 0.8499 0.0349 610.3512 708.6159 1.0004 14.0433
learning_rate.ranef_1.r[2] -0.4670 0.8537 0.0350 608.3110 713.1852 1.0002 13.9964
learning_rate.ranef_1.r[3] 2.2017 1.6325 0.0515 1270.6294 950.6482 0.9997 29.2354
action_noise.β[1] -0.7178 0.8120 0.0312 692.1745 795.3975 1.0025 15.9260
action_noise.β[2] -0.3544 0.2600 0.0060 1854.8970 1229.0322 1.0018 42.6786
action_noise.ranef_1.σ[1] 1.8676 0.7534 0.0211 1660.8644 1146.8143 0.9998 38.2142
action_noise.ranef_1.r[1] -2.1855 0.8294 0.0322 678.3393 844.5278 1.0035 15.6076
action_noise.ranef_1.r[2] -1.0133 0.8285 0.0312 721.8426 979.8111 1.0010 16.6086
action_noise.ranef_1.r[3] 1.5280 0.8368 0.0316 721.1078 918.0065 1.0031 16.5917
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
learning_rate.β[1] -1.8422 -0.8220 -0.2961 0.1968 1.4791
learning_rate.β[2] -1.0520 -0.8968 -0.8262 -0.7595 -0.6147
learning_rate.ranef_1.σ[1] 0.7538 1.2994 1.7527 2.3639 4.5534
learning_rate.ranef_1.r[1] -3.8665 -2.6113 -2.1039 -1.5854 -0.5701
learning_rate.ranef_1.r[2] -2.1903 -0.9553 -0.4542 0.0838 1.1048
learning_rate.ranef_1.r[3] -0.3609 1.1288 1.9308 3.0140 6.0367
action_noise.β[1] -2.3611 -1.2004 -0.7072 -0.2003 0.7812
action_noise.β[2] -0.8566 -0.5237 -0.3545 -0.1755 0.1469
action_noise.ranef_1.σ[1] 0.9175 1.3385 1.7087 2.2094 3.7160
action_noise.ranef_1.r[1] -3.7912 -2.7124 -2.1897 -1.6827 -0.5126
action_noise.ranef_1.r[2] -2.5468 -1.5531 -1.0344 -0.4884 0.6330
action_noise.ranef_1.r[3] 0.0030 0.9953 1.5201 2.0551 3.2351
This returns an MCMCChains Chains object, which contains the samples from the posterior distribution of the model parameters. For each parameter, there is a posterior for the β values (the intercept and the treatment effect), as well as for the standard deviation σ of the random effects and the individual random effects r. Notably, with a full dataset, the posterior will contain a large number of parameters. We can see that the second β value for the learning rate (its dependence on the treatment condition) is negative. The dataset has been constructed to have lower learning rates in the treatment condition, so this is expected.
Notably, sample_posterior! has many options for how to sample the posterior, which can be set with keyword arguments. If we pass either MCMCThreads() or MCMCDistributed() as the second argument, Turing will use multithreading or distributed sampling to parallelise between chains. It is recommended to use MCMCThreads for multithreading, but note that Julia must be started with the --threads flag to enable it. We can specify the number of samples and chains to sample with the n_samples and n_chains keyword arguments. The init_params keyword argument can be used to specify how the initial parameters for the chains are set. It can be set to :MAP or :MLE to use the maximum a posteriori or maximum likelihood estimates as the initial parameters, respectively. It can be set to :sample_prior to draw a single sample from the prior distribution, or to nothing to use Turing's default of random values between -2 and 2. Finally, a vector of initial parameters can be passed, which will be used as the initial parameters for all chains. Other arguments for the sampling can also be passed. This includes the autodifferentiation backend to use, which can be set with the ad_type keyword argument, and the sampler to use, which can be set with the sampler keyword argument. Notably, sample_posterior! will return the already sampled Chains object if the posterior has already been sampled. Set resample = true to override the already sampled posterior.
chns = sample_posterior!(
model,
MCMCThreads(),
n_samples = 500,
n_chains = 4,
init_params = :MAP,
ad_type = AutoForwardDiff(),
sampler = NUTS(),
resample = true,
);
Sampling (2 threads) 0%| | ETA: N/A
┌ Info: Found initial step size
└ ϵ = 0.2
┌ Info: Found initial step size
└ ϵ = 0.0125
┌ Info: Found initial step size
└ ϵ = 0.4
Sampling (2 threads) 25%|███████▌ | ETA: 0:00:44
Sampling (2 threads) 50%|███████████████ | ETA: 0:00:24
Sampling (2 threads) 75%|██████████████████████▌ | ETA: 0:00:08
┌ Info: Found initial step size
└ ϵ = 0.4
Sampling (2 threads) 100%|██████████████████████████████| Time: 0:00:35
Finally, some users may wish to use Turing's own interface for sampling from the posterior instead. The Turing interface is more flexible in general, but requires more boilerplate code to set up. For this case, the ActionModels.ModelFit object contains the Turing model that is used under the hood. Users can extract it and use it as any other Turing model.
turing_model = model.model
DynamicPPL.Model{typeof(ActionModels.full_model), (:action_model, :population_model, :session_model, :observations_per_session, :actions_per_session, :estimated_parameter_names, :session_ids, :initial_states, Symbol("##arg#230"), Symbol("##arg#231")), (), (), Tuple{ActionModel{ActionModels.ContinuousRescorlaWagner}, DynamicPPL.Model{typeof(ActionModels.regression_population_model), (:linear_submodels, :estimated_parameter_names, :n_sessions), (), (), Tuple{Vector{DynamicPPL.Model}, Vector{Symbol}, Int64}, Tuple{}, DynamicPPL.DefaultContext}, ActionModels.var"#session_model#92"{ActionModels.var"#session_model#86#93"}, Vector{Vector{Tuple{Float64}}}, Vector{Vector{Tuple{Float64}}}, Tuple{Symbol, Symbol}, Vector{String}, @NamedTuple{}, DynamicPPL.TypeWrap{Float64}, DynamicPPL.TypeWrap{Int64}}, Tuple{}, DynamicPPL.DefaultContext}(ActionModels.full_model, (action_model = ActionModel{ActionModels.ContinuousRescorlaWagner}(ActionModels.var"#rescorla_wagner_act_after_update#213"{ActionModels.var"#gaussian_report#209"}(ActionModels.var"#gaussian_report#209"()), (action_noise = Parameter{Float64}(1.0, Float64),), NamedTuple(), (observation = Observation{Float64}(Float64),), (report = Action{Float64, Normal}(Float64, Normal),), ActionModels.ContinuousRescorlaWagner(0.0, 0.1)), population_model = DynamicPPL.Model{typeof(ActionModels.regression_population_model), (:linear_submodels, :estimated_parameter_names, :n_sessions), (), (), Tuple{Vector{DynamicPPL.Model}, Vector{Symbol}, Int64}, Tuple{}, DynamicPPL.DefaultContext}(ActionModels.regression_population_model, (linear_submodels = DynamicPPL.Model[DynamicPPL.Model{typeof(ActionModels.linear_model), (:X, :ranef_info), (:inv_link, :prior, :has_ranef), (), Tuple{Matrix{Float64}, @NamedTuple{Z::Vector{Matrix{Float64}}, n_ranef_categories::Vector{Int64}, n_ranef_params::Vector{Int64}}}, Tuple{typeof(logistic), ActionModels.RegPrior, Bool}, DynamicPPL.DefaultContext}(ActionModels.linear_model, (X = [1.0 0.0; 1.0 0.0; 1.0 0.0; 1.0 1.0; 1.0 1.0; 1.0 1.0], ranef_info = (Z = [[1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0; 1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0]], n_ranef_categories = [3], n_ranef_params = [1])), (inv_link = LogExpFunctions.logistic, prior = ActionModels.RegPrior(Product{Continuous, TDist{Float64}, FillArrays.Fill{TDist{Float64}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(TDist{Float64}(ν=3.0), 2)), Distribution[Product{Continuous, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, FillArrays.Fill{Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(Truncated(TDist{Float64}(ν=3.0); lower=0.0), 1))]), has_ranef = true), DynamicPPL.DefaultContext()), DynamicPPL.Model{typeof(ActionModels.linear_model), (:X, :ranef_info), (:inv_link, :prior, :has_ranef), (), Tuple{Matrix{Float64}, @NamedTuple{Z::Vector{Matrix{Float64}}, n_ranef_categories::Vector{Int64}, n_ranef_params::Vector{Int64}}}, Tuple{typeof(exp), ActionModels.RegPrior, Bool}, DynamicPPL.DefaultContext}(ActionModels.linear_model, (X = [1.0 0.0; 1.0 0.0; 1.0 0.0; 1.0 1.0; 1.0 1.0; 1.0 1.0], ranef_info = (Z = [[1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0; 1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0]], n_ranef_categories = [3], n_ranef_params = [1])), (inv_link = exp, prior = ActionModels.RegPrior(Product{Continuous, TDist{Float64}, FillArrays.Fill{TDist{Float64}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(TDist{Float64}(ν=3.0), 2)), Distribution[Product{Continuous, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, 
FillArrays.Fill{Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}, 1, Tuple{Base.OneTo{Int64}}}}(v=Fill(Truncated(TDist{Float64}(ν=3.0); lower=0.0), 1))]), has_ranef = true), DynamicPPL.DefaultContext())], estimated_parameter_names = [:learning_rate, :action_noise], n_sessions = 6), NamedTuple(), DynamicPPL.DefaultContext()), session_model = ActionModels.var"#session_model#92"{ActionModels.var"#session_model#86#93"}(ActionModels.var"#session_model#86#93"(Core.Box(ActionModels.var"#sample_actions_one_idx#91"(Core.Box(#= circular reference @-2 =#)))), Core.Box(ActionModels.var"#session_model#89#96"(Core.Box(ActionModels.var"#session_model#92"{ActionModels.var"#session_model#86#93"}(#= circular reference @-4 =#)))), Core.Box(([0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.0, 0.5, 0.8, 1.0, 1.5, 1.8, 0.0, 2.0, 0.5, 4.0, 5.0, 3.0, 0.0, 0.1, 0.15, 0.2, 0.25, 0.3, 0.0, 0.2, 0.4, 0.7, 1.0, 1.1, 0.0, 2.0, 0.5, 4.0, 5.0, 3.0],))), observations_per_session = [[(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)], [(1.0,), (1.0,), (1.0,), (2.0,), (2.0,), (2.0,)]], actions_per_session = [[(0.0,), (0.2,), (0.3,), (0.4,), (0.5,), (0.6,)], [(0.0,), (0.5,), (0.8,), (1.0,), (1.5,), (1.8,)], [(0.0,), (2.0,), (0.5,), (4.0,), (5.0,), (3.0,)], [(0.0,), (0.1,), (0.15,), (0.2,), (0.25,), (0.3,)], [(0.0,), (0.2,), (0.4,), (0.7,), (1.0,), (1.1,)], [(0.0,), (2.0,), (0.5,), (4.0,), (5.0,), (3.0,)]], estimated_parameter_names = (:learning_rate, :action_noise), session_ids = ["id:A.treatment:control", "id:B.treatment:control", "id:C.treatment:control", "id:A.treatment:treatment", "id:B.treatment:treatment", "id:C.treatment:treatment"], initial_states = NamedTuple(), var"##arg#230" = DynamicPPL.TypeWrap{Float64}(), var"##arg#231" = DynamicPPL.TypeWrap{Int64}()), NamedTuple(), DynamicPPL.DefaultContext())
If users want to sample from the model themselves, but still want to draw on the rest of the ActionModels API, they can store the result in the ModelFit object by creating an ActionModels.ModelFitResult object. This should be assigned to either the posterior or the prior field of the ModelFit object, after which it will interface with the ActionModels API as normal.
using ActionModels: Turing
chns = sample(turing_model, NUTS(), 1000, progress = false);
model.posterior = ActionModels.ModelFitResult(; chains = chns);
┌ Info: Found initial step size
└ ϵ = 0.025
In addition to sampling from the posterior, ActionModels also provides functionality for sampling from the prior distribution of the model parameters. This is done with the sample_prior! function, which works in a similar way to sample_posterior!. Notably, it is much simpler, since it does not require a complex sampler, and it therefore only takes n_chains and n_samples as keyword arguments.
prior_chns = sample_prior!(model, n_chains = 1, n_samples = 1000)
Chains MCMC chain (1000×13×1 Array{Float64, 3}):
Iterations = 1:1:1000
Number of chains = 1
Samples per chain = 1000
Wall duration = 1.65 seconds
Compute duration = 1.65 seconds
parameters = learning_rate.β[1], learning_rate.β[2], learning_rate.ranef_1.σ[1], learning_rate.ranef_1.r[1], learning_rate.ranef_1.r[2], learning_rate.ranef_1.r[3], action_noise.β[1], action_noise.β[2], action_noise.ranef_1.σ[1], action_noise.ranef_1.r[1], action_noise.ranef_1.r[2], action_noise.ranef_1.r[3]
internals = lp
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
learning_rate.β[1] 0.1000 1.6996 0.0561 939.0323 773.7738 1.0035 569.1105
learning_rate.β[2] -0.0300 1.6328 0.0540 879.9462 944.1954 0.9997 533.3007
learning_rate.ranef_1.σ[1] 1.0364 1.3250 0.0437 974.8590 903.5752 1.0004 590.8236
learning_rate.ranef_1.r[1] -0.0336 1.6861 0.0488 1041.2117 880.7797 1.0004 631.0374
learning_rate.ranef_1.r[2] 0.0192 1.7416 0.0536 1064.8222 877.7412 0.9994 645.3468
learning_rate.ranef_1.r[3] -0.0389 1.4450 0.0449 1035.5275 975.2328 1.0009 627.5924
action_noise.β[1] 0.0328 1.6734 0.0543 904.9334 948.6935 1.0005 548.4445
action_noise.β[2] -0.0282 1.5394 0.0508 923.2031 873.7012 1.0029 559.5170
action_noise.ranef_1.σ[1] 1.1720 1.6563 0.0543 851.1144 936.7700 0.9994 515.8269
action_noise.ranef_1.r[1] -0.0654 1.9133 0.0612 1054.8695 876.9116 0.9999 639.3149
action_noise.ranef_1.r[2] 0.0625 2.6637 0.0883 1018.5414 975.2328 0.9995 617.2978
action_noise.ranef_1.r[3] -0.0972 1.6989 0.0542 1008.7916 958.4968 0.9993 611.3888
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
learning_rate.β[1] -3.0365 -0.7435 0.0110 0.8361 3.4479
learning_rate.β[2] -3.2650 -0.7785 0.0331 0.7756 3.1641
learning_rate.ranef_1.σ[1] 0.0354 0.3245 0.6846 1.3208 4.0512
learning_rate.ranef_1.r[1] -3.1597 -0.4616 -0.0079 0.3661 3.1456
learning_rate.ranef_1.r[2] -3.3306 -0.3309 -0.0030 0.4502 2.6127
learning_rate.ranef_1.r[3] -2.8876 -0.4038 -0.0059 0.4012 2.6831
action_noise.β[1] -3.3063 -0.6863 0.0646 0.9014 3.0670
action_noise.β[2] -3.3695 -0.7467 0.0520 0.7749 2.8471
action_noise.ranef_1.σ[1] 0.0324 0.3611 0.7793 1.4660 4.3725
action_noise.ranef_1.r[1] -3.4997 -0.4769 -0.0112 0.4034 3.0606
action_noise.ranef_1.r[2] -2.8894 -0.3918 0.0047 0.4611 3.0907
action_noise.ranef_1.r[3] -3.4951 -0.4197 -0.0043 0.3671 2.6893
Investigating the results
Population model parameters
The first step in investigating model fitting results is often to look at the population model parameters. Which parameters the population model has, and how to visualize them, will depend on the type of population model used. See the population model REF section for more details on how to interpret results from different population models. In general, however, the Chains object returned by sample_posterior! will contain the posterior distribution of the population model parameters. These can be visualized with various plotting functions; see the MCMCChains documentation for an overview. Here, we just use the standard plot function to visualize the posterior distribution over the β values of interest:
#Fit the model
chns = sample_posterior!(model);
#Plot the posterior distribution of the learning rate and action noise beta parameters
plot(chns[[Symbol("learning_rate.β[1]"), Symbol("learning_rate.β[2]"), Symbol("action_noise.β[1]"), Symbol("action_noise.β[2]")]])
#TODO: ArviZ support / specialized plotting functions
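If a tabular overview is preferred over plots, the chains can also be summarized as a DataFrame for further processing. A sketch, assuming the MCMCChains package (which provides the Chains object above) is available alongside the already-loaded DataFrames:
using MCMCChains, DataFrames
#Convert the summary statistics of the posterior into a DataFrame
summary_df = DataFrame(summarystats(chns))
show(summary_df)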
Parameters per session
Beyond the population model parameters, users will often be interested in the parameter estimates for each session in the data. The session parameters can be extracted with the get_session_parameters! function, which returns an ActionModels.SessionParameters object. Whether session parameter estimates should be extracted from the posterior or the prior distribution can be specified as the second argument.
#Extract posterior and prior estimates for session parameters
session_parameters = get_session_parameters!(model)
prior_session_parameters = get_session_parameters!(model, :prior)
-- Session parameters object --
6 sessions, 1 chains, 1000 samples
2 estimated parameters:
learning_rate
action_noise
ActionModels provides convenient functionality for plotting the session parameters.
#TODO: plot(session_parameters)
Users can also access the full distribution over the session parameters, for use in manual downstream analysis. The ActionModels.SessionParameters object contains the distributions for each parameter and each session. These can be found in the value field, which contains nested NamedTuples for each parameter and session, and ultimately an AxisArray with the samples for each chain.
learning_rate_singlesession = getfield(session_parameters.value.learning_rate, Symbol("id:A.treatment:control"));
The user may visualize, summarize or analyze the samples in whichever way they prefer.
#Calculate the mean
mean(learning_rate_singlesession)
#Plot the distribution
density(learning_rate_singlesession, title = "Learning rate for session A in control condition")
ActionModels provides a convenient function for summarizing all the session parameters as a DataFrame, which can be used for further analysis.
median_df = summarize(session_parameters)
show(median_df)
6×4 DataFrame
Row │ id treatment action_noise learning_rate
│ String String Float64 Float64
─────┼────────────────────────────────────────────────
1 │ A control 0.053729 0.0830947
2 │ B control 0.176646 0.322814
3 │ C control 2.2181 0.855045
4 │ A treatment 0.036657 0.0379923
5 │ B treatment 0.120573 0.171572
6 │ C treatment 1.51952 0.724694
This returns the median of each parameter for each session. The user can pass other functions for summarizing the samples, for example std to calculate the standard deviation of the posterior.
std_df = summarize(session_parameters, std)
show(std_df)
6×4 DataFrame
Row │ id treatment action_noise learning_rate
│ String String Float64 Float64
─────┼────────────────────────────────────────────────
1 │ A control 0.0162926 0.00548921
2 │ B control 0.0444528 0.0271187
3 │ C control 0.705664 0.185973
4 │ A treatment 0.0113222 0.00306329
5 │ B treatment 0.0360526 0.0160724
6 │ C treatment 0.397858 0.240445
This can be saved to disk or used for plotting or analysis in whichever way the user prefers.
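For example, the summary DataFrame could be written to disk as a CSV file; a minimal sketch, assuming the CSV package is installed:
using CSV
#Write the per-session standard deviations to a CSV file
CSV.write("session_parameter_std.csv", std_df)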
State trajectories per session
Users can also extract the estimated trajectory of states for each session with the get_state_trajectories! function, which returns an ActionModels.StateTrajectories object. State trajectories are often used to correlate with some external measure, such as neuroimaging data. The second argument specifies which state to extract. Again, the user can request the prior state trajectories instead by passing :prior as the third argument.
state_trajectories = get_state_trajectories!(model, :expected_value)
prior_state_trajectories = get_state_trajectories!(model, :expected_value, :prior)
-- State trajectories object --
6 sessions, 1 chains, 1000 samples
1 estimated states:
expected_value
ActionModels also provides functionality for plotting the state trajectories.
#TODO: plot(state_trajectories)
The ActionModels.StateTrajectories object contains the prior or posterior distribution over state trajectories for each session, and for each of the specified states. This is stored in the value field, which contains nested NamedTuple objects with state names and then session names as keys, and ultimately an AxisArray with the samples across timesteps. This AxisArray has three axes: the first is the sample index, the second is the chain index, and the third is the timestep index.
expected_value_singlesession = getfield(state_trajectories.value.expected_value, Symbol("id:A.treatment:control"));
Again, we can visualize, summarize, or analyze the samples in whichever way we prefer. Here, we choose the second timestep of the first session, calculate the mean, and plot the distribution.
mean(expected_value_singlesession[timestep=2])
density(expected_value_singlesession[timestep=2], title = "Expectation at time 2 session A control condition")
This can also be summarized neatly in a DataFrame:
median_df = summarize(state_trajectories, median)
show(median_df)
#And from here, it can be used for plotting or further analysis as desired by the user.
42×4 DataFrame
Row │ id treatment timestep expected_value
│ String String Int64 Float64
─────┼─────────────────────────────────────────────
1 │ A control 0 0.0
2 │ A control 1 0.0830947
3 │ A control 2 0.159285
4 │ A control 3 0.229144
5 │ A control 4 0.376292
6 │ A control 5 0.511214
7 │ A control 6 0.634924
8 │ B control 0 0.0
⋮ │ ⋮ ⋮ ⋮ ⋮
36 │ C treatment 0 0.0
37 │ C treatment 1 0.724694
38 │ C treatment 2 0.924206
39 │ C treatment 3 0.979133
40 │ C treatment 4 1.71895
41 │ C treatment 5 1.92262
42 │ C treatment 6 1.9787
27 rows omitted
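As one example of such downstream analysis, the median expected value could be plotted over time for each session, here sketched with the already-loaded StatsPlots and the median_df from above:
#Plot the median expected value trajectory per session
@df median_df plot(
    :timestep,
    :expected_value,
    group = (:id, :treatment),
    xlabel = "timestep",
    ylabel = "median expected value",
)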
This page was generated using Literate.jl.