Tutorial on the 2-level continuous HGF

# This is a replication of the tutorial from the MATLAB HGF toolbox, using an HGF to filter the exchange rate between the US dollar (USD) and the Swiss franc (CHF)

First, load the packages

using ActionModels
using HierarchicalGaussianFiltering
using StatsPlots

Get the path to the root folder of the HierarchicalGaussianFiltering package

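# Root folder of the installed HierarchicalGaussianFiltering package.
# `pkgdir(M)` is Base's built-in equivalent of `dirname(dirname(pathof(M)))`.
hgf_path = pkgdir(HierarchicalGaussianFiltering)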
"/home/runner/work/HierarchicalGaussianFiltering.jl/HierarchicalGaussianFiltering.jl"

Build the path to the folder containing the tutorial data files

# Folder holding the tutorial's data files. It keeps a trailing "/" so that
# file names can be appended directly with `*`.
data_path = "$(hgf_path)/docs/julia_files/tutorials/data/"
"/home/runner/work/HierarchicalGaussianFiltering.jl/HierarchicalGaussianFiltering.jl/docs/julia_files/tutorials/data/"

Load the data: each line of the file is parsed as one `Float64` input

# Read the exchange-rate series, parsing every line of the file as one Float64.
# The `do` block guarantees the file is closed. The trailing `;` keeps the
# generated docs from printing the whole vector.
inputs = open(data_path * "classic_usdchf_inputs.dat") do io
    parse.(Float64, readlines(io))
end;

# Create the premade 2-level continuous HGF
hgf = premade_hgf("continuous_2level"; verbose = false);

# Wrap the HGF in an `HGFGaussian` action model and initialise an agent from it
action_model = ActionModel(HGFGaussian(HGF = hgf))
agent = init_agent(action_model)
-- ActionModels Agent --
Action model: hgf_gaussian
This agent has received 0 observations

Set the parameter values used to simulate data for parameter recovery

# Generating parameter values for the simulation.
# Initial precisions are written as 1 / variance.
parameters = (;
    x_xvol_coupling_strength = 1.0,
    u_input_noise = -log(1e4),
    x_volatility = -13,
    xvol_volatility = -2,
    x_initial_mean = 1.04,
    x_initial_precision = 1 / 0.0001,
    xvol_initial_mean = 1.0,
    xvol_initial_precision = 1 / 0.1,
    action_noise = 0.01,
);

# Apply the values to the agent, then reset it before simulating
set_parameters!(agent, parameters)
reset!(agent)

Evolve the agent on the inputs and store its simulated actions

# Feed every input to the agent. The returned `actions` are plotted below and
# serve as the observed data for the parameter recovery that follows.
actions = simulate!(agent, inputs);

Plot the trajectories

# Input node `u` (the observed exchange rates), overlaid with the posterior of
# node `x` and the simulated actions.
plot(
    agent,
    "u",
    size = (1300, 500),
    xlims = (0, 615),
    markersize = 3,
    markercolor = "green2",
    title = "HGF trajectory",
    ylabel = "CHF-USD exchange rate",
    xlabel = "Trading days since 1 January 2010",
)
plot!(agent, ("x", "posterior"), color = "red")
# Keep `xlims` identical to the base plot. An `xlims` passed to a later `plot!`
# overrides the subplot's earlier value, so a mismatch would narrow the axis.
plot!(
    actions,
    size = (1300, 500),
    xlims = (0, 615),
    markersize = 3,
    markercolor = "orange",
)
# Trajectory of the volatility parent node `xvol`
plot(
    agent,
    "xvol";
    title = "Volatility parent trajectory",
    xlabel = "Trading days since 1 January 2010",
    color = "blue",
    size = (1300, 500),
    xlims = (0, 615),
)

Set the priors used for fitting

# Priors for the four parameters estimated during fitting: Normal priors for the
# three HGF parameters, and a LogNormal prior (positive support only) for
# `action_noise`, centred on log(0.01).
# NOTE(review): the values used to simulate the data (u_input_noise = -log(1e4) ≈ -9.2,
# x_volatility = -13, xvol_volatility = -2) lie roughly 3, 9 and 2 prior SDs away
# from these prior means. This may limit how well they can be recovered — confirm
# this is intended.
priors = (;
    u_input_noise = Normal(-6, 1),
    x_volatility = Normal(-4, 1),
    xvol_volatility = Normal(-4, 1),
    action_noise = LogNormal(log(0.01), 1),
);

Do parameter recovery: fit the model to the simulated actions

# Build the model from the action model, the priors, and the inputs together with
# the simulated actions that serve as the data
model = create_model(
    action_model,
    priors,
    inputs,
    actions;
    check_parameter_rejections = true,
)

#Fit: draw 200 posterior samples in each of 2 chains
posterior_chains = sample_posterior!(model; n_samples = 200, n_chains = 2)
Chains MCMC chain (200×16×2 Array{Float64, 3}):

Iterations        = 101:1:300
Number of chains  = 2
Samples per chain = 200
Wall duration     = 169.72 seconds
Compute duration  = 168.08 seconds
parameters        = u_input_noise.session[1], x_volatility.session[1], xvol_volatility.session[1], action_noise.session[1]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
                  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
                      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

    u_input_noise.session[1]   -3.4207    0.7827    0.0833    92.2396   133.5027    1.0014        0.5488
     x_volatility.session[1]   -6.5218    0.7927    0.0809    98.0808   138.1043    1.0059        0.5835
  xvol_volatility.session[1]   -7.7308    0.4444    0.0290   232.6085   275.8073    1.0027        1.3839
     action_noise.session[1]    0.0103    0.0003    0.0000   375.7189   220.5796    1.0283        2.2353

Quantiles
                  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
                      Symbol   Float64   Float64   Float64   Float64   Float64

    u_input_noise.session[1]   -4.9037   -4.0165   -3.3981   -2.7944   -1.9466
     x_volatility.session[1]   -7.9407   -7.1629   -6.5123   -5.9100   -5.0881
  xvol_volatility.session[1]   -8.6732   -8.0598   -7.6919   -7.4282   -6.9607
     action_noise.session[1]    0.0097    0.0102    0.0103    0.0105    0.0109

Plot the chains

# Visualize the posterior samples drawn in both chains
plot(posterior_chains)

This page was generated using Literate.jl.