Full API reference
ActionModels.AbstractRescorlaWagnerSubmodel — Type
AbstractRescorlaWagnerSubmodel
Abstract supertype for Rescorla-Wagner submodels. Used for dispatching between continuous, binary, and categorical variants.
ActionModels.Action — Type
Action(distribution_type, action_type)
Construct an action output for the model, specifying the distribution type (e.g., Normal, Bernoulli).
Arguments
- T: Type of the action (optional; inferred from distribution_type if not given).
- distribution_type: The distribution type for the action (e.g., Normal, Bernoulli, MvNormal). Can also be an abstract type such as Distribution{Multivariate, Continuous}, to allow multiple types of distributions.
Examples
julia> Action(Normal)
Action{Float64, Normal}(Float64, Normal)
julia> Action(Bernoulli)
Action{Int64, Bernoulli}(Int64, Bernoulli)
julia> Action(MvNormal)
Action{Array{Float64}, MvNormal}(Array{Float64}, MvNormal)
julia> Action(Distribution{Multivariate, Discrete})
Action{Array{Int64}, Distribution{Multivariate, Discrete}}(Array{Int64}, Distribution{Multivariate, Discrete})
ActionModels.ActionModel — Type
ActionModel(action_model; parameters, states, observations, actions, submodel, verbose)
Main container for a user-defined or premade action model.
Arguments
- action_model: The function implementing the model's update and action logic.
- parameters: NamedTuple of model parameters or a single Parameter.
- states: NamedTuple of model states or a single State (optional).
- observations: NamedTuple of model observations or a single Observation (optional).
- actions: NamedTuple of model actions or a single Action.
- submodel: Optional submodel for hierarchical or modular models (default: NoSubModel()).
- verbose: Whether to print warnings for singletons, i.e., a single Parameter, State, Observation, or Action passed instead of a NamedTuple (default: true).
Examples
julia> model_fn = (attributes, obs) -> Normal(0, 1);
julia> ActionModel(model_fn, parameters=(learning_rate=Parameter(0.1),), states=(expected_value=State(0.0),), observations=(observation=Observation(),), actions=(report=Action(Normal),))
-- ActionModel --
Action model function: #1
Number of parameters: 1
Number of states: 1
Number of observations: 1
Number of actions: 1
julia> ActionModel(model_fn, parameters=Parameter(0.1), actions = Action(Normal), verbose = false)
-- ActionModel --
Action model function: #1
Number of parameters: 1
Number of states: 0
Number of observations: 0
Number of actions: 1
ActionModels.Agent — Type
Agent{T<:AbstractSubmodel}
Container for a simulated agent, including the action model, model attributes, state history, and number of timesteps.
Fields
- action_model: The function implementing the agent's action model logic.
- model_attributes: The ModelAttributes instance containing parameters, states, and actions.
- history: NamedTuple of vectors storing the history of selected states.
- n_timesteps: Variable tracking the number of timesteps simulated.
Examples
julia> am = ActionModel(RescorlaWagner());
julia> agent = init_agent(am, save_history=true)
-- ActionModels Agent --
Action model: rescorla_wagner_act_after_update
This agent has received 0 observations
ActionModels.AttributeError — Type
AttributeError()
Custom error type for missing or invalid attribute access. Used for error handling in custom submodels.
Examples
julia> throw(AttributeError())
ERROR: AttributeError()
ActionModels.BinaryRescorlaWagner — Type
BinaryRescorlaWagner(; initial_value=0.0, learning_rate=0.1)
Binary Rescorla-Wagner submodel for use in action models.
Fields
- initial_value: Initial expected value parameter (Float64).
- learning_rate: Learning rate parameter (Float64).
Examples
julia> ActionModels.BinaryRescorlaWagner()
ActionModels.BinaryRescorlaWagner(0.0, 0.1)
ActionModels.CategoricalRescorlaWagner — Type
CategoricalRescorlaWagner(; n_categories, initial_value=zeros(n_categories), learning_rate=0.1)
Categorical Rescorla-Wagner submodel for use in action models.
Fields
- n_categories: Number of categories (Int64).
- initial_value: Initial expected value vector parameter (Vector{Float64}).
- learning_rate: Learning rate parameter (Float64).
Examples
julia> ActionModels.CategoricalRescorlaWagner(n_categories=3)
ActionModels.CategoricalRescorlaWagner(3, [0.0, 0.0, 0.0], 0.1)
ActionModels.ContinuousRescorlaWagner — Type
ContinuousRescorlaWagner(; initial_value=0.0, learning_rate=0.1)
Continuous Rescorla-Wagner submodel for use in action models.
Fields
- initial_value: Initial expected value parameter (Float64).
- learning_rate: Learning rate parameter (Float64).
Examples
julia> ActionModels.ContinuousRescorlaWagner()
ActionModels.ContinuousRescorlaWagner(0.0, 0.1)
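The two fields above correspond to the standard Rescorla-Wagner delta rule, in which the expected value moves toward each observation by a fraction set by the learning rate. The following is a minimal, package-independent sketch of that rule for the continuous case. The names rw_update and run_rw are hypothetical, and this is not the ActionModels implementation.

```julia
# Standard Rescorla-Wagner delta rule (illustrative; not ActionModels internals).
# `rw_update` and `run_rw` are hypothetical helper names.
function rw_update(expected_value::Real, observation::Real, learning_rate::Real)
    prediction_error = observation - expected_value
    return expected_value + learning_rate * prediction_error
end

# Uses the documented defaults: initial_value = 0.0, learning_rate = 0.1.
function run_rw(observations; initial_value = 0.0, learning_rate = 0.1)
    expected_value = initial_value
    trajectory = Float64[expected_value]
    for observation in observations
        expected_value = rw_update(expected_value, observation, learning_rate)
        push!(trajectory, expected_value)
    end
    return trajectory
end

# Three identical observations pull the expectation steadily toward 1.0.
trajectory = run_rw([1.0, 1.0, 1.0])
```

With a learning rate of 0.1, each step closes one tenth of the remaining gap between the expectation and the observation.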
ActionModels.InitialStateParameter — Type
InitialStateParameter(initial_value, state_name; discrete=false)
Type for defining initial state parameters in action models. Initial state parameters define the initial value of a state in the model. Can be continuous or discrete.
Arguments
- value: The default value for the initial state parameter. Can be a single value or an array.
- state_name: The symbol name of the state this parameter controls.
- discrete: If true, value is treated as discrete (Int), otherwise as continuous (Float).
Examples
julia> InitialStateParameter(0.0, :expected_value)
InitialStateParameter{Float64}(:expected_value, 0.0, Float64)
julia> InitialStateParameter(1, :counter, discrete=true)
InitialStateParameter{Int64}(:counter, 1, Int64)
julia> InitialStateParameter([0.0, 0.0], :weights)
InitialStateParameter{Array{Float64}}(:weights, [0.0, 0.0], Array{Float64})
ActionModels.ModelAttributes — Type
ModelAttributes(parameters, states, actions, initial_states, submodel)
Internal container for all model variables and submodel attributes used in an action model instance.
Arguments
- parameters: NamedTuple of parameter variables.
- states: NamedTuple of state variables.
- actions: NamedTuple of action variables.
- initial_states: NamedTuple of initial state values.
- submodel: Submodel attributes (or NoSubModelAttributes if not used).
Examples
julia> ModelAttributes((learning_rate=ActionModels.Variable(0.1),), (expected_value=ActionModels.Variable(0.0),), (report=ActionModels.Variable(missing),), (expected_value=ActionModels.Variable(0.0),), ActionModels.NoSubModelAttributes())
ModelAttributes{@NamedTuple{learning_rate::ActionModels.Variable{Float64}}, @NamedTuple{expected_value::ActionModels.Variable{Float64}}, @NamedTuple{report::ActionModels.Variable{Missing}}, @NamedTuple{expected_value::ActionModels.Variable{Float64}}, ActionModels.NoSubModelAttributes}((learning_rate = ActionModels.Variable{Float64}(0.1),), (expected_value = ActionModels.Variable{Float64}(0.0),), (report = ActionModels.Variable{Missing}(missing),), (expected_value = ActionModels.Variable{Float64}(0.0),), ActionModels.NoSubModelAttributes())
ActionModels.ModelFit — Type
ModelFit{T}
Container for a fitted model, including the Turing model, population model type, data, and results (both posterior and prior).
Type Parameters
- T: The population model type (subtype of AbstractPopulationModel).
Fields
- model: The Turing model object.
- population_model_type: The population model type.
- population_data: The data describing each session, i.e., the result of unique(data, session_cols), so it contains no observations or actions.
- info: Metadata about the fit (as a ModelFitInfo).
- prior: Prior fit results (empty until sample_prior! is called).
- posterior: Posterior fit results (empty until sample_posterior! is called).
ActionModels.ModelFitInfo — Type
ModelFitInfo
Container for metadata about a model fit, including session IDs and estimated parameter names.
Fields
- session_ids::Vector{String}: List of session identifiers.
- estimated_parameter_names::Tuple{Vararg{Symbol}}: Names of estimated parameters.
ActionModels.ModelFitResult — Type
ModelFitResult
Container for the results of a model fit (either prior or posterior), including MCMC chains and (optionally) session-level parameters.
Fields
- chains::Chains: MCMC chains from Turing.jl.
- session_parameters::Union{Nothing,AbstractFittingResult}: Session-level parameter results (optional).
ActionModels.NoSubModel — Type
NoSubModel()
Default submodel type used when no submodel is specified in an ActionModel.
Examples
julia> ActionModels.NoSubModel()
ActionModels.NoSubModel()
ActionModels.Observation — Type
Observation(; discrete=false)
Observation([T])
Construct an observation input to the model. Can be continuous (Float64), discrete (Int64), or a custom type.
Arguments
- discrete: If true, observation is treated as discrete (Int64), otherwise as continuous (Float64).
- T: Type of the observation (e.g., Float64, Int64, Vector{Float64}). Used for setting specific types.
Examples
julia> Observation()
Observation{Float64}(Float64)
julia> Observation(discrete=true)
Observation{Int64}(Int64)
julia> Observation(Vector{Float64})
Observation{Vector{Float64}}(Vector{Float64})
julia> Observation(String)
Observation{String}(String)
ActionModels.Parameter — Type
Parameter(value; discrete=false)
Type constructor for defining parameters in action models. Can be continuous or discrete.
Arguments
- value: The default value of the parameter. Can be a single value or an array.
- discrete: If true, parameter is treated as discrete (Int), otherwise as continuous (Float).
Examples
julia> Parameter(0.1)
Parameter{Float64}(0.1, Float64)
julia> Parameter(1, discrete=true)
Parameter{Int64}(1, Int64)
julia> Parameter([0.0, 0.0])
Parameter{Array{Float64}}([0.0, 0.0], Array{Float64})
ActionModels.Regression — Type
Regression
Type for specifying a regression model in a population model. Contains the formula, prior, and inverse link function.
Fields
- formula: The regression formula (as a FormulaTerm).
- prior: The prior for regression coefficients and error terms (as a RegressionPrior). Default is RegressionPrior().
- inv_link: The inverse link function (default: identity).
Examples
julia> Regression(@formula(y ~ x))
Regression(y ~ x, RegressionPrior{TDist{Float64}, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}}(TDist{Float64}(ν=3.0), Truncated(TDist{Float64}(ν=3.0); lower=0.0)), identity)
julia> Regression(@formula(y ~ x + (1|ID)), exp) # With a random effect and an exponential inverse link function
Regression(y ~ x + :(1 | ID), RegressionPrior{TDist{Float64}, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}}(TDist{Float64}(ν=3.0), Truncated(TDist{Float64}(ν=3.0); lower=0.0)), exp)
julia> Regression(@formula(y ~ x), RegressionPrior(β = TDist(4), σ = truncated(TDist(4), lower = 0)), logistic) # With a custom prior and logistic inverse link function
Regression(y ~ x, RegressionPrior{TDist{Float64}, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}}(TDist{Float64}(ν=4.0), Truncated(TDist{Float64}(ν=4.0); lower=0.0)), logistic)
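The role of the inverse link can be illustrated without the package: the regression produces an unbounded linear predictor, and the inverse link maps it onto the scale the parameter lives on. The sketch below uses hypothetical coefficients and a locally defined logistic function. It makes no claim about how ActionModels evaluates regressions internally.

```julia
# Illustrative only: an inverse link maps an unbounded linear predictor onto
# the parameter's natural scale. `logistic` is defined locally for this sketch.
logistic(x) = 1 / (1 + exp(-x))

# Hypothetical coefficients for a formula like `y ~ x`.
const INTERCEPT = -1.0
const SLOPE = 2.0
linear_predictor(x) = INTERCEPT + SLOPE * x

# Apply each candidate inverse link to the same linear predictor values.
function link_table(xs)
    return [(inv_link = nameof(f), mapped = [f(linear_predictor(x)) for x in xs])
            for f in (identity, exp, logistic)]
end

table = link_table((-1.0, 0.0, 1.0))
for row in table
    println(row.inv_link, " => ", row.mapped)
end
```

In this sketch identity leaves the predictor unbounded, exp keeps the result positive, and logistic keeps it between 0 and 1, which is why the latter two suit parameters such as noise levels and learning rates respectively.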
ActionModels.RegressionPrior — Type
RegressionPrior{Dβ,Dσ}
Type for specifying priors for regression coefficients and random effects in regression population models.
Type Parameters
- Dβ: Distribution type for β regression coefficients.
- Dσ: Distribution type for σ random effect deviations.
Fields
- β: Prior or vector of priors for regression coefficients (default: TDist(3)). If only one prior is given, it is used for all coefficients. If a vector is given, it should match the number of coefficients in the regression formula.
- σ: Prior or vector of priors for random effect deviations (default: TDist(3) truncated at 0). If only one prior is given, it is used for all random effects. If a vector is given, it should match the number of random effects in the regression formula.
Examples
julia> RegressionPrior()
RegressionPrior{TDist{Float64}, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}}(TDist{Float64}(ν=3.0), Truncated(TDist{Float64}(ν=3.0); lower=0.0))
julia> RegressionPrior(β = TDist(4), σ = truncated(TDist(4), lower = 0))
RegressionPrior{TDist{Float64}, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}}(TDist{Float64}(ν=4.0), Truncated(TDist{Float64}(ν=4.0); lower=0.0))
julia> RegressionPrior(β = [TDist(4), TDist(2)], σ = truncated(TDist(4), lower = 0)) # For setting multiple coefficients separately
RegressionPrior{TDist{Float64}, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}}(TDist{Float64}[TDist{Float64}(ν=4.0), TDist{Float64}(ν=2.0)], Truncated(TDist{Float64}(ν=4.0); lower=0.0))
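The single-prior versus vector-of-priors rule described for β and σ can be sketched as a small helper. The function expand_priors is hypothetical, strings stand in for distribution objects so that no packages are needed, and raising an error on a length mismatch is an assumption of this sketch rather than documented package behaviour.

```julia
# Hypothetical helper showing the documented rule: a single prior is reused for
# every coefficient, while a vector should match the number of coefficients.
# Strings stand in for distribution objects so the sketch needs no packages.
function expand_priors(prior, n_coefficients::Int)
    if prior isa AbstractVector
        # Assumption: a mismatch is treated as an error in this sketch.
        length(prior) == n_coefficients ||
            throw(ArgumentError("expected $n_coefficients priors, got $(length(prior))"))
        return collect(prior)
    end
    return fill(prior, n_coefficients)
end

shared = expand_priors("TDist(3)", 2)                  # single prior reused
separate = expand_priors(["TDist(4)", "TDist(2)"], 2)  # one prior per coefficient
```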
ActionModels.RejectParameters — Type
RejectParameters(errortext)
Custom error type for rejecting parameter samples during inference.
Arguments
- errortext: Explanation for the rejection.
Examples
julia> throw(RejectParameters("Parameter out of bounds"))
ERROR: RejectParameters("Parameter out of bounds")
ActionModels.RescorlaWagner — Type
RescorlaWagner(; type=:continuous, initial_value=nothing, learning_rate=0.1, n_categories=nothing, response_model=nothing, response_model_parameters=nothing, response_model_observations=nothing, response_model_actions=nothing, action_noise=nothing, act_before_update=false)
Premade configuration type for the Rescorla-Wagner model. Used to construct an ActionModel for continuous, binary, or categorical learning tasks.
Keyword Arguments
- type: Symbol specifying the model type (:continuous, :binary, or :categorical).
- initial_value: Initial value for the expected value (Float64 or Vector{Float64}).
- learning_rate: Learning rate parameter (Float64).
- n_categories: Number of categories (required for categorical models).
- response_model: Custom response model function (optional).
- response_model_parameters: NamedTuple of response model parameters (optional).
- response_model_observations: NamedTuple of response model observations (optional).
- response_model_actions: NamedTuple of response model actions (optional).
- action_noise: Action noise parameter (Float64, optional; used if no custom response model is provided).
- act_before_update: If true, the action is selected before the expectation is updated.
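The effect of the act-before-update option can be sketched in plain Julia. The helper rw_trial is hypothetical, and it assumes that the action is centred on whichever expectation is held at the moment of acting. That may differ in detail from the package's response models.

```julia
# Hypothetical sketch of the two orderings; not the package's code.
function rw_trial(expected_value, observation, learning_rate; act_before_update = false)
    updated = expected_value + learning_rate * (observation - expected_value)
    # Assumption: the action is centred on the expectation held when acting.
    action_mean = act_before_update ? expected_value : updated
    return (action_mean = action_mean, expected_value = updated)
end

after_update = rw_trial(0.0, 1.0, 0.5)                             # act after updating
before_update = rw_trial(0.0, 1.0, 0.5; act_before_update = true)  # act before updating
```

Both calls end with the same expectation. Only the value that drives the action on that trial differs.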
Examples
```jldoctest
julia> rw = RescorlaWagner(type=:continuous, initial_value=0.0, learning_rate=0.2)
RescorlaWagner(:continuous, 0.0, 0.2, nothing, ActionModels.var"#gaussian_report#270"(), (action_noise = Parameter{Float64}(1.0, Float64),), (observation = Observation{Float64}(Float64),), (report = Action{Float64, Normal}(Float64, Normal),), false)

julia> ActionModel(rw)

julia> rw_cat = RescorlaWagner(type=:categorical, n_categories=3, initial_value=[0.0, 0.0, 0.0], learning_rate=0.1)
RescorlaWagner(:categorical, [0.0, 0.0, 0.0], 0.1, 3, ActionModels.var"#categorical_report#272"(), (action_noise = Parameter{Float64}(1.0, Float64),), (observation = Observation{Int64}(Int64),), (report = Action{Int64, Categorical{P} where P<:Real}(Int64, Categorical{P} where P<:Real),), false)

julia> ActionModel(rw_cat)

julia> custom_response_model = (attributes) -> Normal(2 * attributes.submodel.expected_value, load_parameters(attributes).noise);

julia> rw_custom_resp = RescorlaWagner(response_model=custom_response_model, response_model_observations=(observation=Observation(Float64),), response_model_actions=(report=Action(Normal),), response_model_parameters=(noise=Parameter(1.0, Float64),))
RescorlaWagner(...)

julia> ActionModel(rw_custom_resp)
```
ActionModels.RescorlaWagnerAttributes — Type
RescorlaWagnerAttributes{T,AT}
Container for the parameters and state of a Rescorla-Wagner submodel. Has a type parameter T to allow using types supplied by the Turing header, and an attribute type AT to allow dispatching between categorical and continuous/binary models.
Type Parameters
- T: Element type (e.g., Float64).
- AT: Categorical vs continuous/binary typing (e.g., Float64 or Vector{Float64}).
Fields
- initial_value: Initial value parameter for the expected value.
- learning_rate: Learning rate parameter.
- expected_value: Current expected value.
Examples
julia> ActionModels.RescorlaWagnerAttributes{Float64,Float64}(initial_value=0.0, learning_rate=0.1, expected_value=0.0)
ActionModels.RescorlaWagnerAttributes{Float64, Float64}(0.0, 0.1, 0.0)
ActionModels.SampleSaveResume — Type
SampleSaveResume
Configuration for saving and resuming MCMC sampling progress to disk.
This struct is used to enable periodic saving of sampler state during long MCMC runs, allowing interrupted sampling to be resumed from the last save point.
Fields
- save_every::Int: Number of iterations between saves (default: 100).
- path::String: Directory path for saving sampler state (default: "./.samplingstate").
- chain_prefix::String: Prefix for saved chain segment files (default: "ActionModels_chain_segment").
Examples
julia> SampleSaveResume()
SampleSaveResume(100, "./.samplingstate", "ActionModels_chain_segment")
julia> SampleSaveResume(save_every=500, path="./.tmp", chain_prefix="my_chain")
SampleSaveResume(500, "./.tmp", "my_chain")
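One way to picture periodic saving is as splitting a long run into segments of at most save_every iterations, with state written after each segment so that an interrupted run only repeats the unfinished one. The sketch below illustrates that bookkeeping only. The helpers segment_lengths and remaining_segments are hypothetical and do not describe how ActionModels stores or resumes chains.

```julia
# Hypothetical sketch: split `n_iterations` into segments of at most `save_every`
# iterations, so that state could be written to disk after each segment.
function segment_lengths(n_iterations::Int, save_every::Int)
    save_every > 0 || throw(ArgumentError("save_every must be positive"))
    n_full, remainder = divrem(n_iterations, save_every)
    lengths = fill(save_every, n_full)
    remainder > 0 && push!(lengths, remainder)
    return lengths
end

# After an interruption, only the segments that were not yet saved remain.
remaining_segments(lengths, n_saved::Int) = lengths[(n_saved + 1):end]

lengths = segment_lengths(250, 100)      # segments of 100, 100 and 50 iterations
todo = remaining_segments(lengths, 2)    # two segments already on disk
```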
ActionModels.SessionParameters — Type
SessionParameters
Container for session-level parameter estimates from model fitting.
Fields
- value: NamedTuple with one entry per parameter, each containing a NamedTuple with one entry per session. Each session entry is an AxisArray of samples.
- modelfit: The associated ModelFit object.
- estimated_parameter_names: Tuple of estimated parameter names.
- session_ids: Vector of session identifiers.
- parameter_types: NamedTuple of parameter types.
- n_samples: Number of posterior samples.
- n_chains: Number of MCMC chains.
ActionModels.State — Type
State(initial_value; discrete=nothing)
State(initial_value, ::Type{T})
State(; discrete=false)
State(::Type{T})
Construct a model state variable, which can be continuous, discrete, or a custom type.
Arguments
- initial_value: The initial value for the state (can be Real, Array, or custom type). Set to missing for no initial value.
- discrete: If true, state is treated as discrete (Int), otherwise as continuous (Float). Only valid for Real or Array{<:Real} types.
- T: Type of the state (for non-Real types).
Examples
julia> State(0.0)
State{Float64}(0.0, Float64)
julia> State(1, discrete=true)
State{Int64}(1, Int64)
julia> State([0.0, 0.0])
State{Array{Float64}}([0.0, 0.0], Array{Float64})
julia> State(discrete=true)
State{Int64}(missing, Int64)
julia> State(String)
State{String}(missing, String)
ActionModels.StateTrajectories — Type
StateTrajectories
Container for state trajectory estimates from model fitting.
Fields
- value: NamedTuple with one entry per state, each containing a NamedTuple with one entry per session. Each session entry is an AxisArray of samples per timestep.
- modelfit: The associated ModelFit object.
- state_names: Vector of state names.
- session_ids: Vector of session identifiers.
- state_types: NamedTuple of state types.
- n_samples: Number of posterior samples.
- n_chains: Number of MCMC chains.
ActionModels.Variable — Type
Variable(value)
A mutable container for a value of type T. Used for model parameters, states, and actions.
Arguments
- value: The value to store (any type).
Examples
julia> v = ActionModels.Variable(0.5)
ActionModels.Variable{Float64}(0.5)
julia> v.value = 1.0
1.0
julia> v
ActionModels.Variable{Float64}(1.0)
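A plausible reason for wrapping values in a mutable container is that the NamedTuples holding them are immutable, so in-place updates need a mutable cell inside. The sketch below shows the pattern with a hypothetical Box type standing in for Variable. It is a plain-Julia illustration, not the package's definition.

```julia
# `Box` is a hypothetical stand-in for a mutable value container.
mutable struct Box{T}
    value::T
end

# NamedTuples are immutable, but the containers they hold can be updated in place.
attributes = (learning_rate = Box(0.1), expected_value = Box(0.0))
attributes.expected_value.value = 0.5

# Rebinding a field of the NamedTuple itself is not allowed.
rebinding_failed = try
    attributes.expected_value = Box(1.0)
    false
catch
    true
end
```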
ActionModels.create_model — Method
create_model(action_model::ActionModel, population_model::DynamicPPL.Model, data::DataFrame; observation_cols, action_cols, session_cols=Vector{Symbol}(), parameters_to_estimate, impute_missing_actions=false, check_parameter_rejections=false, population_model_type=CustomPopulationModel(), verbose=true)
Create a ModelFit structure that can be used for sampling posterior and prior probability distributions. It consists of an action model, a population model, and a dataset.
This function prepares the data, checks consistency with the action and population models, handles missing data, and returns a ModelFit object ready for sampling and inference.
Arguments
- action_model::ActionModel: The action model to fit.
- population_model::DynamicPPL.Model: The population model (e.g., a Turing model that generates parameters for each session).
- data::DataFrame: The dataset containing observations, actions, and session/grouping columns.
- observation_cols: Columns in data for observations. Can be a NamedTuple, Vector{Symbol}, or Symbol.
- action_cols: Columns in data for actions. Can be a NamedTuple, Vector{Symbol}, or Symbol.
- session_cols: Columns in data identifying sessions/groups (default: empty vector).
- parameters_to_estimate: Tuple of parameter names to estimate.
- impute_missing_actions: Whether to impute missing actions (default: false).
- check_parameter_rejections: Whether to check for parameter rejections (default: false).
- population_model_type: Type of population model (default: CustomPopulationModel()).
- verbose: Whether to print warnings and info (default: true).
Returns
- ModelFit: Struct containing the model, data, and metadata for fitting and inference.
Example
julia> model = create_model(action_model, population_model, data; action_cols = :action, observation_cols = :observation, session_cols = :id, parameters_to_estimate = (:learning_rate,));
julia> model isa ActionModels.ModelFit
true
Notes
- The returned ModelFit object can be used with sample_posterior!, sample_prior!, and other inference utilities.
- Handles missing actions according to the impute_missing_actions argument.
- Checks that columns and types in data match the action model specification.
ActionModels.create_model — Method
create_model(action_model::ActionModel, regressions::Union{Regression,FormulaTerm,Vector}, data::DataFrame; observation_cols, action_cols, session_cols=Vector{Symbol}(), verbose=true, kwargs...)
Create a hierarchical Turing model with a regression-based population model for parameter estimation.
This function builds a model where one or more agent parameters are modeled as a function of covariates using (generalized) linear regression, optionally with random effects. Returns a ModelFit object ready for sampling and inference.
Arguments
- action_model::ActionModel: The agent/action model to fit.
- regressions::Union{Regression, FormulaTerm, Vector}: One or more regression specifications (as Regression objects or formulae).
- data::DataFrame: The dataset containing observations, actions, covariates, and session/grouping columns.
- observation_cols: Columns in data for observations. Can be a NamedTuple, Vector{Symbol}, or Symbol.
- action_cols: Columns in data for actions. Can be a NamedTuple, Vector{Symbol}, or Symbol.
- session_cols: Columns in data identifying sessions/groups (default: empty vector).
- verbose: Whether to print warnings and info (default: true).
- kwargs...: Additional keyword arguments passed to the underlying model constructor.
Returns
- ModelFit: Struct containing the model, data, and metadata for fitting and inference.
Example
julia> create_model(action_model, regression, data; action_cols = :action, observation_cols = :observation, session_cols = :id)
-- ModelFit object --
Action model: rescorla_wagner_act_after_update
Linear regression population model
1 estimated action model parameters, 2 sessions
Posterior not sampled
Prior not sampled
Notes
- Supports both fixed and random effects (random effects via formula syntax, e.g., learning_rate ~ 1 + (1|group)).
- Multiple regressions can be provided as a vector.
- The returned ModelFit object can be used with sample_posterior!, sample_prior!, and other inference utilities.
- Covariate columns must be present in data.
ActionModels.create_model — Method
create_model(action_model::ActionModel, prior::NamedTuple, data::DataFrame; observation_cols, action_cols, session_cols=Vector{Symbol}(), population_model_type=IndependentPopulationModel(), verbose=true, kwargs...)
Create a Turing model in which each session has its own independent session-level parameters.
This function builds a model where each session's parameters are sampled independently from the specified prior distributions. Returns a ModelFit object ready for sampling and inference.
Arguments
- action_model::ActionModel: The agent/action model to fit.
- prior::NamedTuple: Named tuple of prior distributions for each parameter (e.g., (; learning_rate = LogitNormal())).
- data::DataFrame: The dataset containing observations, actions, and session/grouping columns.
- observation_cols: Columns in data for observations. Can be a NamedTuple, Vector{Symbol}, or Symbol.
- action_cols: Columns in data for actions. Can be a NamedTuple, Vector{Symbol}, or Symbol.
- session_cols: Columns in data identifying sessions/groups (default: empty vector).
- population_model_type: Type of population model (default: IndependentPopulationModel()).
- verbose: Whether to print warnings and info (default: true).
- kwargs...: Additional keyword arguments passed to the underlying model constructor.
Returns
- ModelFit: Struct containing the model, data, and metadata for fitting and inference.
Example
julia> model = create_model(action_model, prior, data; action_cols = :action, observation_cols = :observation, session_cols = :id);
julia> model isa ActionModels.ModelFit
true
Notes
- Each session's parameters are sampled independently from the specified priors.
- The returned ModelFit object can be used with sample_posterior!, sample_prior!, and other inference utilities.
- Use this model when you do not want to share information across sessions/groups.
ActionModels.create_model — Method
create_model(action_model::ActionModel, prior::NamedTuple, observations::Vector, actions::Vector; verbose=true, kwargs...)
Create a Turing model for a single session with user-supplied observations and actions.
This function builds a model where all data belong to a single session, and parameters are sampled from the specified prior distributions. Returns a ModelFit object ready for sampling and inference.
Arguments
- action_model::ActionModel: The agent/action model to fit.
- prior::NamedTuple: Named tuple of prior distributions for each parameter (e.g., (; learning_rate = LogitNormal())).
- observations::Vector: Vector of observations (or tuples of observations) for the session.
- actions::Vector: Vector of actions (or tuples of actions) for the session.
- verbose: Whether to print warnings and info (default: true).
- kwargs...: Additional keyword arguments passed to the underlying model constructor.
Returns
- ModelFit: Struct containing the model, data, and metadata for fitting and inference.
Example
julia> model = create_model(action_model, prior, obs, acts);
julia> model isa ActionModels.ModelFit
true
Notes
- Use this model for fitting a single session or subject.
- The returned ModelFit object can be used with sample_posterior!, sample_prior!, and other inference utilities.
- Handles both scalar and tuple-valued observations/actions.
ActionModels.get_actions — Method
get_actions(agent::Agent, target_action::Symbol)
Get a single model action for an agent.
Arguments
- agent::Agent: The agent whose action will be retrieved.
- target_action::Symbol: Name of the action.
Returns
- Value of the specified action.
Example
julia> get_actions(agent, :report)
missing
julia> observe!(agent, 0.5);
julia> action = get_actions(agent, :report);
julia> action isa Real
true
ActionModels.get_actions — Method
get_actions(agent::Agent, target_actions::Tuple{Vararg{Symbol}})
Get multiple previously selected actions for an agent using a tuple of names.
Arguments
- agent::Agent: The agent whose actions will be retrieved.
- target_actions::Tuple{Vararg{Symbol}}: Tuple of action names.
Returns
- NamedTuple of the requested action names and values.
Example
julia> get_actions(agent, (:report,))
(report = missing,)
ActionModels.get_actions — Method
get_actions(agent::Agent)
Get all previously selected actions for an agent.
Arguments
- agent::Agent: The agent whose actions will be retrieved.
Returns
- NamedTuple of all action names and values.
Example
julia> get_actions(agent)
(report = missing,)
ActionModels.get_history — Method
get_history(agent::Agent, target_state::Symbol)
Get the history for a single state for an agent.
Arguments
- agent::Agent: The agent whose history will be retrieved.
- target_state::Symbol: Name of the state.
Returns
- Array of state values over time.
Example
julia> get_history(agent, :expected_value)
1-element Vector{Float64}:
0.0
ActionModels.get_history — Method
get_history(agent::Agent, target_state::Tuple{Vararg{Symbol}})
Get the history for multiple states for an agent.
Arguments
- agent::Agent: The agent whose history will be retrieved.
- target_state::Tuple{Vararg{Symbol}}: Tuple of state names.
Returns
- NamedTuple of history arrays for the specified states.
Example
julia> get_history(agent, (:expected_value,))
(expected_value = [0.0],)
ActionModels.get_history — Method
get_history(agent::Agent)
Get the full state history for an agent.
Arguments
- agent::Agent: The agent whose history will be retrieved.
Returns
- NamedTuple mapping state names to their history arrays.
Example
julia> get_history(agent)
(expected_value = [0.0],)
ActionModels.get_parameters — Method
get_parameters(agent::Agent, target_param::Symbol)
Get a single model parameter for an agent.
Arguments
- agent::Agent: The agent whose parameter will be retrieved.
- target_param::Symbol: Name of the parameter.
Returns
- Value of the specified parameter.
Example
julia> get_parameters(agent, :learning_rate)
0.1
ActionModels.get_parameters — Method
get_parameters(agent::Agent, target_parameters::Tuple{Vararg{Symbol}})
Get multiple model parameters for an agent using a tuple of names.
Arguments
- agent::Agent: The agent whose parameters will be retrieved.
- target_parameters::Tuple{Vararg{Symbol}}: Tuple of parameter names.
Returns
- NamedTuple of the requested parameter names and values.
Example
julia> get_parameters(agent, (:learning_rate, :action_noise))
(learning_rate = 0.1, action_noise = 1.0)
ActionModels.get_parameters — Method
get_parameters(agent::Agent)
Get all model parameters for an agent.
Arguments
- agent::Agent: The agent whose parameters will be retrieved.
Returns
- NamedTuple of all parameter names and values.
Example
julia> get_parameters(agent)
(action_noise = 1.0, learning_rate = 0.1, initial_value = 0.0)
ActionModels.get_session_parameters! — Function
get_session_parameters!(modelfit::ModelFit, prior_or_posterior::Symbol = :posterior; verbose::Bool = true)
Extract posterior or prior samples of session-level parameters from a fitted model.
If the requested samples have not yet been drawn, this function will call sample_posterior! or sample_prior! as needed. Returns a SessionParameters struct containing the samples for each session and parameter.
Arguments
- modelfit::ModelFit: The fitted model object.
- prior_or_posterior::Symbol = :posterior: Whether to extract from the posterior (:posterior) or prior (:prior).
- verbose::Bool = true: Whether to print warnings if sampling is triggered.
Returns
- SessionParameters: Struct containing samples for each session and parameter.
Example
julia> params = get_session_parameters!(model);
julia> params isa ActionModels.SessionParameters
true
Notes
- Use prior_or_posterior = :prior to extract prior samples instead of posterior.
- The returned object can be summarized with Turing.summarize.
ActionModels.get_state_trajectories! — Function
get_state_trajectories!(modelfit::ModelFit, target_states::Union{Symbol,Vector{Symbol}}, prior_or_posterior::Symbol = :posterior)
Extract posterior or prior samples of state trajectories for specified states from a fitted model.
If the requested samples have not yet been drawn, this function will call sample_posterior! or sample_prior! as needed. Returns a StateTrajectories struct containing the samples for each session, state, and timestep.
Arguments
- modelfit::ModelFit: The fitted model object.
- target_states::Union{Symbol,Vector{Symbol}}: State or states to extract trajectories for (e.g., :expected_value).
- prior_or_posterior::Symbol = :posterior: Whether to extract from the posterior (:posterior) or prior (:prior).
Returns
- StateTrajectories: Struct containing samples for each session, state, and timestep.
Example
julia> trajs = get_state_trajectories!(model, :expected_value);
julia> trajs isa ActionModels.StateTrajectories
true
Notes
- Use prior_or_posterior = :prior to extract prior samples instead of posterior.
- The returned object can be summarized with Turing.summarize.
ActionModels.get_states — Method
get_states(agent::Agent, target_state::Symbol)
Get a single model state for an agent.
Arguments
- agent::Agent: The agent whose state will be retrieved.
- target_state::Symbol: Name of the state.
Returns
- Value of the specified state.
Example
julia> get_states(agent, :expected_value)
0.0
ActionModels.get_states
— Methodget_states(agent::Agent, target_states::Tuple{Vararg{Symbol}})
Get multiple model states for an agent using a tuple of names.
Arguments
agent::Agent
: The agent whose states will be retrieved.
target_states::Tuple{Vararg{Symbol}}
: Tuple of state names.
Returns
- NamedTuple of the requested state names and values.
Example
julia> get_states(agent, (:expected_value,))
(expected_value = 0.0,)
ActionModels.get_states
— Methodget_states(agent::Agent)
Get all model states for an agent.
Arguments
agent::Agent
: The agent whose states will be retrieved.
Returns
- NamedTuple of all state names and values.
Example
julia> get_states(agent)
(expected_value = 0.0,)
ActionModels.init_agent
— Methodinit_agent(action_model::ActionModel; save_history=false)
Initialize an Agent for simulation, given an ActionModel. Optionally specify which states to save in the agent's history.
Arguments
action_model
: The ActionModel to use for the agent.
save_history
: If true, save all states; if false, save none; if a Symbol or Vector{Symbol}, save only those states (default: false).
Returns
Agent
: An initialized agent ready for simulation.
Examples
julia> am = ActionModel(RescorlaWagner());
julia> agent = init_agent(am, save_history=true)
-- ActionModels Agent --
Action model: rescorla_wagner_act_after_update
This agent has received 0 observations
ActionModels.load_actions
— Methodload_actions(model_attributes)
Load the actions from the model attributes. This is used within the action model definition to extract the current action values.
Arguments
model_attributes
: The model attributes object.
Returns
A NamedTuple of action names and values.
Example
julia> actions = load_actions(attributes)
(report = missing,)
ActionModels.load_parameters
— Methodload_parameters(model_attributes)
Load the parameters from the model attributes. This is used within the action model definition to extract the current parameter values.
Arguments
model_attributes
: The model attributes object.
Returns
A NamedTuple of parameter names and values.
Example
julia> params = load_parameters(attributes)
(learning_rate = 0.1,)
ActionModels.load_states
— Methodload_states(model_attributes)
Load the states from the model attributes. This is used within the action model definition to extract the current state values.
Arguments
model_attributes
: The model attributes object.
Returns
A NamedTuple of state names and values.
Example
julia> states = load_states(attributes)
(expected_value = 0.0,)
ActionModels.observe!
— Methodobserve!(agent::Agent, observation)
Advance the agent by one timestep using the given observation, updating its state and returning the sampled action.
Arguments
agent
: The Agent to update.
observation
: The observation(s) for this timestep (tuple or value).
Returns
- The sampled action(s) for this timestep.
Examples
julia> action = observe!(agent, 0.5);
julia> action isa Real
true
ActionModels.reset!
— Methodreset!(agent::Agent)
Reset the agent's model attributes and history to their initial states.
This function resets all model parameters, states, and actions to their initial values, clears the agent's history, and sets the timestep counter to zero. This is useful for running new simulations with the same agent instance.
Arguments
agent::Agent
: The agent whose attributes and history will be reset.
Example
julia> reset!(agent)
ActionModels.sample_posterior!
— Functionsample_posterior!(modelfit::ModelFit, parallelization::AbstractMCMC.AbstractMCMCEnsemble = MCMCSerial(); verbose=true, resample=false, save_resume=nothing, init_params=:sample_prior, n_samples=1000, n_chains=2, adtype=AutoForwardDiff(), sampler=NUTS(; adtype=adtype), sampler_kwargs...)
Sample from the posterior distribution of a fitted model using MCMC.
This function runs MCMC sampling for the provided ModelFit object, storing the results in modelfit.posterior. It supports saving and resuming sampling, parallelization, several ways of initializing the sampler's parameters, and detailed sampler settings. Returns the sampled chains.
Arguments
modelfit::ModelFit
: The model structure to sample with.
parallelization::AbstractMCMC.AbstractMCMCEnsemble
: Parallelization strategy (default: MCMCSerial()).
verbose::Bool
: Whether to display warnings (default: true).
resample::Bool
: Whether to force resampling even if results exist (default: false).
save_resume::Union{SampleSaveResume,Nothing}
: Save/resume configuration (default: nothing).
init_params::Union{Nothing,Symbol,Vector{Float64}}
: How to initialize the sampler (default: :sample_prior).
n_samples::Integer
: Number of samples per chain (default: 1000).
n_chains::Integer
: Number of MCMC chains (default: 2).
adtype
: Automatic differentiation type (default: AutoForwardDiff()).
sampler
: Sampler algorithm (default: NUTS).
sampler_kwargs...
: Additional keyword arguments for the sampler.
Returns
Chains
: The sampled posterior chains.
Example
julia> chns = sample_posterior!(model, sampler = HMC(0.8, 10), n_chains = 1, n_samples = 100, progress = false);
julia> chns isa Chains
true
ActionModels.sample_prior!
— Functionsample_prior!(modelfit::ModelFit, parallelization::AbstractMCMC.AbstractMCMCEnsemble = MCMCSerial(); resample=false, n_samples=1000, n_chains=2)
Sample from the prior distribution of a fitted model using MCMC.
This function samples from the prior for the provided ModelFit object, storing the results in modelfit.prior. Returns the sampled chains.
Arguments
modelfit::ModelFit
: The model structure to sample with.
parallelization::AbstractMCMC.AbstractMCMCEnsemble
: Parallelization strategy (default: MCMCSerial()).
resample::Bool
: Whether to force resampling even if results exist (default: false).
n_samples::Integer
: Number of samples per chain (default: 1000).
n_chains::Integer
: Number of MCMC chains (default: 2).
Returns
Chains
: The sampled prior chains.
Example
julia> chns = sample_prior!(model, progress = false);
julia> chns isa Chains
true
ActionModels.set_actions!
— Methodset_actions!(agent::Agent, actions::NamedTuple)
Set multiple actions as having been previously chosen for an agent using a NamedTuple.
Arguments
agent::Agent
: The agent whose actions will be set.
actions::NamedTuple
: NamedTuple of action names and values.
Example
julia> set_actions!(agent, (;report=1.,))
ActionModels.set_actions!
— Methodset_actions!(agent::Agent, target_action::Symbol, target_value::Real)
Set a single model action for an agent.
Arguments
agent::Agent
: The agent whose action will be set.
target_action::Symbol
: Name of the action.
target_value::Real
: Value to set.
Example
julia> set_actions!(agent, :report, 1)
ActionModels.set_actions!
— Methodset_actions!(agent::Agent, action_names::Tuple{Vararg{Symbol}}, action_values::Tuple{Vararg{Real}})
Set multiple actions as having been previously chosen for an agent using tuples of names and values.
Arguments
agent::Agent
: The agent whose actions will be set.
action_names::Tuple{Vararg{Symbol}}
: Tuple of action names.
action_values::Tuple{Vararg{Real}}
: Tuple of action values.
Example
julia> set_actions!(agent, (:report,), (1,))
ActionModels.set_actions!
— Methodset_actions!(agent::Agent, action_names::Vector{Symbol}, action_values::Tuple{Vararg{Real}})
Set multiple actions as having been previously chosen for an agent using a vector of names and a tuple of values.
Arguments
agent::Agent
: The agent whose actions will be set.
action_names::Vector{Symbol}
: Vector of action names.
action_values::Tuple{Vararg{Real}}
: Tuple of action values.
Example
julia> set_actions!(agent, [:report], (1,))
ActionModels.set_parameters!
— Methodset_parameters!(agent::Agent, parameters::NamedTuple)
Set multiple model parameters for an agent using a NamedTuple.
Arguments
agent::Agent
: The agent whose parameters will be set.
parameters::NamedTuple
: NamedTuple of parameter names and values.
Example
julia> set_parameters!(agent, (learning_rate=0.2, action_noise=0.1))
ActionModels.set_parameters!
— Methodset_parameters!(agent::Agent, target_param::Symbol, target_value::Union{R,AbstractArray{R}})
Set a single model parameter for an agent.
Arguments
agent::Agent
: The agent whose parameter will be set.
target_param::Symbol
: Name of the parameter.
target_value::Union{R,AbstractArray{R}}
: Value to set.
Example
julia> set_parameters!(agent, :learning_rate, 0.2)
ActionModels.set_parameters!
— Methodset_parameters!(agent::Agent, parameter_names::Tuple{Vararg{Symbol}}, parameter_values::Tuple{Vararg{Union{R,AbstractArray{R}}}})
Set multiple model parameters for an agent using tuples of names and values.
Arguments
agent::Agent
: The agent whose parameters will be set.
parameter_names::Tuple{Vararg{Symbol}}
: Tuple of parameter names.
parameter_values::Tuple{Vararg{Union{R,AbstractArray{R}}}}
: Tuple of parameter values.
Example
julia> set_parameters!(agent, (:learning_rate, :action_noise), (0.2, 0.1))
ActionModels.set_parameters!
— Methodset_parameters!(agent::Agent, parameter_names::Vector{Symbol}, parameter_values::Tuple{Vararg{Union{R,AbstractArray{R}}}})
Set multiple model parameters for an agent using a vector of names and a tuple of values.
Arguments
agent::Agent
: The agent whose parameters will be set.
parameter_names::Vector{Symbol}
: Vector of parameter names.
parameter_values::Tuple{Vararg{Union{R,AbstractArray{R}}}}
: Tuple of parameter values.
Example
julia> set_parameters!(agent, [:learning_rate, :action_noise], (0.2, 0.1))
ActionModels.set_states!
— Methodset_states!(agent::Agent, states::NamedTuple)
Set multiple model states for an agent using a NamedTuple.
Arguments
agent::Agent
: The agent whose states will be set.
states::NamedTuple
: NamedTuple of state names and values.
Example
julia> set_states!(agent, (expected_value=0.0,))
ActionModels.set_states!
— Methodset_states!(agent::Agent, target_state::Symbol, target_value::Any)
Set a single model state for an agent.
Arguments
agent::Agent
: The agent whose state will be set.
target_state::Symbol
: Name of the state.
target_value::Any
: Value to set.
Example
julia> set_states!(agent, :expected_value, 0.0)
ActionModels.set_states!
— Methodset_states!(agent::Agent, state_names::Tuple{Vararg{Symbol}}, state_values::Tuple{Vararg{Any}})
Set multiple model states for an agent using tuples of names and values.
Arguments
agent::Agent
: The agent whose states will be set.
state_names::Tuple{Vararg{Symbol}}
: Tuple of state names.
state_values::Tuple{Vararg{Any}}
: Tuple of state values.
Example
julia> set_states!(agent, (:expected_value,), (0.0,))
ActionModels.set_states!
— Methodset_states!(agent::Agent, state_names::Vector{Symbol}, state_values::Tuple{Vararg{Any}})
Set multiple model states for an agent using a vector of names and a tuple of values.
Arguments
agent::Agent
: The agent whose states will be set.
state_names::Vector{Symbol}
: Vector of state names.
state_values::Tuple{Vararg{Any}}
: Tuple of state values.
Example
julia> set_states!(agent, [:expected_value], (0.0,))
ActionModels.simulate!
— Methodsimulate!(agent::Agent, observations::AbstractMatrix)
Simulate the agent forward for multiple timesteps, using a matrix of observations (each row is a timestep). Returns a vector of actions.
Arguments
agent
: The Agent to simulate.
observations
: Matrix of observations (each row is a tuple or value for one timestep).
Returns
- Vector of actions, one per timestep.
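The matrix method has no example in this reference. As an illustration of the input layout only, the snippet below builds a matrix with one row per timestep and one column per observation channel, then lists what each timestep would contain. It uses only Base Julia; the variable names and the two-channel setup are hypothetical, and how simulate! unpacks rows internally is not documented here.

```julia
# Toy observation matrix: 3 timesteps (rows) x 2 observation channels (columns).
observations = [0.5 1.0;
                0.7 0.0;
                0.9 1.0]

# View each row as the tuple of observations belonging to one timestep.
per_timestep = [Tuple(row) for row in eachrow(observations)]

# The number of timesteps is the number of rows.
n_timesteps = size(observations, 1)
```

A matrix in this shape would be what is passed as the second argument, with a model that declares as many observations as there are columns.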
ActionModels.simulate!
— Methodsimulate!(agent::Agent, observations::AbstractVector)
Simulate the agent forward for multiple timesteps, using a vector of observations. Returns a vector of actions for each timestep.
Arguments
agent
: The Agent to simulate.
observations
: Vector of observations (each element is a tuple or value for one timestep).
Returns
- Vector of actions, one per timestep.
Examples
julia> actions = simulate!(agent, [0.5, 0.7, 0.9]);
julia> actions isa AbstractVector
true
ActionModels.update!
— Methodupdate!(attributes::RescorlaWagnerAttributes, observation)
Update the expected value(s) stored in a Rescorla-Wagner submodel's attributes according to the observation. Separate methods are dispatched for continuous, binary, and categorical models.
Arguments
attributes
: A RescorlaWagnerAttributes struct containing the current state and parameters.
observation
: The observed outcome. Type depends on the model variant:
- Float64 for continuous models
- Int64 for binary or categorical models (category index)
- Vector{Int64} for categorical models (one-hot or binary vector)
- Vector{Float64} for categorical models (continuous vector)
Examples
julia> attrs = ActionModels.RescorlaWagnerAttributes{Float64,Float64}(initial_value=0.0, learning_rate=0.2, expected_value=0.0);
julia> ActionModels.update!(attrs, 1.0); # Continuous update
julia> attrs.expected_value
0.2
julia> attrs_bin = ActionModels.RescorlaWagnerAttributes{Float64,Float64}(initial_value=0.0, learning_rate=0.5, expected_value=0.0);
julia> ActionModels.update!(attrs_bin, 1); # Binary update
julia> attrs_bin.expected_value
0.25
julia> attrs_cat = ActionModels.RescorlaWagnerAttributes{Float64,Vector{Float64}}(initial_value=[0.0, 0.0, 0.0], learning_rate=0.1, expected_value=[0.0, 0.0, 0.0]);
julia> ActionModels.update!(attrs_cat, 2); # Categorical update with category index
julia> attrs_cat.expected_value
3-element Vector{Float64}:
-0.05
0.05
-0.05
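The three doctest values above are consistent with a delta-rule update in which binary and categorical expectations pass through a logistic function before the prediction error is computed. The sketch below reproduces that arithmetic in plain Julia. The function names (toy_sigmoid, rw_continuous, rw_binary, rw_categorical) are hypothetical, and the formulas are inferred from the example outputs rather than taken from the ActionModels source.

```julia
# Standalone sketch; none of these names are part of ActionModels.
toy_sigmoid(x) = 1 / (1 + exp(-x))

# Continuous variant: move the expectation toward the observation.
rw_continuous(v, α, obs) = v + α * (obs - v)

# Binary variant (assumed): prediction error against the logistic of the expectation.
rw_binary(v, α, obs) = v + α * (obs - toy_sigmoid(v))

# Categorical variant (assumed): one-hot encode the category index,
# then apply the binary-style update to every element.
function rw_categorical(v::Vector{Float64}, α, category::Int)
    onehot = [i == category ? 1.0 : 0.0 for i in eachindex(v)]
    return v .+ α .* (onehot .- toy_sigmoid.(v))
end

# Same starting values and learning rates as the three doctests above.
continuous_result = rw_continuous(0.0, 0.2, 1.0)
binary_result = rw_binary(0.0, 0.5, 1)
categorical_result = rw_categorical([0.0, 0.0, 0.0], 0.1, 2)
```

With an expectation of 0.0 the logistic term equals 0.5, which is why a learning rate of 0.5 moves the binary expectation by 0.25 rather than 0.5.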
ActionModels.update_state!
— Methodupdate_state!(model_attributes, state_name, state_value)
Update the value of a state in the model attributes. This function is used within the action model definition to update the state values during the simulation or fitting process.
Arguments
model_attributes
: The model attributes object.
state_name
: The name of the state to update.
state_value
: The new value for the state.
Example
julia> update_state!(attributes, :expected_value, 2.0) # Update state expected_value to 2.0
MCMCChains.summarize
— FunctionTuring.summarize(session_parameters::SessionParameters, summary_function::Function=median)
Summarize posterior samples of session-level parameters into a tidy DataFrame.
Each row corresponds to a session, and each column to a parameter (or parameter element, for arrays). The summary statistic (e.g., median, mean) is applied to the posterior samples for each parameter and session.
Arguments
session_parameters::SessionParameters
: Posterior samples of session-level parameters, as returned by model fitting.
summary_function::Function=median
: Function to summarize the samples (e.g., mean, median, std).
Returns
DataFrame
: Table with one row per session and columns for each parameter (or parameter element).
Example
julia> df = summarize(get_session_parameters!(model), median);
julia> df isa DataFrame
true
Notes
- For array-valued parameters, columns are named with indices, e.g., :expected_value[1].
- Session identifiers are split into columns if composite.
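To make the "one row per session, one column per parameter" shape concrete, here is a minimal sketch that collapses toy samples with a summary function. It uses only the Statistics standard library. The data layout, the numbers, and the helper name summarize_rows are hypothetical stand-ins, not the SessionParameters struct or the package's implementation.

```julia
using Statistics

# Hypothetical posterior samples for two sessions (toy numbers, not model output).
samples = [
    (session = "A", learning_rate = [0.1, 0.2, 0.4], action_noise = [1.0, 2.0, 3.0]),
    (session = "B", learning_rate = [0.3, 0.5, 0.6], action_noise = [0.5, 1.5, 1.0]),
]

# Collapse each parameter's samples to a single number per session.
# `f` plays the role of summary_function and defaults to the median.
summarize_rows(samples, f = median) = [
    (session = s.session,
     learning_rate = f(s.learning_rate),
     action_noise = f(s.action_noise))
    for s in samples
]

rows = summarize_rows(samples)          # one NamedTuple per session
spread = summarize_rows(samples, std)   # same shape, different statistic
```

Swapping the summary function changes the values in each cell but not the table's shape, which is the behaviour the summary_function argument describes.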
MCMCChains.summarize
— FunctionTuring.summarize(state_trajectories::StateTrajectories, summary_function::Function=median)
Summarize posterior samples of state trajectories into a tidy DataFrame.
Each row corresponds to a session and timestep, and each column to a state variable (or state element, for arrays). The summary statistic (e.g., median, mean) is applied to the posterior samples for each state, session, and timestep.
Arguments
state_trajectories::StateTrajectories
: Posterior samples of state trajectories, as returned by model fitting.
summary_function::Function=median
: Function to summarize the samples (e.g., mean, median, std).
Returns
DataFrame
: Table with one row per session and timestep, and columns for each state (or state element).
Example
julia> df = summarize(get_state_trajectories!(model, :expected_value), mean);
julia> df isa DataFrame
true
Notes
- For array-valued states, columns are named with indices, e.g., :expected_value[1].
- Session identifiers are split into columns if composite.
- The column timestep indicates the time index (starting from 0).
RecipesBase.apply_recipe
— Functionf(agent::Agent, target_state::Symbol, index::Union{Nothing,Vector{Int}}=nothing)
Plot the trajectory (history) of a state variable for an agent over time.
This plotting recipe visualizes the time course of a given state variable from an agent's history. If the state is multivariate (e.g., a vector or array), you can specify an index to plot a particular dimension. If the state is univariate, the index should be left as nothing (the default).
Arguments
agent::Agent
: The agent whose state trajectory will be plotted.
target_state::Symbol
: The name of the state variable to plot (e.g., :expected_value).
index::Union{Nothing,Vector{Int}}
: (Optional) Index for multivariate states (e.g., [1] for the first element).
Example
julia> plot(agent, :expected_value)
Plot{Plots.GRBackend() n=1}
For a multivariate state (e.g., a vector):
julia> plot(agent, :expected_value, [1])
Plot{Plots.GRBackend() n=1}