Source code for pyhgf.updates.posterior.volatile.volatile_node_posterior_update_unbounded
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>
from functools import partial
import jax.numpy as jnp
from jax import jit
from jax.nn import sigmoid
from pyhgf.math import lambert_w0
[docs]
@partial(jit, static_argnames=("node_idx", "max_posterior_precision"))
def volatile_node_posterior_update_unbounded(
attributes: dict,
node_idx: int,
max_posterior_precision: float = 1e10,
) -> dict:
"""Update the volatility level using an unbounded quadratic approximation.
Implements the uhgf update: two quadratic expansions are blended via a
variational energy-based softmax weight, and the final posterior is the
moment-matched Gaussian of the resulting mixture.
Expansion 1 is centred at the prediction (prior mean).
Expansion 2 is centred at the approximate posterior mode found via the Lambert
W_0 function, which solves the mode equation exactly in the limit alpha -> 0.
Parameters
----------
attributes :
The attributes of the probabilistic nodes.
node_idx :
Pointer to the volatile node.
max_posterior_precision :
Upper bound applied to the volatility-level posterior precision write.
Default ``1e10``.
Returns
-------
dict
Updated attributes with ``precision_vol`` and ``mean_vol`` set.
"""
volatility_coupling = attributes[node_idx]["volatility_coupling_internal"]
time_step = attributes[-1]["time_step"]
previous_variance = jnp.maximum(
attributes[node_idx]["temp"]["current_variance"], 1e-128
) # previous-step variance (= 1 / precision at the previous step)
be_aux = (1.0 / attributes[node_idx]["precision"]) + (
attributes[node_idx]["mean"] - attributes[node_idx]["expected_mean"]
) ** 2
expected_mean_vol = attributes[node_idx]["expected_mean_vol"]
expected_precision_vol = attributes[node_idx]["expected_precision_vol"]
tonic_volatility = attributes[node_idx]["tonic_volatility"]
# All quantities that would otherwise pass through ``exp`` of a potentially
# large number are kept in log-space. Materialising ``v = exp(γ)`` is correct
# in the forward pass (downstream saturating uses are stable), but corrupts
# the backward pass: the local partial of a saturating expression is ``0``
# while ``d v / d γ = exp(γ) = ∞``, and ``0 · ∞ = NaN``. A single non-finite
# gradient anywhere in the scan turns the whole gradient into NaN, which
# forces NUTS to reject the trajectory and shrink the step size — blowing up
# the number of leapfrog evaluations per sample. The ``sigmoid``/``logaddexp``
# rewrites below match the direct forms for every finite input and stay
# gradient-safe at the saturation limits. Mirrors
# ``continuous_node_posterior_update_unbounded``.
log_time_step = jnp.log(time_step)
log_previous_variance = jnp.log(previous_variance)
# Canonical exponent at the prediction: γ = log(time_step) + volatility_coupling*expected_mean_vol + tonic_volatility
gamma_c = log_time_step + volatility_coupling * expected_mean_vol + tonic_volatility
# w_jm1 = 1/(1 + previous_variance/exp(γ)) = sigmoid(γ − log α).
w_jm1 = sigmoid(gamma_c - log_previous_variance)
# Volatility prediction error: da_jm1 = pihat_jm1 * be_aux - 1, with
# pihat_jm1 = expected_precision (set in the prediction step at mu_prev_j).
# Matches MATLAB/Julia, which pass da_jm1 in — not recomputed at expected_mean_vol.
da_jm1 = attributes[node_idx]["expected_precision"] * be_aux - 1.0
# ----------------------------------------------------------------------------------
# Expansion 1: quadratic at the prediction (prior mean)
# ----------------------------------------------------------------------------------
pi1 = expected_precision_vol + 0.5 * volatility_coupling**2 * w_jm1 * (1.0 - w_jm1)
mu1 = expected_mean_vol + (volatility_coupling * w_jm1 / (2.0 * pi1)) * da_jm1
# ----------------------------------------------------------------------------------
# Expansion 2: quadratic at the Lambert W_0 approximate mode
# ----------------------------------------------------------------------------------
pihat_y = expected_precision_vol / volatility_coupling**2
# Compute W_arg in log-space and cap at log(float_max) — matches MATLAB's
# "W_arg = exp(min(log_W_arg, log(realmax)))".
log_W_arg = jnp.log(be_aux) - jnp.log(2.0 * pihat_y) + 0.5 / pihat_y - gamma_c
log_float_max = jnp.log(jnp.finfo(jnp.result_type(log_W_arg)).max)
W_arg = jnp.exp(jnp.minimum(log_W_arg, log_float_max))
v_W = lambert_w0(W_arg)
y_star = gamma_c + v_W - 0.5 / pihat_y
x_star = (y_star - log_time_step - tonic_volatility) / volatility_coupling
# Log-space s2/w2/da2 — equivalent to the direct
# s2 = time_step * exp(volatility_coupling*x_star + tonic_volatility); w2 = 1/(1 + previous_variance/s2);
# da2 = be_aux/(previous_variance + s2) - 1
# but without materialising ``s2 = inf`` (which injects 0·∞ NaN gradients).
log_s2 = log_time_step + volatility_coupling * x_star + tonic_volatility
log_denom_s = jnp.logaddexp(
log_previous_variance, log_s2
) # = log(previous_variance + s2)
w2 = sigmoid(log_s2 - log_previous_variance)
da2 = be_aux * jnp.exp(-log_denom_s) - 1.0
pi2_full = expected_precision_vol + 0.5 * volatility_coupling**2 * w2 * (
w2 + (2.0 * w2 - 1.0) * da2
)
# Guard against negative precision (Matlab fallback: use w2*(1-w2) form)
pi2_safe = jnp.where(
pi2_full <= 0.0,
expected_precision_vol + 0.5 * volatility_coupling**2 * w2 * (1.0 - w2),
pi2_full,
)
mu2_safe = (
x_star
+ (
0.5 * volatility_coupling * w2 * da2
- expected_precision_vol * (x_star - expected_mean_vol)
)
/ pi2_safe
)
# Fall back to Expansion 1 if Expansion 2 yields non-finite results —
# matches MATLAB: "if ~isfinite(pi2) || ~isfinite(mu2), pi2 = pi1; mu2 = mu1".
#
# Double-where masking: replace any non-finite ``pi2_safe`` / ``mu2_safe``
# with safe constants *before* they enter the outer ``where``. The bare form
# ``jnp.where(c, pi2_safe, pi1)`` is correct forward but poisons the backward
# pass — ``where``'s VJP routes a zero cotangent into the masked-out branch
# and ``0 * NaN = NaN``.
exp2_finite = jnp.isfinite(pi2_safe) & jnp.isfinite(mu2_safe)
pi2_safe_for_grad = jnp.where(exp2_finite, pi2_safe, 1.0)
mu2_safe_for_grad = jnp.where(exp2_finite, mu2_safe, 0.0)
pi2 = jnp.where(exp2_finite, pi2_safe_for_grad, pi1)
mu2 = jnp.where(exp2_finite, mu2_safe_for_grad, mu1)
# ----------------------------------------------------------------------------------
# Variational energy-based softmax blend (log-space form, gradient-safe).
# The direct ``ey = time_step * exp(volatility_coupling*mu + tonic_volatility)`` materialises ``inf`` for large
# ``volatility_coupling*mu + tonic_volatility`` and injects 0·∞ NaNs in the backward pass; ``logaddexp`` and
# ``exp(-positive)`` stay bounded both forward and backward.
# ----------------------------------------------------------------------------------
log_ey1 = log_time_step + volatility_coupling * mu1 + tonic_volatility
log_denom_1 = jnp.logaddexp(
log_previous_variance, log_ey1
) # = log(previous_variance + ey1)
I1 = (
-0.5 * log_denom_1
- 0.5 * be_aux * jnp.exp(-log_denom_1)
- 0.5 * expected_precision_vol * (mu1 - expected_mean_vol) ** 2
)
log_ey2 = log_time_step + volatility_coupling * mu2 + tonic_volatility
log_denom_2 = jnp.logaddexp(log_previous_variance, log_ey2)
I2 = (
-0.5 * log_denom_2
- 0.5 * be_aux * jnp.exp(-log_denom_2)
- 0.5 * expected_precision_vol * (mu2 - expected_mean_vol) ** 2
)
# Stable sigmoid matches b = 1/(1 + exp(I1 - I2)) without NaN at ±∞.
b = sigmoid(I2 - I1)
# ----------------------------------------------------------------------------------
# Gaussian mixture moment matching
# ----------------------------------------------------------------------------------
posterior_mean = (1.0 - b) * mu1 + b * mu2
sig2 = (1.0 - b) / pi1 + b / pi2 + b * (1.0 - b) * (mu1 - mu2) ** 2
posterior_precision = 1.0 / sig2
attributes[node_idx]["precision_vol"] = jnp.minimum(
posterior_precision, max_posterior_precision
)
attributes[node_idx]["mean_vol"] = posterior_mean
return attributes