Deep Bayesian predictive coding#



import sys

from IPython.utils import io

if "google.colab" in sys.modules:
    with io.capture_output() as captured:
        ! pip uninstall -y jax jaxlib
        ! pip install pyhgf watermark jax[cuda12]==0.4.31
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import treescope
import matplotlib.animation as animation
from IPython.display import HTML
from pyhgf.model import DeepNetwork, Network
from pyhgf.plots.graphviz.plot_network import plot_deep_network

np.random.seed(123)
plt.rcParams["figure.constrained_layout.use"] = True

treescope.basic_interactive_setup(autovisualize_arrays=True)

Warning

The features exposed here are still a work in progress.

The hierarchical Gaussian filter is built on top of a generative model governed by Gaussian random walks, which is best suited to modelling the time-resolved evolution of beliefs in volatile environments. But the framework extends naturally to traditional applications of predictive coding, such as deep neural networks, where variational message passing replaces iterative gradient descent during inference.

In this notebook, we show that the prospective configuration in predictive coding networks can be performed by one-shot variational updates, removing the need for gradient descent over the energy function, while learning expected precision in the hidden layers, which is often fixed in other approaches. We illustrate this with the “bear example” from [Song et al., 2024] and on a classification task with deep networks.

https://media.springernature.com/full/springer-static/image/art%3A10.1038%2Fs41593-023-01514-1/MediaObjects/41593_2023_1514_Fig1_HTML.png?as=webp

Fig. 4 Learning with prospective configuration.#

Learning in deep networks#

Prospective configuration#

In standard backpropagation, weights are updated using gradients computed from a fixed forward pass. The network’s activations are “frozen” while the error signal propagates backward. Prospective configuration [Song et al., 2024] takes a different approach: before any weight change, the network first infers the most likely activations at every layer by settling prediction errors across the whole hierarchy. Only once this inference step has converged are the weights updated. This two-phase process (infer activations, then update weights) prevents the catastrophic interference that arises when a local weight change inadvertently distorts representations elsewhere in the network.

In pyhgf, this is implemented naturally through the belief propagation cycle. At each observation:

  1. Prediction: predictors (\(x\)) are provided in the leaf nodes. Each node generates a top-down prediction for its children via a nonlinear coupling function \(g(\cdot)\).

  2. Observation & prediction errors: the observed values (\(y\)) are compared with predictions, and precision-weighted prediction errors propagate upward through the hierarchy.

  3. Posterior update: node activations (means and precisions) are updated to minimise free energy, settling the network into a new equilibrium. This is comparable to the prospective configuration step.

  4. Weight update: only after the activations have settled are the coupling strengths (weights) adjusted using the prediction errors and the inferred activations.
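As a deliberately minimal sketch of this cycle (plain Python with hypothetical variable names, not the pyhgf implementation), one prediction / error / posterior / weight pass for a single parent-child pair might look like:

```python
import math

def belief_cycle(mu_parent, pi_parent, y, pi_obs, w, g=math.tanh, lr=0.1):
    """One toy prediction / error / posterior / weight cycle for a parent-child pair."""
    # 1. prediction: top-down prediction of the child's value
    mu_hat = w * g(mu_parent)
    # 2. prediction error between the observation and the prediction
    pe = y - mu_hat
    # 3. one-shot posterior settling: a precision-weighted (Kalman-like) update
    gain = pi_obs / (pi_obs + pi_parent)
    mu_post = mu_hat + gain * pe
    # 4. weight update, applied only after the activations have settled
    w_new = w + lr * pe * pi_obs * g(mu_parent)
    return mu_post, w_new

mu_post, w_new = belief_cycle(mu_parent=0.5, pi_parent=1.0, y=1.0, pi_obs=2.0, w=1.0)
```

The key structural point is the ordering: the posterior mean settles first, and the weight update then uses the settled activations rather than the pre-inference forward pass.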

We illustrate this using the “bear” example network from Song et al. [2024].

# here x represents the visual input (River / No River)
x = np.array([1.0, 1.0] * 120)
x += np.random.normal(size=x.shape) / 100

# y represents the auditory and olfactory stimuli
y = np.array([
    x,
    np.concatenate([
        np.array([1.0, 1.0] * 40),
        np.array([-1.0, -1.0] * 40),
        np.array([1.0, 1.0] * 40),
    ])
    + np.random.normal(size=x.shape) / 100,
]).T
# We start by defining a simple network with two branches
network = (
    Network(update_type="unbounded")
    .add_nodes(n_nodes=2, precision=2.0, expected_precision=2.0)
    .add_nodes(
        kind="volatile-state",
        value_children=[0, 1],
        autoconnection_strength=0,
        tonic_volatility=-4.0,
        coupling_fn=(jnp.tanh, jnp.tanh),
    )
    .add_nodes(
        value_children=2,
        autoconnection_strength=0,
        coupling_fn=(jnp.tanh,),
        precision=2.0,
        expected_precision=2.0,
    )
)
network.plot_network()
../_images/071d38088da92cd0df86766e888c4279c215ce522f133366f3797926d6ec1788.svg

Deep networks trained for classification purposes differ from other predictive coding networks in that both roots and leaves receive inputs (predictors and outcomes, respectively).

network.fit(
    x=x,
    y=y,
    inputs_x_idxs=(3,),
    inputs_y_idxs=(0, 1),
    lr="dynamic",
    record_trajectories=True,
    optimizer=None,
);


_, axs = plt.subplots(figsize=(9, 5), nrows=3, sharey=True, sharex=True)

network.plot_nodes(2, show_surprise=False, axs=axs[0])
network.plot_nodes(0, show_surprise=False, axs=axs[1])
network.plot_nodes(1, show_surprise=False, axs=axs[2])
axs[0].grid(linestyle="--")
axs[1].grid(linestyle="--")
axs[2].grid(linestyle="--")

sns.despine();
../_images/aadd161a23211e8c8199ace2e5d9dc5c164b6053a7bb7f31506a7ce4eea0edc3.png

This example illustrates the effectiveness of the variational update as a replacement for the gradient-descent-based prospective configuration step. As new observations contradict the expected outcomes (i.e., observing the river without hearing the water, trials 40 to 80 in the bottom panel), the network efficiently reorganizes without interfering with other predictions (i.e., still expecting to smell the salmon while seeing the water, top panel, trials 40 to 80).

Hint

Volatile state nodes

In a predictive coding network, two types of state nodes can be used to build the hierarchy.

Continuous-state nodes are the standard HGF nodes: their precision is controlled by external volatility parents connected through dedicated volatility edges, and their mean persists across time steps by default (autoconnection_strength=1.0). This makes them well suited for tracking slowly drifting quantities or serving as input/output layers that receive observations.

Volatile-state nodes bundle a value level and an internal volatility level into a single node, removing the need for separate volatility parents. Their mean resets at each time step by default (autoconnection_strength=0.0), making them behave like stateless hidden units whose activation is determined entirely by incoming predictions.

In practice, volatile-state nodes are the natural building block for hidden layers in deep predictive coding networks: they are more parameter-efficient than continuous-state nodes because the volatility coupling is handled internally, and their stateless default mirrors the feedforward activations of a conventional neural network.
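To make the autoconnection defaults concrete, here is a deliberately simplified sketch (plain Python, not the pyhgf update equations) of how the prior mean at the next time step depends on autoconnection_strength:

```python
def next_expected_mean(mu_prev, top_down_prediction, autoconnection_strength):
    """Simplified prior mean: a persistence term plus the incoming prediction."""
    return autoconnection_strength * mu_prev + top_down_prediction

# continuous-state default (autoconnection_strength=1.0): the previous mean persists
persistent = next_expected_mean(mu_prev=0.8, top_down_prediction=0.0,
                                autoconnection_strength=1.0)
# volatile-state default (autoconnection_strength=0.0): the mean resets and is
# driven entirely by the incoming prediction, like a feedforward activation
stateless = next_expected_mean(mu_prev=0.8, top_down_prediction=0.3,
                               autoconnection_strength=0.0)
```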

Weight update#

Once the network has settled into its new posterior (the prospective configuration step), the coupling strengths \(w_i\) between a child node and its value parents are updated. Let \(\text{PE}\) denote the value prediction error at the child node, \(\pi_{\text{child}}\) its posterior precision, and \(g(\mu_i)\) the activation of parent \(i\) passed through the coupling function \(g\).

Fixed learning rate. With a constant step size \(\eta\):

\[ \Delta w_i \;=\; \eta \;\cdot\; \text{PE} \;\cdot\; \pi_{\text{child}} \;\cdot\; g(\mu_i) \]

Dynamic (precision-weighted) learning rate. When no fixed rate is specified, the update uses a Kalman-gain-like rule that automatically scales the step size by the relative precision of parent and child:

\[ K_i = \frac{\pi_{\text{parent}_i}}{\pi_{\text{parent}_i} + \pi_{\text{child}}} \]
\[ \Delta w_i \;=\; K_i \;\cdot\; \text{PE} \;\cdot\; g(\mu_i) \]

In the dynamic case, precise parents exert a larger influence on the weight update while uncertain parents are updated more cautiously. This precision weighting is what produces the depth-dependent learning effect demonstrated in the next section.
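The two rules can be written down directly. The sketch below mirrors the equations above with hypothetical helper names in plain Python:

```python
def fixed_lr_update(w, pe, pi_child, g_mu, eta=0.1):
    """Fixed learning rate: Δw = η · PE · π_child · g(μ)."""
    return w + eta * pe * pi_child * g_mu

def dynamic_lr_update(w, pe, pi_parent, pi_child, g_mu):
    """Precision-weighted rule: Δw = K · PE · g(μ), K = π_parent / (π_parent + π_child)."""
    K = pi_parent / (pi_parent + pi_child)
    return w + K * pe * g_mu

# a precise parent (π = 9) moves the weight more than an uncertain one (π = 1)
precise = dynamic_lr_update(w=0.0, pe=1.0, pi_parent=9.0, pi_child=1.0, g_mu=1.0)
uncertain = dynamic_lr_update(w=0.0, pe=1.0, pi_parent=1.0, pi_child=9.0, g_mu=1.0)
```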

Input precision controls the depth of weight updates#

One natural consequence of this framework is that neural activations carry a precision, and as a result the precision of the inputs (both predictors and outcomes) controls the strength of prediction errors and the amplitude of weight updates along the information flow. For example, more precise outcomes will drive weight updates that are larger at the nodes close to the outcome layer, while a more precise predictor will drive larger weight updates close to the internal representation.

We can simulate this with a deep stack of hidden nodes whose predictions and outcomes differ, which forces the network to reorganize. Here, we show that the balance of precision between predictors and outcomes shapes the depth of weight updates.

linear = lambda x: x
x = np.array([1.0] * 200)
# x += np.random.normal(size=x.shape) / 100

y = np.array([0.0] * 200)
# y += np.random.normal(size=y.shape) / 100
# We start by defining two networks with varying precision
# at the predictor and outcome levels
high_outcome_precision_network = (
    Network(update_type="unbounded")
    .add_nodes(n_nodes=1, precision=1e2, expected_precision=1e2)
    .add_nodes(
        kind="volatile-state",
        value_children=0,
        coupling_fn=(linear,),
        tonic_volatility=-2.0,
        autoconnection_strength=0,
    )
    .add_nodes(
        kind="volatile-state",
        value_children=1,
        coupling_fn=(linear,),
        tonic_volatility=-2.0,
        autoconnection_strength=0,
    )
    .add_nodes(
        kind="volatile-state",
        value_children=2,
        coupling_fn=(linear,),
        tonic_volatility=-2.0,
        autoconnection_strength=0,
    )
    .add_nodes(
        kind="volatile-state",
        value_children=3,
        coupling_fn=(linear,),
        tonic_volatility=-2.0,
        autoconnection_strength=0,
    )
    .add_nodes(
        kind="volatile-state",
        value_children=4,
        coupling_fn=(linear,),
        tonic_volatility=-2.0,
        autoconnection_strength=0,
        precision=1.0,
        expected_precision=1.0,
    )
)

high_predictor_precision_network = (
    Network(update_type="unbounded")
    .add_nodes(n_nodes=1, precision=1.0, expected_precision=1.0)
    .add_nodes(
        kind="volatile-state",
        value_children=0,
        coupling_fn=(linear,),
        tonic_volatility=-2.0,
        autoconnection_strength=0,
    )
    .add_nodes(
        kind="volatile-state",
        value_children=1,
        coupling_fn=(linear,),
        tonic_volatility=-2.0,
        autoconnection_strength=0,
    )
    .add_nodes(
        kind="volatile-state",
        value_children=2,
        coupling_fn=(linear,),
        tonic_volatility=-2.0,
        autoconnection_strength=0,
    )
    .add_nodes(
        kind="volatile-state",
        value_children=3,
        coupling_fn=(linear,),
        tonic_volatility=-2.0,
        autoconnection_strength=0,
    )
    .add_nodes(
        kind="volatile-state",
        value_children=4,
        coupling_fn=(linear,),
        tonic_volatility=-2.0,
        autoconnection_strength=0,
        precision=1e2,
        expected_precision=1e2,
    )
)
high_predictor_precision_network.fit(
    x=x,
    y=y,
    inputs_x_idxs=(5,),
    inputs_y_idxs=(0,),
    lr="dynamic",
    record_trajectories=True,
    optimizer=None,
)
high_outcome_precision_network.fit(
    x=x,
    y=y,
    inputs_x_idxs=(5,),
    inputs_y_idxs=(0,),
    lr="dynamic",
    record_trajectories=True,
    optimizer=None,
);
high_outcome_precision_network.plot_network()
../_images/4ec0a9666c2c9b8e12322e7a8b180061ddf2f88da64d8f8d9a12117a890d7de8.svg


_, ax = plt.subplots(figsize=(7, 6))

# high prediction precision
# -------------------------
palette = sns.color_palette("mako", 8)
handles, labels = [], []
for color, time_step in zip(palette, [0, 2, 4, 10, 25, 50, 100, 200]):
    expected_means = [
        high_predictor_precision_network.node_trajectories[i]["expected_mean"][
            time_step
        ]
        for i in range(6)
    ]
    scatterplot = ax.scatter(
        range(5, -1, -1),
        expected_means,
        s=150,
        color=color,
        label=time_step,
        edgecolor="k",
    )
    handles.append(scatterplot)
    labels.append(f"Step = {time_step}")
    ax.plot(range(5, -1, -1), expected_means, color=color, zorder=-1)

first_legend = ax.legend(
    handles=handles, labels=labels, title="Precise predictor", bbox_to_anchor=(1.3, 1.1)
)
ax.add_artist(first_legend)

# high outcome precision
# ----------------------
palette = sns.color_palette("rocket", 8)
handles, labels = [], []
for color, time_step in zip(palette, [0, 2, 4, 10, 25, 50, 100, 200]):
    expected_means = [
        high_outcome_precision_network.node_trajectories[i]["expected_mean"][time_step]
        for i in range(6)
    ]
    scatterplot = ax.scatter(
        range(5, -1, -1),
        expected_means,
        s=150,
        color=color,
        label=time_step,
        edgecolor="k",
    )
    handles.append(scatterplot)
    labels.append(f"Step = {time_step}")
    ax.plot(range(5, -1, -1), expected_means, color=color, zorder=-1)
second_legend = ax.legend(
    handles=handles, labels=labels, title="Precise outcome", bbox_to_anchor=(1.3, 0.5)
)

plt.xticks(
    ticks=[0, 1, 2, 3, 4, 5],
    labels=["Predictor", "Node 4", "Node 3", "Node 2", "Node 1", "Outcome"],
)
ax.set(
    ylabel="Activation level",
    title="Predictor and outcome precisions \n balance the depth of weight updates",
)
ax.grid(linestyle="--")

sns.despine()
../_images/81fa25b29b316fb53ab93bc4e863e81b407592f8df5cb1456aa248341a2c34d6.png

Building Deep-Network Structures in pyhgf#

Backends#

As of version 0.2.9, PyHGF supports the creation of deep neural networks for classification tasks using:

  • pyhgf.model.Network: the standard network class. Deep networks must be built manually using pyhgf.model.Network.add_nodes. This approach is the most flexible, but it rapidly struggles with large structures (> 20 nodes), as each update function is cached by JAX.

  • pyhgf.rshgf.Network: the Rust equivalent, which scales efficiently to larger structures while remaining flexible in the network configuration. It supports layered designs with pyhgf.rshgf.Network.add_layer(). It is not differentiable.

  • pyhgf.model.DeepNetwork: a vectorised JAX implementation that is differentiable and the fastest solution.

Here, we demonstrate two new high-level functions for constructing layered, fully connected value-parent structures in pyhgf:

  • add_layer() – adds a single fully connected parent layer.

  • add_layer_stack() – builds multiple layers at once, similar to Sequential in deep-learning frameworks.

These functions allow HGF models to be composed in a deep-network style while remaining fully compatible with the probabilistic belief-update dynamics.

Adding fully connected layers with add_layer#

add_layer provides fine-grained control, letting you manually construct each layer.

This is useful when each layer should have different hyperparameters (precision, tonic volatility, autoconnection strength, etc.). The function creates a fully connected parent layer, in which each parent node connects to all children below it.

By default, add_layer automatically connects to all orphan nodes (nodes without value parents). You can also specify value_children explicitly to control which nodes the layer connects to.
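The connectivity pattern itself is simple to state: each new parent node receives a value edge to every child below it. A minimal sketch of that edge construction (a hypothetical helper, not the pyhgf internals):

```python
def fully_connected_edges(child_idxs, n_parents, next_idx):
    """Return parent indices and (parent, child) value-coupling pairs
    for one fully connected parent layer."""
    parent_idxs = list(range(next_idx, next_idx + n_parents))
    # every parent connects to every child below it
    edges = [(p, c) for p in parent_idxs for c in child_idxs]
    return parent_idxs, edges

# a layer of 3 parents over two orphan children (nodes 0 and 1)
parents, edges = fully_connected_edges(child_idxs=[0, 1], n_parents=3, next_idx=2)
```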

# Chain layers in a single expression (like Keras/PyTorch):
net = (
    DeepNetwork(coupling_fn=jnp.tanh)
    .add_layer(size=2)
    .add_layer(size=5, tonic_volatility=-1.0)
    .add_layer(size=3, tonic_volatility=-2.0)
)

# Visualize the network structure
plot_deep_network(net)
../_images/fa2bc0d722b0fa5ab977c8fa6f2da5a3e86330d96de772330d8a285a071ffeb1.svg

Adding multiple layers with add_layer_stack#

add_layer_stack provides a compact way to build several fully connected parent layers at once. Instead of adding each layer manually, you simply specify the desired layer sizes (e.g., [3, 16, 32]), and the function creates them sequentially. Each layer is fully connected to the one below, using the same hyperparameters for all layers you add (precision, tonic volatility, autoconnection strength, etc.).

This is ideal when you want to quickly prototype deep hierarchical networks or mimic the “stacked layer” construction found in deep learning frameworks.

Like add_layer, it also supports method chaining and auto-connects to orphan nodes by default.

# Add 3 fully connected parent layers (4→4→16→32) using method chaining
net = (
    DeepNetwork()
    .add_layer(size=4)
    .add_layer_stack(
        layer_sizes=[4, 16, 32],
        tonic_volatility=-1.0,
    )
)

plot_deep_network(net)
../_images/b2d9b16d6a19b3984c003539b4bdf79679a6ac1909ff62474a70ca5d397d84d5.svg

Deep networks often require weights to be initialised using strategies that preserve the variance of activations between the input and output nodes. The classes expose a pyhgf.model.DeepNetwork.weight_initialisation() method that supports the most popular strategies.

net.weight_initialisation(strategy="xavier");
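For reference, the scaling behind the two most common strategies can be sketched with NumPy. These are the standard formulas; the function below is illustrative and does not reproduce pyhgf's internals:

```python
import numpy as np

def init_weights(fan_in, fan_out, strategy="xavier", seed=0):
    """Draw a (fan_out, fan_in) weight matrix with variance-preserving scaling."""
    rng = np.random.default_rng(seed)
    if strategy == "xavier":   # Var(w) = 2 / (fan_in + fan_out), suited to tanh couplings
        std = np.sqrt(2.0 / (fan_in + fan_out))
    elif strategy == "he":     # Var(w) = 2 / fan_in, suited to ReLU-like units
        std = np.sqrt(2.0 / fan_in)
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return rng.normal(0.0, std, size=(fan_out, fan_in))

w = init_weights(fan_in=16, fan_out=32, strategy="he")
```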

Binary classification on a two-moons dataset#

We now demonstrate DeepNetwork on a non-linearly separable problem: the classic two moons dataset. We train a 2 → 16 → 16 → 16 → 16 → 16 → 16 → 16 → 16 → 1 (binary) predictive coding network with tanh coupling functions and a binary output layer, using the Adam optimiser. After training, we visualise how the learned decision boundary evolves across epochs.

Generate the dataset#

We create a synthetic two-moons dataset with 1,000 samples and split it into 80% training / 20% test. We also pre-compute the mesh grid that will be used for the decision boundary heatmap.

# --- Two-moons dataset ---
def make_moons(n_samples=500, noise=0.15, seed=42):
    """Generate a two-moons dataset."""
    rng = np.random.default_rng(seed)
    n_half = n_samples // 2
    theta_upper = np.linspace(0, np.pi, n_half)
    x_upper = np.column_stack([np.cos(theta_upper), np.sin(theta_upper)])
    theta_lower = np.linspace(0, np.pi, n_samples - n_half)
    x_lower = np.column_stack([1 - np.cos(theta_lower), 1 - np.sin(theta_lower) - 0.5])
    X = np.vstack([x_upper, x_lower]) + rng.normal(scale=noise, size=(n_samples, 2))
    y = np.hstack([np.zeros(n_half), np.ones(n_samples - n_half)])
    idx = rng.permutation(n_samples)
    return X[idx].astype(np.float32), y[idx].astype(np.int32)


N_SAMPLES = 1000
X_moons, y_moons = make_moons(n_samples=N_SAMPLES, noise=0.15, seed=42)

# Train / test split (80 / 20)
n_train = int(0.8 * N_SAMPLES)
X_train_m, X_test_m = X_moons[:n_train], X_moons[n_train:]
y_train_m, y_test_m = y_moons[:n_train], y_moons[n_train:]

# Mesh grid for decision boundary
h = 0.02
x_min, x_max = X_moons[:, 0].min() - 0.5, X_moons[:, 0].max() + 0.5
y_min, y_max = X_moons[:, 1].min() - 0.5, X_moons[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
grid = jnp.array(np.column_stack([xx.ravel(), yy.ravel()]), dtype=jnp.float32)

print(f"Training: {X_train_m.shape[0]} samples  |  Test: {X_test_m.shape[0]} samples")
print(f"Grid points: {grid.shape[0]:,}")
Training: 800 samples  |  Test: 200 samples
Grid points: 37,848
# Quick scatter of the raw data
fig, ax = plt.subplots(figsize=(5, 5))
for cls, color in [(0, "#3b4cc0"), (1, "#b40426")]:
    mask = y_moons == cls
    ax.scatter(X_moons[mask, 0], X_moons[mask, 1], marker="o", s=20,
               edgecolors="k", linewidths=0.5, c=color, label=f"Class {cls}")
ax.set(xlabel="$x_1$", ylabel="$x_2$", title="Two-Moons Dataset")
ax.legend()
plt.minorticks_on()
sns.despine()
../_images/71c10f0b6329d2647b71db9c6e31d946daa679679f69d462fbb630c730b43acf.png

Build the network#

We construct a 2 → 16 → 16 → 16 → 16 → 16 → 16 → 16 → 16 → 1 (binary) DeepNetwork with tanh coupling functions and He weight initialisation. The binary output layer applies a sigmoid internally, producing probabilities directly.

# Build: 2 → 16 (×8) → 1 (binary output)
clf_net = (
    DeepNetwork(coupling_fn=jnp.tanh)
    .add_layer(size=1, kind="binary")
    .add_layer(size=16, tonic_volatility=-4.0)
    .add_layer(size=16, tonic_volatility=-4.0)
    .add_layer(size=16, tonic_volatility=-4.0)
    .add_layer(size=16, tonic_volatility=-4.0)
    .add_layer(size=16, tonic_volatility=-4.0)
    .add_layer(size=16, tonic_volatility=-4.0)
    .add_layer(size=16, tonic_volatility=-4.0)
    .add_layer(size=16, tonic_volatility=-4.0)
    .add_layer(size=2, add_constant_input=False, coupling_fn=lambda x: x)
    .weight_initialisation("he", seed=0)
)

print("Architecture: 2 → 16 (×8) → 1 (binary)")
print(f"Layers: {clf_net.n_layers}  |  Nodes: {clf_net.n_nodes}")
Architecture: 2 → 16 (×8) → 1 (binary)
Layers: 10  |  Nodes: 131

Train over multiple epochs#

We train for 300 epochs using the Adam optimiser and snapshot the decision boundary every two epochs.

NUM_EPOCHS = 300
SNAPSHOT_EVERY = 2
LR = 0.2

# Prepare JAX arrays
jax_X_train = jnp.array(X_train_m)
jax_y_train = jnp.array(y_train_m, dtype=jnp.float32).reshape(-1, 1)
jax_X_test = jnp.array(X_test_m)
jax_y_test = jnp.array(y_test_m, dtype=jnp.float32).reshape(-1, 1)


def bce(probs, labels, eps=1e-7):
    """Binary cross-entropy from probabilities."""
    p = np.clip(probs, eps, 1 - eps)
    return -np.mean(labels * np.log(p) + (1 - labels) * np.log(1 - p))


train_losses, test_losses = [], []
train_accs, test_accs = [], []
snapshots = {}  # epoch → probability grid

for epoch in range(NUM_EPOCHS):
    # Evaluate on test set (forward pass, no weight updates)
    test_preds = np.array(clf_net.predict(jax_X_test)).ravel()
    test_labels = np.array(y_test_m)
    test_losses.append(bce(test_preds, test_labels))
    test_accs.append(np.mean((test_preds > 0.5).astype(int) == test_labels))

    # Train: one full pass through the training set
    clf_net.fit(jax_X_train, jax_y_train, lr=LR, optimizer="adam")

    # Training metrics from the predictions during this epoch
    train_preds = np.array(clf_net.predictions).ravel()
    train_labels = np.array(y_train_m)
    train_losses.append(bce(train_preds, train_labels))
    train_accs.append(np.mean((train_preds > 0.5).astype(int) == train_labels))

    # Snapshot decision boundary
    if epoch % SNAPSHOT_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        snap_preds = np.array(clf_net.predict(grid)).ravel()
        snapshots[epoch] = snap_preds.reshape(xx.shape)

    if epoch % 10 == 0 or epoch == NUM_EPOCHS - 1:
        print(
            f"Epoch {epoch:>3d} | "
            f"train loss={train_losses[-1]:.4f}, acc={train_accs[-1]:.3f} | "
            f"test  loss={test_losses[-1]:.4f}, acc={test_accs[-1]:.3f}"
        )

print(f"\nSaved {len(snapshots)} decision boundary snapshots")
Epoch   0 | train loss=0.3935, acc=0.829 | test  loss=0.7196, acc=0.410
Epoch  10 | train loss=0.2083, acc=0.919 | test  loss=0.2228, acc=0.915
Epoch  20 | train loss=0.1496, acc=0.953 | test  loss=0.1600, acc=0.935
Epoch  30 | train loss=0.0938, acc=0.974 | test  loss=0.1139, acc=0.960
Epoch  40 | train loss=0.0769, acc=0.978 | test  loss=0.1012, acc=0.965
Epoch  50 | train loss=0.0614, acc=0.985 | test  loss=0.0779, acc=0.975
Epoch  60 | train loss=0.0573, acc=0.985 | test  loss=0.0837, acc=0.975
Epoch  70 | train loss=0.0628, acc=0.984 | test  loss=0.0722, acc=0.980
Epoch  80 | train loss=0.0893, acc=0.969 | test  loss=0.0852, acc=0.985
Epoch  90 | train loss=0.0507, acc=0.990 | test  loss=0.0716, acc=0.980
Epoch 100 | train loss=0.0458, acc=0.993 | test  loss=0.0669, acc=0.980
Epoch 110 | train loss=0.0462, acc=0.990 | test  loss=0.0641, acc=0.980
Epoch 120 | train loss=0.0460, acc=0.991 | test  loss=0.0623, acc=0.980
Epoch 130 | train loss=0.0532, acc=0.989 | test  loss=0.0632, acc=0.980
Epoch 140 | train loss=0.0413, acc=0.993 | test  loss=0.0653, acc=0.980
Epoch 150 | train loss=0.0441, acc=0.991 | test  loss=0.0628, acc=0.980
Epoch 160 | train loss=0.0496, acc=0.991 | test  loss=0.0579, acc=0.985
Epoch 170 | train loss=0.0401, acc=0.994 | test  loss=0.0593, acc=0.980
Epoch 180 | train loss=0.0390, acc=0.995 | test  loss=0.0576, acc=0.980
Epoch 190 | train loss=0.0387, acc=0.994 | test  loss=0.0545, acc=0.980
Epoch 200 | train loss=0.0388, acc=0.993 | test  loss=0.0556, acc=0.980
Epoch 210 | train loss=0.0393, acc=0.994 | test  loss=0.0584, acc=0.985
Epoch 220 | train loss=0.0369, acc=0.994 | test  loss=0.0581, acc=0.980
Epoch 230 | train loss=0.0364, acc=0.995 | test  loss=0.0627, acc=0.980
Epoch 240 | train loss=0.0369, acc=0.995 | test  loss=0.0613, acc=0.980
Epoch 250 | train loss=0.0362, acc=0.994 | test  loss=0.0634, acc=0.980
Epoch 260 | train loss=0.0364, acc=0.994 | test  loss=0.0688, acc=0.975
Epoch 270 | train loss=0.0356, acc=0.994 | test  loss=0.0625, acc=0.985
Epoch 280 | train loss=0.0353, acc=0.995 | test  loss=0.0622, acc=0.985
Epoch 290 | train loss=0.0351, acc=0.995 | test  loss=0.0622, acc=0.985
Epoch 299 | train loss=0.0349, acc=0.995 | test  loss=0.0614, acc=0.985

Saved 151 decision boundary snapshots

Training curves#

fig, axs = plt.subplots(1, 2, figsize=(12, 4))

axs[0].plot(train_losses, label="Train", alpha=0.8, color="#4c72b0")
axs[0].plot(test_losses, label="Test", alpha=0.8, color="#55a868")
axs[0].set(xlabel="Epoch", ylabel="Binary Cross-Entropy", title="Loss")
axs[0].legend()
axs[0].grid(linestyle="--")

axs[1].plot(train_accs, label="Train", alpha=0.8, color="#4c72b0")
axs[1].plot(test_accs, label="Test", alpha=0.8, color="#55a868")
axs[1].set(xlabel="Epoch", ylabel="Accuracy", title="Accuracy")
axs[1].legend()
axs[1].grid(linestyle="--")

sns.despine()
../_images/931cf4652443807d90317273f41491c2a99d5e7b5229e2df2b3ab2e05ac05b05.png

Decision boundary evolution#

We animate the decision boundary heatmap across training epochs alongside the loss curve.
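The animation code is hidden here; a minimal sketch with matplotlib.animation (using stand-in data in place of the `snapshots`, `xx`, and `yy` objects computed above) could look like this:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this also runs in scripts
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# stand-in snapshots: epoch → probability grid (swap in the real `snapshots` dict)
xx, yy = np.meshgrid(np.linspace(-1.0, 2.0, 50), np.linspace(-1.0, 1.5, 50))
snapshots = {e: 1.0 / (1.0 + np.exp(-(xx + yy) * (e + 1))) for e in range(0, 10, 2)}
epochs = sorted(snapshots)

fig, ax = plt.subplots(figsize=(5, 4))
im = ax.imshow(snapshots[epochs[0]], origin="lower", cmap="coolwarm",
               vmin=0, vmax=1, extent=(-1.0, 2.0, -1.0, 1.5), aspect="auto")

def update(frame):
    # redraw the probability heatmap for this snapshot
    im.set_data(snapshots[epochs[frame]])
    ax.set_title(f"Epoch {epochs[frame]}")
    return (im,)

anim = animation.FuncAnimation(fig, update, frames=len(epochs), interval=200)
# anim.save("two_moons_training.gif", writer="pillow")  # needs Pillow installed
```

In a notebook, HTML(anim.to_jshtml()) (using the IPython.display.HTML import at the top of this page) renders the animation inline.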

../_images/two_moons_training.gif

Fig. 5 The GIF shows how the predictive coding network gradually carves out a non-linear separator to distinguish the two moons.#

System configuration#

%load_ext watermark
%watermark -n -u -v -iv -w -p pyhgf,jax,jaxlib
Last updated: Wed, 15 Apr 2026

Python implementation: CPython
Python version       : 3.12.3
IPython version      : 9.12.0

pyhgf : 0.2.10
jax   : 0.4.31
jaxlib: 0.4.31

IPython   : 9.12.0
jax       : 0.4.31
matplotlib: 3.10.8
numpy     : 2.4.4
pyhgf     : 0.2.10
seaborn   : 0.13.2
treescope : 0.1.10

Watermark: 2.6.0