Source code for pyhgf.updates.posterior.volatile.volatile_node_posterior_update
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>
# Author: Aleksandrs Baskakovs <aleks@cas.au.dk>
from functools import partial
from jax import jit
from pyhgf.typing import Edges
from .posterior_update_value_level import (
posterior_update_mean_value_level,
posterior_update_precision_value_level,
)
from .posterior_update_volatility_level import (
posterior_update_mean_volatility_level,
posterior_update_precision_volatility_level,
)
[docs]
@partial(jit, static_argnames=("edges", "node_idx"))
def volatile_node_posterior_update(
attributes: dict,
edges: Edges,
node_idx: int,
) -> dict:
"""Update a volatile node and the implied volatility parent.
Unlike the standard continuous-state posterior updates elsewhere in the toolbox,
the volatile-state updates use the *expected* mean (i.e. the prediction) as the
reference point rather than the posterior mean. This choice is made to better suit
deep learning networks where the prediction serves as the natural reference for
computing updates.
"""
# Update precision first
precision_value = posterior_update_precision_value_level(
attributes, edges, node_idx
)
attributes[node_idx]["precision"] = precision_value
# Update mean using new precision
mean_value = posterior_update_mean_value_level(
attributes, edges, node_idx, precision_value
)
attributes[node_idx]["mean"] = mean_value
return attributes
[docs]
@partial(jit, static_argnames=("node_idx",))
def volatile_node_volatility_posterior_update_standard(
attributes: dict,
node_idx: int,
) -> dict:
"""Update the volatility level using the standard ordering.
This updates the implicit volatility parent's mean and precision using the
standard ordering: precision first, then mean using the updated precision.
"""
# Update precision first
precision_vol = posterior_update_precision_volatility_level(attributes, node_idx)
attributes[node_idx]["precision_vol"] = precision_vol
# Update mean using the new precision
mean_vol = posterior_update_mean_volatility_level(
attributes, node_idx, precision_vol
)
attributes[node_idx]["mean_vol"] = mean_vol
return attributes