pyhgf.model.DeepNetwork
- class pyhgf.model.DeepNetwork(coupling_fn=<function DeepNetwork.<lambda>>, update_type='eHGF')
Deep predictive coding network with vectorized operations.
This class implements a deep hierarchical Gaussian filter using layer-wise vectorized operations for efficient scaling to large networks.
Unlike a per-node implementation, which updates each node in a Python loop, this implementation uses JAX matrix operations to update all nodes in a layer simultaneously.
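The difference between per-node and layer-wise updates is easiest to see on the prediction step alone. The sketch below is not pyhgf code. It assumes a weight matrix of shape (n_children, n_parents) and an identity coupling function, and uses NumPy (jax.numpy exposes the same calls) to show that one matrix-vector product reproduces the per-node loop.

```python
import numpy as np

def coupling_fn(x):
    # Identity coupling, mirroring the documented linear default.
    return x

def predict_layer_loop(weights, parent_means):
    # Per-node version: one Python iteration per child node.
    n_children = weights.shape[0]
    out = np.zeros(n_children)
    for i in range(n_children):
        out[i] = np.sum(weights[i] * coupling_fn(parent_means))
    return out

def predict_layer_vectorized(weights, parent_means):
    # Layer-wise version: a single matrix-vector product covers every child.
    return weights @ coupling_fn(parent_means)

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 6))  # hypothetical: 8 child nodes, 6 parent nodes
mu_parent = rng.normal(size=6)

loop_pred = predict_layer_loop(W, mu_parent)
vec_pred = predict_layer_vectorized(W, mu_parent)
print(np.allclose(loop_pred, vec_pred))  # True
```

Both functions compute the same predictions; the vectorized form removes the Python loop, which is what lets the work be dispatched as a single array operation per layer.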
Examples
>>> from pyhgf.model import DeepNetwork
>>> # Build a network with method chaining
>>> net = (
...     DeepNetwork()
...     .add_layer(size=10)  # Output layer
...     .add_layer(size=8)   # Hidden layer 1
...     .add_layer(size=6)   # Hidden layer 2
...     .add_layer(size=4)   # Input layer
... )
>>>
>>> # Fit to data
>>> net.fit(x_train, y_train, lr=0.2)
>>>
>>> # Make predictions
>>> predictions = net.predict(x_test)
Notes
The network uses volatile nodes internally, which have two levels:
- Value level (external): represents the node’s belief about its value.
- Volatility level (internal): represents the volatility of the value level, which sets how much uncertainty is added to the value belief at each time step.
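In the standard HGF formulation, the volatility level acts on the value level through the variance of its Gaussian random walk. A minimal sketch of that prediction step follows; the parameter names (tonic_volatility, volatility_coupling, time_step) and their defaults are illustrative assumptions, not pyhgf's internal implementation.

```python
import numpy as np

def predicted_precision(prev_precision, mu_volatility, tonic_volatility=-4.0,
                        volatility_coupling=1.0, time_step=1.0):
    # Variance added by the random walk of the value level. It grows
    # exponentially with the mean of the volatility level.
    step_variance = time_step * np.exp(
        volatility_coupling * mu_volatility + tonic_volatility
    )
    # Predicted variance = previous posterior variance + random-walk variance.
    return 1.0 / (1.0 / prev_precision + step_variance)

calm = predicted_precision(prev_precision=10.0, mu_volatility=-2.0)
volatile = predicted_precision(prev_precision=10.0, mu_volatility=2.0)
print(calm > volatile)  # True: a higher volatility belief lowers the predicted precision
```

A higher volatility belief therefore makes the value level less certain about where it will be at the next observation, which in turn increases how strongly new prediction errors move it.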
Layer indexing follows this convention:
- Layer 0 is the output layer (receives observations).
- Layer N is the input layer (receives predictors).
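The indexing convention can be traced through the layer sizes used in the Examples section. The sketch assumes, hypothetically, that the weight matrix between layers i+1 and i has shape (size of layer i, size of layer i+1) and that the coupling is the identity; it only illustrates the direction in which predictions flow, not pyhgf's internals.

```python
import numpy as np

# Layer sizes in the order add_layer() is called in the Examples section:
# the first call creates layer 0 (output), the last call layer N (input).
layer_sizes = [10, 8, 6, 4]
n = len(layer_sizes) - 1  # index of the input layer

# Hypothetical weights: weights[i] maps parent layer i+1 onto child layer i,
# so it has shape (size of layer i, size of layer i+1).
rng = np.random.default_rng(0)
weights = [rng.normal(size=(layer_sizes[i], layer_sizes[i + 1])) for i in range(n)]

def top_down_prediction(x):
    # Predictors enter at layer N; predictions flow down to layer 0.
    mu = x
    for i in range(n - 1, -1, -1):
        mu = weights[i] @ mu  # identity coupling assumed
    return mu

y_hat = top_down_prediction(np.ones(layer_sizes[n]))
print(y_hat.shape)  # (10,)
```

With this ordering, predictors of width 4 enter at layer 3 and the prediction that reaches layer 0 has the output width of 10.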
- Parameters:
coupling_fn (Callable)
update_type (str)
- __init__(coupling_fn=<function DeepNetwork.<lambda>>, update_type='eHGF')
Initialize a DeepNetwork.
- Parameters:
coupling_fn (Callable) – Coupling function applied between layers. Default is linear (identity), matching the Rust backend and the Network class. This function is applied to parent means before the weighted sum to predict child means.
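The order of operations described for coupling_fn (transform the parent means first, then take the weighted sum) can be sketched as follows. The helper predict_child_means is hypothetical and uses NumPy; only the role of coupling_fn comes from the description above.

```python
import numpy as np

def predict_child_means(weights, parent_means, coupling_fn=lambda x: x):
    # coupling_fn acts element-wise on the parent means *before*
    # the weighted sum; the weights themselves are left untouched.
    return weights @ coupling_fn(parent_means)

W = np.array([[0.5, -1.0],
              [2.0,  0.0]])
mu_parent = np.array([1.0, 3.0])

linear = predict_child_means(W, mu_parent)             # default: identity, values [-2.5, 2.0]
squashed = predict_child_means(W, mu_parent, np.tanh)  # nonlinear coupling
```

Under the documented signature, a nonlinear coupling would be passed as DeepNetwork(coupling_fn=jnp.tanh); this assumes the function is written with JAX operations so it can be traced inside the vectorized updates.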
update_type (str) – The type of volatility-level posterior update. Can be "eHGF" (default), "standard" or "unbounded". This matches the Network class and the Rust backend.
Methods
__init__([coupling_fn, update_type]) – Initialize a DeepNetwork.
add_layer(size[, kind, tonic_volatility, ...]) – Add a layer of nodes.
add_layer_stack(layer_sizes[, kind, ...]) – Add multiple hidden layers at once.
fit(x, y[, lr, optimizer, params, ...]) – Fit the network to data.
predict(x) – Forward pass without learning.
reset() – Reset the network state.
weight_initialisation([strategy, seed]) – Initialise the inter-layer weight matrices.
Attributes
n_layers – Number of layers in the network.
n_nodes – Total number of nodes in the network.
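The method chaining used in the Examples section works because add_layer returns the network itself. The toy class below is a hypothetical stand-in that only tracks layer sizes to show that pattern and how the two attributes relate to it; counting each volatile node once in n_nodes is an assumption, not documented pyhgf behaviour.

```python
class ToyDeepNetwork:
    """Hypothetical stand-in that only tracks layer sizes."""

    def __init__(self):
        self.layer_sizes = []

    def add_layer(self, size):
        # Returning self is what makes chained calls possible.
        self.layer_sizes.append(size)
        return self

    @property
    def n_layers(self):
        return len(self.layer_sizes)

    @property
    def n_nodes(self):
        # Assumption: each volatile node counts once, whatever its internal levels.
        return sum(self.layer_sizes)

net = ToyDeepNetwork().add_layer(10).add_layer(8).add_layer(6).add_layer(4)
print(net.n_layers, net.n_nodes)  # 4 28
```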