Model Fitting
In many cases, we want to draw conclusions about specific observed phenomena, such as behavioural differences between distinct populations. A conventional approach in this context is model fitting, which involves estimating the parameter values of a model (e.g., prior beliefs) that are most likely given the observed behaviour of a participant. This approach is often used in fields such as computational psychiatry and mathematical psychology to develop more precise models and theories of mental processes, to find mechanistic differences between clinical populations, or to investigate the relationship between computational constructs, such as Bayesian beliefs, and neuronal dynamics.
Quick Start
Model Fitting with ActionModels.jl
Model fitting in 'ActiveInference' is mediated through 'ActionModels', our sister package for implementing and fitting behavioural models to data. The core of 'ActionModels' is the action model function, which takes a single observation, runs the inference scheme (updating the agent's beliefs), and calculates the probability distribution over actions from which the agent samples its action. (Check out the ActionModels documentation for more details.)
To demonstrate this, let's define a very simple generative model with a single state factor and two possible actions, and then initialize our active inference object:
# Define the number of states, observations, and controls
n_states = [4]
n_observations = [4]
n_controls = [2]
# Define the policy length
policy_length = 1
# Use the create_matrix_templates function to create uniform A and B matrices.
A, B = create_matrix_templates(n_states, n_observations, n_controls, policy_length)
# Initialize an active inference object with the created matrices
aif = init_aif(A, B)
We can now use the action_pomdp! function (which serves as our active inference "action model") to calculate the probability distribution over actions for a single observation:
# Define observation
observation = [1]
# Calculate action probabilities
action_distribution = action_pomdp!(aif, observation)
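To actually select an action, we can draw a sample from this distribution. Assuming action_distribution is a standard Distributions.jl distribution (an assumption; check the returned type), sampling is a one-liner:

```julia
using Distributions

# Draw an action from the distribution returned by action_pomdp!
# (assumes a Distributions.jl distribution, e.g. a Categorical over the two actions)
action = rand(action_distribution)
```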
Agent in ActionModels.jl
Another key component of 'ActionModels' is an Agent, which wraps the action model and the active inference object in a more abstract structure. The Agent is initialized with our active inference object as a substruct, and with action_pomdp! as its action model.
Let's first install 'ActionModels' from the official Julia registry and import it:
using Pkg
Pkg.add("ActionModels")
using ActionModels
We can now create an Agent with the action_pomdp! function and the active inference object:
# Initialize agent with active inference object as substruct
agent = init_agent(
    action_pomdp!,   # The active inference action model
    substruct = aif  # The active inference object
)
We use an initialized Agent primarily for fitting; however, it can also be used with a set of convenience functions to run simulations, which are described in Simulation with ActionModels.
Fitting a Single Subject Model
We have our Agent object defined as above. Next, we need to specify priors for the parameters we want to estimate. For example, let's estimate the action precision parameter α and use a Gamma distribution as its prior.
# Import the Distributions package
using Distributions
# Define the prior distribution for the alpha parameter inside a dictionary
priors = Dict("alpha" => Gamma(1, 1))
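Several parameters can be estimated jointly by adding more entries to the same dictionary. In the sketch below, the second parameter name "gamma" (policy precision) is purely illustrative; use whichever parameter names your active inference object actually exposes:

```julia
using Distributions

# Priors for two parameters; "gamma" is a hypothetical example name
priors = Dict(
    "alpha" => Gamma(1, 1),  # Action precision
    "gamma" => Gamma(1, 1),  # Policy precision (illustrative name)
)
```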
We can now use the create_model function to instantiate a probabilistic model object with data. This function takes the Agent object, the priors, and a set of observations and actions as arguments.
First, let's define some observations and actions as vectors:
# Define observations and actions
observations = [1, 1, 2, 3, 1, 4, 2, 1]
actions = [2, 1, 2, 2, 2, 1, 2, 2]
Now we can instantiate the probabilistic model object:
# Create the model object
single_subject_model = create_model(agent, priors, observations, actions)
The single_subject_model can be used as a standard Turing object. Performing inference on this model is as simple as:
results = fit_model(single_subject_model)
Fitting a Model with Multiple Subjects
Often, we have data from multiple subjects that we would like to fit simultaneously. The good news is that this can be done by instantiating our probabilistic model on an entire dataset containing data from multiple subjects.
Let's define a dataset with observations and actions for three subjects:
# Import the DataFrames package
using DataFrames
# Create a DataFrame
data = DataFrame(
    subjectID = [1, 1, 1, 2, 2, 2, 3, 3, 3],    # Subject IDs
    observations = [1, 1, 2, 3, 1, 4, 2, 1, 3], # Observations
    actions = [2, 1, 2, 2, 2, 1, 2, 2, 1]       # Actions
)
Row | subjectID | observations | actions
--- | --- | --- | ---
 | Int64 | Int64 | Int64
1 | 1 | 1 | 2
2 | 1 | 1 | 1
3 | 1 | 2 | 2
4 | 2 | 3 | 2
5 | 2 | 1 | 2
6 | 2 | 4 | 1
7 | 3 | 2 | 2
8 | 3 | 1 | 2
9 | 3 | 3 | 1
To instantiate the probabilistic model on our dataset, we pass the data DataFrame to the create_model function along with the names of the columns that contain the subject identifiers, observations, and actions:
# Create the model object
multi_subject_model = create_model(
    agent,
    priors,
    data;                          # DataFrame with the dataset
    grouping_cols = [:subjectID],  # Column with subject IDs
    input_cols = ["observations"], # Column with observations
    action_cols = ["actions"]      # Column with actions
)
To fit the model, we use the fit_model function as before:
results = fit_model(multi_subject_model)
┌ Info: Found initial step size
└ ϵ = 3.2
Customizing the Fitting Procedure
The fit_model function has several optional arguments that allow us to customize the fitting procedure. For example, you can specify the number of iterations, the number of chains, the sampling algorithm, or whether to parallelize over chains:
results = fit_model(
    multi_subject_model,                 # The model object
    parallelization = MCMCDistributed(), # Run chains in parallel
    sampler = NUTS(; adtype = AutoReverseDiff(compile = true)), # The sampler to use
    n_iterations = 1000,                 # Number of iterations
    n_chains = 4,                        # Number of chains
)
'Turing' allows us to run distributed (MCMCDistributed()) or threaded (MCMCThreads()) parallel sampling; the default is to run chains serially (MCMCSerial()). For information on the available samplers, see the Turing documentation.
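As a sketch, the same fit run with threaded sampling (assuming Julia was started with multiple threads, e.g. julia --threads 4) would look like:

```julia
# Run four chains in parallel across threads
# (requires Julia to be started with multiple threads)
results = fit_model(
    multi_subject_model,
    parallelization = MCMCThreads(),
    n_chains = 4,
)
```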
Results
The output of the fit_model function is an object that contains the standard 'Turing' chains, which we can use to extract the summary statistics of the posterior distribution.
Let's extract the chains from the results object:
chains = results.chains
Chains MCMC chain (1000×15×1 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 1
Samples per chain = 1000
Wall duration = 10.21 seconds
Compute duration = 10.21 seconds
parameters = parameters[1, 1], parameters[1, 2], parameters[1, 3]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
parameters[1, 1] 0.9818 1.0700 0.0400 496.4024 468.7198 1.0015 48.6050
parameters[1, 2] 1.0000 0.9031 0.0352 466.2592 405.1487 1.0007 45.6535
parameters[1, 3] 0.9687 0.9661 0.0350 465.1326 418.9269 0.9997 45.5432
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
parameters[1, 1] 0.0212 0.2333 0.6346 1.3427 3.9843
parameters[1, 2] 0.0340 0.3222 0.7330 1.4429 3.3095
parameters[1, 3] 0.0231 0.2763 0.6083 1.3878 3.4385
Note that the parameter names in the chains are somewhat cryptic. We can use the rename_chains function to rename them to something more understandable:
renamed_chains = rename_chains(chains, multi_subject_model)
Chains MCMC chain (1000×15×1 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 1
Samples per chain = 1000
Wall duration = 10.21 seconds
Compute duration = 10.21 seconds
parameters = subjectID:1.alpha, subjectID:2.alpha, subjectID:3.alpha
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
subjectID:1.alpha 0.9818 1.0700 0.0400 496.4024 468.7198 1.0015 48.6050
subjectID:2.alpha 1.0000 0.9031 0.0352 466.2592 405.1487 1.0007 45.6535
subjectID:3.alpha 0.9687 0.9661 0.0350 465.1326 418.9269 0.9997 45.5432
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
subjectID:1.alpha 0.0212 0.2333 0.6346 1.3427 3.9843
subjectID:2.alpha 0.0340 0.3222 0.7330 1.4429 3.3095
subjectID:3.alpha 0.0231 0.2763 0.6083 1.3878 3.4385
That looks better! We can now use the 'StatsPlots' package to plot the chain traces and density plots of the posterior distributions for all subjects:
using StatsPlots # Load the StatsPlots package
plot(renamed_chains)
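The resulting figure can be written to disk with savefig from the Plots ecosystem (the filename here is only an example):

```julia
# Save the trace and density plots to a file (filename is illustrative)
p = plot(renamed_chains)
savefig(p, "posterior_chains.png")
```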
We can also visualize the posterior distributions against the priors. This can be done by first taking samples from the prior:
# Sample from the prior
prior_chains = sample(multi_subject_model, Prior(), 1000)
# Rename parameters in the prior chains
renamed_prior_chains = rename_chains(prior_chains, multi_subject_model)
To plot the posterior distributions against the priors, we use the plot_parameters function:
plot_parameters(renamed_prior_chains, renamed_chains)
This page was generated using Literate.jl.