pyhgf.updates.vectorized.learning.vectorized_weight_update
- pyhgf.updates.vectorized.learning.vectorized_weight_update(parent_state, child_state, weights, coupling_fn, kind='precision_weighted', lr=0.0, parent_has_constant=False, child_is_binary=False, adam_m=None, adam_v=None, adam_t=0, adam_lr=0.001, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08)
Unified weight update for vectorized layers.
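The gradient \(g\) is first computed according to kind and then applied according to lr, in the same way for every kind. In the formulas below, PE is the child prediction error, \(g(\text{parent})\) is coupling_fn applied to the parent means (not to be confused with the gradient \(g\)), \(\otimes\) is the outer product (one entry per weight), \(\pi_\text{child}\) is the child precision and \(\pi_\text{parent}\) is the parent layer's expected precision: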
- standard (kind="standard"): \(g = \text{PE} \otimes g(\text{parent})\)
- precision_weighted (kind="precision_weighted"): \(g = \text{PE} \otimes g(\text{parent}) \cdot \pi_\text{child}\)
- precision_ratio (kind="precision_ratio"): Kalman-gain-style weighting with the parent's expected precision in the numerator: \(K = \pi_\text{parent} / (\pi_\text{parent} + \pi_\text{child})\) and \(g = \text{PE} \otimes g(\text{parent}) \cdot K\).
- map_natural (kind="map_natural"): MAP weight update derived from the predictive-coding free energy with a Gaussian weight prior whose precision is the parent layer's expected precision. The child precision appears in the numerator; the parent prior precision plus the per-weight Fisher curvature \(g(\text{parent})^2\) appear in the denominator. The result is a curvature-aware update that uses both precisions and stays bounded however large \(g(\text{parent})\) becomes. Gaussian child: \(g = \text{PE} \otimes g(\text{parent}) \cdot \pi_\text{child} / (\pi_\text{parent} + \pi_\text{child} \cdot g(\text{parent})^2)\). Binary child: the \(\pi_\text{child}\) factor is dropped as redundant, because the Bernoulli Fisher information cancels through the sigmoid, giving \(g = \text{PE} \otimes g(\text{parent}) / (\pi_\text{parent} + g(\text{parent})^2)\).
- pure_natural (kind="pure_natural"): Riemannian natural gradient under the parent's precision metric. It uses both precisions, with no curvature term and no extra hyperparameter. Gaussian child: \(g = \text{PE} \otimes g(\text{parent}) \cdot \pi_\text{child} / \pi_\text{parent}\). Binary child: \(g = \text{PE} \otimes g(\text{parent}) / \pi_\text{parent}\). This update is not bounded and can blow up when \(\pi_\text{parent}\) is small.
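The five gradient formulas above can be sketched in NumPy as follows. This is a minimal illustration, not the pyhgf implementation: the function name weight_gradient and its arguments are hypothetical, the layer states are replaced by plain arrays (pe for the child prediction errors, g_parent for coupling_fn applied to the parent means, pi_child and pi_parent for the precisions), and broadcasting the child precisions along rows and the parent precisions along columns is an assumption of this sketch.

```python
import numpy as np


def weight_gradient(pe, g_parent, pi_child, pi_parent,
                    kind="precision_weighted", child_is_binary=False):
    """Hypothetical sketch of the gradient for an (n_children, n_parents) weight matrix.

    pe:        child prediction errors, shape (n_children,)
    g_parent:  coupling_fn applied to the parent means, shape (n_parents,)
    pi_child:  child precisions, shape (n_children,)
    pi_parent: parent expected precisions, shape (n_parents,)
    """
    base = np.outer(pe, g_parent)        # PE ⊗ g(parent)
    pc = pi_child[:, None]               # one precision per row (child)
    pp = pi_parent[None, :]              # one precision per column (parent)
    curv = (g_parent ** 2)[None, :]      # per-weight Fisher curvature g(parent)^2

    if kind == "standard":
        return base
    if kind == "precision_weighted":
        # Binary child: the Bernoulli variance is already in the prediction error
        return base if child_is_binary else base * pc
    if kind == "precision_ratio":
        # Kalman-gain-style factor K (this sketch uses it unchanged for binary children)
        return base * pp / (pp + pc)
    if kind == "map_natural":
        if child_is_binary:
            return base / (pp + curv)
        return base * pc / (pp + pc * curv)
    if kind == "pure_natural":
        # No curvature term, so nothing limits the step when pi_parent is small
        return base / pp if child_is_binary else base * pc / pp
    raise ValueError(f"Unknown kind: {kind!r}")


pe = np.array([1.0, -2.0])
g_parent = np.array([0.5, 2.0])
pi_child = np.array([4.0, 1.0])
pi_parent = np.array([1.0, 3.0])
for kind in ("standard", "precision_weighted", "precision_ratio",
             "map_natural", "pure_natural"):
    print(kind, weight_gradient(pe, g_parent, pi_child, pi_parent, kind))
```

The boundedness of "map_natural" can be checked directly: for a Gaussian child, \(\pi_{\text{parent},j} + \pi_{\text{child},i}\, g(\text{parent})_j^2 \ge 2\sqrt{\pi_{\text{parent},j}\,\pi_{\text{child},i}}\,|g(\text{parent})_j|\), so every entry satisfies \(|g_{ij}| \le \tfrac{1}{2}|\text{PE}_i| \sqrt{\pi_{\text{child},i} / \pi_{\text{parent},j}}\) whatever the value of \(g(\text{parent})_j\). The "pure_natural" entries grow linearly with \(|g(\text{parent})_j|\) and have no such bound.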
lr controls how the gradient is applied, with the same semantics for all five kinds:
- float ≥ 0: \(\Delta w = g \cdot \text{lr}\). The default, lr=0.0, leaves the weights unchanged.
- "adam": the gradient is filtered through the Adam optimiser (Kingma & Ba, 2015), and the step size is controlled by adam_lr.
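The two ways of applying the gradient can be sketched as below. This is a hedged illustration rather than pyhgf's code: apply_gradient is a hypothetical helper, the step is assumed to be added to the weights (\(w + \Delta w\)), and the Adam branch follows the standard bias-corrected recipe of Kingma & Ba (2015). Its return value mirrors the documented (new_weights, new_adam_m, new_adam_v) tuple, and the loop at the end shows the returned moments being passed back in on the next call with adam_t incremented beforehand.

```python
import numpy as np


def apply_gradient(weights, grad, lr=0.0, adam_m=None, adam_v=None, adam_t=0,
                   adam_lr=1e-3, adam_beta1=0.9, adam_beta2=0.999,
                   adam_epsilon=1e-8):
    """Hypothetical helper returning (new_weights, new_adam_m, new_adam_v).

    Assumption: the step is *added* to the weights (w + delta).
    """
    if isinstance(lr, str):
        if lr != "adam":
            raise ValueError(f"lr must be a non-negative float or 'adam', got {lr!r}")
        if adam_m is None or adam_v is None:
            raise ValueError("adam_m and adam_v are required when lr='adam'")
        # Exponential moving averages of the gradient and of its square
        m = adam_beta1 * adam_m + (1.0 - adam_beta1) * grad
        v = adam_beta2 * adam_v + (1.0 - adam_beta2) * grad ** 2
        # Bias correction: adam_t is already incremented, so it must be >= 1 here
        m_hat = m / (1.0 - adam_beta1 ** adam_t)
        v_hat = v / (1.0 - adam_beta2 ** adam_t)
        delta = adam_lr * m_hat / (np.sqrt(v_hat) + adam_epsilon)
        return weights + delta, m, v
    # Direct scaling: delta = g * lr; the default lr = 0.0 leaves the weights unchanged
    return weights + grad * lr, None, None


# Threading the Adam state through successive updates
w = np.zeros((2, 3))
m, v = np.zeros_like(w), np.zeros_like(w)
grad = np.array([[1.0, -2.0, 0.0], [0.5, 0.5, -0.5]])
for t in range(1, 4):  # adam_t is incremented before each call
    w, m, v = apply_gradient(w, grad, lr="adam", adam_m=m, adam_v=v, adam_t=t)
```

With a constant gradient the bias-corrected moments equal \(g\) and \(g^2\), so each Adam step moves every weight by roughly adam_lr in the direction of its gradient, independently of the gradient's magnitude; plain scaling instead moves it by exactly \(g \cdot \text{lr}\).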
- Parameters:
parent_state (LayerState) – Current state of the parent layer.
child_state (LayerState) – Current state of the child layer (with observations).
weights (Array) – Current weight matrix of shape (n_children, n_parents), or (n_children, n_parents + 1) when the parent layer includes a constant input node.
coupling_fn (Callable) – Coupling function applied to the parent means.
kind (str) – Gradient computation mode: "standard", "precision_weighted", "precision_ratio", "map_natural", or "pure_natural".
lr (float | str) – How the gradient is applied: a non-negative float for direct scaling, or "adam" for the Adam optimiser. Applied uniformly across all kind values, including "precision_ratio".
parent_has_constant (bool) – If True, the parent layer has a constant input node. Constant nodes are assumed to have mean = 1.0 and precision = 1.0 (a fully known bias) and are concatenated to the coupled parent vector after coupling_fn is applied, so the bias entry stays linear regardless of the coupling function.
child_is_binary (bool) – If True, the child layer is a binary node. In "precision_weighted" mode the precision multiplication is skipped because the Bernoulli variance is already embedded in the binary prediction-error formula; "map_natural" and "pure_natural" likewise use their binary-child forms.
adam_m (Array | None) – First-moment estimate for Adam, same shape as weights. Required when lr="adam".
adam_v (Array | None) – Second-moment estimate for Adam, same shape as weights. Required when lr="adam".
adam_t (int) – Global Adam timestep, already incremented for this step.
adam_lr (float) – Adam step size. Only used when lr="adam".
adam_beta1 (float) – Exponential decay rate for the first-moment estimate.
adam_beta2 (float) – Exponential decay rate for the second-moment estimate.
adam_epsilon (float) – Small constant added to the denominator of the Adam step for numerical stability.
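The parent_has_constant behaviour can be sketched as follows, to show why the weight matrix gains an extra column. This is a minimal sketch under stated assumptions: augment_parent is a hypothetical helper working on plain arrays, and the constant entry is appended as the last element, which the description above does not specify.

```python
import numpy as np


def augment_parent(parent_means, parent_precisions, coupling_fn,
                   parent_has_constant=False):
    """Hypothetical helper: coupled parent vector and parent precisions.

    The constant node (mean 1.0, precision 1.0) is appended *after*
    coupling_fn is applied, so its entry stays exactly 1.0 whatever the
    coupling function. Appending it last is an assumption of this sketch.
    """
    coupled = coupling_fn(parent_means)
    precisions = np.asarray(parent_precisions, dtype=float)
    if parent_has_constant:
        coupled = np.concatenate([coupled, [1.0]])
        precisions = np.concatenate([precisions, [1.0]])
    return coupled, precisions


parent_means = np.array([0.5, -1.0])        # n_parents = 2
parent_precisions = np.array([2.0, 3.0])
coupled, precisions = augment_parent(parent_means, parent_precisions,
                                     np.tanh, parent_has_constant=True)

pe = np.array([0.2, -0.1, 0.4])             # n_children = 3
weights = np.zeros((3, 2 + 1))              # (n_children, n_parents + 1)
grad = np.outer(pe, coupled)                # PE ⊗ g(parent), bias column included
assert grad.shape == weights.shape
```

Had the constant been passed through coupling_fn instead, np.tanh would have turned its 1.0 into about 0.76, so the bias column would no longer receive the raw prediction error.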
- Returns:
new_weights – Updated weight matrix.
new_adam_m – Updated first moment (or None when Adam is not used).
new_adam_v – Updated second moment (or None when Adam is not used).
- Raises:
ValueError – If kind is not one of "standard", "precision_weighted", "precision_ratio", "map_natural", or "pure_natural".
ValueError – If lr is a string other than "adam".
- Return type:
tuple[Array, Array | None, Array | None]