# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>
from typing import Callable, Optional, Union
import jax.numpy as jnp
from pyhgf.typing import AdjacencyLists, Edges
[docs]
def add_edges(
    attributes: dict,
    edges: Edges,
    kind="value",
    parent_idxs=Union[int, list[int]],
    children_idxs=Union[int, list[int]],
    coupling_strengths: Union[float, list[float], tuple[float]] = 1.0,
    coupling_fn: tuple[Optional[Callable], ...] = (None,),
) -> tuple:
    """Add a value or volatility coupling link between a set of nodes.
    Parameters
    ----------
    attributes :
        Attributes of the neural network.
    edges :
        Edges of the neural network.
    kind :
        The kind of coupling can be `"value"` or `"volatility"`.
    parent_idxs :
        The index(es) of the parent node(s).
    children_idxs :
        The index(es) of the children node(s).
    coupling_strengths :
        The coupling strength between the parents and children.
    coupling_fn :
        Coupling function(s) between the current node and its value children.
        It has to be provided as a tuple. If multiple value children are specified,
        the coupling functions must be stated in the same order of the children.
        Note: if a node has multiple parents nodes with different coupling
        functions, a coupling function should be indicated for all the parent nodes.
        If no coupling function is stated, the relationship between nodes is assumed
        linear.
    """
    if kind not in ["value", "volatility"]:
        raise ValueError(
            f"The kind of coupling should be value or volatility, got {kind}"
        )
    if isinstance(children_idxs, int):
        children_idxs = [children_idxs]
    assert isinstance(children_idxs, (list, tuple))
    if isinstance(parent_idxs, int):
        parent_idxs = [parent_idxs]
    assert isinstance(parent_idxs, (list, tuple))
    if isinstance(coupling_strengths, int):
        coupling_strengths = [float(coupling_strengths)]
    if isinstance(coupling_strengths, float):
        coupling_strengths = [coupling_strengths]
    assert isinstance(coupling_strengths, (list, tuple))
    edges_as_list = list(edges)
    # update the parent nodes
    # -----------------------
    for parent_idx in parent_idxs:
        # unpack the parent's edges
        (
            node_type,
            value_parents,
            volatility_parents,
            value_children,
            volatility_children,
            this_coupling_fn,
        ) = edges_as_list[parent_idx]
        if kind == "value":
            if value_children is None:
                value_children = tuple(children_idxs)
                attributes[parent_idx]["value_coupling_children"] = jnp.array(
                    coupling_strengths
                )
            else:
                value_children = value_children + tuple(children_idxs)
                attributes[parent_idx]["value_coupling_children"] = jnp.concatenate(
                    [
                        attributes[parent_idx]["value_coupling_children"],
                        jnp.array(coupling_strengths),
                    ]
                )
        elif kind == "volatility":
            if volatility_children is None:
                volatility_children = tuple(children_idxs)
                attributes[parent_idx]["volatility_coupling_children"] = jnp.array(
                    coupling_strengths
                )
            else:
                volatility_children = volatility_children + tuple(children_idxs)
                attributes[parent_idx]["volatility_coupling_children"] = (
                    jnp.concatenate(
                        [
                            attributes[parent_idx]["volatility_coupling_children"],
                            jnp.array(coupling_strengths),
                        ]
                    )
                )
        # save the updated edges back
        edges_as_list[parent_idx] = AdjacencyLists(
            node_type,
            value_parents,
            volatility_parents,
            value_children,
            volatility_children,
            this_coupling_fn,
        )
    # update the children nodes
    # -------------------------
    for children_idx in children_idxs:
        # unpack this node's edges
        (
            node_type,
            value_parents,
            volatility_parents,
            value_children,
            volatility_children,
            coupling_fn,
        ) = edges_as_list[children_idx]
        if kind == "value":
            if value_parents is None:
                value_parents = tuple(parent_idxs)
                attributes[children_idx]["value_coupling_parents"] = jnp.array(
                    coupling_strengths
                )
            else:
                value_parents = value_parents + tuple(parent_idxs)
                attributes[children_idx]["value_coupling_parents"] = jnp.concatenate(
                    [
                        attributes[children_idx]["value_coupling_parents"],
                        jnp.array(coupling_strengths),
                    ]
                )
        elif kind == "volatility":
            if volatility_parents is None:
                volatility_parents = tuple(parent_idxs)
                attributes[children_idx]["volatility_coupling_parents"] = jnp.array(
                    coupling_strengths
                )
            else:
                volatility_parents = volatility_parents + tuple(parent_idxs)
                attributes[children_idx]["volatility_coupling_parents"] = (
                    jnp.concatenate(
                        [
                            attributes[children_idx]["volatility_coupling_parents"],
                            jnp.array(coupling_strengths),
                        ]
                    )
                )
        # save the updated edges back
        edges_as_list[children_idx] = AdjacencyLists(
            node_type,
            value_parents,
            volatility_parents,
            value_children,
            volatility_children,
            coupling_fn,
        )
    # convert the list back to a tuple
    edges = tuple(edges_as_list)
    return attributes, edges