API

Update functions

Update functions are at the heart of probabilistic networks, as they shape the propagation of beliefs in the neural hierarchy. The library implements the standard variational updates for value and volatility coupling, as described in Weber et al. (2023).

The updates module contains the update functions used during belief propagation. Update functions are available through three sub-modules, organized according to their functional roles. We usually distinguish three kinds of updates. The first, the prediction steps, are triggered top-down (from the highest parent nodes down to the input nodes) and recover the current state of inference. The second are the prediction errors, signalling the divergence between the prediction and the new observation (for input nodes) or the new state (for state nodes). Interleaved with these steps are the posterior update steps, where a node receives prediction errors from its child nodes and estimates new sufficient statistics.
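To make the order of these steps concrete, here is a minimal sketch of a single continuous state node observed through a Gaussian input, written in plain Python rather than with the library's JAX functions. It assumes a driftless Gaussian random walk with a fixed tonic volatility and a known input precision; the names `predict`, `prediction_error`, `posterior_update` and `filter_series` are hypothetical and do not belong to this module's API.

```python
import math


def predict(mean, precision, tonic_volatility):
    """Prediction step: a driftless random walk whose variance grows by exp(tonic_volatility)."""
    expected_mean = mean
    expected_precision = 1.0 / (1.0 / precision + math.exp(tonic_volatility))
    return expected_mean, expected_precision


def prediction_error(observation, expected_mean):
    """Value prediction error: divergence between the new observation and the prediction."""
    return observation - expected_mean


def posterior_update(expected_mean, expected_precision, value_pe, input_precision):
    """Posterior update: precision-weighted correction of the prediction."""
    precision = expected_precision + input_precision
    mean = expected_mean + (input_precision / precision) * value_pe
    return mean, precision


def filter_series(observations, mean=0.0, precision=1.0,
                  tonic_volatility=-2.0, input_precision=1.0):
    """Run prediction -> prediction error -> posterior update for each observation."""
    trajectory = []
    for observation in observations:
        expected_mean, expected_precision = predict(mean, precision, tonic_volatility)
        value_pe = prediction_error(observation, expected_mean)
        mean, precision = posterior_update(
            expected_mean, expected_precision, value_pe, input_precision
        )
        trajectory.append((mean, precision))
    return trajectory


trajectory = filter_series([1.0] * 50)
print(trajectory[-1])
```

With a constant input, the posterior mean converges to the observed value, while the posterior precision settles at a steady state set by the tonic volatility and the input precision.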

Posterior updates

Update the sufficient statistics of a state node after receiving prediction errors from its child nodes. The prediction errors from all the children below the node should be computed before calling the posterior update step.

Categorical nodes

categorical_state_update(attributes, ...)

Update the categorical input node given an array of binary observations.

Continuous nodes

posterior_update_mean_continuous_node

posterior_update_precision_continuous_node

continuous_node_posterior_update(attributes, ...)

Update the posterior of a continuous node using the standard HGF update.

continuous_node_posterior_update_ehgf(...)

Update the posterior of a continuous node using the eHGF update.
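The value-coupling part of a continuous posterior update can be sketched as a precision-weighted sum over children, which is why every child's prediction error must be available first. This is a minimal sketch assuming linear value coupling with strength `kappa`; volatility coupling contributes further terms that it omits, and `value_coupling_posterior` is a hypothetical name, not `continuous_node_posterior_update`.

```python
def value_coupling_posterior(expected_mean, expected_precision, children):
    """Posterior of a node from its value-coupled children.

    `children` is a list of (kappa, child_expected_precision, child_value_pe) tuples;
    this structure is an illustrative assumption.
    """
    # Each child adds kappa^2 * its expected precision to the parent's precision
    precision = expected_precision + sum(
        kappa ** 2 * child_pi_hat for kappa, child_pi_hat, _ in children
    )
    # Each child's prediction error shifts the mean, weighted by its precision share
    mean = expected_mean + sum(
        kappa * child_pi_hat * child_pe for kappa, child_pi_hat, child_pe in children
    ) / precision
    return mean, precision


# Compare one child with two children reporting the same prediction error
print(value_coupling_posterior(0.0, 1.0, [(1.0, 1.0, 1.0)]))
print(value_coupling_posterior(0.0, 1.0, [(1.0, 1.0, 1.0), (1.0, 1.0, 1.0)]))
```

Under these assumptions the update is the exact Gaussian posterior: each additional child increases the parent's precision and pulls its mean further toward what the children report.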

Exponential family

posterior_update_exponential_family_dynamic(...)

Update the hyperparameters of an exponential family (EF) state node using HGF-implied learning rates.

Prediction steps

Compute the expectation for future observations given the influence of parent nodes. The prediction steps are executed for all nodes, top-down, before any observation.

Binary nodes

binary_state_node_prediction(attributes, ...)

Get the new expected mean and precision of a binary state node.
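In the standard binary HGF, the expected mean of a binary state node is the logistic sigmoid of its value parent's expected mean, and its expected precision is the inverse of the Bernoulli variance. A minimal sketch of that mapping, assuming a single value parent with unit coupling; `binary_prediction` is a hypothetical helper, not `binary_state_node_prediction`.

```python
import math


def binary_prediction(parent_expected_mean):
    """Map a continuous parent's prediction to a binary node's prediction."""
    # Expected probability of observing 1: logistic sigmoid of the parent's expected mean
    expected_mean = 1.0 / (1.0 + math.exp(-parent_expected_mean))
    # Expected precision: inverse of the Bernoulli variance mu_hat * (1 - mu_hat)
    expected_precision = 1.0 / (expected_mean * (1.0 - expected_mean))
    return expected_mean, expected_precision


print(binary_prediction(0.0))
print(binary_prediction(2.0))
```

The expected precision is lowest when the prediction is most uncertain (an expected mean of 0.5) and grows as the prediction approaches 0 or 1.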

Continuous nodes

predict_mean(attributes, edges, node_idx)

Compute the expected mean of a continuous state node.

predict_precision(attributes, edges, node_idx)

Compute the expected precision of a continuous state node.

continuous_node_prediction(attributes, ...)

Update the expected mean and expected precision of a continuous node.
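For a continuous node with a volatility parent, the prediction inflates the posterior variance by an amount that depends on the parent's mean. The sketch below assumes a driftless random walk (the expected mean equals the previous posterior mean), a single volatility parent and a unit time step; `predict_continuous` and its parameter names are hypothetical and are not the signatures of `predict_mean` or `predict_precision`.

```python
import math


def predict_continuous(mean, precision, tonic_volatility=-4.0,
                       volatility_parent_mean=0.0, volatility_coupling=1.0):
    """Expected mean and expected precision of a continuous node for the next step."""
    # Driftless random walk: the best guess is the previous posterior mean
    expected_mean = mean
    # Predicted variance = posterior variance + exp(coupling * parent mean + tonic volatility)
    predicted_variance = 1.0 / precision + math.exp(
        volatility_coupling * volatility_parent_mean + tonic_volatility
    )
    expected_precision = 1.0 / predicted_variance
    return expected_mean, expected_precision


# A higher volatility-parent mean lowers the expected precision of the child
print(predict_continuous(0.0, 1.0, volatility_parent_mean=0.0))
print(predict_continuous(0.0, 1.0, volatility_parent_mean=2.0))
```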

Dirichlet processes

dirichlet_node_prediction(edges, attributes, ...)

Prediction of a Dirichlet process node.

Prediction error steps

Compute the value and volatility prediction errors of a given node. The prediction error can only be computed after the posterior update (or observation) of that node.

Binary state nodes

binary_state_node_prediction_error(...)

Compute the value prediction errors and predicted precision of a binary node.

binary_finite_state_node_prediction_error(...)

Update the posterior of a binary node given finite precision of the input.

Categorical state nodes

categorical_state_prediction_error(...)

Prediction error from a categorical state node.

Continuous state nodes

continuous_node_value_prediction_error(...)

Compute the value prediction error of a state node.

continuous_node_volatility_prediction_error(...)

Compute the volatility prediction error of a state node.

continuous_node_prediction_error(attributes, ...)

Store prediction errors in an input node.
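The value and volatility prediction errors of a continuous node can be written directly from its prediction and its posterior. This sketch uses the standard HGF definitions, with δ = μ − μ̂ and Δ = π̂/π + π̂δ² − 1; the function names are hypothetical and simply echo the entries above rather than reproducing their signatures.

```python
def value_prediction_error(mean, expected_mean):
    """Value prediction error: posterior mean minus expected mean."""
    return mean - expected_mean


def volatility_prediction_error(mean, expected_mean, precision, expected_precision):
    """Volatility prediction error of the standard HGF.

    Positive when the posterior variance plus the squared value prediction error
    exceeds the predicted variance, negative when it falls short.
    """
    delta = value_prediction_error(mean, expected_mean)
    return expected_precision / precision + expected_precision * delta ** 2 - 1.0


print(value_prediction_error(1.5, 1.0))
print(volatility_prediction_error(1.0, 0.0, 2.0, 1.0))
```

The volatility prediction error is what a volatility parent receives: it signals whether the child changed more or less than its predicted precision anticipated.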

Dirichlet state nodes

dirichlet_node_prediction_error(edges, ...)

Compute the prediction error of a Dirichlet process node and update its child networks.

update_cluster(operands, edges, node_idx)

Update an existing cluster.

create_cluster(operands, edges, node_idx)

Create a new cluster.

get_candidate(value, sensory_precision, ...)

Find the best cluster candidate given previous clusters and an input value.

likely_cluster_proposal(mean_mu_G0, ...[, ...])

Sample likely new belief distributions given pre-existing clusters.

clusters_likelihood(value, expected_mean, ...)

Likelihood of a parametrized candidate under the new observation.

Exponential family

prediction_error_update_exponential_family_fixed(...)

Update the parameters of an exponential family distribution.

prediction_error_update_exponential_family_dynamic(...)

Pass the expected sufficient statistics to the implied continuous nodes.
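A fixed-ν exponential family update can be pictured as an exponentially weighted average of sufficient statistics. The sketch below assumes a univariate normal with sufficient statistics T(x) = (x, x²) and the update ξ ← ξ + (T(x) − ξ) / (1 + ν); the helper names are hypothetical, and this is not `prediction_error_update_exponential_family_fixed` itself.

```python
def normal_sufficient_statistics(x):
    """Sufficient statistics T(x) = (x, x^2) of a univariate normal."""
    return (x, x * x)


def fixed_nu_update(xis, value, nu):
    """Move the expected sufficient statistics toward T(value) at rate 1 / (1 + nu)."""
    rate = 1.0 / (1.0 + nu)
    stats = normal_sufficient_statistics(value)
    return tuple(xi + rate * (t - xi) for xi, t in zip(xis, stats))


def normal_parameters(xis):
    """Recover mean and variance from the expected sufficient statistics."""
    mean = xis[0]
    variance = xis[1] - xis[0] ** 2
    return mean, variance


xis = (0.0, 1.0)  # prior expectations of x and x^2 (mean 0, variance 1)
for value in [1.2, 0.8, 1.1, 0.9, 1.0]:
    xis = fixed_nu_update(xis, value, nu=4.0)
print(normal_parameters(xis))
```

In this sketch, a larger ν gives a smaller step toward each new observation, that is, a longer effective memory.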

Distribution

The Hierarchical Gaussian Filter as a PyMC distribution. This distribution can be embedded in models using PyMC>=5.0.0.

logp

Compute the log-probability of a decision model under belief trajectories.

hgf_logp

Compute log-probabilities of a batch of Hierarchical Gaussian Filters.

HGFLogpGradOp

Gradient Op for the HGF distribution.

HGFDistribution

The HGF distribution, compatible with PyMC >= 5.0.

HGFPointwise

The HGF distribution returning pointwise log probability.

Model

The main class creates a standard Hierarchical Gaussian Filter for binary or continuous inputs, with two or three levels. It wraps the JAX modules described above and creates the standard node structure for these models.

HGF

The two-level and three-level Hierarchical Gaussian Filters (HGF).

Network

A predictive coding neural network.

add_continuous_state

Add continuous state node(s) to a network.

add_binary_state

Add binary state node(s) to a network.

add_ef_state

Add exponential family state node(s) to a network.

add_categorical_state

Add categorical state node(s) to a network.

add_dp_state

Add a Dirichlet Process node to a network.

get_couplings

Transform coupling parameters into a tuple of indexes and strengths.

update_parameters

Update the default node parameters using keyword arguments and a dictionary.

insert_nodes

Insert a set of parametrised nodes into a network.

Plots

Plotting functionalities to visualize parameter trajectories and correlations after observing new data. Graphviz is currently fully supported, and NetworkX is also available for some functions.

Graphviz

plot_trajectories(network[, ci, ...])

Plot the trajectories of the nodes' sufficient statistics and surprise.

plot_correlations(network)

Plot a correlation heatmap of the sufficient statistics trajectories.

plot_network(network)

Visualization of the node network using Graphviz.

plot_nodes(network, node_idxs[, ci, ...])

Plot the trajectory of expected sufficient statistics of a set of nodes.

NetworkX

plot_network(network[, figsize, node_size, ...])

Visualization of the node network using NetworkX and a pydot layout.

Response

A collection of response functions. A response function is a callable that takes at least the HGF instance, after it has observed the data, as input and returns a surprise.

first_level_gaussian_surprise(hgf[, ...])

Gaussian surprise at the first level of a probabilistic network.

total_gaussian_surprise(hgf[, ...])

Sum of the Gaussian surprise across the probabilistic network.

first_level_binary_surprise(hgf[, ...])

Time series of binary surprises for all binary state nodes.

binary_softmax(hgf[, ...])

Surprise under the binary softmax model.

binary_softmax_inverse_temperature(hgf[, ...])

Surprise from a binary softmax parametrized by the inverse temperature.
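To illustrate the response-function contract, the sketch below sums the surprise of binary choices under a two-option softmax over log-beliefs, sharpened by an inverse temperature. It is not pyhgf's `binary_softmax_inverse_temperature`: the `ToyBeliefs` object is a hypothetical stand-in for an HGF instance after observation, and the clipping constant `eps` is an assumption.

```python
import math
from dataclasses import dataclass
from typing import Sequence


@dataclass
class ToyBeliefs:
    """Hypothetical stand-in for an HGF after observation: first-level expected probabilities."""
    expected_mean: Sequence[float]


def softmax_inverse_temperature_surprise(beliefs, responses, inverse_temperature=1.0, eps=1e-6):
    """Summed surprise of binary responses under a two-option softmax of the beliefs."""
    total = 0.0
    for mu_hat, response in zip(beliefs.expected_mean, responses):
        # Clip the belief away from 0 and 1 to keep the logarithm finite
        mu_hat = min(max(mu_hat, eps), 1.0 - eps)
        # Softmax over log-beliefs: the inverse temperature sharpens or flattens the choice
        numerator = mu_hat ** inverse_temperature
        p_one = numerator / (numerator + (1.0 - mu_hat) ** inverse_temperature)
        p_response = p_one if response == 1 else 1.0 - p_one
        total -= math.log(p_response)
    return total


beliefs = ToyBeliefs(expected_mean=[0.8, 0.3, 0.6])
print(softmax_inverse_temperature_surprise(beliefs, responses=[1, 0, 1], inverse_temperature=2.0))
```

With an inverse temperature of 1, the choice probability equals the belief itself; larger values make choices more deterministic, which lowers the surprise of responses that follow the beliefs.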

Utils

Utilities for manipulating neural networks.

beliefs_propagation(attributes, inputs, ...)

Update the network's parameters after observing new data point(s).

list_branches(node_idxs, edges[, branch_list])

Return the branch of a network from a given set of root nodes.

fill_categorical_state_node(network, ...)

Generate a binary network implied by categorical state(-transition) nodes.

get_update_sequence(network, update_type)

Generate an update sequence from the network's structure.

to_pandas(network)

Export the nodes' trajectories and surprise as a Pandas data frame.

add_edges(attributes, edges[, kind, ...])

Add a value or volatility coupling link between a set of nodes.

get_input_idxs(edges)

List all possible default input nodes.

add_parent(attributes, edges, index, ...)

Add a new continuous-state parent node to the attributes and edges of a network.

remove_node(attributes, edges, index)

Remove a given node from the network.
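The ordering constraints stated in the update-function sections (predictions top-down for all nodes, a node's posterior update only after all its children's prediction errors, then its own prediction error) can be turned into a sequence with a topological sort. This is a sketch assuming the structure is given as a hypothetical `parents` dictionary; it is not the algorithm used by `get_update_sequence`.

```python
from graphlib import TopologicalSorter


def toy_update_sequence(parents):
    """Order prediction, posterior-update and prediction-error steps.

    `parents` maps each node to the list of its parents (hypothetical format).
    Nodes without children are treated as input nodes (no posterior update).
    """
    nodes = set(parents) | {p for ps in parents.values() for p in ps}
    children = {n: [c for c, ps in parents.items() if n in ps] for n in nodes}
    # Using parents as predecessors puts every parent before its children (top-down)
    graph = {n: parents.get(n, []) for n in nodes}
    top_down = list(TopologicalSorter(graph).static_order())

    sequence = [(n, "prediction") for n in top_down]
    # Reversed order visits every child before its parents (bottom-up)
    for n in reversed(top_down):
        if children[n]:
            sequence.append((n, "posterior_update"))
        sequence.append((n, "prediction_error"))
    return sequence


# Node 0 receives observations, node 1 is its parent, node 2 is the parent of node 1
print(toy_update_sequence({0: [1], 1: [2]}))
```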

Math

Math functions and probability densities.

MultivariateNormal()

The multivariate normal as an exponential family distribution.

Normal()

The univariate normal as an exponential family distribution.

gaussian_predictive_distribution(x, xi, nu)

Density of the Gaussian-predictive distribution.

gaussian_density(x, mean, precision)

Gaussian density as defined by mean and precision.

sigmoid(x[, lower_bound, upper_bound])

Logistic sigmoid function.

binary_surprise(x, expected_mean)

Surprise at a binary outcome.

gaussian_surprise(x, expected_mean, ...)

Surprise at an outcome under a Gaussian prediction.

dirichlet_kullback_leibler(alpha_1, alpha_2)

Compute the Kullback-Leibler divergence between two Dirichlet distributions.

binary_surprise_finite_precision(value, ...)

Compute the binary surprise with finite precision.
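The Gaussian and binary quantities listed above follow textbook definitions: the Gaussian is parametrized by a precision rather than a variance, and surprise is the negative log-probability of an outcome. A standalone sketch using only the standard library; the names mirror the entries above, but these are independent re-implementations, and the default bounds of `sigmoid` are an assumption.

```python
import math


def gaussian_density(x, mean, precision):
    """Normal density with the spread given as a precision (inverse variance)."""
    return math.sqrt(precision / (2 * math.pi)) * math.exp(-0.5 * precision * (x - mean) ** 2)


def gaussian_surprise(x, expected_mean, expected_precision):
    """Negative log of the Gaussian density, computed in closed form."""
    return 0.5 * (
        math.log(2 * math.pi)
        - math.log(expected_precision)
        + expected_precision * (x - expected_mean) ** 2
    )


def binary_surprise(x, expected_mean):
    """Negative log-probability of a binary outcome x in {0, 1}."""
    return -(x * math.log(expected_mean) + (1 - x) * math.log(1 - expected_mean))


def sigmoid(x, lower_bound=0.0, upper_bound=1.0):
    """Logistic sigmoid rescaled to the interval (lower_bound, upper_bound)."""
    return lower_bound + (upper_bound - lower_bound) / (1 + math.exp(-x))


print(gaussian_surprise(1.3, 0.2, 2.5), -math.log(gaussian_density(1.3, 0.2, 2.5)))
print(binary_surprise(1, 0.5), sigmoid(0.0))
```

Computing the Gaussian surprise in closed form avoids taking the logarithm of a density that can underflow to zero for outcomes far from the expected mean.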