pyhgf.updates.vectorized.learning.vectorized_weight_update

pyhgf.updates.vectorized.learning.vectorized_weight_update(parent_state, child_state, weights, coupling_fn, kind='precision_weighted', lr=0.0, parent_has_constant=False, child_is_binary=False, adam_m=None, adam_v=None, adam_t=0, adam_lr=0.001, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08)

Unified weight update for vectorized layers.

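The gradient \(g\) is first computed according to kind and then applied according to lr, which behaves the same way for every kind. In the formulas below, PE is the child's prediction error and \(g(\text{parent})\) is coupling_fn applied to the parent means; the bare \(g\) on the left-hand side denotes the gradient, not the coupling function: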

  • standard (kind="standard"): \(g = \text{PE} \otimes g(\text{parent})\)

  • precision_weighted (kind="precision_weighted"): \(g = \text{PE} \otimes g(\text{parent}) \cdot \pi_\text{child}\)

  • precision_ratio (kind="precision_ratio"): Kalman-gain-style weighting with the parent’s expected precision in the numerator: \(K = \pi_\text{parent} / (\pi_\text{parent} + \pi_\text{child})\), giving \(g = \text{PE} \otimes g(\text{parent}) \cdot K\)

  • map_natural (kind="map_natural"): MAP weight update derived from the predictive-coding free energy with a Gaussian weight prior whose precision is the parent layer’s expected precision. Combines child precision (numerator) with the parent prior plus per-weight Fisher curvature \(g(\text{parent})^2\) (denominator), giving a bounded, curvature-aware update that uses both precisions. Gaussian child: \(g = \text{PE} \otimes g(\text{parent}) \cdot \pi_\text{child} / (\pi_\text{parent} + \pi_\text{child} \cdot g(\text{parent})^2)\). Binary child (the redundant \(\pi_\text{child}\) factor is dropped because the Bernoulli Fisher information cancels through the sigmoid): \(g = \text{PE} \otimes g(\text{parent}) / (\pi_\text{parent} + g(\text{parent})^2)\).

  • pure_natural (kind="pure_natural"): Riemannian natural gradient under the parent’s precision metric. Uses both precisions, with no curvature term and no extra hyperparameter. Gaussian child: \(g = \text{PE} \otimes g(\text{parent}) \cdot \pi_\text{child} / \pi_\text{parent}\). Binary child: \(g = \text{PE} \otimes g(\text{parent}) / \pi_\text{parent}\). Unbounded: the update can blow up when \(\pi_\text{parent}\) is small.
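The five formulas above can be compared side by side in a small sketch. It is written in plain Python rather than JAX to stay dependency-free, and it is not the library's implementation. The helper name weight_gradient, the list-based inputs, and the broadcasting convention (child precision per row, parent precision per column) are assumptions made for illustration.

```python
KINDS = ("standard", "precision_weighted", "precision_ratio",
         "map_natural", "pure_natural")


def weight_gradient(pe, g_parent, pi_child, pi_parent, kind,
                    child_is_binary=False):
    """Illustrative gradient matrix, one row per child, one column per parent.

    pe, pi_child        : per-child prediction errors and precisions
    g_parent, pi_parent : per-parent coupled means and expected precisions
    """
    if kind not in KINDS:
        raise ValueError(f"unknown kind: {kind!r}")

    grad = []
    for pe_i, pc_i in zip(pe, pi_child):
        # Binary child: the pi_child factor is dropped (equivalent to 1 here).
        pc = 1.0 if child_is_binary else pc_i
        row = []
        for g_j, pp in zip(g_parent, pi_parent):
            base = pe_i * g_j  # entry (i, j) of the outer product PE x g(parent)
            if kind == "standard":
                scale = 1.0
            elif kind == "precision_weighted":
                scale = pc
            elif kind == "precision_ratio":
                # Assumption: no binary variant is documented, so pi_child is used as-is.
                scale = pp / (pp + pc_i)
            elif kind == "map_natural":
                scale = pc / (pp + pc * g_j ** 2)
            else:  # "pure_natural"
                scale = pc / pp
            row.append(base * scale)
        grad.append(row)
    return grad
```

With a single entry where PE = 0.5, g(parent) = 2.0, π_child = 4.0 and π_parent = 1e-6, this sketch returns roughly 4e6 for "pure_natural" but about 0.25 for "map_natural", which illustrates the difference in boundedness noted above.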

lr controls how the gradient is applied (same semantics for all five kinds):

  • float ≥ 0: \(\Delta w = g \cdot \text{lr}\). With the default lr=0.0 this gives \(\Delta w = 0\).

  • "adam": gradient filtered through the Adam optimiser (Kingma & Ba, 2015); step size controlled by adam_lr.
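The two lr modes can be sketched as follows. The Adam branch uses the standard update from Kingma & Ba with bias correction. The helper name apply_lr, the flat-list representation of the gradient, the sign convention of the returned step, and the error raised when the moment estimates are missing are all assumptions for illustration, not the library's code.

```python
import math


def apply_lr(grad, lr, adam_m=None, adam_v=None, adam_t=0, adam_lr=0.001,
             adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-8):
    """Return (delta_w, new_m, new_v) for a flat list of gradient entries."""
    if isinstance(lr, str):
        if lr != "adam":
            raise ValueError(f"unsupported lr string: {lr!r}")
        if adam_m is None or adam_v is None:
            # Assumption: how the real function reacts here is not documented.
            raise ValueError('adam_m and adam_v are required when lr="adam"')
        # Exponential moving averages of the gradient and its square.
        new_m = [adam_beta1 * m + (1 - adam_beta1) * g
                 for m, g in zip(adam_m, grad)]
        new_v = [adam_beta2 * v + (1 - adam_beta2) * g * g
                 for v, g in zip(adam_v, grad)]
        # Bias correction; adam_t must be >= 1 (already incremented by the caller).
        m_hat = [m / (1 - adam_beta1 ** adam_t) for m in new_m]
        v_hat = [v / (1 - adam_beta2 ** adam_t) for v in new_v]
        delta = [adam_lr * mh / (math.sqrt(vh) + adam_epsilon)
                 for mh, vh in zip(m_hat, v_hat)]
        return delta, new_m, new_v
    # Float mode: direct scaling, no optimiser state.
    return [g * lr for g in grad], None, None
```

On the first Adam step (zero-initialised moments, adam_t=1) the bias correction makes each entry of the step approximately adam_lr in magnitude, regardless of the size of the gradient, which is why adam_lr rather than the gradient scale sets the step size in this mode.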

Parameters:
  • parent_state (LayerState) – Current state of the parent layer.

  • child_state (LayerState) – Current state of the child layer (with observations).

  • weights (Array) – Current weight matrix, shape (n_children, n_parents) or (n_children, n_parents + 1) when the parent layer includes a constant input node.

  • coupling_fn (Callable) – Coupling function applied to parent means.

  • kind (str) – Gradient computation mode: "standard", "precision_weighted", "precision_ratio", "map_natural", or "pure_natural".

  • lr (float | str) – How the gradient is applied: a non-negative float for direct scaling, or "adam" for the Adam optimiser. Applied uniformly across all kind values, including "precision_ratio".

  • parent_has_constant (bool) – If True, the parent layer has a constant input node. Constant nodes are assumed to have mean = 1.0 and precision = 1.0 (fully known bias), and are concatenated to the coupled parent vector after coupling_fn is applied so the bias entry is unconditionally linear regardless of the coupling function.

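  • child_is_binary (bool) – If True, the child layer is a binary node. In "precision_weighted" mode the precision multiplication is skipped because the Bernoulli variance is already embedded in the binary prediction-error formula; "map_natural" and "pure_natural" likewise use their binary-child formulas, which omit the \(\pi_\text{child}\) factor.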

  • adam_m (Array | None) – First moment estimate for Adam, same shape as weights. Required when lr="adam".

  • adam_v (Array | None) – Second moment estimate for Adam, same shape as weights. Required when lr="adam".

  • adam_t (int) – Global Adam timestep (already incremented for this step).

  • adam_lr (float) – Adam step size. Only used when lr="adam".

  • adam_beta1 (float) – Exponential decay rate for the first moment.

  • adam_beta2 (float) – Exponential decay rate for the second moment.

  • adam_epsilon (float) – Small constant for numerical stability.
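The handling of parent_has_constant described above can be illustrated with a short sketch. Because the constant is added after coupling_fn is applied, its entry stays exactly 1.0 even for a nonlinear coupling. The helper names, the list-based inputs, and the choice to place the constant in the last position (matching the extra weight column) are assumptions for illustration.

```python
import math


def coupled_parent_vector(parent_means, coupling_fn, parent_has_constant=False):
    """Apply coupling_fn to each parent mean, then append the constant input."""
    coupled = [coupling_fn(m) for m in parent_means]
    if parent_has_constant:
        # Appended after coupling: the bias entry is 1.0, not coupling_fn(1.0).
        coupled.append(1.0)
    return coupled


def parent_precisions(precisions, parent_has_constant=False):
    """Parent precisions, with precision 1.0 for the constant node if present."""
    return list(precisions) + ([1.0] if parent_has_constant else [])


# Two parents with a tanh coupling plus a constant node: three entries,
# so the matching weight matrix would have n_parents + 1 = 3 columns.
vec = coupled_parent_vector([0.0, 2.0], math.tanh, parent_has_constant=True)
```

Had the constant been concatenated before coupling, the last entry in this example would be tanh(1.0) instead of 1.0, and the bias weight would no longer act linearly.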

Returns:

  • new_weights – Updated weight matrix.

  • new_adam_m – Updated first moment (or None when Adam is not used).

  • new_adam_v – Updated second moment (or None when Adam is not used).

Raises:
  • ValueError – If kind is not one of "standard", "precision_weighted", "precision_ratio", "map_natural", or "pure_natural".

  • ValueError – If lr is a string other than "adam".

Return type:

tuple[Array, Array | None, Array | None]