Source code for pyhgf.updates.posterior.volatile.volatile_node_posterior_update

# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>
# Author: Aleksandrs Baskakovs <aleks@cas.au.dk>

from functools import partial

import jax.numpy as jnp
from jax import jit

from pyhgf.typing import Edges

from .posterior_update_value_level import (
    posterior_update_mean_value_level,
    posterior_update_mean_value_level_mean_field,
    posterior_update_precision_value_level,
    posterior_update_precision_value_level_mean_field,
)
from .posterior_update_volatility_level import (
    posterior_update_mean_volatility_level,
    posterior_update_precision_volatility_level,
)


@partial(jit, static_argnames=("edges", "node_idx", "max_posterior_precision"))
def volatile_node_posterior_update(
    attributes: dict,
    edges: Edges,
    node_idx: int,
    max_posterior_precision: float = 1e10,
) -> dict:
    """Update the value level (precision, then mean) of a volatile node.

    Unlike the standard continuous-state posterior updates elsewhere in the
    toolbox, the volatile-state updates use the *expected* mean (i.e. the
    prediction) as the reference point rather than the posterior mean. This
    choice is made to better suit deep learning networks where the prediction
    serves as the natural reference for computing updates.

    This function itself only writes the ``"precision"`` and ``"mean"`` entries
    of the node. The volatility-level entries (``"precision_vol"``,
    ``"mean_vol"``) are not written here.

    Parameters
    ----------
    attributes :
        The attributes of the probabilistic nodes.
    edges :
        The edges of the probabilistic nodes as a tuple of
        :py:class:`pyhgf.typing.Indexes`.
    node_idx :
        Pointer to the volatile node that needs to be updated.
    max_posterior_precision :
        Upper bound applied to the value-level posterior precision write.
        Default ``1e10``.

    Returns
    -------
    attributes :
        The attributes of the probabilistic nodes, with the ``"precision"`` and
        ``"mean"`` of node ``node_idx`` updated.

    """
    # Update precision first. Clip both ends:
    #   * upper bound at ``max_posterior_precision`` (precision blow-up guard);
    #   * lower bound at the *expected* precision — the nonlinear g'' term in
    #     the value-coupling contribution can swing strongly negative when a
    #     child's PE flips sign.
    precision_value = posterior_update_precision_value_level(
        attributes, edges, node_idx
    )
    # The bounds are passed positionally: the ``a_min=`` / ``a_max=`` keywords
    # of ``jnp.clip`` are deprecated in recent JAX releases (renamed to
    # ``min`` / ``max``), whereas the positional form is accepted by both the
    # old and the new signature.
    precision_value = jnp.clip(
        precision_value,
        attributes[node_idx]["expected_precision"],
        max_posterior_precision,
    )
    attributes[node_idx]["precision"] = precision_value

    # Update the mean using the new (clipped) precision; the precision must be
    # written to ``attributes`` before this call, as in the original ordering.
    mean_value = posterior_update_mean_value_level(
        attributes, edges, node_idx, precision_value
    )
    attributes[node_idx]["mean"] = mean_value

    return attributes
@partial(jit, static_argnames=("node_idx", "max_posterior_precision"))
def volatile_node_volatility_posterior_update_standard(
    attributes: dict,
    node_idx: int,
    max_posterior_precision: float = 1e10,
) -> dict:
    """Run the standard-order posterior update of the volatility level.

    The implicit volatility parent of the volatile node is updated in the
    standard order: the posterior precision is computed (and capped) first,
    and the posterior mean is then computed from that new precision.

    Parameters
    ----------
    attributes :
        The attributes of the probabilistic nodes.
    node_idx :
        Pointer to the volatile node that needs to be updated.
    max_posterior_precision :
        Upper bound applied to the volatility-level posterior precision write.
        Default ``1e10``.

    Returns
    -------
    attributes :
        The attributes of the probabilistic nodes, with the ``"precision_vol"``
        and ``"mean_vol"`` of node ``node_idx`` updated.

    """
    node = attributes[node_idx]

    # Step 1: posterior precision, capped from above only.
    capped_precision = jnp.minimum(
        posterior_update_precision_volatility_level(attributes, node_idx),
        max_posterior_precision,
    )
    node["precision_vol"] = capped_precision

    # Step 2: posterior mean, computed after the capped precision is stored.
    node["mean_vol"] = posterior_update_mean_volatility_level(
        attributes, node_idx, capped_precision
    )

    return attributes
@partial(jit, static_argnames=("edges", "node_idx", "max_posterior_precision"))
def volatile_node_posterior_update_mean_field(
    attributes: dict,
    edges: Edges,
    node_idx: int,
    max_posterior_precision: float = 1e10,
) -> dict:
    """Update the value level of a volatile node using the mean-field variants.

    Unlike the standard continuous-state posterior updates elsewhere in the
    toolbox, the volatile-state updates use the *expected* mean (i.e. the
    prediction) as the reference point rather than the posterior mean. This
    choice is made to better suit deep learning networks where the prediction
    serves as the natural reference for computing updates.

    This function itself only writes the ``"precision"`` and ``"mean"`` entries
    of the node. The volatility-level entries (``"precision_vol"``,
    ``"mean_vol"``) are not written here.

    Parameters
    ----------
    attributes :
        The attributes of the probabilistic nodes.
    edges :
        The edges of the probabilistic nodes as a tuple of
        :py:class:`pyhgf.typing.Indexes`.
    node_idx :
        Pointer to the volatile node that needs to be updated.
    max_posterior_precision :
        Upper bound applied to the value-level posterior precision write.
        Default ``1e10``.

    Returns
    -------
    attributes :
        The attributes of the probabilistic nodes, with the ``"precision"`` and
        ``"mean"`` of node ``node_idx`` updated.

    """
    # Update precision first (mean-field variant), then clamp it to
    # ``[expected_precision, max_posterior_precision]``: the posterior
    # precision is not allowed to drop below the expected precision nor to
    # exceed the configured upper bound.
    precision_value = posterior_update_precision_value_level_mean_field(
        attributes, edges, node_idx
    )
    # The bounds are passed positionally: the ``a_min=`` / ``a_max=`` keywords
    # of ``jnp.clip`` are deprecated in recent JAX releases (renamed to
    # ``min`` / ``max``), whereas the positional form is accepted by both the
    # old and the new signature.
    precision_value = jnp.clip(
        precision_value,
        attributes[node_idx]["expected_precision"],
        max_posterior_precision,
    )
    attributes[node_idx]["precision"] = precision_value

    # Update the mean (mean-field variant) using the new, clipped precision;
    # the precision is written to ``attributes`` before this call, as in the
    # original ordering.
    mean_value = posterior_update_mean_value_level_mean_field(
        attributes, edges, node_idx, precision_value
    )
    attributes[node_idx]["mean"] = mean_value

    return attributes