pyhgf.utils.vectorized_belief_propagation.propagation_step

pyhgf.utils.vectorized_belief_propagation.propagation_step(network, opt_state, inputs, *, optimizer, time_step=1.0, learning_kind='precision_weighted', weight_update=True)

Single propagation step through the network.

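Runs a belief-propagation sweep followed by an optional weight-learning phase. The sweep computes top-down predictions, then the prediction error (PE) at the leaf, then interleaved posterior and PE updates from the bottom up. Each phase dispatches per element on whether the element is a Layer or a LayerStack: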

  • Layer → standard per-layer kernel call (unrolled).

  • LayerStack → jax.lax.scan over the stack’s slices.
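The dispatch can be pictured with a toy top-down prediction pass. This is a minimal sketch, not pyhgf's implementation: Layer and LayerStack here are hypothetical dataclasses holding weight arrays, and predict is a made-up kernel (tanh of a matrix-vector product). A Layer triggers one direct kernel call, while a LayerStack hands the same kernel to jax.lax.scan over its leading (depth) axis. Because a scan carry must keep one shape, the sketch assumes square slices inside a stack.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


# Hypothetical stand-ins for the element types: a Layer holds one weight
# matrix, a LayerStack holds `depth` square matrices stacked on axis 0.
@dataclass
class Layer:
    weights: jnp.ndarray  # shape (width_out, width_in)


@dataclass
class LayerStack:
    weights: jnp.ndarray  # shape (depth, width, width)


def predict(weights, mu):
    """Toy per-layer kernel: top-down prediction for the element below."""
    return jnp.tanh(weights @ mu)


def top_down(elements, x):
    """Propagate predictions from the top element down, dispatching on type."""
    mu = x
    for element in elements:
        if isinstance(element, Layer):
            # Layer: one ordinary (unrolled) kernel call.
            mu = predict(element.weights, mu)
        else:
            # LayerStack: scan the same kernel over the stacked slices; the
            # carry keeps shape (width,) because every slice is width -> width.
            def body(carry, w):
                out = predict(w, carry)
                return out, out

            mu, _ = jax.lax.scan(body, mu, element.weights)
    return mu
```

The scan body is traced once for the whole stack instead of once per slice, which is the usual reason to prefer jax.lax.scan over an unrolled Python loop for deep, uniform stacks.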

The top and bottom elements must be Layers. A LayerStack’s child below and parent above can each be either a Layer or a LayerStack; in the stack-stack case, the widths at the shared boundary must match.
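These ordering rules can be expressed as a small validation pass. The sketch below is illustrative only: the Layer and LayerStack dataclasses, their width and depth fields, and the validate_elements helper are hypothetical, and it assumes "boundary width" means the common width of the two adjacent stacks' slices.

```python
from dataclasses import dataclass


@dataclass
class Layer:
    width: int  # number of units (hypothetical attribute)


@dataclass
class LayerStack:
    depth: int  # number of stacked slices
    width: int  # every slice is assumed to share this width


def validate_elements(elements):
    """Check the top/bottom and stack-stack constraints (hypothetical helper)."""
    if not elements:
        raise ValueError("the network needs at least one element")
    if not isinstance(elements[0], Layer) or not isinstance(elements[-1], Layer):
        raise ValueError("top and bottom elements must be Layers")
    # Walk adjacent (parent, child) pairs from top to bottom.
    for parent, child in zip(elements, elements[1:]):
        if isinstance(parent, LayerStack) and isinstance(child, LayerStack):
            if parent.width != child.width:
                raise ValueError(
                    f"adjacent LayerStacks have mismatched widths "
                    f"{parent.width} != {child.width}"
                )
```

A Layer between two stacks needs no width check in this sketch, since a single weight matrix can map any input width to any output width.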

Parameters:
  • network (Network) – The current vectorised network state.

  • opt_state (optax.OptState) – The current optax optimiser state.

  • inputs (tuple) – A tuple (x, y) with the predictors set on the top element and the observations clamped on the bottom element.

  • optimizer (optax.GradientTransformation) – The optax optimiser used for the weight-learning phase.

  • time_step (float) – The time elapsed since the previous step.

  • learning_kind (str) – The weight-gradient mode passed to pyhgf.updates.vectorized.learning.vectorized_weight_gradient().

  • weight_update (bool) – Whether to apply the weight-learning phase after belief propagation.

Returns:

A tuple ((network, opt_state), surprise), where network and opt_state are the updated network and optimiser states, and surprise is the surprise computed at this step.
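The return value has the (carry, output) shape that jax.lax.scan expects from its body, which makes it natural to iterate the step over a sequence of (x, y) pairs. The sketch below assumes that use. Since building a real Network is outside the scope of this page, a hypothetical toy_step with the same positional signature and return shape stands in for propagation_step; its one-weight update rule is made up for illustration. The scan_body wrapper unpacks the (network, opt_state) carry into the positional arguments.

```python
import jax
import jax.numpy as jnp


def toy_step(network, opt_state, inputs, *, learning_rate=0.1):
    """Hypothetical stand-in with propagation_step's positional signature and
    return shape; the update rule itself is made up for illustration."""
    x, y = inputs
    error = y - network * x  # prediction error for a one-weight linear model
    surprise = 0.5 * error**2  # Gaussian surprise (unit precision), constant dropped
    network = network + learning_rate * error * x
    opt_state = opt_state + 1  # here the "optimiser state" just counts steps
    return (network, opt_state), surprise


def scan_body(carry, inputs):
    # With the real function this would be, e.g.,
    # propagation_step(network, opt_state, inputs, optimizer=optimizer)
    network, opt_state = carry
    return toy_step(network, opt_state, inputs, learning_rate=0.1)


xs = jnp.ones(50)
ys = 2.0 * xs
init = (jnp.zeros(()), jnp.zeros((), dtype=jnp.int32))
(network, opt_state), surprises = jax.lax.scan(scan_body, init, (xs, ys))
```

Because surprise is the per-step output, scan stacks it into an array with one entry per time step, here of shape (50,).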

Return type:

carry