Example 1: Bayesian filtering of cardiac volatility#
import jax
# Enable 64-bit precision so JAX matches PyTensor's float64 graphs and avoids
# dtype-truncation warnings when the network is wrapped for PyMC.
jax.config.update("jax_enable_x64", True)
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
from pytensor import wrap_jax
from pyhgf.model import Network
from pyhgf.response import total_gaussian_surprise
plt.rcParams["figure.constrained_layout.use"] = True
The nodalized version of the Hierarchical Gaussian Filter implemented in pyhgf makes it possible to create filters with multiple inputs. Here, we illustrate how this feature can be used to create an agent that filters their own physiological signals in real time. We use a two-level Hierarchical Gaussian Filter to predict the dynamics of the instantaneous heart rate, expressed as the RR interval measured at each heartbeat. We then extract the trajectory of surprise at each predictive node and relate it to the cognitive task performed by the participant while the signal was being recorded.
Loading physiological recording#
We use an RR time series included in Systole as an example.
rr_s = (
    pd.read_csv(
        "https://raw.githubusercontent.com/LegrandNico/systole/refs/heads/main/src/systole/datasets/rr.txt"
    ).rr.to_numpy()
    / 1000
)
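The relation between RR intervals and instantaneous heart rate can be sketched with a few lines of NumPy. This is a minimal illustration that assumes RR intervals expressed in seconds, as the `rr_s` name and the division by 1000 suggest. The values in `rr_example_s` are hypothetical and are not taken from the Systole recording.

```python
import numpy as np

# Hypothetical RR intervals in seconds (not the Systole data)
rr_example_s = np.array([0.80, 0.85, 0.75, 1.00])

# Instantaneous heart rate in beats per minute: one beat every `rr` seconds
heart_rate_bpm = 60.0 / rr_example_s

# Time of each heartbeat relative to the start of the recording
beat_times_s = np.cumsum(rr_example_s)

print(heart_rate_bpm)
print(beat_times_s)
```

Shorter RR intervals correspond to a faster heart rate, so a filter that tracks the RR series is tracking the inverse of the heart rate.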
Model#
Note
Here we use the total Gaussian surprise (pyhgf.response.total_gaussian_surprise()) as a response function. This response function deviates from the default behaviour for the continuous HGF in that it returns the sum of the surprise for all the probabilistic nodes in the network, whereas the default (pyhgf.response.first_level_gaussian_surprise()) only computes the surprise at the first level (i.e. the value parent of the continuous input node). We explicitly specify this parameter here to indicate that we want our model to minimise its prediction errors over all variables, and not only at the observation level. In this case, however, the results are expected to be very similar between the two methods.
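To make the difference between the two response functions concrete, the sketch below computes the Gaussian surprise (the negative log-density of a value under a Gaussian prediction) for two nodes and compares the first-level sum with the sum over both nodes. It does not call pyhgf: `gaussian_surprise` is a helper written for this illustration, and the expected means and precisions are hypothetical values rather than output from the model.

```python
import numpy as np


def gaussian_surprise(x, expected_mean, expected_precision):
    """Negative log-density of x under N(expected_mean, 1 / expected_precision)."""
    return 0.5 * (
        np.log(2.0 * np.pi)
        - np.log(expected_precision)
        + expected_precision * (x - expected_mean) ** 2
    )


# Hypothetical values observed at two nodes over three time steps
node_a_values = np.array([0.80, 0.82, 0.95])
node_b_values = np.array([0.81, 0.81, 0.85])

# Surprise of each value under that node's (hypothetical) prediction
node_a_surprise = gaussian_surprise(node_a_values, 0.81, 100.0)
node_b_surprise = gaussian_surprise(node_b_values, 0.80, 50.0)

# Summing one node only, versus summing over every node
single_node_sum = node_a_surprise.sum()
all_nodes_sum = node_a_surprise.sum() + node_b_surprise.sum()

print(single_node_sum, all_nodes_sum)
```

Because the surprise is a negative log-density, it can be negative when the prediction is precise and accurate, so the sum over all nodes is not necessarily larger than the sum at a single node.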
@wrap_jax
def two_level_logp(tonic_volatility):
    """Compute the log-probability of the two-level HGF."""
    return (
        -(
            Network()
            # node 0: continuous input node receiving the RR intervals
            .add_nodes(precision=1e4)
            # node 1: value parent of node 0
            .add_nodes(value_children=0, mean=1.0)
            # node 2: volatility parent of node 0, with the tonic volatility to estimate
            .add_nodes(tonic_volatility=tonic_volatility, volatility_children=0)
        )
        .input_data(input_data=rr_s)
        .surprise(response_function=total_gaussian_surprise)
        .sum()
    )
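The role of the `tonic_volatility` parameter can be illustrated outside pyhgf. In the standard HGF formulation, a node without further influence from a volatility parent evolves as a Gaussian random walk whose step variance is the exponential of its tonic volatility. The sketch below simulates such a walk with NumPy under that assumption. `simulate_random_walk` is a hypothetical helper written for this illustration and is not part of pyhgf.

```python
import numpy as np


def simulate_random_walk(tonic_volatility, n_steps=1000, start=1.0, seed=0):
    """Simulate a Gaussian random walk with step variance exp(tonic_volatility)."""
    rng = np.random.default_rng(seed)
    step_sd = np.sqrt(np.exp(tonic_volatility))
    steps = rng.normal(loc=0.0, scale=step_sd, size=n_steps)
    return start + np.cumsum(steps)


# Lower (more negative) values give a smoother, more stable trajectory
for omega in (-8.0, -4.0, 0.0):
    walk = simulate_random_walk(omega)
    print(f"omega={omega:+.1f}  empirical step variance={np.diff(walk).var():.5f}")
```

Since the parameter lives on a log scale, a prior that is symmetric around zero covers several orders of magnitude of step variance.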
with pm.Model() as two_level_hgf:
    # prior over the tonic volatility (omega) of the volatility parent
    tonic_volatility = pm.Normal("tonic_volatility", 0.0, 5.0)
    # add the HGF log-probability to the model
    pm.Potential(
        "hgf_loglike",
        two_level_logp(tonic_volatility=tonic_volatility),
    )
pm.model_to_graphviz(two_level_hgf)
with two_level_hgf:
    idata = pm.sample(chains=2, cores=1, backend="jax")
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [tonic_volatility]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 7 seconds.
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
We recommend running at least 4 chains for robust computation of convergence diagnostics
az.plot_trace_dist(idata);
# retrieve the best fit for omega_2 (posterior mean)
tonic_volatility = az.summary(idata)["mean"].astype(float)["tonic_volatility"]
hgf = (
    Network()
    .add_nodes(precision=1e4)
    .add_nodes(value_children=0, mean=1.0)
    .add_nodes(tonic_volatility=tonic_volatility, volatility_children=0)
).input_data(input_data=rr_s)
hgf.plot_trajectories();
System configuration#
%load_ext watermark
%watermark -n -u -v -iv -w -p pyhgf,jax,jaxlib
Last updated: Tue, 16 Jun 2026
Python implementation: CPython
Python version : 3.12.13
IPython version : 9.14.1
pyhgf : 0.3.0
jax : 0.4.31
jaxlib: 0.4.31
IPython : 9.14.1
arviz : 1.2.0
jax : 0.4.31
matplotlib: 3.11.0
numpy : 2.4.6
pandas : 3.0.3
pyhgf : 0.3.0
pymc : 6.0.1
pytensor : 3.0.7
Watermark: 2.6.0