pyhgf.updates.vectorized.learning.vectorized_weight_gradient

pyhgf.updates.vectorized.learning.vectorized_weight_gradient(parent_state, child_state, coupling_fn, kind='precision_weighted', parent_has_constant=False, child_is_binary=False)

Per-layer weight gradient for the vectorised deep network.

Returns the descent gradient for the weight matrix. The sign is flipped relative to the natural “ascent” formulation so that the result composes with standard optax: optax.sgd(lr).update(grad, state, w) returns -lr * grad, and apply_updates(weights, updates) computes weights + updates. Together they reproduce the legacy rule weights + lr * g, with g = -grad.
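The sign convention can be checked numerically. The snippet below is a minimal sketch that mimics the two optax steps with plain NumPy arithmetic, following the description above; it does not call optax itself, and the variable names and numbers are illustrative assumptions.

```python
import numpy as np

lr = 0.1
weights = np.array([[0.5, -0.2], [0.1, 0.3]])

# Hypothetical "ascent" direction g from the legacy formulation.
ascent_g = np.array([[1.0, 2.0], [-1.0, 0.5]])

# Legacy rule: step along g.
legacy = weights + lr * ascent_g

# Descent convention: the function is documented to return grad = -g.
grad = -ascent_g

# Mimics optax.sgd(lr).update(...), which is described as returning -lr * grad.
updates = -lr * grad

# Mimics optax.apply_updates(weights, updates), i.e. weights + updates.
optax_style = weights + updates

# Both routes land on the same weights.
assert np.allclose(legacy, optax_style)
```

The two sign flips cancel, which is why a descent gradient can be handed to an unmodified optax optimiser.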

The gradient is computed according to kind:

  • standard (kind="standard"): \(g = \text{PE} \otimes g(\text{parent})\)

  • precision_weighted (kind="precision_weighted"): \(g = \text{PE} \otimes g(\text{parent}) \cdot \pi_\text{child}\)

  • precision_ratio (kind="precision_ratio"): Kalman-style gain with the parent’s expected precision in the numerator. \(K = \pi_\text{parent} / (\pi_\text{parent} + \pi_\text{child})\), \(g = \text{PE} \otimes g(\text{parent}) \cdot K\)

  • map_natural (kind="map_natural"): MAP weight update from the predictive-coding free energy with a Gaussian weight prior whose precision is the parent’s expected precision; combines child precision (numerator) with parent prior plus per-weight Fisher curvature \(g(\text{parent})^2\) (denominator). Gaussian child: \(g = \text{PE} \otimes g(\text{parent}) \cdot \pi_\text{child} / (\pi_\text{parent} + \pi_\text{child} \cdot g(\text{parent})^2)\). Binary child: \(g = \text{PE} \otimes g(\text{parent}) / (\pi_\text{parent} + g(\text{parent})^2)\).

  • pure_natural (kind="pure_natural"): Riemannian natural gradient under the parent’s precision metric. Gaussian child: \(g = \text{PE} \otimes g(\text{parent}) \cdot \pi_\text{child} / \pi_\text{parent}\). Binary child: \(g = \text{PE} \otimes g(\text{parent}) / \pi_\text{parent}\).
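The five formulas above can be written out as a small NumPy sketch. This is not the library's implementation: the helper name sketch_gradient, the argument layout, and the broadcasting choices (child precision per row, parent precision and \(g(\text{parent})\) per column of a (n_child, n_parent) matrix) are assumptions made for illustration. The sketch evaluates the formulas as written and leaves out the descent sign flip and the NaN handling described elsewhere on this page.

```python
import numpy as np

def sketch_gradient(pe, g_parent, pi_child, pi_parent,
                    kind="precision_weighted", child_is_binary=False):
    """Hypothetical helper: evaluate the listed formulas element-wise.

    pe, pi_child : arrays of shape (n_child,)
    g_parent, pi_parent : arrays of shape (n_parent,)
    """
    base = np.outer(pe, g_parent)      # PE (x) g(parent), shape (n_child, n_parent)
    pc = pi_child[:, None]             # assumed: one child precision per row
    pp = pi_parent[None, :]            # assumed: one parent precision per column
    g2 = g_parent[None, :] ** 2        # per-weight curvature term g(parent)^2

    if kind == "standard":
        return base
    if kind == "precision_weighted":
        # Binary children drop the child precision factor.
        return base if child_is_binary else base * pc
    if kind == "precision_ratio":
        gain = pp / (pp + pc)          # K = pi_parent / (pi_parent + pi_child)
        return base * gain
    if kind == "map_natural":
        if child_is_binary:
            return base / (pp + g2)
        return base * pc / (pp + pc * g2)
    if kind == "pure_natural":
        if child_is_binary:
            return base / pp
        return base * pc / pp
    raise ValueError(f"Unrecognised kind: {kind!r}")

# Illustrative inputs: two child units, two parent units.
pe = np.array([1.0, -2.0])
g_parent = np.array([0.5, 2.0])
pi_child = np.array([2.0, 4.0])
pi_parent = np.array([1.0, 3.0])

for kind in ("standard", "precision_weighted", "precision_ratio",
             "map_natural", "pure_natural"):
    print(kind)
    print(sketch_gradient(pe, g_parent, pi_child, pi_parent, kind=kind))
```

Under these assumptions every mode is a rescaling of the same outer product, so the modes differ only in how strongly each weight's step is damped or amplified by the two precisions.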

Parameters:
  • parent_state (LayerState) – Current state of the parent layer.

  • child_state (LayerState) – Current state of the child layer (with observations).

  • coupling_fn (Callable) – Coupling function applied to parent means.

  • kind (str) – Gradient computation mode.

  • parent_has_constant (bool) – If True, the parent layer has a constant input node (mean = 1.0, precision = 1.0) appended to its activations after coupling.

  • child_is_binary (bool) – If True, the child layer is a binary node. This drops the redundant precision factor in the precision_weighted, map_natural and pure_natural modes, because the Bernoulli Fisher information cancels through the sigmoid.
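As an illustration of parent_has_constant, the sketch below shows one way a constant node could be appended after coupling. The helper coupled_parent_activations, the use of np.tanh as the coupling function, and the array shapes are hypothetical and not taken from the library.

```python
import numpy as np

def coupled_parent_activations(parent_mean, parent_precision, coupling_fn,
                               parent_has_constant=False):
    """Hypothetical helper: couple the parent means, then append the constant node."""
    g_parent = coupling_fn(parent_mean)
    pi_parent = parent_precision
    if parent_has_constant:
        # The constant node is appended after coupling, so its mean of 1.0
        # is not passed through coupling_fn. Its precision is also 1.0.
        g_parent = np.append(g_parent, 1.0)
        pi_parent = np.append(pi_parent, 1.0)
    return g_parent, pi_parent

parent_mean = np.array([0.0, 1.0])
parent_precision = np.array([2.0, 0.5])

g_const, pi_const = coupled_parent_activations(
    parent_mean, parent_precision, np.tanh, parent_has_constant=True
)

# With three child units, the outer product gains one extra column,
# which plays the role of a bias weight for each child.
pe = np.array([0.5, -0.5, 2.0])
grad_shape = np.outer(pe, g_const).shape
print(g_const, pi_const, grad_shape)
```

In this sketch the weight matrix needs one more column than the parent layer has units, which is consistent with the gradient having the same shape as the weights.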

Returns:

Descent gradient, same shape as weights. NaN / inf entries are zeroed out so optax does not propagate them through its moment accumulators.
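The effect of zeroing non-finite entries can be sketched in NumPy. The masking expression and the toy running average below are assumptions chosen for illustration; the library's own code may differ, and the running average is only a stand-in for an optimiser's moment accumulator.

```python
import numpy as np

raw_grad = np.array([[0.2, np.nan], [np.inf, -0.7]])

# Keep finite entries, replace NaN and +/-inf with 0.
safe_grad = np.where(np.isfinite(raw_grad), raw_grad, 0.0)

# Toy first-moment accumulator: m <- beta * m + (1 - beta) * grad.
beta = 0.9
m_raw = beta * np.zeros_like(raw_grad) + (1 - beta) * raw_grad
m_safe = beta * np.zeros_like(safe_grad) + (1 - beta) * safe_grad

# Without masking, the non-finite entries contaminate the accumulator
# and stay there on later steps; with masking it remains finite.
print(np.isfinite(m_raw).all(), np.isfinite(m_safe).all())
```

A zeroed entry means that weight simply receives no update on that step, rather than corrupting the optimiser state.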

Return type:

grad

Raises:

ValueError – If kind is unrecognised.