pyhgf.utils.predict_step
- pyhgf.utils.predict_step(attributes, x_row, prediction_steps, edges, inputs_x_idxs, inputs_y_idxs)
Run a single forward (prediction-only) pass through the network.
This is the per-sample function that Network.predict() maps over the rows of the predictor array with jax.vmap(). It sets the predictor values, runs the prediction sequence top-down, and collects the expected_mean from the target nodes.
- Parameters:
attributes (dict[int | str, dict]) – Current node attributes (shared across all samples).
x_row (Array | ndarray | bool | number | bool | int | float | complex) – A single row of predictor values with shape (n_x_inputs,).
prediction_steps (tuple[tuple[int, PjitFunction], ...]) – The prediction steps to execute (excluding predictor nodes).
edges (tuple[AdjacencyLists, ...]) – The network’s edge structure.
inputs_x_idxs (tuple[int]) – Node indexes receiving the predictor values.
inputs_y_idxs (tuple[int]) – Node indexes whose expected_mean is collected as output.
- Returns:
A 1-D array of expected_mean values, one per target node.
- Return type:
predictions
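The three stages described above (set the predictor values, run the prediction steps top-down, collect expected_mean from the target nodes) can be illustrated with a toy, pure-Python sketch. Everything in it is hypothetical and is not pyhgf's implementation: the attribute layout, the edges format (node index mapped to a list of parent indexes), the linear_prediction rule, and the (attrs, node_idx, edges) function signature are assumptions chosen for illustration only. Network.predict() batches with jax.vmap(). Here a list comprehension stands in for that mapping so that the sketch runs without JAX.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical toy types (not pyhgf's real data structures).
Attributes = Dict[int, Dict[str, float]]   # node index -> node values
Edges = Dict[int, List[int]]               # node index -> parent indexes
PredictionFn = Callable[[Attributes, int, Edges], Attributes]


def linear_prediction(attrs: Attributes, node_idx: int, edges: Edges) -> Attributes:
    """Hypothetical top-down rule: expected mean is the coupling-weighted
    sum of the parents' expected means."""
    coupling = attrs[node_idx]["coupling"]
    attrs[node_idx]["expected_mean"] = sum(
        coupling * attrs[parent]["expected_mean"] for parent in edges[node_idx]
    )
    return attrs


def toy_predict_step(
    attributes: Attributes,
    x_row: Sequence[float],
    prediction_steps: Sequence[Tuple[int, PredictionFn]],
    edges: Edges,
    inputs_x_idxs: Sequence[int],
    inputs_y_idxs: Sequence[int],
) -> List[float]:
    # Copy so the attributes shared across samples are never mutated.
    attrs = {idx: dict(values) for idx, values in attributes.items()}
    # 1. Set the predictor values on the predictor nodes.
    for node_idx, value in zip(inputs_x_idxs, x_row):
        attrs[node_idx]["expected_mean"] = value
    # 2. Run the prediction steps in order (top-down).
    for node_idx, fn in prediction_steps:
        attrs = fn(attrs, node_idx, edges)
    # 3. Collect expected_mean from the target nodes.
    return [attrs[idx]["expected_mean"] for idx in inputs_y_idxs]


# Toy chain: node 0 (predictor) -> node 1 (hidden) -> node 2 (target).
attributes = {
    0: {"expected_mean": 0.0, "coupling": 1.0},
    1: {"expected_mean": 0.0, "coupling": 2.0},
    2: {"expected_mean": 0.0, "coupling": 0.5},
}
edges = {1: [0], 2: [1]}
steps = ((1, linear_prediction), (2, linear_prediction))
x = [[1.0], [3.0]]  # two samples, one predictor each

# Plain-Python stand-in for mapping the per-sample function over rows.
predictions = [toy_predict_step(attributes, row, steps, edges, (0,), (2,)) for row in x]
print(predictions)  # → [[1.0], [3.0]]
```

Because each call works on its own copy of the shared attributes, the samples stay independent. That independence is the property that makes a per-row function a natural fit for vectorized mapping such as jax.vmap().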