# Source code for pyhgf.utils.vectorized_belief_propagation

# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>
# Author: Aleksandrs Baskakovs <aleks@cas.au.dk>

"""Vectorized belief propagation step for deep predictive coding networks."""

from __future__ import annotations

import dataclasses

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

from pyhgf.typing.vectorised import Layer, LayerStack, Network
from pyhgf.updates.vectorized.binary import (
    vectorized_binary_prediction,
    vectorized_binary_prediction_error,
)
from pyhgf.updates.vectorized.learning import vectorized_weight_gradient
from pyhgf.updates.vectorized.volatile import (
    vectorized_layer_posterior_update,
    vectorized_layer_prediction,
    vectorized_layer_prediction_error,
)

# ---------------------------------------------------------------------------
# Element-shape helpers
# ---------------------------------------------------------------------------


def _bottom_slice(stack: LayerStack):
    """Extract the index-``0`` (*bottommost*) slice of a stack.

    Every leaf of ``stack.state`` and ``stack.params`` is indexed with ``[0]``
    along its leading axis; ``stack.weights_in`` is indexed the same way.

    Returns
    -------
    tuple
        ``(state, params, weights_in)`` of the bottommost slice.
    """

    def first(leaf):
        return leaf[0]

    # One tree_map over both trees: the enclosing tuple is just another pytree node.
    state, params = jax.tree_util.tree_map(first, (stack.state, stack.params))
    return state, params, first(stack.weights_in)


def _top_slice(stack: LayerStack):
    """Extract the index-``-1`` (*topmost*) slice of a stack.

    Every leaf of ``stack.state`` and ``stack.params`` is indexed with ``[-1]``
    along its leading axis; ``stack.weights_in`` is indexed the same way.

    Returns
    -------
    tuple
        ``(state, params, weights_in)`` of the topmost slice.
    """

    def last(leaf):
        return leaf[-1]

    # One tree_map over both trees: the enclosing tuple is just another pytree node.
    state, params = jax.tree_util.tree_map(last, (stack.state, stack.params))
    return state, params, last(stack.weights_in)


def _parent_view(elem):
    """Expose a ``Layer`` or ``LayerStack`` through one parent-side interface.

    A ``LayerStack`` contributes the state and incoming weights of its
    *bottommost* slice (the slice adjacent to the child below the stack); a
    ``Layer`` contributes its own ``state`` and ``weights_in``. The coupling
    function and constant-input flag are read from ``elem`` in both cases.

    Returns
    -------
    tuple
        ``(state, weights_in, coupling_fn, add_constant_input)``.
    """
    if isinstance(elem, LayerStack):
        state, _, weights = _bottom_slice(elem)
    else:
        state, weights = elem.state, elem.weights_in
    return state, weights, elem.coupling_fn, elem.add_constant_input


def _child_view(elem):
    """Expose a ``Layer`` or ``LayerStack`` through one child-side interface.

    A ``LayerStack`` contributes the state of its *topmost* slice (the slice
    adjacent to the parent above the stack) and is always reported as a
    non-input element; a ``Layer`` contributes its own ``state`` and
    ``is_input_layer`` flag.

    Returns
    -------
    tuple
        ``(state, kind, is_input_layer)``.
    """
    if not isinstance(elem, LayerStack):
        return elem.state, elem.kind, elem.is_input_layer

    top_state = _top_slice(elem)[0]
    # A stack sits in the interior of the network, so it is never an input layer.
    return top_state, elem.kind, False


# ---------------------------------------------------------------------------
# Top-down prediction
# ---------------------------------------------------------------------------


def _predict_layer_from_parent(
    child: Layer,
    parent_state,
    parent_weights,
    parent_coupling_fn,
    parent_has_constant: bool,
    *,
    time_step: float,
    precision_clipping_value: float,
):
    """Run the top-down prediction for one ``Layer`` child.

    Children with ``child.kind == "binary"`` are routed to
    ``vectorized_binary_prediction`` (which receives
    ``precision_clipping_value``); every other kind is routed to
    ``vectorized_layer_prediction`` (which receives the child's params,
    ``time_step`` and the child's volatility-parent / input-layer flags).

    Returns
    -------
    Layer
        A copy of ``child`` whose ``state`` is the kernel's output.
    """
    # Keyword arguments accepted by both kernels.
    common = {
        "child_state": child.state,
        "parent_state": parent_state,
        "weights": parent_weights,
        "coupling_fn": parent_coupling_fn,
        "parent_has_constant": parent_has_constant,
    }
    if child.kind != "binary":
        predicted = vectorized_layer_prediction(
            params=child.params,
            time_step=time_step,
            has_volatility_parent=child.has_volatility_parent,
            is_input_layer=child.is_input_layer,
            **common,
        )
    else:
        predicted = vectorized_binary_prediction(
            precision_clipping_value=precision_clipping_value,
            **common,
        )
    return dataclasses.replace(child, state=predicted)


def _predict_stack_from_parent(
    stack: LayerStack,
    parent_state,
    parent_weights,
    parent_coupling_fn,
    parent_has_constant: bool,
    *,
    time_step: float,
):
    """Top-down sweep over a ``LayerStack``.

    Step 1 (boundary): predict the topmost slice from the external parent, using
    the parent's coupling/weights/bias. These may differ from the stack's own:
    for ``(L, S)`` they are the layer's; for ``(S, S)`` they are those of the
    parent stack's bottommost slice.

    Step 2 (scan): predict slice ``k`` from slice ``k+1`` for ``k = N-2 ... 0``
    using ``stack.weights_in[k+1]`` and the stack's own coupling_fn / bias.
    Scan in reverse so the carry threads top-to-bottom through the stack.

    Parameters
    ----------
    stack :
        The stack being predicted (acts as the child).
    parent_state, parent_weights, parent_coupling_fn, parent_has_constant :
        The parent-side pieces used for the boundary prediction of the topmost
        slice.
    time_step :
        Forwarded to every ``vectorized_layer_prediction`` call.

    Returns
    -------
    LayerStack
        A copy of ``stack`` whose ``state`` holds the newly predicted slices.
    """
    # Boundary step: the topmost slice is the only one that sees the external parent.
    top_slice_state, top_slice_params, _ = _top_slice(stack)
    new_top_state = vectorized_layer_prediction(
        child_state=top_slice_state,
        parent_state=parent_state,
        weights=parent_weights,
        params=top_slice_params,
        time_step=time_step,
        coupling_fn=parent_coupling_fn,
        parent_has_constant=parent_has_constant,
        has_volatility_parent=stack.has_volatility_parent,
        is_input_layer=False,
    )

    n = stack.n_layers
    if n == 1:
        # Degenerate: stack with a single slice. Just write the new state.
        new_state = jax.tree_util.tree_map(
            lambda x, v: x.at[0].set(v), stack.state, new_top_state
        )
        return dataclasses.replace(stack, state=new_state)

    # xs: per-iteration data for predicting slices N-2 ... 0 from the slice above.
    # At step k, body(parent_state, xs[k]) → predict slice k. The "parent's
    # weights" used to predict slice k come from slice k+1 — i.e. stack.weights_in[k+1].
    xs_child_state = jax.tree_util.tree_map(lambda x: x[: n - 1], stack.state)
    xs_child_params = jax.tree_util.tree_map(lambda x: x[: n - 1], stack.params)
    xs_parent_weights = stack.weights_in[1:]  # shape (n-1, ...)

    def body(parent_state_carry, k_data):
        # The carry is the freshly predicted state of the slice directly above.
        child_state, child_params, parent_weights_k = k_data
        new_child_state = vectorized_layer_prediction(
            child_state=child_state,
            parent_state=parent_state_carry,
            weights=parent_weights_k,
            params=child_params,
            time_step=time_step,
            coupling_fn=stack.coupling_fn,
            parent_has_constant=stack.add_constant_input,
            has_volatility_parent=stack.has_volatility_parent,
            is_input_layer=False,
        )
        # The new state is both the next carry and this step's stacked output.
        return new_child_state, new_child_state

    # reverse=True visits index n-2 first (seeded with the new top state) while
    # the stacked outputs stay in original index order.
    _, new_states_below = jax.lax.scan(
        body,
        init=new_top_state,
        xs=(xs_child_state, xs_child_params, xs_parent_weights),
        reverse=True,
    )

    # new_states_below has shape (n-1, ...) for slices 0..n-2;
    # new_top_state is for slice n-1. Concatenate along axis 0.
    new_full_state = jax.tree_util.tree_map(
        lambda below, top: jnp.concatenate([below, top[None, ...]], axis=0),
        new_states_below,
        new_top_state,
    )
    return dataclasses.replace(stack, state=new_full_state)


def _topdown_predict(
    parent_elem, child_elem, *, time_step: float, precision_clipping_value: float = 1e-6
):
    """Predict ``child_elem`` from ``parent_elem`` (each a Layer or LayerStack).

    The parent is reduced to its ``_parent_view`` tuple, then the call is
    dispatched on the child's type. ``precision_clipping_value`` is only
    forwarded on the ``Layer`` path; the stack path does not take it.

    Returns
    -------
    Layer or LayerStack
        The child element with its predicted state.
    """
    # (state, weights_in, coupling_fn, add_constant_input) of the parent.
    parent_pieces = _parent_view(parent_elem)

    if not isinstance(child_elem, LayerStack):
        return _predict_layer_from_parent(
            child_elem,
            *parent_pieces,
            time_step=time_step,
            precision_clipping_value=precision_clipping_value,
        )

    # LayerStacks are continuous/volatile only, so the binary clip never applies.
    return _predict_stack_from_parent(child_elem, *parent_pieces, time_step=time_step)


# ---------------------------------------------------------------------------
# Leaf PE (bottom element of the network)
# ---------------------------------------------------------------------------


def _leaf_pe(layer: Layer, *, volatility_updates: str, max_posterior_precision: float):
    """Compute the prediction error of the bottom ``Layer`` of the network.

    Binary layers use ``vectorized_binary_prediction_error``; all other kinds use
    ``vectorized_layer_prediction_error`` with the layer's params and the
    network-level ``volatility_updates`` / ``max_posterior_precision`` settings.

    Returns
    -------
    Layer
        A copy of ``layer`` carrying the state returned by the PE kernel.
    """
    if layer.kind != "binary":
        pe_state = vectorized_layer_prediction_error(
            layer=layer.state,
            params=layer.params,
            volatility_updates=volatility_updates,
            has_volatility_parent=layer.has_volatility_parent,
            max_posterior_precision=max_posterior_precision,
        )
    else:
        pe_state = vectorized_binary_prediction_error(layer=layer.state)
    return dataclasses.replace(layer, state=pe_state)


# ---------------------------------------------------------------------------
# Bottom-up posterior + PE
# ---------------------------------------------------------------------------


def _posterior_pe_layer(
    parent: Layer,
    child_state,
    child_is_input_layer: bool,
    *,
    volatility_updates: str,
    max_posterior_precision: float,
):
    """Posterior update followed by prediction error for a single ``Layer``.

    The posterior is computed from ``child_state`` through the parent's own
    ``weights_in`` / ``coupling_fn`` / ``add_constant_input``; the result is then
    passed to the PE kernel matching ``parent.kind``.

    Returns
    -------
    Layer
        A copy of ``parent`` carrying the post-PE state.
    """
    posterior = vectorized_layer_posterior_update(
        layer=parent.state,
        child=child_state,
        weights=parent.weights_in,
        coupling_fn=parent.coupling_fn,
        parent_has_constant=parent.add_constant_input,
        max_posterior_precision=max_posterior_precision,
        child_is_input_layer=child_is_input_layer,
    )

    if parent.kind != "binary":
        with_pe = vectorized_layer_prediction_error(
            layer=posterior,
            params=parent.params,
            volatility_updates=volatility_updates,
            has_volatility_parent=parent.has_volatility_parent,
            max_posterior_precision=max_posterior_precision,
        )
    else:
        with_pe = vectorized_binary_prediction_error(layer=posterior)

    return dataclasses.replace(parent, state=with_pe)


def _posterior_pe_stack(
    stack: LayerStack,
    child_state_init,
    *,
    volatility_updates: str,
    max_posterior_precision: float,
):
    """Bottom-up sweep over a ``LayerStack``: posterior update + PE per slice.

    A forward ``jax.lax.scan`` walks the slices from index 0 (bottommost) to
    index N-1 (topmost). The carry is the post-PE state of the slice just below
    the current one; it is seeded with the external ``child_state_init``.

    ``child_is_input_layer`` is passed as ``False`` for every slice, including
    the boundary with the external child (the element below a stack is expected
    to be a non-leaf element).

    Returns
    -------
    LayerStack
        A copy of ``stack`` whose ``state`` is the stacked per-slice result.
    """

    def update_slice(state_below, per_slice):
        state_k, params_k, weights_k = per_slice
        posterior = vectorized_layer_posterior_update(
            layer=state_k,
            child=state_below,
            weights=weights_k,
            coupling_fn=stack.coupling_fn,
            parent_has_constant=stack.add_constant_input,
            max_posterior_precision=max_posterior_precision,
            child_is_input_layer=False,
        )
        with_pe = vectorized_layer_prediction_error(
            layer=posterior,
            params=params_k,
            volatility_updates=volatility_updates,
            has_volatility_parent=stack.has_volatility_parent,
            max_posterior_precision=max_posterior_precision,
        )
        # Emitted twice: once as the next carry, once as the stacked output.
        return with_pe, with_pe

    per_slice_inputs = (stack.state, stack.params, stack.weights_in)
    _, stacked_states = jax.lax.scan(update_slice, child_state_init, per_slice_inputs)
    return dataclasses.replace(stack, state=stacked_states)


def _bottomup_posterior_pe(
    parent_elem,
    child_elem,
    *,
    volatility_updates: str,
    max_posterior_precision: float,
):
    """Update ``parent_elem`` (posterior + PE) from the ``child_elem`` below it.

    The child is reduced to its ``_child_view``; the call is then dispatched on
    the parent's type. The stack path does not take the child's
    ``is_input_layer`` flag.

    Returns
    -------
    Layer or LayerStack
        The updated parent element.
    """
    child_state, _, child_is_input_layer = _child_view(child_elem)
    settings = {
        "volatility_updates": volatility_updates,
        "max_posterior_precision": max_posterior_precision,
    }

    if not isinstance(parent_elem, LayerStack):
        return _posterior_pe_layer(
            parent_elem, child_state, child_is_input_layer, **settings
        )
    return _posterior_pe_stack(parent_elem, child_state, **settings)


# ---------------------------------------------------------------------------
# Weight gradients
# ---------------------------------------------------------------------------


def _grad_layer(parent: Layer, child_elem, learning_kind: str):
    """Weight gradient for a single ``Layer`` parent.

    The child may be a ``Layer`` or a ``LayerStack``; it is reduced through
    ``_child_view`` and flagged as binary when its kind is ``"binary"``.

    Returns
    -------
    The output of ``vectorized_weight_gradient`` (intended to match the shape of
    ``parent.weights_in``).
    """
    child_state, child_kind, _ = _child_view(child_elem)
    child_is_binary = child_kind == "binary"
    return vectorized_weight_gradient(
        parent_state=parent.state,
        child_state=child_state,
        coupling_fn=parent.coupling_fn,
        kind=learning_kind,
        parent_has_constant=parent.add_constant_input,
        child_is_binary=child_is_binary,
    )


def _grad_stack(stack: LayerStack, child_elem, learning_kind: str):
    """Per-slice weight gradients for a ``LayerStack``.

    Slice 0's child is the element below the stack (``child_elem``, seen through
    ``_child_view``); slice ``k > 0``'s child is slice ``k-1`` of the same stack.
    The N child slots are therefore the external child state followed by stack
    slices ``0 .. N-2``. The per-slice gradient is then ``vmap``-ed over the N
    parent slices and the N child slots. Children are always treated as
    non-binary.
    """
    external_child, _, _ = _child_view(child_elem)

    # Child slot k: the external child for k == 0, stack slice k-1 otherwise.
    child_slots = jax.tree_util.tree_map(
        lambda outside, inside: jnp.concatenate(
            [outside[None, ...], inside[:-1]], axis=0
        ),
        external_child,
        stack.state,
    )

    def slice_gradient(parent_slice, child_slice):
        return vectorized_weight_gradient(
            parent_state=parent_slice,
            child_state=child_slice,
            coupling_fn=stack.coupling_fn,
            kind=learning_kind,
            parent_has_constant=stack.add_constant_input,
            child_is_binary=False,
        )

    return jax.vmap(slice_gradient)(stack.state, child_slots)


def _weight_grad(parent_elem, child_elem, learning_kind: str):
    """Compute the gradient for ``parent_elem.weights_in``.

    Dispatches to ``_grad_stack`` when the parent is a ``LayerStack`` and to
    ``_grad_layer`` otherwise.
    """
    grad_fn = _grad_stack if isinstance(parent_elem, LayerStack) else _grad_layer
    return grad_fn(parent_elem, child_elem, learning_kind)


# ---------------------------------------------------------------------------
# Element-level state writeback (for clamping x/y at the boundaries)
# ---------------------------------------------------------------------------


def _set_top_predictors(elem, x):
    """Clamp both ``expected_mean`` and ``mean`` of the top element to ``x``.

    Raises
    ------
    NotImplementedError
        If ``elem`` is a ``LayerStack``; the top of the network must be a
        ``Layer``.
    """
    if not isinstance(elem, LayerStack):
        clamped = dataclasses.replace(elem.state, expected_mean=x, mean=x)
        return dataclasses.replace(elem, state=clamped)
    raise NotImplementedError("Top of network must be a Layer, not a LayerStack.")


def _set_bottom_observations(elem, y):
    """Clamp ``mean`` of the bottom element to the observation ``y``.

    Raises
    ------
    NotImplementedError
        If ``elem`` is a ``LayerStack``; the bottom of the network must be a
        ``Layer``.
    """
    if not isinstance(elem, LayerStack):
        observed = dataclasses.replace(elem.state, mean=y)
        return dataclasses.replace(elem, state=observed)
    raise NotImplementedError("Bottom of network must be a Layer, not a LayerStack.")


def _writeback_weights(elem, new_w):
    """Return a copy of ``elem`` (Layer or LayerStack) with ``weights_in = new_w``.

    ``elem`` itself is left untouched; every other field is carried over by
    ``dataclasses.replace``.
    """
    updated = dataclasses.replace(elem, weights_in=new_w)
    return updated


# ---------------------------------------------------------------------------
# Top-level propagation step
# ---------------------------------------------------------------------------


def propagation_step(
    network: Network,
    opt_state: optax.OptState,
    inputs: tuple,
    *,
    optimizer: optax.GradientTransformation,
    time_step: float = 1.0,
    learning_kind: str = "precision_weighted",
    weight_update: bool = True,
) -> tuple[tuple[Network, optax.OptState], jnp.ndarray]:
    """Single propagation step through the network.

    Belief-propagation sweep (top-down prediction → leaf PE → interleaved
    posterior/PE bottom-up) followed by an optional weight-learning phase.

    Each step dispatches per element on ``Layer`` vs ``LayerStack``:

    * ``Layer`` → standard per-layer kernel call (unrolled).
    * ``LayerStack`` → ``jax.lax.scan`` over the stack's slices.

    Top and bottom elements must be ``Layer``s. A ``LayerStack``'s child below
    (and parent above) can themselves be ``Layer`` or ``LayerStack``; the
    stack-stack case requires the boundary widths to match.

    Parameters
    ----------
    network :
        The current vectorised network state.
    opt_state :
        The current optax optimiser state.
    inputs :
        A tuple ``(x, y)`` with the predictors set on the top element and the
        observations clamped on the bottom element.
    optimizer :
        The optax optimiser used for the weight-learning phase.
    time_step :
        The time elapsed since the previous step.
    learning_kind :
        The weight-gradient mode passed to
        :py:func:`pyhgf.updates.vectorized.learning.vectorized_weight_gradient`.
    weight_update :
        Whether to apply the weight-learning phase after belief propagation.

    Returns
    -------
    carry :
        The updated ``(network, opt_state)`` pair. ``opt_state`` is returned
        unchanged when ``weight_update`` is ``False``.
    output_pred :
        The bottom element's ``expected_mean``, read from the updated network.

    """
    x, y = inputs
    elements = list(network.layers)
    n_elements = len(elements)
    max_posterior_precision = network.max_posterior_precision
    volatility_updates = network.volatility_updates
    precision_clipping_value = network.precision_clipping_value

    # 1. Set predictors on the top element.
    elements[-1] = _set_top_predictors(elements[-1], x)

    # 2. Clamp observations on the bottom element.
    elements[0] = _set_bottom_observations(elements[0], y)

    # 3. Top-down prediction: predict each element from the one above.
    for i in range(n_elements - 1, 0, -1):
        elements[i - 1] = _topdown_predict(
            elements[i],
            elements[i - 1],
            time_step=time_step,
            precision_clipping_value=precision_clipping_value,
        )

    # 4a. PE on the bottom (leaf) element.
    elements[0] = _leaf_pe(
        elements[0],
        volatility_updates=volatility_updates,
        max_posterior_precision=max_posterior_precision,
    )

    # 4b. Interleaved bottom-up: posterior + PE on every interior element.
    # The top element is skipped: it was clamped to the predictors in step 1.
    for i in range(1, n_elements - 1):
        elements[i] = _bottomup_posterior_pe(
            elements[i],
            elements[i - 1],
            volatility_updates=volatility_updates,
            max_posterior_precision=max_posterior_precision,
        )

    # 5. Weight learning: one gradient per element, mirroring the weights tuple.
    if weight_update:
        weights = tuple(elem.weights_in for elem in elements)
        # The bottom element has no weights_in, hence the leading None; every
        # other element is paired with the element directly below it.
        grads = (
            None,
            *(
                _weight_grad(parent, child, learning_kind)
                for child, parent in zip(elements, elements[1:])
            ),
        )
        updates, new_opt_state = optimizer.update(grads, opt_state, weights)
        new_weights = optax.apply_updates(weights, updates)
        for i, new_w in enumerate(new_weights):
            if new_w is not None:
                elements[i] = _writeback_weights(elements[i], new_w)
    else:
        new_opt_state = opt_state

    new_network = dataclasses.replace(network, layers=tuple(elements))
    output_pred = new_network.layers[0].state.expected_mean
    return (new_network, new_opt_state), output_pred
# ---------------------------------------------------------------------------
# Scan driver + prediction-only sweep (unchanged contract)
# ---------------------------------------------------------------------------
@eqx.filter_jit
def run_scan(
    init_carry: tuple,
    inputs: tuple,
    optimizer: optax.GradientTransformation,
    learning_kind: str,
    weight_update: bool,
    record: tuple,
    time_step: float = 1.0,
) -> tuple:
    r"""Run ``jax.lax.scan`` over the belief-propagation step.

    Decorated with ``eqx.filter_jit``: arrays in ``init_carry`` / ``inputs`` are
    dynamic; ``optimizer`` / ``learning_kind`` / ``weight_update`` / ``record`` /
    ``time_step`` are static and form the JIT cache key.

    Parameters
    ----------
    init_carry :
        The initial scan carry, a tuple ``(network, opt_state)``.
    inputs :
        The per-step inputs scanned over, a tuple of predictor/observation
        arrays with a leading time axis.
    optimizer :
        The optax optimiser used for the weight-learning phase.
    learning_kind :
        The weight-gradient mode passed to
        :py:func:`pyhgf.updates.vectorized.learning.vectorized_weight_gradient`.
    weight_update :
        Whether to apply the weight-learning phase at every step.
    record :
        Tuple of ``LayerState`` field names to record at every time step (e.g.
        ``("expected_mean", "precision")``). An empty tuple disables recording;
        the scan output is then just the per-step ``output_pred``. With a
        non-empty tuple, the per-step output is ``(traj_step, output_pred)``
        where ``traj_step`` is ``dict[field_name, tuple[Array, ...]]`` (one
        per-element array per field, with ``LayerStack`` elements contributing
        arrays of shape ``(N, n_nodes)``). After ``scan`` stacks across time,
        each leaf carries a leading ``(T,)`` axis.
    time_step :
        Uniform inference time step :math:`\Delta t` passed to every
        ``propagation_step`` call. Defaults to ``1.0``.

    Returns
    -------
    ``((final_network, final_opt_state), step_output)`` where ``step_output`` is
    either the stacked predictions alone (``record == ()``) or a
    ``(stacked_traj, stacked_predictions)`` tuple.

    """

    def _scan_body(carry, xs):
        network, opt_state = carry
        (new_network, new_opt_state), pred = propagation_step(
            network,
            opt_state,
            xs,
            optimizer=optimizer,
            time_step=time_step,
            learning_kind=learning_kind,
            weight_update=weight_update,
        )
        if record:
            # One tuple per requested field, holding one array per network element.
            traj_step = {
                field: tuple(
                    getattr(_state_for_record(elem), field)
                    for elem in new_network.layers
                )
                for field in record
            }
            return (new_network, new_opt_state), (traj_step, pred)
        return (new_network, new_opt_state), pred

    return jax.lax.scan(_scan_body, init_carry, inputs)
def _state_for_record(elem):
    """Return the ``LayerState`` to read trajectory fields from.

    For a ``Layer`` this is ``elem.state`` (shape ``(n_nodes,)`` per field). For
    a ``LayerStack`` it is the stacked state (shape ``(N, n_nodes)`` per field),
    so the caller gets the whole stack's trajectory in one block.
    """
    return elem.state
@eqx.filter_jit
def prediction_pass(
    network: Network, x: jnp.ndarray, time_step: float = 1.0
) -> jnp.ndarray:
    """Forward-only sweep through the network (no PE / posterior / learning).

    Sets predictors on the top element and runs the top-down prediction pass;
    returns the bottom element's ``expected_mean``. Used by
    :meth:`pyhgf.model.DeepNetwork.predict`.

    Parameters
    ----------
    network :
        The current vectorised network state.
    x :
        The predictors set on the top element.
    time_step :
        The time step forwarded to every top-down prediction. Defaults to
        ``1.0``, which reproduces the previous hard-coded behaviour; pass the
        value used with ``run_scan`` to keep prediction consistent with
        inference.

    Returns
    -------
    expected_mean :
        The bottom element's ``expected_mean`` after the forward sweep.

    """
    elements = list(network.layers)
    n_elements = len(elements)
    elements[-1] = _set_top_predictors(elements[-1], x)
    # Walk from the top element down, predicting each element from the one above.
    for i in range(n_elements - 1, 0, -1):
        elements[i - 1] = _topdown_predict(
            elements[i],
            elements[i - 1],
            time_step=time_step,
            precision_clipping_value=network.precision_clipping_value,
        )
    return elements[0].state.expected_mean