pyhgf.updates.vectorized.learning.vectorized_weight_gradient
- pyhgf.updates.vectorized.learning.vectorized_weight_gradient(parent_state, child_state, coupling_fn, kind='precision_weighted', parent_has_constant=False, child_is_binary=False)
Per-layer weight gradient for the vectorised deep network.
Returns the descent gradient for the weight matrix. The sign is flipped relative to the natural “ascent” formulation so that it composes with standard optax: apply_updates(weights, updates) performs weights + updates, and optax.sgd(lr).update(grad, state, w) returns updates equal to -lr * grad; together they reproduce the legacy weights + lr * g rule with g = -grad.
The gradient is computed according to kind:
standard (kind="standard"): \(g = \text{PE} \otimes g(\text{parent})\)
precision_weighted (kind="precision_weighted"): \(g = \text{PE} \otimes g(\text{parent}) \cdot \pi_\text{child}\)
precision_ratio (kind="precision_ratio"): Kalman-gain-style gain using the parent’s expected precision in the numerator. \(K = \pi_\text{parent} / (\pi_\text{parent} + \pi_\text{child})\), \(g = \text{PE} \otimes g(\text{parent}) \cdot K\)
map_natural (kind="map_natural"): MAP weight update from the predictive-coding free energy with a Gaussian weight prior whose precision is the parent’s expected precision; combines child precision (numerator) with parent prior plus per-weight Fisher curvature \(g(\text{parent})^2\) (denominator). Gaussian child: \(g = \text{PE} \otimes g(\text{parent}) \cdot \pi_\text{child} / (\pi_\text{parent} + \pi_\text{child} \cdot g(\text{parent})^2)\). Binary child: \(g = \text{PE} \otimes g(\text{parent}) / (\pi_\text{parent} + g(\text{parent})^2)\).
pure_natural (kind="pure_natural"): Riemannian natural gradient under the parent’s precision metric. Gaussian child: \(g = \text{PE} \otimes g(\text{parent}) \cdot \pi_\text{child} / \pi_\text{parent}\). Binary child: \(g = \text{PE} \otimes g(\text{parent}) / \pi_\text{parent}\).
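The formulas above can be sketched on plain arrays. This is an illustrative NumPy reading, not the library's implementation: the function name sketch_weight_direction is hypothetical, and the shapes are assumptions (prediction errors and child precisions as vectors of length n_child, coupled parent activations and parent precisions as vectors of length n_parent, with child precisions broadcast across rows and parent precisions across columns). It evaluates \(g\) exactly as written; whether the returned array is \(g\) or its negative depends on the sign convention described earlier.

```python
import numpy as np


def sketch_weight_direction(pe, g_parent, pi_child, pi_parent,
                            kind="precision_weighted", child_is_binary=False):
    """Evaluate the per-kind formulas on plain arrays (illustrative only).

    Assumed shapes: pe, pi_child -> (n_child,); g_parent, pi_parent -> (n_parent,).
    """
    base = np.outer(pe, g_parent)  # PE ⊗ g(parent), shape (n_child, n_parent)
    pc = pi_child[:, None]         # assumed: one child precision per row
    pp = pi_parent[None, :]        # assumed: one parent precision per column
    g2 = g_parent[None, :] ** 2    # per-weight curvature term g(parent)^2

    if kind == "standard":
        return base
    if kind == "precision_weighted":
        # A binary child drops the child precision factor.
        return base if child_is_binary else base * pc
    if kind == "precision_ratio":
        return base * (pp / (pp + pc))
    if kind == "map_natural":
        if child_is_binary:
            return base / (pp + g2)
        return base * pc / (pp + pc * g2)
    if kind == "pure_natural":
        return base / pp if child_is_binary else base * pc / pp
    raise ValueError(f"Unrecognised kind: {kind!r}")


# Made-up values: 2 child nodes, 3 parent nodes.
pe = np.array([0.5, -1.0])
g_parent = np.array([1.0, 2.0, 0.0])
pi_child = np.array([2.0, 4.0])
pi_parent = np.array([1.0, 1.0, 2.0])

g_pw = sketch_weight_direction(pe, g_parent, pi_child, pi_parent)
g_map = sketch_weight_direction(pe, g_parent, pi_child, pi_parent, kind="map_natural")
```

Every mode starts from the same outer product and differs only in the elementwise scaling, which is why the result always has one entry per (child, parent) weight.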
- Parameters:
parent_state (LayerState) – Current state of the parent layer.
child_state (LayerState) – Current state of the child layer (with observations).
coupling_fn (Callable) – Coupling function applied to parent means.
kind (str) – Gradient computation mode: one of "standard", "precision_weighted", "precision_ratio", "map_natural" or "pure_natural".
parent_has_constant (bool) – If True, the parent layer has a constant input node (mean = 1.0, precision = 1.0) appended to its activations after coupling.
child_is_binary (bool) – If True, the child layer is a binary node, so the redundant precision factor is dropped in the precision_weighted, map_natural and pure_natural modes (the Bernoulli Fisher information cancels through the sigmoid).
- Returns:
Descent gradient, same shape as weights. NaN / inf entries are zeroed out so optax does not propagate them through its moment accumulators.
- Return type:
grad
- Raises:
ValueError – If kind is unrecognised.
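The sign convention and the NaN / inf handling described above can be checked with a few lines of arithmetic. The sketch below uses plain NumPy as a stand-in for the optax calls: the helper sanitize is hypothetical, updates = -lr * grad mimics what the description says an SGD transform returns, and weights + updates mimics apply_updates. All values are made up for illustration.

```python
import numpy as np


def sanitize(grad):
    """Zero out NaN / inf entries (hypothetical stand-in for the documented behaviour)."""
    return np.nan_to_num(grad, nan=0.0, posinf=0.0, neginf=0.0)


lr = 0.1
weights = np.array([[1.0, 2.0], [3.0, 4.0]])

# Legacy "ascent" direction g, with two deliberately invalid entries.
g = np.array([[0.5, np.nan], [np.inf, -1.0]])

# Descent gradient as the function would return it: g = -grad, invalid entries zeroed.
grad = sanitize(-g)

updates = -lr * grad             # SGD-style update, as described for optax.sgd
new_weights = weights + updates  # additive application, as described for apply_updates

# Legacy rule for comparison: weights + lr * g, with the same entries zeroed.
legacy_weights = weights + lr * sanitize(g)
```

Both routes give the same weights, and the entries that were NaN or inf leave the corresponding weights unchanged.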