Source code for pyhgf.plots.graphviz.plot_network
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from graphviz.sources import Source
from pyhgf.model import Network
[docs]
def plot_network(network: "Network") -> "Source":
    """Visualization of node network using GraphViz.

    Parameters
    ----------
    network :
        An instance of main Network class.

    Returns
    -------
    graphviz_structure :
        The network drawn as a Graphviz digraph. The object returned is the result
        of ``Digraph.unflatten(stagger=3)``, which helps lay out large or uneven
        networks.

    Raises
    ------
    ImportError
        If the optional ``graphviz`` package is not installed.

    Notes
    -----
    This function requires [Graphviz](https://github.com/xflr6/graphviz) to be
    installed to work correctly.

    Nodes are drawn according to ``node_type``:

    * 1 (binary state): square.
    * 2 (continuous state): circle.
    * 3 (exponential family state): filled circle labelled ``EF-<idx>``.
    * 4 (Dirichlet process state): filled double circle labelled ``DP-<idx>``.
    * 5 (categorical state): diamond labelled ``Ca-<idx>``.

    Binary, continuous and categorical nodes are filled only when their index is
    in ``network.input_idxs``.

    Edges are drawn as follows:

    * Value-parent edges are solid black lines. They become a double line when the
      parent's coupling function for that child is not ``None``.
    * Volatility-parent edges are gray dashed lines with a dot arrowhead.
    """
    try:
        import graphviz
    except ImportError as err:
        # Nothing below can run without graphviz. Fail loudly here instead of
        # continuing and crashing later on an unbound local name.
        raise ImportError(
            "Graphviz is required to plot networks. "
            "See https://pypi.org/project/graphviz/"
        ) from err

    graphviz_structure = graphviz.Digraph("hgf-nodes", comment="Nodes structure")
    graphviz_structure.attr("node", shape="circle")

    # create one graph node per network node
    for idx, node in enumerate(network.edges):
        node_name = f"x_{idx}"
        # input nodes are highlighted (for the node types that honour `style`)
        style = "filled" if idx in network.input_idxs else ""
        if node.node_type == 1:
            # binary state node
            graphviz_structure.node(
                node_name, label=str(idx), shape="square", style=style
            )
        elif node.node_type == 2:
            # continuous state node
            graphviz_structure.node(
                node_name, label=str(idx), shape="circle", style=style
            )
        elif node.node_type == 3:
            # exponential family state node (always filled)
            graphviz_structure.node(
                node_name,
                label=f"EF-{idx}",
                style="filled",
                shape="circle",
                fillcolor="#ced6e4",
            )
        elif node.node_type == 4:
            # Dirichlet process state node (always filled)
            graphviz_structure.node(
                node_name,
                label=f"DP-{idx}",
                style="filled",
                shape="doublecircle",
                fillcolor="#e2d8c1",
            )
        elif node.node_type == 5:
            # categorical state node
            graphviz_structure.node(
                node_name,
                label=f"Ca-{idx}",
                style=style,
                shape="diamond",
                fillcolor="#e2d8c1",
            )

    # connect value parents
    for node_idx, node in enumerate(network.edges):
        if node.value_parents is None:
            continue
        for parent_idx in node.value_parents:
            parent = network.edges[parent_idx]
            # The coupling function is stored on the parent, at the position
            # this node occupies in the parent's list of value children.
            position = parent.value_children.index(node_idx)
            coupling_fn = parent.coupling_fn[position]
            graphviz_structure.edge(
                f"x_{parent_idx}",
                f"x_{node_idx}",
                # a double line marks a non-default coupling function
                color="black" if coupling_fn is None else "black:invis:black",
            )

    # connect volatility parents
    for node_idx, node in enumerate(network.edges):
        if node.volatility_parents is None:
            continue
        for parent_idx in node.volatility_parents:
            graphviz_structure.edge(
                f"x_{parent_idx}",
                f"x_{node_idx}",
                color="gray",
                style="dashed",
                arrowhead="dot",
            )

    # unflatten the structure to better handle large/uneven networks
    graphviz_structure = graphviz_structure.unflatten(stagger=3)

    return graphviz_structure