The continuous Hierarchical Gaussian Filter#


import sys
from IPython.utils import io

if 'google.colab' in sys.modules:
    with io.capture_output() as captured:
        ! pip install pyhgf watermark
import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pymc as pm

from pyhgf import load_data
from pyhgf.distribution import HGFDistribution
from pyhgf.model import HGF, Network
from pyhgf.response import first_level_gaussian_surprise

plt.rcParams["figure.constrained_layout.use"] = True

In this notebook, we illustrate applications of the standard two-level and three-level Hierarchical Gaussian Filters (HGF) for continuous inputs. This class of models differs slightly from the previous binary example in that the input nodes are not restricted to boolean variables but accept observations on a continuous domain. Fitting continuous data allows using the HGF with any time series, which has several applications in neuroscience (see for example the case study on physiological modelling with the Hierarchical Gaussian Filter, Example 1: Bayesian filtering of cardiac volatility). The continuous HGF is built on top of the following probabilistic networks:

../_images/continuous.svg

Fig. 5 The two-level and three-level Hierarchical Gaussian Filter for continuous inputs. All nodes are continuous state nodes. The first node (\(x_0\)) can observe new values.#
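
For reference, in the canonical HGF formulation (a sketch, omitting drift terms), each continuous state node \(x_i\) performs a Gaussian random walk whose step variance combines a tonic parameter \(\omega_i\) with a phasic contribution from its volatility parent \(x_j\) (when present), weighted by a coupling strength \(\kappa\):

\[x_i^{(k)} \sim \mathcal{N}\left(x_i^{(k-1)},\ \exp\left(\kappa \, x_j^{(k)} + \omega_i\right)\right)\]

The tonic_volatility parameters used throughout this notebook correspond to \(\omega\) in this formulation.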

Here, we will use the continuous HGF to filter and predict the exchange rate of the US Dollar to the Swiss Franc during much of 2010 and 2011 (we use this time series as it is a classic example from the Matlab toolbox).

timeserie = load_data("continuous")

Fitting the continuous HGF with fixed parameters#

The two-level continuous Hierarchical Gaussian Filter#

Create the model#

Note

The default response function for a continuous HGF is the sum of the Gaussian surprise at the first level. In other words, at each time point the model tries to update its hierarchy so as to minimize the discrepancy between the expected and the actually observed value in the continuous domain.
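
For intuition, the Gaussian surprise of a single observation is its negative log-density under the first-level prediction. A minimal sketch is given below; the function name and arguments are illustrative and are not part of the pyhgf API:

import jax.numpy as jnp

def gaussian_surprise(observation, expected_mean, expected_precision):
    # Negative log-density of the observation under a Gaussian prediction
    # parametrized by its expected mean and expected precision.
    return 0.5 * (
        jnp.log(2 * jnp.pi)
        - jnp.log(expected_precision)
        + expected_precision * (observation - expected_mean) ** 2
    )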

two_levels_continuous_hgf = (
    Network()
    .add_nodes(precision=1e4)  # node 0: receives the observations
    .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=0)  # node 1: value parent of node 0
    .add_nodes(precision=1e1, tonic_volatility=-2.0, volatility_children=1)  # node 2: volatility parent of node 1
)

This chain of add_nodes() calls builds a two-level continuous HGF: the first node receives the observations, its value parent tracks the underlying state, and the volatility parent on top tracks the volatility of that state. The resulting network embeds methods to add new observations and to plot the results and the network structure. We can visualize the node structure using the pyhgf.plots.plot_network() function, which draws the network with Graphviz.

two_levels_continuous_hgf.plot_network()
../_images/27fe5f25c909183ea7b612320a02628b71b5e1915fe2c4bc35dc815fc064b0cc.svg

Add data#

# Provide new observations
two_levels_continuous_hgf = two_levels_continuous_hgf.input_data(input_data=timeserie)

Plot trajectories#

A Hierarchical Gaussian Filter parametrized with the standard Gaussian surprise as a response function acts as a Bayesian filter. By presenting new continuous observations and running the update equations forward, we can observe the trajectories of the node parameters (i.e. the mean \(\mu\) and the precision \(\pi\)) as they adapt to the trajectory and volatility of the input time series. The plot_trajectories function automatically extracts the relevant parameters given the model structure and plots their evolution together with the input data.

two_levels_continuous_hgf.plot_trajectories(show_total_surprise=True);
../_images/4fcf63979264103b29ce0d99e860e16553fb61058f4597157b132c7d9be79793.png
# ensure that the results are valid
df = two_levels_continuous_hgf.to_pandas()
assert jnp.isclose(df.x_0_surprise.sum(), -1910.0183)
assert jnp.isclose(df.x_1_surprise.sum(), -2679.9297)
assert jnp.isclose(df.x_2_surprise.sum(), 886.30963)
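
Beyond the built-in plots, the trajectories can also be inspected programmatically through the data frame returned by to_pandas(). The short sketch below only relies on the x_1_... column-naming pattern visible in the surprise columns above; the exact set of exported columns may vary between pyhgf versions:

# List the columns exported for the value parent (node 1); the exact
# column names may differ between pyhgf versions.
df = two_levels_continuous_hgf.to_pandas()
print([column for column in df.columns if column.startswith("x_1_")])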

Looking at the volatility level (i.e. the orange line in the first panel), we see two salient events in our time series where volatility shoots up. The first was in April 2010, when the currency markets reacted to the news that Greece was effectively bankrupt. This led to a flight into the US dollar (the grey dots rising very quickly), sending volatility higher. The second is an accelerating increase in the value of the Swiss Franc in August and September 2011, as the Euro crisis dragged on. The point where the Swiss central bank intervened and put a floor under how far the Euro could fall with respect to the Franc is visible in the Franc’s valuation against the Dollar: this surprising intervention shows up as another spike in volatility.

The surprise increases when the time series exhibits unexpected behaviour. The degree to which a given observation is expected depends on the expected value and volatility in the input node, which are in turn influenced by the values of the higher-order nodes. One way to assess model fit is to look at the total Gaussian surprise summed over all observations. This value can be returned using the pyhgf.model.HGF.surprise() method:

two_levels_continuous_hgf.surprise(response_function=first_level_gaussian_surprise).sum()
Array(-886.6174, dtype=float32)

Note

The surprise returned by a model when presented with new observations is a function of the response model that was used. Different response functions can be provided, together with additional parameters, to the pyhgf.model.HGF.surprise() method. The surprise is the negative log probability density of the new observations under the model:

\[\mathrm{surprise} = -\log(p)\]

Plot correlation#

Node parameters that are highly correlated across time are likely to indicate that the model did not learn a hierarchical structure in the data but instead overfitted some of its components. One way to quickly check the correlations between node trajectories is to use the plot_correlations function embedded in the HGF class.

two_levels_continuous_hgf.plot_correlations();
../_images/0484242a84112acc5bca80afe32d86d70eb6688f3e07b10002e21e4f484493e0.png

The three-level continuous Hierarchical Gaussian Filter#

Create the model#

The three-level HGF can add a meta-volatility layer to the model. This can be useful if we suspect that the volatility of the time series is not stable across time and we would like our model to learn it. Here, we create a new pyhgf.model.HGF instance, setting the number of levels to 3. Note that we are extending the size of the dictionaries accordingly.

three_levels_continuous_hgf = HGF(
    n_levels=3,
    model_type="continuous",
    initial_mean={"1": 1.04, "2": 0.0, "3": 0.0},
    initial_precision={"1": 1e4, "2": 1.0, "3": 1.0},
    tonic_volatility={"1": -13.0, "2": -2.0, "3": -2.0}
)

The node structure now includes a volatility parent at the third level.

three_levels_continuous_hgf.plot_network()
../_images/b206995142a3c111a412aa92199b648e4347df5db1c8da041cf657c02c724447.svg

Add data#

three_levels_continuous_hgf = three_levels_continuous_hgf.input_data(input_data=timeserie)

Plot trajectories#

three_levels_continuous_hgf.plot_trajectories();
../_images/598443ae7be904755d092607786b490fde5899ef115f6bdb56f9dc283a18beb6.png

Surprise#

Similarly, we can retrieve the overall Gaussian surprise at the first node for each new observation using the built-in method:

three_levels_continuous_hgf.surprise().sum()
Array(-903.4093, dtype=float32)

The overall amount of surprise returned by the three-level HGF is quite similar to what was observed with the two-level model (-903 vs. -886). Because an agent will aim to minimize surprise, the three-level model appears to do slightly better in this context. However, the surprise also depends on the node parameters, which we have so far fixed by hand. One important parameter for each node is the tonic volatility (sometimes noted \(\omega\)). This is the tonic part of the variance (the part of the variance in each node that is not affected by the parent node). Here we are going to change the tonic volatility at the second level to see if it can help to minimize surprise:

# create an alternative model with different omega values
# the input time series is passed in the same call
three_levels_continuous_hgf_bis = HGF(
    n_levels=3,
    model_type="continuous",
    initial_mean={"1": 1.04, "2": 0.0, "3": 0.0},
    initial_precision={"1": 1e4, "2": 1e1, "3": 1e1},
    tonic_volatility={"1": -13.0, "2": -1.0, "3": -2.0},
).input_data(input_data=timeserie)
three_levels_continuous_hgf_bis.plot_trajectories();
../_images/8f01782ce28a8d5f4736b5e50590479960ffe11433396b57469ee92ac68484af.png
three_levels_continuous_hgf_bis.surprise().sum()
Array(-828.698, dtype=float32)

Now we are getting a global surprise of -828 with the new model, compared to a global surprise of -903 before. Clearly, the \(\omega\) value at the second level has a strong influence on how well the model filters this kind of time series. But how can we decide which value to choose? Doing this by trial and error would be tedious. Instead, we can use dedicated Bayesian methods that will infer the values of \(\omega\) that minimize the surprise (i.e. that maximize the likelihood of the new observations given the parameter priors).
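
For illustration only, such a manual search could look like the sketch below, looping over a handful of arbitrary candidate values for the tonic volatility at the second level; the sampling approach introduced in the next section replaces this with a principled inference:

# Manual (and tedious) search over a few arbitrary candidate values
# for the tonic volatility at the second level.
for candidate in [-6.0, -4.0, -2.0, -1.0]:
    surprise = HGF(
        n_levels=3,
        model_type="continuous",
        initial_mean={"1": 1.04, "2": 0.0, "3": 0.0},
        initial_precision={"1": 1e4, "2": 1e1, "3": 1e1},
        tonic_volatility={"1": -13.0, "2": candidate, "3": -2.0},
    ).input_data(input_data=timeserie).surprise().sum()
    print(f"tonic_volatility_2 = {candidate}: surprise = {surprise}")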

Learning parameters with MCMC sampling#

In the previous section, we assumed we knew the parameters of the HGF models beforehand. This can give us information on how an agent using these values would have behaved when presented with these inputs. We can also adopt a different perspective and consider that we want to learn these parameters from the data, and then ask what would be the best parameter values for an agent to minimize surprises when presented with this data. Here, we are going to set priors over some parameters and use Hamiltonian Monte Carlo methods (NUTS) to sample their probability density.

Because the HGF classes are built on top of JAX, they are natively differentiable and compatible with optimisation libraries. Here, we use PyMC to perform MCMC sampling. PyMC can use any log probability function (here the negative surprise of the model) as a building block for a new distribution by wrapping it in its underlying tensor library, PyTensor (formerly Aesara). pyhgf includes a PyMC-compatible distribution that does this automatically: pyhgf.distribution.HGFDistribution.

Two-level model#

Creating the model#

Note

The HGF distribution class pyhgf.distribution.HGFDistribution uses the first-level Gaussian surprise (i.e. the sum of the Gaussian surprises over the new observations) as its default response function, so passing this argument here is optional, but we do so for clarity.

hgf_logp_op = HGFDistribution(
    n_levels=2,
    input_data=timeserie[jnp.newaxis, :],
    response_function=first_level_gaussian_surprise
)

This log probability function can then be embedded in a PyMC model using the standard API. Here, we are going to estimate the tonic volatility at the first level (tonic_volatility_1); the other parameters are fixed.

Note

The data were already passed to the distribution in the cell above, when the log-probability function was created.

with pm.Model() as two_level_hgf:

    # Set a prior over the tonic volatility at the first level.
    tonic_volatility_1 = pm.Normal("tonic_volatility_1", -10, 2.0)

    # Call the pre-parametrized HGF distribution here.
    # All parameters are set to their default values except tonic_volatility_1, tonic_volatility_2 and mean_1.
    pm.Potential(
        "hgf_loglike", hgf_logp_op(
            tonic_volatility_1=tonic_volatility_1, tonic_volatility_2=-2.0, mean_1=1.0
        )
    )

Note

The \(\omega\) parameters are real numbers defined from \(-\infty\) to \(+\infty\). However, because they act on the variance in log space, values higher than 2 are extremely unlikely and could produce aberrant fits to the data. Therefore, we use here a prior centred on more plausible values.
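
If we wanted to enforce this constraint explicitly rather than relying on the thin upper tail of the normal prior, a truncated prior could be used instead; this is an illustrative alternative only and is not used in the models below:

# Illustrative alternative prior (not used here): explicitly rule out
# tonic volatility values above 2 with a truncated normal distribution.
with pm.Model():
    pm.TruncatedNormal("tonic_volatility_1", mu=-10.0, sigma=2.0, upper=2.0)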

Visualizing the model#

pm.model_to_graphviz(two_level_hgf)
../_images/276f245ebd9fd4bcae5e926f51a69ac0c4a8b7994776641b7ca80207aba0e073.svg

Sampling#

with two_level_hgf:
    two_level_hgf_idata = pm.sample(chains=2, cores=1)
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [tonic_volatility_1]


Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 5 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
az.plot_trace(two_level_hgf_idata);
plt.tight_layout()
../_images/018e55a574b3af09f7888881c9f811290f1d3d138d41e5da5dd3f16495f3eb2b.png

Using the learned parameters#

We can see from the density distributions that the most probable values for \(\omega_{1}\) are found around -7. To get an idea of the belief trajectories implied by such parameters, we can fit the model again, this time using the posterior mean obtained from the samples:

tonic_volatility_1 = az.summary(two_level_hgf_idata)["mean"]["tonic_volatility_1"]
hgf_mcmc = HGF(
    n_levels=2,
    model_type="continuous",
    initial_mean={"1": timeserie[0], "2": 0.0},
    initial_precision={"1": 1e4, "2": 1e1},
    tonic_volatility={"1": tonic_volatility_1, "2": -2.0}).input_data(
        input_data=timeserie
    )
hgf_mcmc.plot_trajectories();
../_images/0eb9ae9e8212ba850779acb80fa068d22e69445110000af26edca45659e1d734.png
hgf_mcmc.surprise().sum()
Array(-1106.0878, dtype=float32)

Three-level model#

Creating the model#

hgf_logp_op = HGFDistribution(
    n_levels=3,
    input_data=timeserie[jnp.newaxis, :]
)
with pm.Model() as three_level_hgf:

    # Set a prior over the tonic volatility at the first level.
    tonic_volatility_1 = pm.Normal("tonic_volatility_1", -10, 2.0)

    # Call the pre-parametrized HGF distribution here.
    # All parameters are set to their default values except tonic_volatility_1, tonic_volatility_2, tonic_volatility_3 and mean_1.
    pm.Potential(
        "hgf_loglike", hgf_logp_op(
            tonic_volatility_1=tonic_volatility_1, tonic_volatility_2=-2.0,
            tonic_volatility_3=-2.0, mean_1=1.0
        )
    )

Visualizing the model#

pm.model_to_graphviz(three_level_hgf)
../_images/276f245ebd9fd4bcae5e926f51a69ac0c4a8b7994776641b7ca80207aba0e073.svg

Sampling#

with three_level_hgf:
    three_level_hgf_idata = pm.sample(chains=2, cores=1)
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [tonic_volatility_1]


Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 7 seconds.
There were 1 divergences after tuning. Increase `target_accept` or reparameterize.
We recommend running at least 4 chains for robust computation of convergence diagnostics
az.plot_trace(three_level_hgf_idata);
../_images/270754c1974623598f4616cb780f7a3ea6590971fa2a91dc17d403345ddd5093.png

Using the learned parameters#

tonic_volatility_1 = az.summary(three_level_hgf_idata)["mean"]["tonic_volatility_1"]
hgf_mcmc = HGF(
    n_levels=3,
    model_type="continuous",
    initial_mean={"1": timeserie[0], "2": 0.0, "3": 0.0},
    initial_precision={"1": 1e4, "2": 1e1, "3": 1e1},
    tonic_volatility={"1": tonic_volatility_1, "2": -2.0, "3": -2.0}).input_data(
        input_data=timeserie
    )
hgf_mcmc.plot_trajectories();
../_images/df5a3665ddfb56114973c65a1fc3ab4aea050de3d4249505d894a26fa339cb93.png
hgf_mcmc.surprise().sum()
Array(-1117.9757, dtype=float32)

System configuration#

%load_ext watermark
%watermark -n -u -v -iv -w -p pyhgf,jax,jaxlib
Last updated: Wed Jan 29 2025

Python implementation: CPython
Python version       : 3.12.8
IPython version      : 8.31.0

pyhgf : 0.2.3
jax   : 0.4.31
jaxlib: 0.4.31

jax       : 0.4.31
sys       : 3.12.8 (main, Dec  4 2024, 06:20:31) [GCC 13.2.0]
arviz     : 0.20.0
IPython   : 8.31.0
pyhgf     : 0.2.3
matplotlib: 3.10.0
pymc      : 5.20.0

Watermark: 2.5.0