pyhgf.plots.matplotlib.plot_layers

pyhgf.plots.matplotlib.plot_layers(network, layers=None, variables=('expected_mean',), mode='all', figsize=None, color=None, axs=None)

Plot layer-wise parameter trajectories of a DeepNetwork.

Each row of the resulting figure corresponds to a variable (a field of pyhgf.typing.LayerState, or the derived "PWPE" signal) and each column to a layer. In "all" mode, every node's trajectory is drawn as its own line; in "mean_ci" mode, the across-node mean is drawn as a Matplotlib line and its 95% confidence interval as a shaded band.
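The "mean_ci" summary for one variable and one layer can be sketched with NumPy as below. This is a minimal illustration, not pyhgf's implementation: it assumes the trajectories are held in an array of shape (n_timesteps, n_nodes), and that the 95% normal-approximation band is mean ± 1.96 × standard error of the mean across nodes. The helper name mean_ci is hypothetical.

```python
import numpy as np


def mean_ci(trajectory: np.ndarray, z: float = 1.96):
    """Across-node mean and a normal-approximation confidence band.

    Assumes `trajectory` has shape (n_timesteps, n_nodes); this layout is an
    illustrative assumption, not pyhgf's internal storage. Needs >= 2 nodes.
    """
    n_nodes = trajectory.shape[1]
    mean = trajectory.mean(axis=1)
    # Standard error of the mean across nodes, using the sample std (ddof=1).
    sem = trajectory.std(axis=1, ddof=1) / np.sqrt(n_nodes)
    return mean, mean - z * sem, mean + z * sem


# Three time steps, four nodes.
traj = np.array([
    [0.0, 1.0, 2.0, 3.0],
    [1.0, 1.0, 1.0, 1.0],
    [2.0, 4.0, 6.0, 8.0],
])
mean, lower, upper = mean_ci(traj)
```

On a Matplotlib axis, the mean would then be drawn with ax.plot(time, mean) and the band with ax.fill_between(time, lower, upper, alpha=0.3). When all nodes agree (second row above), the band collapses onto the mean.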

Parameters:
  • network (DeepNetwork) – A pyhgf.model.DeepNetwork instance whose trajectories attribute has been populated (call network.fit(..., record_trajectories=True) first).

  • layers (int | Sequence[int] | None) – Index or indices of the layers to plot. A single int is accepted as shorthand for a one-element list. None (default) plots every layer.

  • variables (str | Sequence[str]) – Name (or sequence of names) of pyhgf.typing.LayerState fields to plot, for example "expected_mean", "precision", "value_prediction_error" or "mean_vol". The derived name "PWPE" is also accepted: it plots the magnitude of the precision-weighted prediction error, |mean - expected_mean| * expected_precision. Taking the absolute value of the prediction error means that positive and negative deviations both increase the displayed signal. A single string is accepted as shorthand for a one-element list. Defaults to ("expected_mean",).

  • mode (str) – "all" to draw one line per node, "mean_ci" to draw the across-node mean with a 95% normal-approximation confidence band. Defaults to "all".

  • figsize (tuple | None) – Figure size in inches. Defaults to (3.5 * n_cols, 2.5 * n_rows), where n_cols is the number of layers plotted and n_rows the number of variables plotted.

  • color (tuple | str | None) – The color of the lines ("all" mode) or of the mean curve and confidence band ("mean_ci" mode). When None (default), Matplotlib’s default color cycle is used.

  • axs (ndarray | None) – A 2D array of Matplotlib axes (rows = variables, columns = layers) on which to draw the trajectories. Defaults to None, which creates a new figure, as plot_trajectories() does.
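The derived "PWPE" variable follows directly from the formula given for the variables parameter. The sketch below assumes the three quantities are available as equally shaped NumPy arrays (for example, one LayerState field each); pwpe_magnitude is a hypothetical helper, not part of pyhgf.

```python
import numpy as np


def pwpe_magnitude(mean, expected_mean, expected_precision):
    """Element-wise |mean - expected_mean| * expected_precision.

    Hypothetical helper for illustration; not part of pyhgf's API.
    """
    mean = np.asarray(mean, dtype=float)
    expected_mean = np.asarray(expected_mean, dtype=float)
    expected_precision = np.asarray(expected_precision, dtype=float)
    # The absolute value makes over- and under-shoots count the same way.
    return np.abs(mean - expected_mean) * expected_precision


# A positive and a negative deviation of equal size give the same value.
pwpe = pwpe_magnitude(mean=[1.5, 0.5], expected_mean=[1.0, 1.0],
                      expected_precision=[2.0, 2.0])
```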

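Returns:

axs – 2D ndarray of Matplotlib axes with shape (len(variables), len(layers)), where variables and layers are the resolved lists (a single str or int counts as one entry, and layers=None counts every layer).

Return type:

numpy.ndarray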
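One way to guarantee the documented 2D shape, even for a single variable or a single layer, is to create the grid with plt.subplots(..., squeeze=False). The source does not say how plot_layers builds its figure, so this is an assumed approach; make_axes_grid is a hypothetical name, and it reproduces the documented figsize default.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt


def make_axes_grid(n_variables, n_layers, figsize=None):
    """Create a (n_variables, n_layers) grid of axes.

    Rows are variables and columns are layers; figsize defaults to
    (3.5 * n_cols, 2.5 * n_rows) inches, as documented for plot_layers.
    """
    n_rows, n_cols = n_variables, n_layers
    if figsize is None:
        figsize = (3.5 * n_cols, 2.5 * n_rows)
    # squeeze=False keeps a 2D ndarray even when n_rows or n_cols is 1.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    return fig, axs


fig, axs = make_axes_grid(n_variables=1, n_layers=3)
```

With Matplotlib's default squeeze=True, the same 1 × 3 grid would come back as a 1D array, so indexing by [row, column] would break.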

Raises:

ValueError – If network.trajectories is None, if a variable name is neither a LayerState field nor "PWPE", if a layer index is out of range, or if mode is not "all" or "mean_ci".
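The error conditions above, together with the int/str shorthands described for layers and variables, can be sketched as a small argument-resolution step. Everything here is assumed for illustration: resolve_arguments is a hypothetical helper, the valid field names and layer count are passed in explicitly rather than read from a real DeepNetwork, and negative layer indices are treated as out of range, which may differ from pyhgf's behavior.

```python
from typing import Any, List, Optional, Sequence, Tuple, Union

VALID_MODES = ("all", "mean_ci")


def resolve_arguments(
    trajectories: Optional[Any],
    n_layers: int,
    layer_state_fields: Sequence[str],
    layers: Union[int, Sequence[int], None] = None,
    variables: Union[str, Sequence[str]] = ("expected_mean",),
    mode: str = "all",
) -> Tuple[List[int], List[str]]:
    """Normalize shorthand arguments; raise ValueError on invalid input."""
    if trajectories is None:
        raise ValueError("No trajectories; fit with record_trajectories=True.")
    if mode not in VALID_MODES:
        raise ValueError(f"mode must be one of {VALID_MODES}, got {mode!r}.")

    # None means every layer; a single int or str is a one-element list.
    if layers is None:
        layers = list(range(n_layers))
    elif isinstance(layers, int):
        layers = [layers]
    else:
        layers = list(layers)
    variables = [variables] if isinstance(variables, str) else list(variables)

    # Assumption: negative indices are rejected rather than wrapped around.
    for idx in layers:
        if not 0 <= idx < n_layers:
            raise ValueError(f"Layer index {idx} out of range ({n_layers} layers).")
    allowed = set(layer_state_fields) | {"PWPE"}  # "PWPE" is a derived name
    for name in variables:
        if name not in allowed:
            raise ValueError(f"Unknown variable {name!r}.")
    return layers, variables


fields = ("expected_mean", "precision", "mean")
print(resolve_arguments({}, 3, fields, layers=1, variables="PWPE"))
# -> ([1], ['PWPE'])
```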