Source code for pyhgf.updates.posterior.volatile.volatile_node_posterior_update_unbounded

# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

from functools import partial

import jax.numpy as jnp
from jax import jit
from jax.nn import sigmoid

from pyhgf.math import lambert_w0


[docs] @partial(jit, static_argnames=("node_idx",)) def volatile_node_posterior_update_unbounded( attributes: dict, node_idx: int, ) -> dict: """Update the volatility level using an unbounded quadratic approximation. Implements the uhgf update: two quadratic expansions are blended via a variational energy-based softmax weight, and the final posterior is the moment-matched Gaussian of the resulting mixture. Expansion 1 is centred at the prediction (prior mean). Expansion 2 is centred at the approximate posterior mode found via the Lambert W_0 function, which solves the mode equation exactly in the limit alpha -> 0. Parameters ---------- attributes : The attributes of the probabilistic nodes. node_idx : Pointer to the volatile node. Returns ------- dict Updated attributes with ``precision_vol`` and ``mean_vol`` set. """ volatility_coupling = attributes[node_idx]["volatility_coupling_internal"] t_k = attributes[-1]["time_step"] al_aux = jnp.maximum( attributes[node_idx]["temp"]["current_variance"], 1e-128 ) # 1/pi_prev_jm1 be_aux = (1.0 / attributes[node_idx]["precision"]) + ( attributes[node_idx]["mean"] - attributes[node_idx]["expected_mean"] ) ** 2 muhat_j = attributes[node_idx]["expected_mean_vol"] pihat_j = attributes[node_idx]["expected_precision_vol"] ka = volatility_coupling om = attributes[node_idx]["tonic_volatility"] # Canonical exponent at the prediction: y = log(t_k) + ka*muhat_j + om gamma_c = jnp.log(t_k) + ka * muhat_j + om # Recompute v and w using muhat_j. The w formula is written as # 1/(1 + al_aux/v) so it stays finite when v_jm1 overflows to ∞ (→ 1). v_jm1 = jnp.exp(gamma_c) w_jm1 = 1.0 / (1.0 + al_aux / v_jm1) # Volatility prediction error: da_jm1 = pihat_jm1 * be_aux - 1, with # pihat_jm1 = expected_precision (set in the prediction step at mu_prev_j). # Matches MATLAB/Julia, which pass da_jm1 in — not recomputed at muhat_j. da_jm1 = attributes[node_idx]["expected_precision"] * be_aux - 1.0 # ---------------------------------------------------------------------------------- # Expansion 1: quadratic at the prediction (prior mean) # ---------------------------------------------------------------------------------- pi1 = pihat_j + 0.5 * ka**2 * w_jm1 * (1.0 - w_jm1) mu1 = muhat_j + (ka * w_jm1 / (2.0 * pi1)) * da_jm1 # ---------------------------------------------------------------------------------- # Expansion 2: quadratic at the Lambert W_0 approximate mode # ---------------------------------------------------------------------------------- pihat_y = pihat_j / ka**2 # Compute W_arg in log-space and cap at log(float_max) — matches MATLAB's # "W_arg = exp(min(log_W_arg, log(realmax)))". log_W_arg = jnp.log(be_aux) - jnp.log(2.0 * pihat_y) + 0.5 / pihat_y - gamma_c log_float_max = jnp.log(jnp.finfo(jnp.result_type(log_W_arg)).max) W_arg = jnp.exp(jnp.minimum(log_W_arg, log_float_max)) v_W = lambert_w0(W_arg) y_star = gamma_c + v_W - 0.5 / pihat_y x_star = (y_star - jnp.log(t_k) - om) / ka # Rearranged w/da formulas stay finite when s2 overflows (→ w=1, da=-1). s2 = t_k * jnp.exp(ka * x_star + om) w2 = 1.0 / (1.0 + al_aux / s2) da2 = be_aux / (al_aux + s2) - 1.0 pi2_full = pihat_j + 0.5 * ka**2 * w2 * (w2 + (2.0 * w2 - 1.0) * da2) # Guard against negative precision (Matlab fallback: use w2*(1-w2) form) pi2_safe = jnp.where( pi2_full <= 0.0, pihat_j + 0.5 * ka**2 * w2 * (1.0 - w2), pi2_full, ) mu2_safe = x_star + (0.5 * ka * w2 * da2 - pihat_j * (x_star - muhat_j)) / pi2_safe # Fall back to Expansion 1 if Expansion 2 yields non-finite results — # matches MATLAB: "if ~isfinite(pi2) || ~isfinite(mu2), pi2 = pi1; mu2 = mu1". exp2_finite = jnp.isfinite(pi2_safe) & jnp.isfinite(mu2_safe) pi2 = jnp.where(exp2_finite, pi2_safe, pi1) mu2 = jnp.where(exp2_finite, mu2_safe, mu1) # ---------------------------------------------------------------------------------- # Variational energy-based softmax blend (direct form, matches MATLAB) # ---------------------------------------------------------------------------------- ey1 = t_k * jnp.exp(ka * mu1 + om) I1 = ( -0.5 * jnp.log(al_aux + ey1) - 0.5 * be_aux / (al_aux + ey1) - 0.5 * pihat_j * (mu1 - muhat_j) ** 2 ) ey2 = t_k * jnp.exp(ka * mu2 + om) I2 = ( -0.5 * jnp.log(al_aux + ey2) - 0.5 * be_aux / (al_aux + ey2) - 0.5 * pihat_j * (mu2 - muhat_j) ** 2 ) # Stable sigmoid matches b = 1/(1 + exp(I1 - I2)) without NaN at ±∞. b = sigmoid(I2 - I1) # ---------------------------------------------------------------------------------- # Gaussian mixture moment matching # ---------------------------------------------------------------------------------- posterior_mean = (1.0 - b) * mu1 + b * mu2 sig2 = (1.0 - b) / pi1 + b / pi2 + b * (1.0 - b) * (mu1 - mu2) ** 2 posterior_precision = 1.0 / sig2 attributes[node_idx]["precision_vol"] = posterior_precision attributes[node_idx]["mean_vol"] = posterior_mean return attributes