# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>
"""Equinox PyTree types for the vectorised deep network."""
from __future__ import annotations
from typing import Callable, Optional
import equinox as eqx
import jax
import jax.numpy as jnp
from equinox import field
from jax import Array
[docs]
class LayerState(eqx.Module):
"""Vectorised per-layer state, as an ``eqx.Module``.
Each field is an array with one entry per node in the layer.
Parameters
----------
mean :
The posterior mean of the value level.
precision :
The posterior precision of the value level.
expected_mean :
The predicted (expected) mean of the value level.
expected_precision :
The marginal predicted precision of the value level.
conditional_expected_precision :
The conditional predicted precision of the value level used by the
structured-Gaussian (smoothing) update.
effective_precision :
The effective precision of the value-level prediction.
value_prediction_error :
The value prediction error of the value level.
mean_vol :
The posterior mean of the volatility level.
precision_vol :
The posterior precision of the volatility level.
expected_mean_vol :
The predicted (expected) mean of the volatility level.
expected_precision_vol :
The marginal predicted precision of the volatility level.
effective_precision_vol :
The effective precision of the volatility-level prediction.
volatility_prediction_error :
The volatility prediction error of the volatility level.
"""
# Value level (external)
mean: Array
precision: Array
expected_mean: Array
expected_precision: Array
conditional_expected_precision: Array
effective_precision: Array
value_prediction_error: Array
# Volatility level (internal)
mean_vol: Array
precision_vol: Array
expected_mean_vol: Array
expected_precision_vol: Array
effective_precision_vol: Array
volatility_prediction_error: Array
@classmethod
def create(cls, n_nodes: int) -> "LayerState":
"""Initialise a layer state with defaults."""
return cls(
mean=jnp.zeros(n_nodes),
precision=jnp.ones(n_nodes),
expected_mean=jnp.zeros(n_nodes),
expected_precision=jnp.ones(n_nodes),
conditional_expected_precision=jnp.ones(n_nodes),
effective_precision=jnp.zeros(n_nodes),
value_prediction_error=jnp.zeros(n_nodes),
mean_vol=jnp.zeros(n_nodes),
precision_vol=jnp.ones(n_nodes),
expected_mean_vol=jnp.zeros(n_nodes),
expected_precision_vol=jnp.ones(n_nodes),
effective_precision_vol=jnp.zeros(n_nodes),
volatility_prediction_error=jnp.zeros(n_nodes),
)
[docs]
class LayerParams(eqx.Module):
"""Per-layer static parameters.
Each field is an array with one entry per node in the layer.
Parameters
----------
tonic_volatility :
The tonic (baseline) volatility of the value level.
tonic_volatility_vol :
The tonic (baseline) volatility of the volatility level.
volatility_coupling :
The volatility-coupling strength between the value and volatility levels.
autoconnection_strength_vol :
The autoconnection (self-coupling) strength of the volatility level.
"""
tonic_volatility: Array
tonic_volatility_vol: Array
volatility_coupling: Array
autoconnection_strength_vol: Array
@classmethod
def create(
cls,
n_nodes: int,
tonic_volatility: float = -4.0,
tonic_volatility_vol: float = -4.0,
volatility_coupling: float = 1.0,
autoconnection_strength_vol: float = 1.0,
) -> "LayerParams":
"""Initialise layer params with defaults."""
return cls(
tonic_volatility=jnp.full(n_nodes, tonic_volatility),
tonic_volatility_vol=jnp.full(n_nodes, tonic_volatility_vol),
volatility_coupling=jnp.full(n_nodes, volatility_coupling),
autoconnection_strength_vol=jnp.full(n_nodes, autoconnection_strength_vol),
)
[docs]
class Layer(eqx.Module):
"""One layer of the vectorised deep network.
``weights_in`` is the matrix connecting the layer *below* (child) into this layer
(parent). The bottom layer (index 0) has ``weights_in=None`` because no layer sits
below it. Shape: ``(n_child, n_self[+1])``; the optional ``+1`` column carries the
bias when ``add_constant_input=True``.
Parameters
----------
state :
The per-layer state (see :py:class:`LayerState`).
params :
The per-layer static parameters (see :py:class:`LayerParams`).
weights_in :
The matrix connecting the layer below (child) into this layer, or `None`
for the bottom layer.
coupling_fn :
The coupling function applied to the incoming weights.
add_constant_input :
Whether a constant (bias) input column is appended to the weights.
has_volatility_parent :
Whether the layer has a volatility parent.
is_input_layer :
Whether the layer is the input (bottom) layer of the network.
fully_connected :
Whether the incoming weights are fully connected.
kind :
The kind of layer, either ``"volatile"`` or ``"binary"``.
"""
state: LayerState
params: LayerParams
weights_in: Optional[Array]
coupling_fn: Callable = field(static=True)
add_constant_input: bool = field(static=True)
has_volatility_parent: bool = field(static=True)
is_input_layer: bool = field(static=True)
fully_connected: bool = field(static=True)
kind: str = field(static=True) # "volatile" | "binary"
[docs]
class LayerStack(eqx.Module):
"""N identical layers stacked into one PyTree with a leading ``(N,)`` axis.
``state``/``params`` have leading axis ``N`` (each field shape goes from
``(n_nodes,)`` to ``(N, n_nodes)``). ``weights_in`` goes from
``(n_child, n_self[+1])`` to ``(N, n_child, n_self[+1])``. Slice index 0 is the
*bottommost* slice in the stack (closest to layer 0 of the network); slice ``N-1``
is the topmost.
Validation constraints, enforced at build time:
* The layer immediately below the stack must have the same node count as the stack
width (so ``weights_in[0]`` shape matches).
* ``weights_in[k]`` for k > 0 is a square ``(W, W+bias)`` block connecting slice k
(parent) to slice k-1 (child) within the stack.
Parameters
----------
state :
The stacked per-layer state, each field with a leading ``(N,)`` axis.
params :
The stacked per-layer static parameters, each field with a leading
``(N,)`` axis.
weights_in :
The stacked incoming weight matrices, shape ``(N, n_child, n_self[+1])``.
coupling_fn :
The coupling function shared by all stacked layers.
add_constant_input :
Whether a constant (bias) input column is appended to the weights.
has_volatility_parent :
Whether the layers have a volatility parent.
fully_connected :
Whether the incoming weights are fully connected.
kind :
The kind of layer, either ``"volatile"`` or ``"binary"``.
n_layers :
The number of stacked layers ``N``.
"""
state: LayerState # each field shape: (N, n_nodes)
params: LayerParams # each field shape: (N, n_nodes)
weights_in: Array # shape: (N, n_child, n_self[+1])
coupling_fn: Callable = field(static=True)
add_constant_input: bool = field(static=True)
has_volatility_parent: bool = field(static=True)
fully_connected: bool = field(static=True)
kind: str = field(static=True)
n_layers: int = field(static=True)
[docs]
def stack_layers(layers: list) -> LayerStack:
"""Combine N identical ``Layer`` instances into a single ``LayerStack``.
All ``Layer``s must share static-field values (kind, coupling_fn,
add_constant_input, has_volatility_parent, fully_connected) and have ``weights_in``
of identical shape. Static fields are taken from the first layer; arrays are stacked
along a new axis 0.
Parameters
----------
layers :
The list of identical ``Layer`` instances to stack.
Returns
-------
layer_stack :
The combined :py:class:`LayerStack`.
"""
if not layers:
raise ValueError("Cannot stack an empty list of Layers.")
first = layers[0]
for k, lay in enumerate(layers):
if not isinstance(lay, Layer):
raise TypeError(f"layers[{k}] is not a Layer: {type(lay).__name__}")
for attr in (
"add_constant_input",
"has_volatility_parent",
"fully_connected",
"kind",
):
if getattr(lay, attr) != getattr(first, attr):
raise ValueError(
f"Cannot stack layers with differing static field {attr!r}: "
f"layers[0].{attr}={getattr(first, attr)!r}, "
f"layers[{k}].{attr}={getattr(lay, attr)!r}."
)
if lay.coupling_fn is not first.coupling_fn:
raise ValueError(
f"Cannot stack layers with differing coupling_fn identities. "
f"Hoist the function to module scope so all layers share it."
)
if lay.weights_in is None:
raise ValueError(
f"layers[{k}] has weights_in=None (bottom layer of the network "
f"can't be inside a LayerStack)."
)
if lay.weights_in.shape != first.weights_in.shape:
raise ValueError(
f"layers[{k}].weights_in.shape={lay.weights_in.shape} differs "
f"from layers[0].weights_in.shape={first.weights_in.shape}."
)
stacked_state = jax.tree_util.tree_map(
lambda *xs: jnp.stack(xs), *(lay.state for lay in layers)
)
stacked_params = jax.tree_util.tree_map(
lambda *xs: jnp.stack(xs), *(lay.params for lay in layers)
)
stacked_weights = jnp.stack([lay.weights_in for lay in layers])
return LayerStack(
state=stacked_state,
params=stacked_params,
weights_in=stacked_weights,
coupling_fn=first.coupling_fn,
add_constant_input=first.add_constant_input,
has_volatility_parent=first.has_volatility_parent,
fully_connected=first.fully_connected,
kind=first.kind,
n_layers=len(layers),
)
[docs]
class Network(eqx.Module):
"""Complete vectorised network state.
``time_step`` is *not* stored on the network — it is passed as a per-step input to
``propagation_step``, matching the nodalised backend's
``input_data(time_steps=...)`` API.
Optimiser state lives in a separate ``optax`` opt-state carried alongside
``Network`` in the scan carry; it is not part of the network PyTree.
``layers`` is a mixed tuple of ``Layer`` and ``LayerStack`` elements.
Parameters
----------
layers :
A mixed tuple of ``Layer`` and ``LayerStack`` elements, ordered from the
bottom (input) layer to the top.
volatility_updates :
The volatility update scheme, e.g. ``"unbounded"``.
max_posterior_precision :
The maximum posterior precision used to clip the precision updates.
"""
layers: tuple
volatility_updates: str = field(static=True)
max_posterior_precision: float = field(static=True)
precision_clipping_value: float = field(static=True, default=1e-6)
@property
def n_layers(self) -> int:
"""Number of *elements* (``Layer`` or ``LayerStack``) in the network.
A ``LayerStack`` counts as one element; use ``n_total_slices`` for the number of
unrolled layers.
"""
return len(self.layers)
@property
def n_total_slices(self) -> int:
"""Total unrolled layer count, expanding every ``LayerStack``."""
return sum(
(e.n_layers if isinstance(e, LayerStack) else 1) for e in self.layers
)
def get_layer_sizes(self) -> list[int]:
"""Per-element node count (one entry per ``Layer`` / ``LayerStack``)."""
out = []
for elem in self.layers:
if isinstance(elem, LayerStack):
out.append(elem.state.mean.shape[1]) # (N, n_nodes) -> n_nodes
else:
out.append(elem.state.mean.shape[0])
return out
def weights_tuple(self) -> tuple:
"""Per-element ``weights_in`` tuple, matched 1:1 to ``self.layers``."""
return tuple(elem.weights_in for elem in self.layers)
# ------------------------------------------------------------------
# Legacy-shape views used by existing tests and the Rust-parity
# cross-check. These are not used in the hot path — the kernels read
# ``layer.state`` / ``layer.weights_in`` directly. For ``LayerStack``
# elements these views flatten the stack into its constituent slices
# so consumers see the unrolled shape.
# ------------------------------------------------------------------
@property
def weights(self) -> tuple:
"""Tuple of weight matrices (legacy view).
Stacks are flattened. Each entry is a ``(n_child, n_self[+1])`` array.
The ``None`` slot on layer 0 is dropped, and any ``LayerStack`` is
expanded slice-by-slice.
"""
out = []
for elem in self.layers:
if isinstance(elem, LayerStack):
for k in range(elem.n_layers):
out.append(elem.weights_in[k])
elif elem.weights_in is not None:
out.append(elem.weights_in)
return tuple(out)
@property
def params(self) -> tuple:
"""Per-layer ``LayerParams`` tuple."""
out = []
for elem in self.layers:
if isinstance(elem, LayerStack):
for k in range(elem.n_layers):
out.append(jax.tree_util.tree_map(lambda x, k=k: x[k], elem.params))
else:
out.append(elem.params)
return tuple(out)
# Convenience constant: every ``LayerState`` field, ordered as declared. Pass
# to ``DeepNetwork.fit(record=RECORD_ALL)`` for the legacy "record everything"
# behaviour without enumerating the field list at the call site.
RECORD_ALL: tuple = tuple(LayerState.__dataclass_fields__.keys())