pyhgf.typing.vectorised.LayerStack
- class pyhgf.typing.vectorised.LayerStack(state, params, weights_in, coupling_fn, add_constant_input, has_volatility_parent, fully_connected, kind, n_layers)
N identical layers stacked into one PyTree with a leading (N,) axis.
state / params have a leading axis N (each field shape goes from (n_nodes,) to (N, n_nodes)). weights_in goes from (n_child, n_self[+1]) to (N, n_child, n_self[+1]). Slice index 0 is the bottommost slice in the stack (closest to layer 0 of the network); slice N-1 is the topmost.
Validation constraints, enforced at build time:
* The layer immediately below the stack must have the same node count as the stack width (so the weights_in[0] shape matches).
* weights_in[k] for k > 0 is a square (W, W+bias) block connecting slice k (parent) to slice k-1 (child) within the stack.
- Parameters:
state (pyhgf.typing.vectorised.LayerState) – The stacked per-layer state, each field with a leading (N,) axis.
params (pyhgf.typing.vectorised.LayerParams) – The stacked per-layer static parameters, each field with a leading (N,) axis.
weights_in (jax.Array) – The stacked incoming weight matrices, shape (N, n_child, n_self[+1]).
coupling_fn (Callable) – The coupling function shared by all stacked layers.
add_constant_input (bool) – Whether a constant (bias) input column is appended to the weights.
has_volatility_parent (bool) – Whether the layers have a volatility parent.
fully_connected (bool) – Whether the incoming weights are fully connected.
kind (str) – The kind of layer, either "volatile" or "binary".
n_layers (int) – The number of stacked layers N.
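The leading (N,) axis described above can be illustrated with plain JAX. The following is a minimal sketch, not pyhgf's own construction code: ToyLayerState and stack_layers are hypothetical names invented for this example, and the field names mean and precision are assumptions rather than the actual fields of LayerState.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyLayerState(NamedTuple):
    """Hypothetical stand-in for a per-layer state (not pyhgf's LayerState)."""

    mean: jax.Array  # shape (n_nodes,)
    precision: jax.Array  # shape (n_nodes,)


def stack_layers(layers):
    """Stack identical PyTrees leaf by leaf along a new leading (N,) axis."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *layers)


n_layers, n_nodes = 3, 4

# Build N layers with identical structure; the mean encodes the layer index.
layers = [
    ToyLayerState(
        mean=jnp.full((n_nodes,), float(k)),
        precision=jnp.ones((n_nodes,)),
    )
    for k in range(n_layers)
]
stacked = stack_layers(layers)  # each field: (n_nodes,) -> (N, n_nodes)

# Per-layer weights of shape (n_child, n_self + 1), i.e. with a bias column,
# become a single (N, n_child, n_self + 1) array.
weights = [jnp.zeros((n_nodes, n_nodes + 1)) for _ in range(n_layers)]
stacked_weights = jnp.stack(weights)

# Slice 0 is the bottommost layer of the stack, slice N-1 the topmost.
bottom = jax.tree_util.tree_map(lambda leaf: leaf[0], stacked)
top = jax.tree_util.tree_map(lambda leaf: leaf[-1], stacked)
```

Stacking leaf by leaf keeps the PyTree structure unchanged, so the stacked object can be passed to transformations such as jax.lax.scan or jax.vmap that operate over a leading axis.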
- __init__(state, params, weights_in, coupling_fn, add_constant_input, has_volatility_parent, fully_connected, kind, n_layers)
- Parameters:
state (LayerState)
params (LayerParams)
weights_in (Array)
coupling_fn (Callable)
add_constant_input (bool)
has_volatility_parent (bool)
fully_connected (bool)
kind (str)
n_layers (int)
- Return type:
None
Methods
__init__(state, params, weights_in, ...)
Attributes
state
params
weights_in
coupling_fn
add_constant_input
has_volatility_parent
fully_connected
kind
n_layers
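The two validation constraints listed in the class description amount to a shape check on weights_in. The sketch below shows one way such a check could look; check_stack_shapes is a hypothetical helper written for this example and is not part of pyhgf, whose actual build-time validation may differ.

```python
def check_stack_shapes(weights_in_shape, width_below, add_constant_input):
    """Hypothetical shape check mirroring the two documented constraints.

    weights_in_shape: expected to be (N, n_child, n_self + bias).
    width_below: node count of the layer immediately below the stack.
    Returns (N, W), where W is the stack width.
    """
    if len(weights_in_shape) != 3:
        raise ValueError("weights_in must have shape (N, n_child, n_self[+1])")

    n_layers, n_child, n_cols = weights_in_shape
    bias = 1 if add_constant_input else 0
    width = n_cols - bias  # stack width W, i.e. n_self without the bias column

    # Constraint 1: the layer below the stack has the same node count as the
    # stack width, so that the weights_in[0] shape matches.
    if width_below != width:
        raise ValueError(
            f"layer below has {width_below} nodes, stack width is {width}"
        )

    # Constraint 2: for k > 0, weights_in[k] is a (W, W + bias) block linking
    # slice k (parent) to slice k-1 (child), so n_child must equal W.
    if n_child != width:
        raise ValueError(
            f"expected blocks of shape ({width}, {width + bias}), "
            f"got ({n_child}, {n_cols})"
        )

    return n_layers, width


# Three stacked layers of width 4 with a bias column, above a 4-node layer.
n_stacked, stack_width = check_stack_shapes((3, 4, 5), 4, True)
```

Because all slices share one array, a single child dimension has to serve both the link to the layer below and the links inside the stack, which is why both constraints reduce to n_child equal to W.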