pyhgf.utils.vectorized_belief_propagation.propagation_step
- pyhgf.utils.vectorized_belief_propagation.propagation_step(network, opt_state, inputs, *, optimizer, time_step=1.0, learning_kind='precision_weighted', weight_update=True)
Single propagation step through the network.
Belief-propagation sweep (top-down prediction → leaf prediction error (PE) → interleaved posterior/PE updates bottom-up) followed by an optional weight-learning phase. Each step of the sweep dispatches per element on whether the element is a Layer or a LayerStack:
- Layer → standard per-layer kernel call (unrolled).
- LayerStack → jax.lax.scan over the stack's slices.
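The two dispatch paths can be illustrated in plain JAX. The sketch below is not pyhgf's implementation: `layer_kernel`, its tanh update and the shapes are hypothetical stand-ins. It applies one per-slice kernel to a stack of weights in two ways, as an unrolled Python loop (the Layer-style path) and with jax.lax.scan over the stack's leading axis (the LayerStack-style path), and checks that both give the same result.

```python
import jax
import jax.numpy as jnp


def layer_kernel(h, params):
    """Hypothetical per-layer kernel: maps incoming activity to the next layer."""
    w, b = params
    return jnp.tanh(w @ h + b)


depth, width = 4, 3
key_w, key_b, key_h = jax.random.split(jax.random.PRNGKey(0), 3)

# A stack stores all slices along a leading axis, so every slice must share
# the same (width, width) weight shape.
w_stack = jax.random.normal(key_w, (depth, width, width))
b_stack = jax.random.normal(key_b, (depth, width))
h0 = jax.random.normal(key_h, (width,))


def unrolled(h):
    # Layer-style path: one explicit kernel call per slice
    # (under jax.jit this becomes `depth` separate copies of the kernel).
    for i in range(depth):
        h = layer_kernel(h, (w_stack[i], b_stack[i]))
    return h


def scanned(h):
    # LayerStack-style path: the kernel is traced once and looped over slices.
    def body(carry, params):
        new_h = layer_kernel(carry, params)
        return new_h, new_h  # (next carry, per-slice output)

    return jax.lax.scan(body, h, (w_stack, b_stack))


h_unrolled = unrolled(h0)
h_scanned, trajectory = scanned(h0)
print(jnp.allclose(h_unrolled, h_scanned, atol=1e-5))  # True
print(trajectory.shape)  # (4, 3)
```

The scanned version traces the kernel once regardless of depth, which keeps compilation cost flat for deep stacks; the trade-off is that every slice must have the same shape, because scan iterates over a single stacked array.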
The top and bottom elements must be Layers. The child below (and the parent above) a LayerStack can itself be a Layer or a LayerStack; in the stack-to-stack case, the boundary widths must match.
- Parameters:
network (Network) – The current vectorised network state.
opt_state (optax.OptState) – The current optax optimiser state.
inputs (tuple) – A tuple (x, y) with the predictors set on the top element and the observations clamped on the bottom element.
optimizer (optax.GradientTransformation) – The optax optimiser used for the weight-learning phase.
time_step (float) – The time elapsed since the previous step.
learning_kind (str) – The weight-gradient mode passed to pyhgf.updates.vectorized.learning.vectorized_weight_gradient().
weight_update (bool) – Whether to apply the weight-learning phase after belief propagation.
- Returns:
A tuple ((network, opt_state), surprise), where network and opt_state are the updated network and optimiser states and surprise is the surprise incurred at this step.
- Return type:
carry