pyhgf.model.Network
- class pyhgf.model.Network(volatility_updates='unbounded', max_posterior_precision=10000000000.0, mean_field_updates=False, precision_clipping_value=1e-06)
A predictive coding neural network.
This is the core class for defining and manipulating neural networks, which consist of three parts: 1. attributes, 2. structure, and 3. update sequences.
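The three parts can be pictured with plain Python containers. The sketch below is a hypothetical, library-independent illustration, not pyhgf's internal representation: `Adjacency` is a made-up stand-in for pyhgf.typing.AdjacencyLists, and the attribute keys, the placeholder update functions, the update ordering and the `propagate` helper are all assumptions chosen for readability.

```python
from typing import Callable, Dict, List, NamedTuple, Tuple


# Hypothetical stand-in for one node's entry in the edges tuple
# (not the actual field layout of pyhgf.typing.AdjacencyLists).
class Adjacency(NamedTuple):
    value_parents: Tuple[int, ...] = ()
    volatility_parents: Tuple[int, ...] = ()
    value_children: Tuple[int, ...] = ()
    volatility_children: Tuple[int, ...] = ()


Attributes = Dict[int, Dict[str, float]]
UpdateFn = Callable[[Attributes, int], Attributes]

# 1. Attributes: one parameter dictionary per node (illustrative keys only).
attributes: Attributes = {
    0: {"mean": 0.0, "precision": 1.0},
    1: {"mean": 0.0, "precision": 1.0},
    2: {"mean": 0.0, "precision": 1.0},
}

# 2. Structure: node 1 is the value parent of node 0,
#    and node 2 is the volatility parent of node 1.
edges: Tuple[Adjacency, ...] = (
    Adjacency(value_parents=(1,)),
    Adjacency(volatility_parents=(2,), value_children=(0,)),
    Adjacency(volatility_children=(1,)),
)


def edges_are_consistent(edges: Tuple[Adjacency, ...]) -> bool:
    """Check that every parent link is mirrored by the matching child link."""
    for idx, node in enumerate(edges):
        for parent in node.value_parents:
            if idx not in edges[parent].value_children:
                return False
        for parent in node.volatility_parents:
            if idx not in edges[parent].volatility_children:
                return False
    return True


# 3. Update sequence: ordered (node index, update function) pairs.
log: List[Tuple[str, int]] = []


def make_step(name: str) -> UpdateFn:
    """Build a placeholder update that only records which node it visited."""
    def step(attrs: Attributes, node_idx: int) -> Attributes:
        log.append((name, node_idx))  # a real update would modify attrs here
        return attrs
    return step


# Illustrative ordering: predictions from the top node down,
# then posterior updates from the observed node's parent upward.
update_sequence: List[Tuple[int, UpdateFn]] = [
    (2, make_step("prediction")),
    (1, make_step("prediction")),
    (0, make_step("prediction")),
    (1, make_step("posterior")),
    (2, make_step("posterior")),
]


def propagate(attrs: Attributes, sequence: List[Tuple[int, UpdateFn]]) -> Attributes:
    """Apply each update in order, threading the attributes through."""
    for node_idx, fn in sequence:
        attrs = fn(attrs, node_idx)
    return attrs


attributes = propagate(attributes, update_sequence)
```

Keeping the structure (edges) separate from the numbers (attributes) means the same update sequence can be replayed on new parameter values without rebuilding the graph.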
- Parameters:
volatility_updates (str) – The type of update to perform for volatility coupling. Can be "unbounded" (default), "eHGF" or "standard".
max_posterior_precision (float) – Upper bound applied to every posterior precision write. Defaults to 1e10.
mean_field_updates (bool) – If False (default), use the relaxed prediction and posterior updates. If True, use the original mean-field updates.
precision_clipping_value (float) – Bound used to clip the predicted mean of binary state nodes to [precision_clipping_value, 1 - precision_clipping_value]. Defaults to 1e-6.
- attributes
The attributes of the probabilistic nodes.
- edges
The edges of the probabilistic nodes, stored as a tuple of pyhgf.typing.AdjacencyLists. The tuple has one entry per node; each entry lists that node's value and volatility parents and children.
- inputs
Information on the input nodes.
- node_trajectories
The dynamics of the nodes' beliefs after updating.
- update_sequence
The sequence of update functions applied during the belief propagation step.
- scan_fn
The function passed to jax.lax.scan(). This is a pre-parametrized version of pyhgf.networks.beliefs_propagation().
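jax.lax.scan() threads a carry through a function of signature (carry, x) -> (new_carry, y) and collects the per-step outputs. The sketch below mirrors those reference semantics with a pure-Python loop (so it runs without JAX) and uses functools.partial to pre-parametrize a step function, which is the general idea behind scan_fn. `toy_step` is a hypothetical delta-rule update, not pyhgf.networks.beliefs_propagation(), and its signature is an assumption.

```python
from functools import partial
from typing import Callable, Dict, List, Sequence, Tuple

Carry = Dict[str, float]


def scan(f: Callable[[Carry, float], Tuple[Carry, Carry]],
         init: Carry, xs: Sequence[float]) -> Tuple[Carry, List[Carry]]:
    """Pure-Python reference semantics of jax.lax.scan (no tracing, no stacking)."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def toy_step(attributes: Carry, observation: float, *,
             learning_rate: float) -> Tuple[Carry, Carry]:
    """Hypothetical one-step update (a delta rule), not pyhgf's belief propagation."""
    new_mean = attributes["mean"] + learning_rate * (observation - attributes["mean"])
    new_attributes = {"mean": new_mean}
    # Return the new carry and the value to record for this time step.
    return new_attributes, new_attributes


# Fix the static settings once so the result has the (carry, x) signature scan expects.
scan_fn = partial(toy_step, learning_rate=0.5)

final, trajectory = scan(scan_fn, {"mean": 0.0}, [1.0, 1.0, 1.0])
```

With the real jax.lax.scan, the recorded outputs would be stacked into arrays along a leading time axis, which is one way per-step beliefs can be gathered into something like node_trajectories.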
- __init__(volatility_updates='unbounded', max_posterior_precision=10000000000.0, mean_field_updates=False, precision_clipping_value=1e-06)
Initialize an empty neural network.
- Parameters:
volatility_updates (str) – The type of update to perform for volatility coupling. Can be "unbounded" (default), "eHGF" or "standard". The unbounded approximation was recently introduced to avoid negative precision updates, which greatly improves sampling performance. The eHGF update step was proposed as an alternative to the original definition: it updates the mean of the parent node first and then its precision, which generally reduces errors caused by impossible regions of the parameter space and improves sampling.
max_posterior_precision (float) – Upper bound applied to every posterior precision write (the value level for continuous and volatile nodes, and the implicit volatility level for volatile nodes). Defaults to 1e10 and is shared with the vectorized JAX and Rust backends. Increase it to relax the cap, or lower it to guard more conservatively against precision blow-up.
mean_field_updates (bool) – If False (default), use the relaxed prediction and posterior updates, which lift the mean-field assumption on value-coupling edges via Schur-complement and Laplace/MGF corrections. If True, use the original mean-field updates from [1].
precision_clipping_value (float) – Binary state nodes clip their predicted mean to [precision_clipping_value, 1 - precision_clipping_value] so that the implied binary precision \(\hat{\mu}(1 - \hat{\mu})\) never collapses to zero. A larger value (e.g. 1e-3, matching the TAPAS HGF Toolbox) keeps the forward filter stable in high-volatility regimes; a very small value (the default, 1e-6) keeps the bound from creating flat, zero-gradient plateaus that hurt gradient-based inference (HMC/NUTS, optimisation). Shared with the vectorized JAX and Rust backends.
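The two numerical safeguards described above can be sketched as clamping functions. This is an illustrative reading of the parameter descriptions, assuming the cap is a plain minimum and the clip a plain clamp; the helper names are hypothetical and not part of pyhgf. The finite-difference check illustrates the zero-gradient plateau: once the bound is active, the clipped mean no longer responds to its input.

```python
def clip_binary_mean(mu_hat: float, precision_clipping_value: float = 1e-6) -> float:
    """Clamp a predicted probability to [c, 1 - c] (hypothetical helper)."""
    low, high = precision_clipping_value, 1.0 - precision_clipping_value
    return min(max(mu_hat, low), high)


def binary_precision(mu_hat: float, precision_clipping_value: float = 1e-6) -> float:
    """Implied binary precision mu_hat * (1 - mu_hat), computed after clipping."""
    m = clip_binary_mean(mu_hat, precision_clipping_value)
    return m * (1.0 - m)


def cap_posterior_precision(precision: float, max_posterior_precision: float = 1e10) -> float:
    """Upper-bound a posterior precision before it is written (hypothetical helper)."""
    return min(precision, max_posterior_precision)


def central_difference(f, x: float, h: float = 1e-9) -> float:
    """Numerical derivative of f at x."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


# Without clipping, mu_hat = 0 would give a precision of exactly 0.
floor_default = binary_precision(0.0)       # about 1e-6
floor_tapas = binary_precision(0.0, 1e-3)   # about 1e-3

# At mu_hat = 1e-4 the 1e-3 bound is active, so the slope is exactly zero;
# the 1e-6 bound is inactive, so the slope is about 1 - 2 * mu_hat.
slope_tapas = central_difference(lambda m: binary_precision(m, 1e-3), 1e-4)
slope_default = central_difference(lambda m: binary_precision(m, 1e-6), 1e-4)
```

The trade-off in the parameter description shows up directly here: a larger bound raises the precision floor but widens the region where gradients vanish.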
- Return type:
None
References
[1] Weber, L. A., Waade, P. T., Legrand, N., Møller, A. H., Stephan, K. E., & Mathys, C. (2026). The generalized hierarchical Gaussian filter. doi:10.7554/elife.110174.1
Methods
__init__([volatility_updates, ...]) – Initialize an empty neural network.
add_edges(parent_idxs, children_idxs[, ...]) – Add a value or volatility coupling link between a set of nodes.
add_nodes([kind, n_nodes, node_parameters, ...]) – Add new input/state node(s) to the neural network.
create_belief_propagation_fn([overwrite, ...]) – Create the belief propagation function.
create_learning_propagation_fn(...[, ...]) – Create the belief propagation function.
fit(x, y, inputs_x_idxs, inputs_y_idxs[, ...]) – Add new observations.
get_input_dimension() – Get input node dimensions.
get_network() – Return the attributes, edges and update sequence defining the network.
input_custom_sequence(update_branches, ...) – Add new observations with custom update sequences.
input_data(input_data[, time_steps, ...]) – Add new observations.
plot_correlations() – Plot the heatmap of cross-trajectory correlations.
plot_network([backend]) – Visualize the node network using Graphviz or NetworkX.
plot_nodes(node_idxs, **kwargs) – Plot the belief trajectories of the node(s).
plot_samples(**kwargs) – Plot the parameter trajectories.
plot_trajectories(**kwargs) – Plot the parameter trajectories.
predict(x, inputs_x_idxs, inputs_y_idxs) – Generate predictions from the network using only the prediction steps.
sample(n_predictions[, time_steps, rng_key, ...]) – Generate predictions using the utility predict function.
surprise(response_function[, ...]) – Surprise of the model conditioned on the response function.
to_pandas() – Export the node trajectories and surprise as a Pandas data frame.
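For intuition about the kind of quantity a response function passed to surprise() can score, the Gaussian surprise of an observation u under a prediction with mean \(\hat{\mu}\) and precision \(\hat{\pi}\) is the negative log density \(\tfrac{1}{2}\left(\log 2\pi - \log \hat{\pi} + \hat{\pi}(u - \hat{\mu})^2\right)\). The sketch below computes it and sums it over a sequence; the function names and arguments are hypothetical and do not reflect the signature pyhgf expects from a response function.

```python
import math
from statistics import NormalDist
from typing import Sequence


def gaussian_surprise(u: float, mu_hat: float, pi_hat: float) -> float:
    """Negative log density of u under N(mu_hat, 1 / pi_hat)."""
    return 0.5 * (math.log(2.0 * math.pi) - math.log(pi_hat) + pi_hat * (u - mu_hat) ** 2)


def total_surprise(us: Sequence[float], mu_hats: Sequence[float],
                   pi_hats: Sequence[float]) -> float:
    """Sum of per-observation surprises (lower means better-predicted data)."""
    if not (len(us) == len(mu_hats) == len(pi_hats)):
        raise ValueError("all sequences must have the same length")
    return sum(gaussian_surprise(u, m, p) for u, m, p in zip(us, mu_hats, pi_hats))


# Cross-check against the standard library: precision 4 means standard deviation 0.5.
reference = -math.log(NormalDist(mu=0.0, sigma=0.5).pdf(1.3))
ours = gaussian_surprise(1.3, 0.0, 4.0)
```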
Attributes
input_idxs – Indexes of the state nodes that can observe new data points by default.