Source code for pyhgf.plots.matplotlib.plot_correlations
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>
from typing import TYPE_CHECKING, Optional

import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes

if TYPE_CHECKING:
    from pyhgf.model import Network
[docs]
def plot_correlations(network: "Network", ax: Optional[Axes] = None) -> Axes:
    """Plot the heatmap correlation of the sufficient statistics trajectories.

    The correlation matrix is computed from the data frame returned by
    ``network.to_pandas()``, restricted to (in this order): the ``time``
    column, the ``x_{i}_mean`` column of every input node, every column whose
    name contains ``"expected"``, and every column whose name contains
    ``"surprise"``.

    Parameters
    ----------
    network :
        An instance of the HGF model.
    ax :
        Matplotlib axes in which to draw the heatmap. If `None` (default), the
        heatmap is drawn in the currently active axes.

    Returns
    -------
    ax :
        The Matplotlib axes instance containing the heatmap of parameters
        trajectories correlation.

    """
    trajectories_df = network.to_pandas()

    # Mean trajectories of the input nodes, listed in node-index order.
    input_columns = [
        f"x_{i}_mean" for i in range(len(network.edges)) if i in network.input_idxs
    ]

    trajectories_df = pd.concat(
        [
            trajectories_df[["time"]],
            trajectories_df[input_columns],
            trajectories_df.filter(regex="expected"),
            trajectories_df.filter(regex="surprise"),
        ],
        axis=1,
    )

    correlation_mat = trajectories_df.corr()

    # Passing ``ax=None`` lets seaborn fall back to the current axes, which is
    # the behavior callers relied on before ``ax`` was exposed.
    ax = sns.heatmap(
        correlation_mat,
        cmap="RdBu",
        vmin=-1,
        vmax=1,
        linewidths=2,
        square=True,
        ax=ax,
    )
    ax.set_title("Correlations between the model trajectories")
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", size=8)
    ax.set_yticklabels(ax.get_yticklabels(), size=8)

    return ax