pyhgf.model.Network

class pyhgf.model.Network(update_type='eHGF')

A predictive coding neural network.

This is the core class for defining and manipulating neural networks, which consist of (1) attributes, (2) structure and (3) update sequences.

Attributes:
attributes

The attributes of the probabilistic nodes.

edges

The edges of the probabilistic nodes, as a tuple of pyhgf.typing.AdjacencyLists. The tuple has one entry per node, so its length equals the number of nodes. Each entry lists that node's value and volatility parents and children.

inputs

Information on the input nodes.

node_trajectories

The dynamics of the nodes' beliefs after updating.

update_sequence

The sequence of update functions that are applied during the belief propagation step.

scan_fn

The function that is passed to jax.lax.scan(). This is a pre-parametrized version of pyhgf.networks.beliefs_propagation().
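The edges attribute can be pictured as one record per node, each holding the indices of its value/volatility parents and children. The sketch below is a minimal, hypothetical illustration in plain Python of how such a tuple stays consistent when a coupling is added: the parent is registered on the child's side and the child on the parent's side. The NodeEdges record and the couple helper are assumptions made for this example; their field names may not match pyhgf.typing.AdjacencyLists, and this is not how pyhgf itself builds its edges.

```python
from typing import NamedTuple, Tuple


class NodeEdges(NamedTuple):
    # Hypothetical per-node record; field names are illustrative only
    # and may differ from pyhgf.typing.AdjacencyLists.
    value_parents: Tuple[int, ...] = ()
    volatility_parents: Tuple[int, ...] = ()
    value_children: Tuple[int, ...] = ()
    volatility_children: Tuple[int, ...] = ()


def empty_edges(n_nodes: int) -> Tuple[NodeEdges, ...]:
    """Create one empty record per node, so len(edges) == number of nodes."""
    return tuple(NodeEdges() for _ in range(n_nodes))


def couple(edges, parent: int, child: int, kind: str = "value"):
    """Return a new edges tuple with a parent -> child coupling of the given kind."""
    if kind not in ("value", "volatility"):
        raise ValueError(f"unknown coupling kind: {kind!r}")
    new = list(edges)
    # Register the parent on the child's side...
    child_field = f"{kind}_parents"
    new[child] = new[child]._replace(
        **{child_field: getattr(new[child], child_field) + (parent,)}
    )
    # ...and the child on the parent's side, keeping both views in sync.
    parent_field = f"{kind}_children"
    new[parent] = new[parent]._replace(
        **{parent_field: getattr(new[parent], parent_field) + (child,)}
    )
    return tuple(new)


# Three-node example: node 1 is the value parent of node 0,
# and node 2 is the volatility parent of node 1.
edges = empty_edges(3)
edges = couple(edges, parent=1, child=0, kind="value")
edges = couple(edges, parent=2, child=1, kind="volatility")
```

Returning a new tuple instead of mutating in place mirrors the immutable, tuple-based layout described for the edges attribute, which is convenient when the structure is closed over by a function traced by JAX.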

Parameters:

update_type (str)

__init__(update_type='eHGF')

Initialize an empty neural network.

Parameters:
update_type (str)

The type of update to perform for volatility coupling. Can be “eHGF” (the default in the signature above), “unbounded” or “standard”. The unbounded approximation was recently introduced to avoid negative precision updates, which greatly improves sampling performance. The eHGF update step was proposed as an alternative to the original definition: it updates the mean of the parent node before its precision, which generally reduces errors associated with impossible parameter spaces and improves sampling.

Return type:

None

Methods

__init__([update_type])

Initialize an empty neural network.

add_edges([kind, parent_idxs, ...])

Add a value or volatility coupling link between a set of nodes.

add_nodes([kind, n_nodes, node_parameters, ...])

Add new input/state node(s) to the neural network.

create_belief_propagation_fn([overwrite, ...])

Create the belief propagation function.

get_input_dimension()

Get input node dimensions.

get_network()

Return the attributes, edges and update sequence defining the network.

input_custom_sequence(update_branches, ...)

Add new observations with custom update sequences.

input_data(input_data[, time_steps, ...])

Add new observations.

plot_correlations()

Plot a heatmap of the cross-trajectory correlations.

plot_network([backend])

Visualize the node network using GraphViz or NetworkX.

plot_nodes(node_idxs, **kwargs)

Plot the belief trajectories of the node(s).

plot_samples(**kwargs)

Plot the parameter trajectories.

plot_trajectories(**kwargs)

Plot the parameter trajectories.

sample(n_predictions[, time_steps, rng_key, ...])

Generate predictions using the utility predict function.

surprise(response_function[, ...])

Compute the surprise of the model conditioned on the response function.

to_pandas()

Export the node trajectories and surprise as a pandas DataFrame.
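The scan_fn attribute is described as the function passed to jax.lax.scan(), which threads a carry (here, the network's beliefs) through a sequence of inputs and records one output per step. The sketch below emulates the documented scan semantics in plain Python and pairs it with a toy step function for a single continuous node: a Gaussian random-walk prediction (variance exp(tonic_volatility)) followed by a precision-weighted prediction-error update. The scan emulation, make_step, tonic_volatility and obs_precision are all hypothetical names chosen for this example; this is not pyhgf's beliefs_propagation() or its update functions.

```python
import math
from typing import Callable, List, Sequence, Tuple

Carry = Tuple[float, float]  # (mean, precision) of the node's belief


def scan(f: Callable, init, xs: Sequence) -> Tuple[object, List[object]]:
    """Plain-Python emulation of jax.lax.scan (outputs kept as a list)."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def make_step(tonic_volatility: float = -2.0, obs_precision: float = 1.0):
    """Build a hypothetical one-node step with its parameters fixed in advance."""
    drift_variance = math.exp(tonic_volatility)

    def step(carry: Carry, observation: float) -> Tuple[Carry, Carry]:
        mean, precision = carry
        # Prediction: the Gaussian random walk inflates the variance.
        expected_precision = 1.0 / (1.0 / precision + drift_variance)
        # Update: move the mean by the precision-weighted prediction error.
        new_precision = expected_precision + obs_precision
        new_mean = mean + (obs_precision / new_precision) * (observation - mean)
        new_carry = (new_mean, new_precision)
        # Carry the belief forward and also record it in the trajectory.
        return new_carry, new_carry

    return step


final, trajectory = scan(make_step(), init=(0.0, 1.0), xs=[1.0, 1.0, 1.0])
```

Fixing the parameters inside make_step before scanning is one way to read "pre-parametrized": the scanned function only receives the carry and the current observation.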

Attributes

input_idxs

Indices of the state nodes that can observe new data points by default.