pyhgf.updates.learning.learning_weights
pyhgf.updates.learning.learning_weights(attributes, node_idx, edges, kind='precision_weighted', lr=None, adam_beta1=None, adam_beta2=None, adam_epsilon=None)
Unified weights update.

The gradient is first computed according to kind, then scaled by lr:

- standard (kind="standard"): raw prediction-error outer product, no precision weighting.
  \(g_i = \text{PE} \cdot g(\text{parent}_i)\)
- precision_weighted (kind="precision_weighted"): gradient weighted by the child posterior precision.
  \(g_i = \text{PE} \cdot \pi_\text{child} \cdot g(\text{parent}_i)\)
- precision_ratio (kind="precision_ratio"): Kalman-gain-weighted PE using the posterior precisions of the child and parent.
  \(K_i = \pi_\text{child} / (\pi_{\text{parent}_i} + \pi_\text{child})\)
  \(g_i = K_i \cdot \text{PE} \cdot g(\text{parent}_i)\)
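The three gradient rules can be sketched for a single child with several value parents. This is a hypothetical plain-Python illustration, not pyhgf's implementation: weight_gradients is an invented name, PE and the precisions are passed in as plain floats, and g is assumed to be a coupling function applied to each parent's mean (identity by default here).

```python
from typing import Callable


def identity(x: float) -> float:
    """Assumed default coupling function g (hypothetical)."""
    return x


def weight_gradients(
    pe: float,
    pi_child: float,
    parent_means: list[float],
    parent_precisions: list[float],
    kind: str = "precision_weighted",
    g: Callable[[float], float] = identity,
) -> list[float]:
    """Return one gradient g_i per parent, using the rule selected by kind."""
    grads = []
    for mu_parent, pi_parent in zip(parent_means, parent_precisions):
        g_parent = g(mu_parent)  # assumption: g is applied to the parent's mean
        if kind == "standard":
            # g_i = PE * g(parent_i), no precision weighting
            grads.append(pe * g_parent)
        elif kind == "precision_weighted":
            # g_i = PE * pi_child * g(parent_i)
            grads.append(pe * pi_child * g_parent)
        elif kind == "precision_ratio":
            # Kalman-like gain K_i = pi_child / (pi_parent_i + pi_child)
            k_i = pi_child / (pi_parent + pi_child)
            grads.append(k_i * pe * g_parent)
        else:
            raise ValueError(f"unknown kind: {kind!r}")
    return grads


# Example: PE = 0.5, child precision 2.0, two parents with means 1.0 and -2.0.
means, precisions = [1.0, -2.0], [2.0, 6.0]
print(weight_gradients(0.5, 2.0, means, precisions, "standard"))            # [0.5, -1.0]
print(weight_gradients(0.5, 2.0, means, precisions, "precision_weighted"))  # [1.0, -2.0]
print(weight_gradients(0.5, 2.0, means, precisions, "precision_ratio"))     # [0.25, -0.25]
```

In the precision_ratio case the gain shrinks the step for the parent with the higher precision (6.0 gives \(K_i = 0.25\), against 0.5 for the parent with precision 2.0), so confident parents have their weights moved less.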
lr controls how the gradient is applied (same semantics for all three kinds):

- Adam (adam_beta1 is a float): the gradient is filtered through Adam, and the step size is controlled by lr.
- Fixed (adam_beta1 is None): \(\Delta w_i = \text{lr} \cdot g_i\).

To recover the old “full Kalman step” behaviour for kind="precision_ratio", pass lr=1.0.

- Parameters:
  - attributes (dict[int | str, dict]) – The attributes of the probabilistic network.
  - node_idx (int) – Pointer to the input node.
  - edges (tuple[AdjacencyLists, ...]) – The edges of the probabilistic nodes as a tuple of pyhgf.typing.Indexes. The tuple has the same length as the number of nodes. For each node, the index lists the value and volatility parents and children.
  - kind (str) – Gradient computation mode: "standard", "precision_weighted" (default), or "precision_ratio".
  - lr (float | None) – Fixed learning rate or Adam step size. Applied uniformly across all kind values, including "precision_ratio".
  - adam_beta1 (float | None) – Adam first-moment decay rate. When None (default), Adam is not used.
  - adam_beta2 (float | None) – Adam second-moment decay rate.
  - adam_epsilon (float | None) – Adam numerical stability constant.
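As an illustration of how lr and the adam_* parameters interact, here is a minimal, hypothetical sketch in plain Python. It assumes textbook Adam with bias correction and adds the step to the weights, matching the sign of the fixed rule \(\Delta w_i = \text{lr} \cdot g_i\). The names apply_weight_update and AdamState, and the adam_beta2 and adam_epsilon defaults (0.999 and 1e-8, the usual Adam values), are assumptions, not pyhgf's API; how pyhgf stores the moments is not shown here.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class AdamState:
    """Hypothetical per-weight Adam moments and step counter."""
    m: list[float]
    v: list[float]
    t: int = 0


def apply_weight_update(
    weights: list[float],
    grads: list[float],
    lr: float,
    adam_beta1: Optional[float] = None,
    adam_beta2: float = 0.999,   # assumed default
    adam_epsilon: float = 1e-8,  # assumed default
    state: Optional[AdamState] = None,
) -> list[float]:
    """Apply either the fixed rule or a textbook Adam step to each weight."""
    if adam_beta1 is None:
        # Fixed: delta w_i = lr * g_i
        return [w + lr * g for w, g in zip(weights, grads)]
    if state is None:
        raise ValueError("an AdamState is required when adam_beta1 is set")
    state.t += 1
    new_weights = []
    for i, (w, g) in enumerate(zip(weights, grads)):
        # Exponential moving averages of the gradient and squared gradient
        state.m[i] = adam_beta1 * state.m[i] + (1 - adam_beta1) * g
        state.v[i] = adam_beta2 * state.v[i] + (1 - adam_beta2) * g * g
        # Bias correction for the zero-initialised moments
        m_hat = state.m[i] / (1 - adam_beta1**state.t)
        v_hat = state.v[i] / (1 - adam_beta2**state.t)
        # Step added to the weight, same sign convention as the fixed rule
        new_weights.append(w + lr * m_hat / (math.sqrt(v_hat) + adam_epsilon))
    return new_weights


grads = [0.5, -1.0]
print(apply_weight_update([0.0, 0.0], grads, lr=0.1))  # [0.05, -0.1]
state = AdamState(m=[0.0, 0.0], v=[0.0, 0.0])
print(apply_weight_update([0.0, 0.0], grads, lr=0.1, adam_beta1=0.9, state=state))
# approximately [0.1, -0.1]: the first bias-corrected step is close to lr * sign(g_i)
```

On the first Adam step the bias-corrected ratio \(\hat{m} / \sqrt{\hat{v}}\) is close to \(\text{sign}(g_i)\), so each weight moves by roughly lr whatever the gradient's scale. This is why lr behaves as a step size under Adam rather than as a plain gradient multiplier.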
- Returns:
The attributes of the probabilistic network.
- Return type: