Source code for pyhgf.plots.graphviz.plot_network

# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

from __future__ import annotations

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from graphviz.sources import Source

    from pyhgf.model import DeepNetwork, Network

from graphviz import Digraph


[docs] def plot_network(network: Network) -> Source: """Visualization of node network using GraphViz. Parameters ---------- network : An instance of main Network class. Returns ------- graphviz_structure : Graphviz object. Notes ----- This function requires [Graphviz](https://github.com/xflr6/graphviz) to be installed to work correctly. """ try: import graphviz except ImportError: print( ( "Graphviz is required to plot networks. " "See https://pypi.org/project/graphviz/" ) ) graphviz_structure = graphviz.Digraph("hgf-nodes", comment="Nodes structure") graphviz_structure.attr("node", shape="circle") # create the rest of nodes for idx in range(len(network.edges)): style = "filled" if idx in network.input_idxs else "" if network.edges[idx].node_type == 1: # binary state node graphviz_structure.node( f"x_{idx}", label=str(idx), shape="square", style=style ) elif network.edges[idx].node_type == 2: # Continuous state nore graphviz_structure.node( f"x_{idx}", label=str(idx), shape="circle", style=style ) elif network.edges[idx].node_type == 3: # Exponential family state nore graphviz_structure.node( f"x_{idx}", label=f"EF-{idx}", style="filled", shape="circle", fillcolor="#ced6e4", ) elif network.edges[idx].node_type == 4: # Dirichlet Process state node graphviz_structure.node( f"x_{idx}", label=f"DP-{idx}", style="filled", shape="doublecircle", fillcolor="#e2d8c1", ) elif network.edges[idx].node_type == 5: # Categorical state node graphviz_structure.node( f"x_{idx}", label=f"Ca-{idx}", style=style, shape="diamond", fillcolor="#e2d8c1", ) elif network.edges[idx].node_type == 6: # Value-volatility hybrid node # Double circle with gray outer ring (volatility) and solid inner (value) graphviz_structure.node( f"x_{idx}", label=f"{idx}", style="filled", shape="doublecircle", color="gray", fillcolor="white" if idx not in network.input_idxs else "lightgray", ) # connect value parents for i, index in enumerate(network.edges): value_parents = index.value_parents if value_parents is not None: for value_parents_idx in value_parents: # get the coupling function from the value parent child_idx = network.edges[value_parents_idx].value_children.index(i) coupling_fn = network.edges[value_parents_idx].coupling_fn[child_idx] graphviz_structure.edge( f"x_{value_parents_idx}", f"x_{i}", color="black" if coupling_fn is None else "black:invis:black", ) # connect volatility parents for i, index in enumerate(network.edges): volatility_parents = index.volatility_parents if volatility_parents is not None: for volatility_parents_idx in volatility_parents: graphviz_structure.edge( f"x_{volatility_parents_idx}", f"x_{i}", color="gray", style="dashed", arrowhead="dot", ) # unflat the structure to better handle large/uneven networks graphviz_structure = graphviz_structure.unflatten(stagger=3) return graphviz_structure
[docs] def plot_deep_network( deep_network: DeepNetwork, filename: Optional[str] = None, view: bool = True ): """Visualisation of a fully connected deep network using GraphViz. Parameters ---------- deep_network : layers : Returns ------- graphviz_structure : Graphviz object. """ graphviz_structure = Digraph( "deep-network", graph_attr={ "rankdir": "TB", # Top → Bottom flow "splines": "ortho", "nodesep": "0.7", "ranksep": "0.9", }, node_attr={ "shape": "box", "style": "rounded,filled", "fillcolor": "#E8E8E8", "color": "#444444", "penwidth": "1.2", "fontname": "Helvetica", "fontsize": "12", }, edge_attr={ "arrowhead": "vee", "arrowsize": "0.9", "color": "#444444", "penwidth": "1.0", }, ) # Reverse so the bottom layer appears at the bottom visually layers_reversed = list(reversed(deep_network.layer_sizes)) layer_names = [] num_layers = len(deep_network.layer_sizes) # Create each layer block for i, n_units in enumerate(layers_reversed): true_idx = num_layers - 1 - i # index in original (bottom=0) if true_idx == 0: label = f"Outcome Layer (Y) \n({n_units} units)" elif true_idx == num_layers - 1: label = f"Prediction Layer (X)\n({n_units} units)" else: label = f"Hidden Layer {true_idx}\n({n_units} units)" name = f"layer_{i}" layer_names.append(name) graphviz_structure.node(name, label=label) # Draw downward arrows between layers for i in range(len(layer_names) - 1): # layer_names[i] is the top-most visual layer; true_idx counts from # the bottom (0 = outcome). coupling_fns[j] is applied to layer j's # output, so the edge from parent layer to child layer j uses # coupling_fns[j]. child_true_idx = num_layers - 1 - (i + 1) fn = deep_network.coupling_fns[child_true_idx] fn_name = getattr(fn, "__name__", None) or type(fn).__name__ # Replace the anonymous identity with a readable label if fn_name == "<lambda>": fn_name = "linear" graphviz_structure.edge(layer_names[i], layer_names[i + 1], xlabel=fn_name) if filename is not None: graphviz_structure.render(filename, view=view, format="pdf") return graphviz_structure