pyhgf.utils.vectorized_belief_propagation.run_scan

pyhgf.utils.vectorized_belief_propagation.run_scan(init_carry, inputs, optimizer, learning_kind, weight_update, record, time_step=1.0)

Run jax.lax.scan over the belief-propagation step.

Decorated with eqx.filter_jit: arrays in init_carry / inputs are dynamic (traced), while optimizer / learning_kind / weight_update / record / time_step are static and form the JIT cache key, so changing any of them triggers a recompilation.
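The static/dynamic split can be illustrated with a small sketch. It uses plain jax.jit with static_argnames as a stand-in for eqx.filter_jit, and toy_run and TRACE_COUNT are hypothetical names that are not part of pyhgf. It only shows that new array values reuse the compiled function while a new static value forces a re-trace.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical counter: incremented each time the function body is traced.
TRACE_COUNT = {"n": 0}


# Assumption: jax.jit + static_argnames stands in for eqx.filter_jit here.
@partial(jax.jit, static_argnames=("weight_update", "time_step"))
def toy_run(values, weight_update, time_step):
    TRACE_COUNT["n"] += 1  # Python side effect, executed only while tracing
    scaled = values * time_step
    # weight_update is a static Python bool, so this branch is resolved at trace time.
    return scaled + 1.0 if weight_update else scaled


# First call: traces and compiles.
out_a = toy_run(jnp.ones(3), weight_update=True, time_step=1.0)
# New array values, same shape and same static arguments: cache hit.
out_b = toy_run(jnp.zeros(3), weight_update=True, time_step=1.0)
# Different static value: the function is traced again.
out_c = toy_run(jnp.ones(3), weight_update=False, time_step=1.0)
```

In practice this suggests keeping static arguments such as record and time_step fixed across repeated calls when compilation time matters.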

Parameters:
  • init_carry (tuple) – The initial scan carry, a tuple (network, opt_state).

  • inputs (tuple) – The per-step inputs scanned over, a tuple of predictor/observation arrays with a leading time axis.

  • optimizer (GradientTransformation) – The optax optimiser used for the weight-learning phase.

  • learning_kind (str) – The weight-gradient mode passed to pyhgf.updates.vectorized.learning.vectorized_weight_gradient().

  • weight_update (bool) – Whether to apply the weight-learning phase at every step.

  • record (tuple) – Tuple of LayerState field names to record at every time step (e.g. ("expected_mean", "precision")). With an empty tuple, recording is disabled and the per-step scan output is just output_pred. With a non-empty tuple, the per-step output is (traj_step, output_pred), where traj_step is dict[field_name, tuple[Array, ...]]: for each field, one array per network element, with LayerStack elements contributing arrays of shape (N, n_nodes). After scan stacks the per-step outputs across time, each leaf carries a leading (T,) axis.

  • time_step (float) – Uniform inference time step \(\Delta t\) passed to every propagation_step call. Defaults to 1.0.
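How record changes the scan output can be sketched with a toy jax.lax.scan. This is not pyhgf's implementation: toy_scan, the dict-based state, and the update rule are assumptions made for illustration, and each recorded field holds a single array rather than a tuple of per-element arrays.

```python
import jax
import jax.numpy as jnp


def toy_scan(init_state, inputs, record=(), time_step=1.0):
    """Hypothetical stand-in for run_scan; the state is a dict of arrays."""

    def step(state, x):
        # Toy update: move the mean toward the observation, scaled by time_step.
        mean = state["expected_mean"] + time_step * 0.5 * (x - state["expected_mean"])
        new_state = {"expected_mean": mean, "precision": state["precision"] + time_step}
        output_pred = mean.sum()
        if not record:
            # Empty tuple: the per-step output is the prediction alone.
            return new_state, output_pred
        # Non-empty tuple: also emit the requested fields for this step.
        traj_step = {name: new_state[name] for name in record}
        return new_state, (traj_step, output_pred)

    return jax.lax.scan(step, init_state, inputs)


T, n_nodes = 5, 3
init = {"expected_mean": jnp.zeros(n_nodes), "precision": jnp.ones(n_nodes)}
obs = jnp.ones((T, n_nodes))

# record == (): the stacked predictions have shape (T,).
_, preds = toy_scan(init, obs)

# Non-empty record: every recorded leaf gains a leading time axis of length T.
_, (traj, preds_rec) = toy_scan(init, obs, record=("expected_mean",))
print(preds.shape, traj["expected_mean"].shape)
```

Because record is an ordinary Python tuple, the branch is decided before tracing, which is consistent with record being a static argument.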

Returns:

((final_network, final_opt_state), step_output), where step_output is either the stacked predictions alone (record == ()) or a (stacked_traj, stacked_predictions) tuple.

Return type:

tuple
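Since the shape of step_output depends on record, calling code may want to normalise it before use. The helper below is a hypothetical sketch (split_step_output is not part of pyhgf) that always returns a (trajectory, predictions) pair, with None as the trajectory when nothing was recorded.

```python
def split_step_output(step_output, record):
    """Hypothetical helper: normalise the two possible shapes of step_output."""
    if record == ():
        # Recording disabled: step_output is the stacked predictions alone.
        return None, step_output
    # Recording enabled: step_output is (stacked_traj, stacked_predictions).
    stacked_traj, stacked_predictions = step_output
    return stacked_traj, stacked_predictions


# Usage with placeholder values standing in for stacked arrays.
traj_none, preds_only = split_step_output([0.1, 0.2], record=())
traj_rec, preds_with_traj = split_step_output(
    ({"expected_mean": [[0.0], [0.5]]}, [0.1, 0.2]), record=("expected_mean",)
)
```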