# Source code for pyhgf.utils.fill_categorical_state_node

# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

from typing import TYPE_CHECKING, Dict, List

from pyhgf.typing import AdjacencyLists

if TYPE_CHECKING:
    from pyhgf.model import Network


def fill_categorical_state_node(
    network: "Network",
    node_idx: int,
    binary_states_idxs: List[int],
    binary_parameters: Dict,
) -> "Network":
    """Attach the binary HGFs implied by a categorical state(-transition) node.

    Three things are added to ``network``:

    1. ``len(binary_states_idxs)`` binary state nodes, value-coupled with the
       categorical node at ``node_idx`` by rewriting the adjacency lists.
    2. One continuous value parent per binary state (second level).
    3. One continuous node shared as volatility parent by all second-level
       nodes (third level).

    Parameters
    ----------
    network :
        Instance of a Network. Its ``edges`` attribute is reassigned and its
        ``add_nodes`` method is called, so the instance is updated in place.
    node_idx :
        Index of the categorical state node.
    binary_states_idxs :
        Indexes of the binary state nodes. The index arithmetic below assumes
        these are the last ``binary_parameters["n_categories"]`` nodes of the
        network after the binary states are added, in order, and that their
        count equals ``n_categories``. TODO: confirm against callers.
    binary_parameters :
        Parameters for the set of implied binary HGFs. Keys read here:
        ``"mean_1"``, ``"precision_1"``, ``"n_categories"``, ``"mean_2"``,
        ``"precision_2"``, ``"tonic_volatility_2"``, ``"mean_3"``,
        ``"precision_3"`` and ``"tonic_volatility_3"``.

    Returns
    -------
    network :
        The same network instance that was passed in, after the update.

    """
    # Level 1: one binary state node for each entry of ``binary_states_idxs``.
    network.add_nodes(
        kind="binary-state",
        n_nodes=len(binary_states_idxs),
        node_parameters=dict(
            mean=binary_parameters["mean_1"],
            precision=binary_parameters["precision_1"],
        ),
    )

    # Value coupling between the categorical node and its binary states.
    # The positional fields presumably follow (node_type, value_parents,
    # volatility_parents, value_children, volatility_children, coupling_fn).
    # NOTE(review): confirm against pyhgf.typing.AdjacencyLists.
    edge_overrides = [
        (
            node_idx,
            AdjacencyLists(5, tuple(binary_states_idxs), None, None, None, (None,)),
        )
    ]
    edge_overrides += [
        (state_idx, AdjacencyLists(1, None, None, (node_idx,), None, (None,)))
        for state_idx in binary_states_idxs
    ]
    updated_edges: List[AdjacencyLists] = list(network.edges)
    for position, adjacency in edge_overrides:
        updated_edges[position] = adjacency
    network.edges = tuple(updated_edges)

    # Level 2: one continuous value parent per binary state. The binary states
    # are taken to be the trailing ``n_categories`` nodes of the network.
    n_total = len(network.edges)
    n_categories = binary_parameters["n_categories"]
    for child_idx in range(n_total - n_categories, n_total):
        network.add_nodes(
            value_children=child_idx,
            node_parameters=dict(
                mean=binary_parameters["mean_2"],
                precision=binary_parameters["precision_2"],
                tonic_volatility=binary_parameters["tonic_volatility_2"],
            ),
        )

    # Level 3: a single volatility parent shared by the second-level nodes.
    # Assumes each second-level node landed exactly ``n_categories`` positions
    # after the binary state it was attached to. TODO: confirm.
    second_level_idxs = [
        state_idx + binary_parameters["n_categories"]
        for state_idx in binary_states_idxs
    ]
    network.add_nodes(
        volatility_children=second_level_idxs,
        node_parameters=dict(
            mean=binary_parameters["mean_3"],
            precision=binary_parameters["precision_3"],
            tonic_volatility=binary_parameters["tonic_volatility_3"],
        ),
    )

    return network