Tutorial on the 3-level binary HGF
This tutorial replicates the MATLAB tutorial for the 3-level binary HGF.
First, load the packages.
using ActionModels
using HierarchicalGaussianFiltering
using CSV
using DataFrames
using StatsPlots
Get the path to the root folder of the HierarchicalGaussianFiltering package
hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering)))
"/home/runner/work/HierarchicalGaussianFiltering.jl/HierarchicalGaussianFiltering.jl"
Build the path to the data files
data_path = hgf_path * "/docs/julia_files/tutorials/data/"
"/home/runner/work/HierarchicalGaussianFiltering.jl/HierarchicalGaussianFiltering.jl/docs/julia_files/tutorials/data/"
Load the binary inputs (the first column of the CSV file)
inputs = CSV.read(data_path * "classic_binary_inputs.csv", DataFrame)[!, 1];
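Create an HGF: set the volatility, initial mean and initial precision of the probability (xprob) and volatility (xvol) nodes, and the coupling strengths between nodes, then build the premade 3-level binary HGF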
hgf_parameters = Dict(
    ("xprob", "volatility") => -2.5,
    ("xprob", "initial_mean") => 0,
    ("xprob", "initial_precision") => 1,
    ("xvol", "volatility") => -6.0,
    ("xvol", "initial_mean") => 1,
    ("xvol", "initial_precision") => 1,
    ("xbin", "xprob", "coupling_strength") => 1.0,
    ("xprob", "xvol", "coupling_strength") => 1.0,
);
hgf = premade_hgf("binary_3level", hgf_parameters, verbose = false);
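To make the parameters above more concrete, the following is a minimal from-scratch sketch of the standard 3-level binary HGF updates (the equations of Mathys and colleagues), written in plain Julia. It is not the HierarchicalGaussianFiltering.jl implementation, whose details may differ. It assumes that "volatility" is the tonic volatility ω of a node, that "coupling_strength" is κ, that the inputs are noiseless 0/1 values, and that the states have no drift; `binary_hgf_sketch` and `sgm` are hypothetical names.

```julia
# Logistic sigmoid: maps the level-2 (logit-scale) belief to a probability
sgm(x) = 1 / (1 + exp(-x))

# Hypothetical helper: runs the standard 3-level binary HGF updates over a
# vector of 0/1 inputs. Defaults mirror `hgf_parameters` above, assuming
# "volatility" = omega and "coupling_strength" = kappa.
function binary_hgf_sketch(inputs::AbstractVector{<:Real};
                           omega2 = -2.5, omega3 = -6.0,
                           kappa1 = 1.0, kappa2 = 1.0,
                           mu2 = 0.0, pi2 = 1.0, mu3 = 1.0, pi3 = 1.0)
    n = length(inputs)
    prediction = zeros(n)   # predicted probability that the input is 1
    mu2s, pi2s, mu3s, pi3s = zeros(n), zeros(n), zeros(n), zeros(n)
    for k in 1:n
        # Predictions (no drift, so predicted means equal previous posteriors)
        muhat1 = sgm(kappa1 * mu2)
        v2 = exp(kappa2 * mu3 + omega2)          # level-2 step variance
        pihat2 = 1 / (1 / pi2 + v2)
        pihat3 = 1 / (1 / pi3 + exp(omega3))

        # Level-2 update driven by the outcome prediction error
        delta1 = inputs[k] - muhat1
        pi2_new = pihat2 + kappa1^2 * muhat1 * (1 - muhat1)
        mu2_new = mu2 + kappa1 * delta1 / pi2_new

        # Volatility prediction error, then level-3 update
        delta2 = (1 / pi2_new + (mu2_new - mu2)^2) * pihat2 - 1
        w2 = v2 * pihat2
        pi3_new = pihat3 + 0.5 * kappa2^2 * w2 * (w2 + (2 * w2 - 1) * delta2)
        # The update is only valid while this precision stays positive
        pi3_new > 0 || error("level-3 precision became negative at trial $k")
        mu3_new = mu3 + 0.5 * kappa2 * w2 * delta2 / pi3_new

        # Store this trial's values and carry the posteriors forward
        mu2, pi2, mu3, pi3 = mu2_new, pi2_new, mu3_new, pi3_new
        prediction[k] = muhat1
        mu2s[k], pi2s[k], mu3s[k], pi3s[k] = mu2, pi2, mu3, pi3
    end
    return (; prediction, mu2 = mu2s, pi2 = pi2s, mu3 = mu3s, pi3 = pi3s)
end

beliefs = binary_hgf_sketch([1, 1, 0, 1, 0, 0, 1, 1, 1, 0])
```

In this formulation the level-2 mean lives on the logit scale, and `sgm.(beliefs.mu2)` maps it back to probabilities.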
Create an agent that responds through a sigmoid action model with action noise 0.2
action_model = ActionModel(HGFSigmoid(; HGF = hgf, action_noise = 0.2))
agent = init_agent(action_model);
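A response model commonly paired with the binary HGF, including in the MATLAB toolbox, is the unit-square sigmoid: the probability of responding 1 is m^ζ / (m^ζ + (1 - m)^ζ), where m is the predicted probability that the outcome is 1 and ζ is a response precision. The sketch below implements that formula. Whether `HGFSigmoid` uses exactly this form, and how `action_noise` maps onto ζ, are assumptions here (ζ = 1 / action_noise is used purely for illustration); `unit_square_sigmoid` is a hypothetical name.

```julia
# Hypothetical helper: unit-square sigmoid response model.
# m    = predicted probability that the outcome is 1
# zeta = response precision (larger values give more deterministic choices)
function unit_square_sigmoid(m::Real, zeta::Real)
    a = m^zeta
    b = (1 - m)^zeta
    return a / (a + b)
end

# Assumed mapping for illustration only: zeta = 1 / action_noise
zeta = 1 / 0.2
p_high = unit_square_sigmoid(0.7, zeta)   # pushed towards 1
p_even = unit_square_sigmoid(0.5, zeta)   # stays at 0.5
```

With ζ = 1 the response probability equals the prediction m; larger ζ makes responses closer to deterministic.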
Simulate the agent's responses to the inputs and store the simulated actions
actions = simulate!(agent, inputs);
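Simulating actions comes down to drawing each binary response from that trial's response probability. The sketch below isolates that step using only the Random standard library; the probabilities are made up, and `sample_binary_actions` is a hypothetical helper, not a description of what `simulate!` does internally.

```julia
using Random

# Hypothetical helper: draw one binary action per trial, where
# probs[t] is the probability of responding 1 on trial t.
function sample_binary_actions(rng::AbstractRNG, probs::AbstractVector{<:Real})
    # rand(rng) is uniform on [0, 1), so the comparison is true with probability p
    return [rand(rng) < p ? 1 : 0 for p in probs]
end

rng = MersenneTwister(42)   # fixed seed so the draws are reproducible
probs = [0.9, 0.1, 0.5, 0.99, 0.01]
simulated = sample_binary_actions(rng, probs)
```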
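Plot the agent's trajectories: the inputs overlaid with the predictions of the binary node, then the posteriors of the probability (xprob) and volatility (xvol) nodes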
plot(agent, ("u", "input_value"))
plot!(agent, ("xbin", "prediction"))
plot(agent, ("xprob", "posterior"))
plot(agent, ("xvol", "posterior"))
Set a prior for parameter recovery; here only the volatility of xprob is estimated
prior = (; xprob_volatility = Normal(-3.0, 0.5))
(xprob_volatility = Normal{Float64}(μ=-3.0, σ=0.5),)
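The prior Normal(-3.0, 0.5) has mean -3.0 and standard deviation 0.5, so roughly 95% of its mass lies within -3.0 ± 1.96 × 0.5, that is between -3.98 and -2.02. The plain-Julia sketch below computes this interval and the prior's log density, which is the term a prior contributes to a log posterior. Distributions.jl provides the same quantities; the hand-written helpers `normal_logpdf` and `normal_interval95` are hypothetical and only keep the example self-contained.

```julia
# Hypothetical helper: log density of Normal(mu, sigma) at x
normal_logpdf(x, mu, sigma) = -log(sigma) - 0.5 * log(2π) - 0.5 * ((x - mu) / sigma)^2

# Hypothetical helper: central 95% interval of Normal(mu, sigma), using z = 1.96
normal_interval95(mu, sigma) = (mu - 1.96 * sigma, mu + 1.96 * sigma)

interval = normal_interval95(-3.0, 0.5)      # approximately (-3.98, -2.02)
lp_at_mean = normal_logpdf(-3.0, -3.0, 0.5)  # log density at the prior mean
```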
Load the actions from the MATLAB tutorial; these overwrite the simulated actions
actions = CSV.read(data_path * "classic_binary_actions.csv", DataFrame)[!, 1];
Fit the model to the actions
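Fitting here means that an MCMC sampler draws parameter values in proportion to their posterior density. For one parameter and binary actions, the unnormalized log posterior is the prior log density plus the Bernoulli log probability of every observed action, as in the sketch below. `action_probs_fn` is a hypothetical stand-in for running the HGF and the response model with a given parameter value, and a toy function replaces it in the usage line. This illustrates the quantity the sampler targets, not the internals of `create_model`.

```julia
# Hypothetical helper: unnormalized log posterior for one parameter theta.
#   prior_mu, prior_sigma : parameters of a Normal prior on theta
#   action_probs_fn(theta): per-trial probabilities of responding 1
#   actions               : observed 0/1 responses
function log_posterior(theta, prior_mu, prior_sigma, action_probs_fn, actions)
    # Normal prior log density
    lp = -log(prior_sigma) - 0.5 * log(2π) - 0.5 * ((theta - prior_mu) / prior_sigma)^2
    probs = action_probs_fn(theta)
    for (p, a) in zip(probs, actions)
        # Bernoulli log-likelihood of each observed action
        lp += a == 1 ? log(p) : log1p(-p)
    end
    return lp
end

# Toy stand-in (hypothetical): the same response probability on every trial
toy_probs(theta) = fill(1 / (1 + exp(-theta)), 4)
example_lp = log_posterior(-2.5, -3.0, 0.5, toy_probs, [0, 0, 1, 0])
```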
#Create model
model =
    create_model(action_model, prior, inputs, actions, check_parameter_rejections = true)
#Fit model
posterior_chains = sample_posterior!(model, n_samples = 200, n_chains = 2)
Chains MCMC chain (200×13×2 Array{Float64, 3}):
Iterations = 101:1:300
Number of chains = 2
Samples per chain = 200
Wall duration = 22.21 seconds
Compute duration = 20.33 seconds
parameters = xprob_volatility.session[1]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
                   parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
                       Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64
  xprob_volatility.session[1]   -2.4086    0.1543    0.0116   170.2517   248.1719    1.0137        8.3756
Quantiles
                   parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                       Symbol   Float64   Float64   Float64   Float64   Float64
  xprob_volatility.session[1]   -2.7133   -2.5030   -2.4113   -2.3078   -2.0865
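The mean, std and quantile columns above are ordinary summaries of the pooled posterior draws, and rhat compares variation between and within chains. The sketch below computes these for a matrix of draws with one column per chain, using the Statistics standard library. Its R̂ is the classic split-R̂ of Gelman and colleagues; MCMCChains may report a rank-normalized variant, so the numbers need not match exactly. How to extract the draws from `posterior_chains` is not shown, the draws used here are synthetic, and the helper names are hypothetical.

```julia
using Random, Statistics

# Hypothetical helper: pooled summaries of a draws matrix
# (rows = iterations, columns = chains), mirroring the tables above.
function summarize_draws(draws::AbstractMatrix{<:Real})
    pooled = vec(draws)
    return (mean = mean(pooled),
            std = std(pooled),
            quantiles = quantile(pooled, [0.025, 0.25, 0.5, 0.75, 0.975]))
end

# Hypothetical helper: classic split-R̂. Each chain is cut into two halves
# (a trailing odd draw is dropped) and the halves are compared.
function split_rhat(draws::AbstractMatrix{<:Real})
    n = size(draws, 1) ÷ 2                     # draws per half-chain
    halves = hcat([draws[1:n, j] for j in axes(draws, 2)]...,
                  [draws[n+1:2n, j] for j in axes(draws, 2)]...)
    chain_means = vec(mean(halves; dims = 1))
    chain_vars = vec(var(halves; dims = 1))
    B = n * var(chain_means)                   # between-chain variance
    W = mean(chain_vars)                       # within-chain variance
    var_plus = (n - 1) / n * W + B / n         # pooled variance estimate
    return sqrt(var_plus / W)
end

# Synthetic draws (not the tutorial's): 200 iterations × 2 chains
rng = MersenneTwister(1)
draws = -2.4 .+ 0.15 .* randn(rng, 200, 2)
stats = summarize_draws(draws)
rhat_value = split_rhat(draws)
```

For well-mixed chains R̂ is close to 1, as in the table above (1.0137); chains that settle in different regions push it well above 1.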
#Plot the chains
plot(posterior_chains)
This page was generated using Literate.jl.