using ActionModels, HierarchicalGaussianFiltering
using CSV, DataFrames
using StatsPlots
# Get the path to the HGF package's root folder
# pkgdir returns the root directory of the package that declares the module,
# replacing the manual dirname(dirname(pathof(...))) walk up from src/
hgf_path = pkgdir(HierarchicalGaussianFiltering)
"/home/runner/work/HierarchicalGaussianFiltering.jl/HierarchicalGaussianFiltering.jl"
# Build the path to the tutorial data files
# Directory holding the tutorial data files. joinpath inserts the platform's
# path separator, so no hard-coded "/" or trailing slash is needed.
data_path = joinpath(hgf_path, "docs", "julia_files", "tutorials", "data")

# Load the data into a DataFrame
data = CSV.read(joinpath(data_path, "classic_cannonball_data.csv"), DataFrame)

# Keep the `outcome` values of the rows with ID 20 and session 1.
# Indexing the column directly avoids copying every column of the table first.
inputs = data.outcome[(data.ID .== 20) .& (data.session .== 1)]
# Create the HGF and the agent that uses it
# Start from the premade "JGET" HGF, with verbose output turned off
hgf = premade_hgf("JGET"; verbose = false)
# Wrap the HGF in an HGFGaussian action model
action_model = ActionModel(HGFGaussian(HGF = hgf))
# Initialise an agent from the action model
agent = action_model |> init_agent
# Set parameters
# Collected in a NamedTuple so each value is passed to the agent by name.
parameters = (;
    action_noise = 1,
    u_input_noise = 0,
    # Start the belief about x two units above the first observed input
    x_initial_mean = first(inputs) + 2,
    x_initial_precision = 0.001,
    x_volatility = -8,
    xvol_volatility = -8,
    xnoise_volatility = -7,
    xnoise_vol_volatility = -2,
    x_xvol_coupling_strength = 1,
    xnoise_xnoise_vol_coupling_strength = 1,
)

# Apply the parameters, then reset the agent before simulating
# NOTE(review): presumably reset! is what makes the new initial values take
# effect in the agent's state; confirm against the ActionModels documentation
set_parameters!(agent, parameters)
reset!(agent)
# Simulate updates and actions: feed the inputs to the agent and keep what
# simulate! returns. The trailing semicolon suppresses display of the result.
actions = simulate!(agent, inputs);
# Plot belief trajectories

# Inputs "u", with the "x" trajectory overlaid on the same figure (plot!)
plot(agent, "u")
plot!(agent, "x")

# Each remaining node starts a new figure (plot rather than plot!)
plot(agent, "xvol")
plot(agent, "xnoise")
plot(agent, "xnoise_vol")
# This page was generated using Literate.jl.