# Source code for pyhgf.updates.prediction.volatile
from functools import partial
import jax.numpy as jnp
from jax import Array, jit
from pyhgf.typing import Edges
# [docs]
@partial(jit, static_argnames=("node_idx",))
def predict_precision_volatility_level(
    attributes: dict,
    node_idx: int,
) -> tuple[Array, Array]:
    """Predict the precision of the implicit volatility level.

    The predicted volatility is ``time_step * exp(tonic_volatility_vol)``. It is
    combined with the current variance ``1 / precision_vol`` to give the expected
    precision, and the effective precision is the product of the two.

    Parameters
    ----------
    attributes :
        The attributes of the network. This function reads
        ``attributes[-1]["time_step"]`` as well as ``"precision_vol"`` and
        ``"tonic_volatility_vol"`` from ``attributes[node_idx]``.
    node_idx :
        Index of the node in `attributes`. Declared static for `jit`.

    Returns
    -------
    expected_precision_vol :
        ``1 / (1 / precision_vol + predicted_volatility_vol)``.
    effective_precision_vol :
        ``predicted_volatility_vol * expected_precision_vol``.

    """
    time_step = attributes[-1]["time_step"]

    # Get volatility level parameters
    precision_vol = attributes[node_idx]["precision_vol"]
    tonic_volatility_vol = attributes[node_idx]["tonic_volatility_vol"]

    # Compute predicted volatility for the volatility level
    predicted_volatility_vol = time_step * jnp.exp(tonic_volatility_vol)

    # Any value that is not strictly above the threshold is replaced by NaN, so
    # the invalid prediction propagates to both returned precisions instead of
    # silently producing a finite number.
    predicted_volatility_vol = jnp.where(
        predicted_volatility_vol > 1e-128, predicted_volatility_vol, jnp.nan
    )

    # Expected precision: inverse of (current variance + predicted volatility)
    expected_precision_vol = 1 / ((1 / precision_vol) + predicted_volatility_vol)

    # Effective precision
    effective_precision_vol = predicted_volatility_vol * expected_precision_vol

    return expected_precision_vol, effective_precision_vol
# [docs]
@partial(jit, static_argnames=("edges", "node_idx"))
def predict_mean_value_level(
    attributes: dict,
    edges: Edges,
    node_idx: int,
) -> Array:
    """Predict the mean of the value level (external facing).

    The new expected mean is the previous mean scaled by the node's
    autoconnection strength, plus ``time_step`` times a drift rate. The drift
    rate starts at zero and accumulates the (optionally transformed) expected
    mean of every value parent, weighted by the matching coupling strength.

    Parameters
    ----------
    attributes :
        The attributes of the network. This function reads
        ``attributes[-1]["time_step"]``, ``"autoconnection_strength"`` and
        ``"mean"`` from ``attributes[node_idx]``, and, when value parents exist,
        ``"value_coupling_parents"`` from the node and ``"expected_mean"`` from
        each parent.
    edges :
        The edges of the network. ``edges[node_idx].value_parents`` lists the
        value parents (or is None); for each parent, ``value_children`` and
        ``coupling_fn`` are used to find the coupling function that applies to
        this child. Declared static for `jit`.
    node_idx :
        Index of the node in `attributes` and `edges`. Declared static for `jit`.

    Returns
    -------
    expected_mean :
        ``autoconnection_strength * mean + time_step * driftrate``.

    """
    time_step = attributes[-1]["time_step"]

    # List the node's value parents
    value_parents_idxs = edges[node_idx].value_parents

    # The drift rate starts at zero here; only value parents contribute to it
    driftrate = 0.0

    # Look at the (optional) value parents for this node
    if value_parents_idxs is not None:
        for value_parent_idx, value_coupling_parent in zip(
            value_parents_idxs,
            attributes[node_idx]["value_coupling_parents"],
        ):
            # The coupling function is stored on the parent, at the position
            # this node occupies in the parent's list of value children
            child_position = edges[value_parent_idx].value_children.index(node_idx)
            coupling_fn = edges[value_parent_idx].coupling_fn[child_position]

            # A coupling function of None means the parent's expected mean is
            # used as is
            parent_value = attributes[value_parent_idx]["expected_mean"]
            if coupling_fn is not None:
                parent_value = coupling_fn(parent_value)

            driftrate += value_coupling_parent * parent_value

    # The new expected mean from the previous value
    expected_mean = (
        attributes[node_idx]["autoconnection_strength"] * attributes[node_idx]["mean"]
    ) + (time_step * driftrate)

    return expected_mean
# [docs]
@partial(jit, static_argnames=("node_idx",))
def predict_precision_value_level(
    attributes: dict,
    node_idx: int,
) -> tuple[Array, Array]:
    """Predict the precision of the value level using the implicit volatility level.

    The volatility level's expected mean modulates the value level's precision:
    the total volatility is ``tonic_volatility + volatility_coupling_internal *
    expected_mean_vol`` and the predicted volatility is
    ``time_step * exp(total_volatility)``.

    Parameters
    ----------
    attributes :
        The attributes of the network. This function reads
        ``attributes[-1]["time_step"]`` as well as ``"precision"``,
        ``"tonic_volatility"``, ``"expected_mean_vol"`` and
        ``"volatility_coupling_internal"`` from ``attributes[node_idx]``.
        ``"expected_mean_vol"`` must therefore be set before this call.
    node_idx :
        Index of the node in `attributes`. Declared static for `jit`.

    Returns
    -------
    expected_precision :
        ``1 / (1 / precision + predicted_volatility)``.
    effective_precision :
        ``predicted_volatility * expected_precision``.

    """
    time_step = attributes[-1]["time_step"]

    # Get value level parameters
    precision = attributes[node_idx]["precision"]
    tonic_volatility = attributes[node_idx]["tonic_volatility"]

    # Get volatility level's expected mean (already computed)
    expected_mean_vol = attributes[node_idx]["expected_mean_vol"]

    # Get internal coupling strength
    volatility_coupling_internal = attributes[node_idx]["volatility_coupling_internal"]

    # Total volatility = tonic + contribution from implicit volatility parent
    total_volatility = tonic_volatility + (
        volatility_coupling_internal * expected_mean_vol
    )

    # Compute predicted volatility
    predicted_volatility = time_step * jnp.exp(total_volatility)

    # Any value that is not strictly above the threshold is replaced by NaN, so
    # the invalid prediction propagates to both returned precisions instead of
    # silently producing a finite number.
    predicted_volatility = jnp.where(
        predicted_volatility > 1e-128, predicted_volatility, jnp.nan
    )

    # Expected precision: inverse of (current variance + predicted volatility)
    expected_precision = 1 / ((1 / precision) + predicted_volatility)

    # Effective precision
    effective_precision = predicted_volatility * expected_precision

    return expected_precision, effective_precision
# [docs]
@partial(jit, static_argnames=("edges", "node_idx"))
def volatile_node_prediction(
    attributes: dict, node_idx: int, edges: Edges, **args
) -> dict:
    """Update the expected mean and expected precision of a value-volatility node.

    This node has two internal levels:

    1. Volatility level (implicit, internal)
    2. Value level (external facing)

    The volatility level predicts first, then affects the value level's precision:
    ``"expected_mean_vol"`` is written before the value level's precision is
    predicted because that prediction reads it.

    Parameters
    ----------
    attributes :
        The attributes of the network. The entry ``attributes[node_idx]`` must
        contain a ``"temp"`` dictionary in addition to the parameters read by the
        three prediction helpers.
    node_idx :
        Index of the node in `attributes` and `edges`. Declared static for `jit`.
    edges :
        The edges of the network, forwarded to the value level's mean
        prediction. Declared static for `jit`.
    **args :
        Additional keyword arguments. They are accepted but not used here.

    Returns
    -------
    attributes :
        The attributes dictionary, where ``attributes[node_idx]`` has new values
        for ``"expected_mean_vol"``, ``"expected_precision_vol"``,
        ``"expected_mean"`` and ``"expected_precision"``, and its ``"temp"``
        entry has new values for ``"current_variance"``,
        ``"effective_precision_vol"`` and ``"effective_precision"``.

    """
    # Store current variance for potential unbounded updates
    attributes[node_idx]["temp"]["current_variance"] = (
        1 / attributes[node_idx]["precision"]
    )

    # 1. PREDICT VOLATILITY LEVEL (implicit internal state)
    expected_precision_vol, effective_precision_vol = (
        predict_precision_volatility_level(attributes, node_idx)
    )

    # The volatility level's expected mean is its current mean, carried over
    # unchanged
    attributes[node_idx]["expected_mean_vol"] = attributes[node_idx]["mean_vol"]
    attributes[node_idx]["expected_precision_vol"] = expected_precision_vol
    attributes[node_idx]["temp"]["effective_precision_vol"] = effective_precision_vol

    # 2. PREDICT VALUE LEVEL (external facing)
    # Value level's precision depends on volatility level, so this call must
    # come after "expected_mean_vol" has been stored above
    expected_precision, effective_precision = predict_precision_value_level(
        attributes, node_idx
    )

    # Value level's mean
    expected_mean = predict_mean_value_level(attributes, edges, node_idx)

    attributes[node_idx]["expected_mean"] = expected_mean
    attributes[node_idx]["expected_precision"] = expected_precision
    attributes[node_idx]["temp"]["effective_precision"] = effective_precision

    return attributes