Tutorial on the 3-level binary HGF

This tutorial reproduces the 3-level binary HGF tutorial originally written for MATLAB

First, load the packages

using ActionModels
using HierarchicalGaussianFiltering
using CSV
using DataFrames
using StatsPlots

Get the path to the root folder of the HGF package

hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering)))
"/home/runner/work/HierarchicalGaussianFiltering.jl/HierarchicalGaussianFiltering.jl"

Build the path to the data files

data_path = hgf_path * "/docs/julia_files/tutorials/data/"
"/home/runner/work/HierarchicalGaussianFiltering.jl/HierarchicalGaussianFiltering.jl/docs/julia_files/tutorials/data/"
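String concatenation with `*` works here because the separators are written out by hand. As a more portable alternative, the sketch below uses `joinpath` from Julia's `Base`, which inserts the platform's separator. The `root` value is a hypothetical stand-in for `hgf_path`, not the real package location.

```julia
# Hypothetical stand-in for hgf_path; the real value comes from pathof(...)
root = joinpath("some", "package", "root")

# joinpath inserts the platform-specific separator between components
data_dir = joinpath(root, "docs", "julia_files", "tutorials", "data")
data_file = joinpath(data_dir, "classic_binary_inputs.csv")

println(data_file)
```

With this form, no trailing separator has to be kept on the directory path.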

Load the data

inputs = CSV.read(data_path * "classic_binary_inputs.csv", DataFrame)[!, 1];

Create an HGF

hgf_parameters = Dict(
    ("xprob", "volatility") => -2.5,
    ("xprob", "initial_mean") => 0,
    ("xprob", "initial_precision") => 1,
    ("xvol", "volatility") => -6.0,
    ("xvol", "initial_mean") => 1,
    ("xvol", "initial_precision") => 1,
    ("xbin", "xprob", "coupling_strength") => 1.0,
    ("xprob", "xvol", "coupling_strength") => 1.0,
);

hgf = premade_hgf("binary_3level", hgf_parameters, verbose = false);
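To give a feel for what these parameters do, here is a minimal sketch of a single-trial belief update for the first two levels of a binary HGF, written from the standard published update equations. It is not the package's internal code. The function and argument names are hypothetical, and it assumes that the ("xprob", "volatility") parameter plays the role of `omega2`, that the "xprob"/"xvol" coupling strength plays the role of `kappa`, and that the "xbin"/"xprob" coupling is 1. The third-level update is left out for brevity.

```julia
# Logistic sigmoid, used to map the second-level belief onto a probability
sigmoid(x) = 1 / (1 + exp(-x))

# One-trial update of levels 1 and 2 of a binary HGF (textbook form; a sketch,
# not the package's internal code). All names here are illustrative.
function binary_hgf_step(u, mu2, pi2, mu3; omega2 = -2.5, kappa = 1.0)
    # Predicted variance at level 2: previous variance plus a volatility-driven step
    pred_var2 = 1 / pi2 + exp(kappa * mu3 + omega2)
    pred_pi2 = 1 / pred_var2

    # Level 1 prediction: probability that the next input is 1
    pred_mu1 = sigmoid(mu2)

    # Prediction error at level 1
    delta1 = u - pred_mu1

    # Posterior precision and mean at level 2
    new_pi2 = pred_pi2 + pred_mu1 * (1 - pred_mu1)
    new_mu2 = mu2 + delta1 / new_pi2

    return (prediction = pred_mu1, mu2 = new_mu2, pi2 = new_pi2)
end

# Start from the initial values given in hgf_parameters and observe a single u = 1
result = binary_hgf_step(1, 0.0, 1.0, 1.0)
println(result)
```

A higher volatility (`omega2`) lowers the predicted precision at level 2, so each prediction error moves the belief further.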

Create an agent

action_model = ActionModel(HGFSigmoid(; HGF = hgf, action_noise = 0.2))
agent = init_agent(action_model);
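The action noise controls how deterministically beliefs are turned into actions. The sketch below shows a generic logistic response rule with a noise parameter. It illustrates the idea only and is an assumption on my part: the exact form used by `HGFSigmoid` may differ, and `action_probability` is a hypothetical helper.

```julia
# Generic logistic response rule: belief (in logit space) -> P(action = 1).
# `noise` scales the steepness; this is an illustrative form, not HGFSigmoid itself.
action_probability(belief, noise) = 1 / (1 + exp(-belief / noise))

for belief in (-1.0, 0.0, 1.0)
    p_noisy = round(action_probability(belief, 1.0), digits = 3)
    p_sharp = round(action_probability(belief, 0.2), digits = 3)
    println("belief = $belief: noise 1.0 -> $p_noisy, noise 0.2 -> $p_sharp")
end
```

Under this rule, a smaller noise value pushes the action probabilities towards 0 or 1 for the same belief, while a belief of 0 always gives a probability of 0.5.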

Simulate the agent on the inputs and save its actions

actions = simulate!(agent, inputs);

Plot the trajectory of the agent

plot(agent, ("u", "input_value"))
plot!(agent, ("xbin", "prediction"))
plot(agent, ("xprob", "posterior"))
plot(agent, ("xvol", "posterior"))

Set priors for parameter recovery

prior = (; xprob_volatility = Normal(-3.0, 0.5))
(xprob_volatility = Normal{Float64}(μ=-3.0, σ=0.5),)
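To see what this prior implies, the sketch below computes its central 95% interval by hand and checks where the value used when the HGF was created above (-2.5) falls relative to it. It uses plain arithmetic only; the variable names are illustrative.

```julia
# Prior from the tutorial: Normal(-3.0, 0.5)
prior_mean = -3.0
prior_sd = 0.5

# Central 95% interval of a normal distribution: mean plus or minus 1.96 standard deviations
lower = prior_mean - 1.96 * prior_sd
upper = prior_mean + 1.96 * prior_sd

# Value of ("xprob", "volatility") used when the HGF was created
configured_value = -2.5
z = (configured_value - prior_mean) / prior_sd

println("95% prior interval: [$lower, $upper]")
println("configured value is $z prior standard deviations from the prior mean")
```

The prior is therefore centered slightly below the configured value but still places substantial mass on it.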

Load the actions from the MATLAB tutorial (these replace the simulated actions from above)

actions = CSV.read(data_path * "classic_binary_actions.csv", DataFrame)[!, 1];

Fit the model to the actions

#Create model
model =
    create_model(action_model, prior, inputs, actions, check_parameter_rejections = true)

#Fit model
posterior_chains = sample_posterior!(model, n_samples = 200, n_chains = 2)
Chains MCMC chain (200×13×2 Array{Float64, 3}):

Iterations        = 101:1:300
Number of chains  = 2
Samples per chain = 200
Wall duration     = 22.21 seconds
Compute duration  = 20.33 seconds
parameters        = xprob_volatility.session[1]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                   parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
                       Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

  xprob_volatility.session[1]   -2.4086    0.1543    0.0116   170.2517   248.1719    1.0137        8.3756

Quantiles
                   parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                       Symbol   Float64   Float64   Float64   Float64   Float64

  xprob_volatility.session[1]   -2.7133   -2.5030   -2.4113   -2.3078   -2.0865

#Plot the chains
plot(posterior_chains)
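As a quick sanity check on the fit, the sketch below compares the reported posterior with the prior, using numbers copied by hand from the summary and quantile tables in the output above. The 1.05 threshold for rhat is an assumed rule of thumb, not something the tutorial prescribes.

```julia
# Values copied from the summary and quantile tables in the output
posterior_mean = -2.4086
q_low, q_high = -2.7133, -2.0865   # 2.5% and 97.5% quantiles
rhat = 1.0137

# Prior used for the fit: Normal(-3.0, 0.5)
prior_mean, prior_sd = -3.0, 0.5

# How far the posterior mean moved from the prior mean, in prior standard deviations
shift_in_sd = (posterior_mean - prior_mean) / prior_sd

# Width of the 95% posterior interval relative to the 95% prior interval
width_ratio = (q_high - q_low) / (2 * 1.96 * prior_sd)

println("shift from prior mean: $(round(shift_in_sd, digits = 2)) prior SDs")
println("posterior / prior 95% width: $(round(width_ratio, digits = 2))")
println("rhat below 1.05 (assumed threshold): $(rhat < 1.05)")
```

A posterior that is clearly narrower than the prior and shifted away from its mean suggests that the data, rather than the prior, are driving the estimate. With only 200 samples per chain, the effective sample sizes remain modest.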

This page was generated using Literate.jl.