pyhgf.utils.predict_step

pyhgf.utils.predict_step(attributes, x_row, prediction_steps, edges, inputs_x_idxs, inputs_y_idxs)

Run a single forward (prediction-only) pass through the network.

This is the per-sample function that Network.predict() maps over samples with jax.vmap(). It writes x_row into the predictor nodes (inputs_x_idxs), runs the prediction sequence top-down, and collects the expected_mean of the target nodes (inputs_y_idxs).
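The batching pattern described above, where one per-sample function is mapped over the rows of a predictor matrix while the attributes are shared by every sample, can be illustrated with a toy stand-in. This is a minimal sketch, assuming only the public jax.vmap API; toy_predict_step and its "weights" attribute are hypothetical and are not part of pyhgf. Passing in_axes=(None, 0) leaves the attributes unmapped and maps over axis 0 of the predictors.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a per-sample prediction function (not pyhgf code):
# it reads shared parameters and ONE row of predictors and returns one value
# per target node.
def toy_predict_step(attributes, x_row):
    return attributes["weights"] @ x_row  # shape (n_targets,)

# Shared parameters: 3 target nodes, 2 predictor inputs.
attributes = {"weights": jnp.array([[1.0, 2.0], [0.5, -1.0], [0.0, 3.0]])}

# Predictor matrix: 4 samples, each a row of shape (2,).
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, 1.0]])

# in_axes=(None, 0): do not map over the attributes (every sample sees the
# same dict), map over axis 0 of x (one call per row).
batched_predict = jax.vmap(toy_predict_step, in_axes=(None, 0))
predictions = batched_predict(attributes, x)
print(predictions.shape)  # (4, 3)
```

Each per-sample call returns a 1-D array of length n_targets, so the batched result stacks them into shape (n_samples, n_targets).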

Parameters:
  • attributes (dict[int | str, dict]) – Current node attributes (shared across all samples).

  • x_row (Array | ndarray | np.bool_ | np.number | bool | int | float | complex) – A single row of predictor values with shape (n_x_inputs,).

  • prediction_steps (tuple[tuple[int, PjitFunction], ...]) – The prediction steps to execute, in order, as (node index, jitted prediction function) pairs. Predictor nodes are excluded.

  • edges (tuple[AdjacencyLists, ...]) – The network’s edge structure.

  • inputs_x_idxs (tuple[int]) – Node indexes receiving the predictor values.

  • inputs_y_idxs (tuple[int]) – Node indexes whose expected_mean is collected as output.
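One way the parameters above could fit together in a single forward pass is sketched below. This is a minimal sketch, not pyhgf's implementation: sketch_predict_step and weighted_parent_sum are hypothetical, the dict-of-dicts attribute layout with "mean" and "expected_mean" keys is assumed, and edges is encoded as a plain tuple of (parent_idx, weight) pairs per node rather than pyhgf's AdjacencyLists. The real prediction steps are jitted functions; plain Python functions are used here for simplicity.

```python
import jax.numpy as jnp

# Hypothetical prediction function (not pyhgf's): a node's expected_mean is the
# weighted sum of its parents' means. In this toy, edges[node_idx] is a tuple of
# (parent_idx, weight) pairs, unlike pyhgf's AdjacencyLists.
def weighted_parent_sum(attributes, node_idx, edges):
    value = sum(attributes[p]["mean"] * w for p, w in edges[node_idx])
    node = {**attributes[node_idx], "expected_mean": value, "mean": value}
    return {**attributes, node_idx: node}

def sketch_predict_step(attributes, x_row, prediction_steps, edges,
                        inputs_x_idxs, inputs_y_idxs):
    # Copy the per-node dicts so the shared attributes are left untouched.
    attributes = {k: dict(v) for k, v in attributes.items()}
    # 1. Write the predictor values into the predictor nodes.
    for i, node_idx in enumerate(inputs_x_idxs):
        attributes[node_idx]["mean"] = x_row[i]
        attributes[node_idx]["expected_mean"] = x_row[i]
    # 2. Run the prediction steps in the given (top-down) order.
    for node_idx, predict_fn in prediction_steps:
        attributes = predict_fn(attributes, node_idx, edges)
    # 3. Collect expected_mean from the target nodes, one value per target.
    return jnp.stack([attributes[idx]["expected_mean"] for idx in inputs_y_idxs])

# Toy network: nodes 0 and 1 are predictors, node 2 depends on both,
# node 3 depends on node 2.
edges = ((), (), ((0, 2.0), (1, 1.0)), ((2, 0.5),))
attributes = {i: {"mean": jnp.array(0.0), "expected_mean": jnp.array(0.0)}
              for i in range(4)}
prediction_steps = ((2, weighted_parent_sum), (3, weighted_parent_sum))

prediction = sketch_predict_step(
    attributes, jnp.array([1.0, 3.0]), prediction_steps, edges,
    inputs_x_idxs=(0, 1), inputs_y_idxs=(2, 3),
)
# Expected values: 2*1 + 1*3 = 5.0 for node 2, and 0.5*5 = 2.5 for node 3.
```

In this sketch the Python loops unroll at trace time, so the function can be batched with jax.vmap once the non-array arguments (prediction_steps, edges and the index tuples) are fixed, for example with functools.partial.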

Returns:

predictions – A 1-D array of expected_mean values, one per target node in inputs_y_idxs.