pyhgf.typing.vectorised.LayerStack

class pyhgf.typing.vectorised.LayerStack(state, params, weights_in, coupling_fn, add_constant_input, has_volatility_parent, fully_connected, kind, n_layers)

N identical layers stacked into one PyTree with a leading (N,) axis.

Every field of state and params gains a leading axis of size N, so its shape goes from (n_nodes,) to (N, n_nodes). Likewise, weights_in goes from (n_child, n_self[+1]) to (N, n_child, n_self[+1]). Slice index 0 is the bottommost slice in the stack (closest to layer 0 of the network); slice N-1 is the topmost.
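The shape bookkeeping described above can be sketched in plain Python. This is only an illustration of the arithmetic, not pyhgf code: the helper `stacked_shape` and the example sizes are hypothetical, and it assumes that the optional `[+1]` column is the bias column added when add_constant_input is true.

```python
from typing import Tuple


def stacked_shape(per_layer_shape: Tuple[int, ...], n_layers: int) -> Tuple[int, ...]:
    """Prepend the stack axis N to a per-layer shape (hypothetical helper)."""
    return (n_layers, *per_layer_shape)


# Illustrative sizes, not taken from the library.
n_layers, n_nodes, n_child = 3, 4, 4
add_constant_input = True

# Assumption: the "[+1]" in (n_child, n_self[+1]) is the bias column.
n_cols = n_nodes + (1 if add_constant_input else 0)

field_shape = stacked_shape((n_nodes,), n_layers)
weights_shape = stacked_shape((n_child, n_cols), n_layers)

# Slice 0 is the bottom of the stack, slice N-1 the top.
bottom_index, top_index = 0, n_layers - 1

print(field_shape, weights_shape, bottom_index, top_index)
```

With these sizes, each state or params field has shape (3, 4) and weights_in has shape (3, 4, 5).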

Validation constraints, enforced at build time:

  • The layer immediately below the stack must have the same node count as the stack width (so weights_in[0] shape matches).

  • weights_in[k] for k > 0 is a square (W, W+bias) block connecting slice k (parent) to slice k-1 (child) within the stack.
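As a rough sketch of what these two constraints imply for shapes, the check below works on shape tuples only. It is not the library's validation code: the function `check_stack_shapes` and its arguments are hypothetical, and it assumes that bias is 1 when add_constant_input is true and 0 otherwise.

```python
from typing import Tuple


def check_stack_shapes(
    weights_shape: Tuple[int, ...],
    below_n_nodes: int,
    add_constant_input: bool,
) -> int:
    """Return the stack width W if the shapes satisfy both constraints.

    Hypothetical helper: weights_shape is the shape of the stacked
    weights_in, i.e. (N, n_child, n_self[+1]).
    """
    if len(weights_shape) != 3:
        raise ValueError("weights_in must have shape (N, n_child, n_self[+1])")
    _, n_child, n_cols = weights_shape
    bias = 1 if add_constant_input else 0  # assumed meaning of "+bias"
    width = n_cols - bias

    # Constraint 1: the layer below the stack has as many nodes as the stack is wide.
    if below_n_nodes != width:
        raise ValueError(
            f"layer below has {below_n_nodes} nodes, stack width is {width}"
        )
    # Constraint 2: every slice is a square (W, W+bias) block.
    if n_child != width:
        raise ValueError(
            f"expected ({width}, {width + bias}) blocks, got ({n_child}, {n_cols})"
        )
    return width


width = check_stack_shapes((3, 4, 5), below_n_nodes=4, add_constant_input=True)
print(width)
```

Because all slices share one stacked array, checking a single (n_child, n_cols) pair covers every k.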

Parameters:
  • state (pyhgf.typing.vectorised.LayerState) – The stacked per-layer state, each field with a leading (N,) axis.

  • params (pyhgf.typing.vectorised.LayerParams) – The stacked per-layer static parameters, each field with a leading (N,) axis.

  • weights_in (jax.Array) – The stacked incoming weight matrices, shape (N, n_child, n_self[+1]).

  • coupling_fn (Callable) – The coupling function shared by all stacked layers.

  • add_constant_input (bool) – Whether a constant (bias) input column is appended to the weights.

  • has_volatility_parent (bool) – Whether the layers have a volatility parent.

  • fully_connected (bool) – Whether the incoming weights are fully connected.

  • kind (str) – The kind of layer, either "volatile" or "binary".

  • n_layers (int) – The number of stacked layers N.

__init__(state, params, weights_in, coupling_fn, add_constant_input, has_volatility_parent, fully_connected, kind, n_layers)

Return type:
  None

Methods

__init__(state, params, weights_in, ...)

Attributes

state

params

weights_in

coupling_fn

add_constant_input

has_volatility_parent

fully_connected

kind

n_layers