pyhgf.utils.vectorized_belief_propagation.run_scan
- pyhgf.utils.vectorized_belief_propagation.run_scan(init_carry, inputs, optimizer, learning_kind, weight_update, record, time_step=1.0)
Run jax.lax.scan over the belief-propagation step.

Decorated with eqx.filter_jit: arrays in init_carry / inputs are dynamic; optimizer / learning_kind / weight_update / record / time_step are static and form the JIT cache key, so changing any of them triggers a recompilation.
- Parameters:
  - init_carry (tuple) – The initial scan carry, a tuple (network, opt_state).
  - inputs (tuple) – The per-step inputs scanned over, a tuple of predictor/observation arrays with a leading time axis.
  - optimizer (GradientTransformation) – The optax optimiser used for the weight-learning phase.
  - learning_kind (str) – The weight-gradient mode passed to pyhgf.updates.vectorized.learning.vectorized_weight_gradient().
  - weight_update (bool) – Whether to apply the weight-learning phase at every step.
  - record (tuple) – Tuple of LayerState field names to record at every time step (e.g. ("expected_mean", "precision")). An empty tuple disables recording, in which case the scan output is just the per-step output_pred. With a non-empty tuple, the per-step output is (traj_step, output_pred), where traj_step is dict[field_name, tuple[Array, ...]] (one per-element array per field, with LayerStack elements contributing arrays of shape (N, n_nodes)). After scan stacks across time, each leaf carries a leading (T,) axis.
  - time_step (float) – Uniform inference time step \(\Delta t\) passed to every propagation_step call. Defaults to 1.0.
- Returns:
  ((final_network, final_opt_state), step_output), where step_output is either the stacked predictions alone (record == ()) or a (stacked_traj, stacked_predictions) tuple.
- Return type:
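The carry and output layout described above can be illustrated with a small toy sketch. It is not pyhgf's implementation: toy_run_scan, its running-mean update, the fixed 0.5 learning rate, and the "mean" / "count" field names are all hypothetical stand-ins, and plain jax.jit with static_argnames is used in place of eqx.filter_jit to keep the example dependency-free. It only shows how a static record tuple can switch the per-step scan output between output_pred and (traj_step, output_pred).

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for run_scan: `record` and `time_step` are static,
# so each distinct value compiles (and caches) its own version of the scan.
@partial(jax.jit, static_argnames=("record", "time_step"))
def toy_run_scan(init_carry, inputs, record=(), time_step=1.0):
    def step(carry, x):
        mean, n_seen = carry
        (observation,) = x
        output_pred = mean  # prediction made before seeing the observation
        # Toy update (assumption): move the mean halfway towards the observation.
        mean = mean + 0.5 * time_step * (observation - mean)
        new_carry = (mean, n_seen + 1)
        if not record:  # resolved at trace time because `record` is static
            return new_carry, output_pred
        fields = {"mean": mean, "count": n_seen + 1}
        # One tuple of arrays per recorded field name.
        traj_step = {name: (fields[name],) for name in record}
        return new_carry, (traj_step, output_pred)

    return jax.lax.scan(step, init_carry, inputs)


observations = jnp.array([1.0, 1.0, 1.0, 1.0])  # leading axis is time (T = 4)
init_carry = (jnp.zeros(()), jnp.zeros((), dtype=jnp.int32))

# record == (): the stacked output is the predictions alone.
(final_mean, final_count), preds = toy_run_scan(init_carry, (observations,))

# Non-empty record: the stacked output is (stacked_traj, stacked_predictions).
_, (traj, preds_rec) = toy_run_scan(init_carry, (observations,), record=("mean",))
```

In this sketch every leaf of traj gains a leading (T,) axis once scan has stacked the per-step outputs, mirroring the behaviour described for record.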