Source code for pyhgf.utils.add_edges

# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

from typing import Callable, Dict, List, Optional, Tuple, Union

from pyhgf.typing import AdjacencyLists, Edges


[docs] def add_edges( attributes: Dict, edges: Edges, kind="value", parent_idxs=Union[int, List[int]], children_idxs=Union[int, List[int]], coupling_strengths: Union[float, List[float], Tuple[float]] = 1.0, coupling_fn: Tuple[Optional[Callable], ...] = (None,), ) -> Tuple: """Add a value or volatility coupling link between a set of nodes. Parameters ---------- attributes : Attributes of the neural network. edges : Edges of the neural network. kind : The kind of coupling can be `"value"` or `"volatility"`. parent_idxs : The index(es) of the parent node(s). children_idxs : The index(es) of the children node(s). coupling_strengths : The coupling strength between the parents and children. coupling_fn : Coupling function(s) between the current node and its value children. It has to be provided as a tuple. If multiple value children are specified, the coupling functions must be stated in the same order of the children. Note: if a node has multiple parents nodes with different coupling functions, a coupling function should be indicated for all the parent nodes. If no coupling function is stated, the relationship between nodes is assumed linear. """ if kind not in ["value", "volatility"]: raise ValueError( f"The kind of coupling should be value or volatility, got {kind}" ) if isinstance(children_idxs, int): children_idxs = [children_idxs] assert isinstance(children_idxs, (list, tuple)) if isinstance(parent_idxs, int): parent_idxs = [parent_idxs] assert isinstance(parent_idxs, (list, tuple)) if isinstance(coupling_strengths, int): coupling_strengths = [float(coupling_strengths)] if isinstance(coupling_strengths, float): coupling_strengths = [coupling_strengths] assert isinstance(coupling_strengths, (list, tuple)) edges_as_list = list(edges) # update the parent nodes # ----------------------- for parent_idx in parent_idxs: # unpack the parent's edges ( node_type, value_parents, volatility_parents, value_children, volatility_children, this_coupling_fn, ) = edges_as_list[parent_idx] if kind == "value": if value_children is None: value_children = tuple(children_idxs) attributes[parent_idx]["value_coupling_children"] = tuple( coupling_strengths ) else: value_children = value_children + tuple(children_idxs) attributes[parent_idx]["value_coupling_children"] += tuple( coupling_strengths ) this_coupling_fn = this_coupling_fn + coupling_fn elif kind == "volatility": if volatility_children is None: volatility_children = tuple(children_idxs) attributes[parent_idx]["volatility_coupling_children"] = tuple( coupling_strengths ) else: volatility_children = volatility_children + tuple(children_idxs) attributes[parent_idx]["volatility_coupling_children"] += tuple( coupling_strengths ) # save the updated edges back edges_as_list[parent_idx] = AdjacencyLists( node_type, value_parents, volatility_parents, value_children, volatility_children, this_coupling_fn, ) # update the children nodes # ------------------------- for children_idx in children_idxs: # unpack this node's edges ( node_type, value_parents, volatility_parents, value_children, volatility_children, coupling_fn, ) = edges_as_list[children_idx] if kind == "value": if value_parents is None: value_parents = tuple(parent_idxs) attributes[children_idx]["value_coupling_parents"] = tuple( coupling_strengths ) else: value_parents = value_parents + tuple(parent_idxs) attributes[children_idx]["value_coupling_parents"] += tuple( coupling_strengths ) elif kind == "volatility": if volatility_parents is None: volatility_parents = tuple(parent_idxs) attributes[children_idx]["volatility_coupling_parents"] = tuple( coupling_strengths ) else: volatility_parents = volatility_parents + tuple(parent_idxs) attributes[children_idx]["volatility_coupling_parents"] += tuple( coupling_strengths ) # save the updated edges back edges_as_list[children_idx] = AdjacencyLists( node_type, value_parents, volatility_parents, value_children, volatility_children, coupling_fn, ) # convert the list back to a tuple edges = tuple(edges_as_list) return attributes, edges