Full API reference

ActionModels.Action (Type)

Action(distribution_type, action_type)

Construct an action output for the model, specifying the distribution type (e.g., Normal, Bernoulli).

Arguments

  • T: Type of the action (optional, inferred from distribution_type if not given).
  • distribution_type: The distribution type for the action (e.g., Normal, Bernoulli, MvNormal). Can also be an abstract type like Distribution{Multivariate, Continuous} to allow multiple types of distributions.

Examples

julia> Action(Normal)
Action{Float64, Normal}(Float64, Normal)

julia> Action(Bernoulli)
Action{Int64, Bernoulli}(Int64, Bernoulli)

julia> Action(MvNormal)
Action{Array{Float64}, MvNormal}(Array{Float64}, MvNormal)

julia> Action(Distribution{Multivariate, Discrete})
Action{Array{Int64}, Distribution{Multivariate, Discrete}}(Array{Int64}, Distribution{Multivariate, Discrete})
source
ActionModels.ActionModel (Type)

ActionModel(action_model; parameters, states, observations, actions, submodel, verbose)

Main container for a user-defined or premade action model.

Arguments

  • action_model: The function implementing the model's update and action logic.
  • parameters: NamedTuple of model parameters or a single Parameter.
  • states: NamedTuple of model states or a single State (optional).
  • observations: NamedTuple of model observations or a single Observation (optional).
  • actions: NamedTuple of model actions or a single Action.
  • submodel: Optional submodel for hierarchical or modular models (default: NoSubModel()).
  • verbose: Print warnings when a single Parameter, State, Observation, or Action is passed instead of a NamedTuple (default: true).

Examples

julia> model_fn = (attributes, obs) -> Normal(0, 1);

julia> ActionModel(model_fn, parameters=(learning_rate=Parameter(0.1),), states=(expected_value=State(0.0),), observations=(observation=Observation(),), actions=(report=Action(Normal),))
-- ActionModel --
Action model function: #1
Number of parameters: 1
Number of states: 1
Number of observations: 1
Number of actions: 1

julia> ActionModel(model_fn, parameters=Parameter(0.1), actions = Action(Normal), verbose = false)
-- ActionModel --
Action model function: #1
Number of parameters: 1
Number of states: 0
Number of observations: 0
Number of actions: 1
source
ActionModels.Agent (Type)
Agent{T<:AbstractSubmodel}

Container for a simulated agent, including the action model, model attributes, state history, and number of timesteps.

Fields

  • action_model: The function implementing the agent's action model logic.
  • model_attributes: The ModelAttributes instance containing parameters, states, and actions.
  • history: NamedTuple of vectors storing the history of selected states.
  • n_timesteps: Variable tracking the number of timesteps simulated.

Examples

julia> am = ActionModel(RescorlaWagner());

julia> agent = init_agent(am, save_history=true)
-- ActionModels Agent --
Action model: rescorla_wagner_act_after_update
This agent has received 0 observations
source
ActionModels.AttributeError (Type)

AttributeError()

Custom error type for missing or invalid attribute access. Used for error handling in custom submodels.

Examples

julia> throw(AttributeError())
ERROR: AttributeError()
source
ActionModels.BinaryRescorlaWagner (Type)

BinaryRescorlaWagner(; initial_value=0.0, learning_rate=0.1)

Binary Rescorla-Wagner submodel for use in action models.

Fields

  • initial_value: Initial expected value parameter (Float64).
  • learning_rate: Learning rate parameter (Float64).

Examples

julia> ActionModels.BinaryRescorlaWagner()
ActionModels.BinaryRescorlaWagner(0.0, 0.1)
source
ActionModels.CategoricalRescorlaWagner (Type)

CategoricalRescorlaWagner(; n_categories, initial_value=zeros(n_categories), learning_rate=0.1)

Categorical Rescorla-Wagner submodel for use in action models.

Fields

  • n_categories: Number of categories (Int64).
  • initial_value: Initial expected value vector parameter (Vector{Float64}).
  • learning_rate: Learning rate parameter (Float64).

Examples

julia> ActionModels.CategoricalRescorlaWagner(n_categories=3)
ActionModels.CategoricalRescorlaWagner(3, [0.0, 0.0, 0.0], 0.1)
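The categorical variant keeps one expected value per category. In the textbook form of the Rescorla-Wagner rule, the observed category acts as a one-hot target. The plain-Julia sketch below shows that form. `categorical_rw_update` is a hypothetical name, and it is an assumption that ActionModels updates on this scale rather than on a transformed one.

```julia
# Hypothetical helper, not ActionModels' implementation: one categorical
# Rescorla-Wagner step with a one-hot target. The observed category's expected
# value moves toward 1 and every other category's value moves toward 0.
function categorical_rw_update(expected_value::Vector{Float64}, observation::Int,
                               learning_rate::Float64)
    n_categories = length(expected_value)
    1 <= observation <= n_categories ||
        throw(ArgumentError("observation must be in 1:$n_categories"))
    target = zeros(n_categories)
    target[observation] = 1.0
    return expected_value .+ learning_rate .* (target .- expected_value)
end

ev = categorical_rw_update(zeros(3), 2, 0.1)   # [0.0, 0.1, 0.0]
```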
source
ActionModels.ContinuousRescorlaWagner (Type)

ContinuousRescorlaWagner(; initial_value=0.0, learning_rate=0.1)

Continuous Rescorla-Wagner submodel for use in action models.

Fields

  • initial_value: Initial expected value parameter (Float64).
  • learning_rate: Learning rate parameter (Float64).

Examples

julia> ActionModels.ContinuousRescorlaWagner()
ActionModels.ContinuousRescorlaWagner(0.0, 0.1)
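The two fields map onto the textbook Rescorla-Wagner delta rule, where the expected value moves a fraction learning_rate of the way toward each new observation. The plain-Julia sketch below implements that rule. The names `rw_update` and `rw_trajectory` are hypothetical and are not part of ActionModels. The sketch also assumes the continuous submodel applies the standard update on the raw observation scale.

```julia
# Hypothetical helper: one step of the textbook Rescorla-Wagner delta rule.
# Not part of ActionModels; assumes the update acts on the raw observation scale.
rw_update(expected_value::Real, observation::Real, learning_rate::Real) =
    expected_value + learning_rate * (observation - expected_value)

# Apply the rule to a sequence of observations, starting from initial_value,
# and return the expected value before the first and after every observation.
function rw_trajectory(observations; initial_value = 0.0, learning_rate = 0.1)
    values = Float64[initial_value]
    for obs in observations
        push!(values, rw_update(values[end], obs, learning_rate))
    end
    return values
end

traj = rw_trajectory([1.0, 1.0, 1.0]; learning_rate = 0.5)  # [0.0, 0.5, 0.75, 0.875]
```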
source
ActionModels.InitialStateParameter (Type)

InitialStateParameter(initial_value, state_name; discrete=false)

Type for defining initial state parameters in action models. Initial state parameters define the initial value of a state in the model. Can be continuous or discrete.

Arguments

  • value: The default value for the initial state parameter. Can be a single value or an array.
  • state_name: The symbol name of the state this parameter controls.
  • discrete: If true, value is treated as discrete (Int), otherwise as continuous (Float).

Examples

julia> InitialStateParameter(0.0, :expected_value)
InitialStateParameter{Float64}(:expected_value, 0.0, Float64)

julia> InitialStateParameter(1, :counter, discrete=true)
InitialStateParameter{Int64}(:counter, 1, Int64)

julia> InitialStateParameter([0.0, 0.0], :weights)
InitialStateParameter{Array{Float64}}(:weights, [0.0, 0.0], Array{Float64})
source
ActionModels.ModelAttributes (Type)

ModelAttributes(parameters, states, actions, initial_states, submodel)

Internal container for all model variables and submodel attributes used in an action model instance.

Arguments

  • parameters: NamedTuple of parameter variables
  • states: NamedTuple of state variables
  • actions: NamedTuple of action variables
  • initial_states: NamedTuple of initial state values
  • submodel: Submodel attributes (or NoSubModelAttributes if not used)

Examples

julia> ModelAttributes((learning_rate=ActionModels.Variable(0.1),), (expected_value=ActionModels.Variable(0.0),), (report=ActionModels.Variable(missing),), (expected_value=ActionModels.Variable(0.0),), ActionModels.NoSubModelAttributes())
ModelAttributes{@NamedTuple{learning_rate::ActionModels.Variable{Float64}}, @NamedTuple{expected_value::ActionModels.Variable{Float64}}, @NamedTuple{report::ActionModels.Variable{Missing}}, @NamedTuple{expected_value::ActionModels.Variable{Float64}}, ActionModels.NoSubModelAttributes}((learning_rate = ActionModels.Variable{Float64}(0.1),), (expected_value = ActionModels.Variable{Float64}(0.0),), (report = ActionModels.Variable{Missing}(missing),), (expected_value = ActionModels.Variable{Float64}(0.0),), ActionModels.NoSubModelAttributes())
source
ActionModels.ModelFit (Type)

ModelFit{T}

Container for a fitted model, including the Turing model, population model type, data, and results (both posterior and prior).

Type Parameters

  • T: The population model type (subtype of AbstractPopulationModel).

Fields

  • model: The Turing model object.
  • population_model_type: The population model type.
  • population_data: The data describing each session (i.e., the result of unique(data, session_cols), so it contains no observations or actions).
  • info: Metadata about the fit (as a ModelFitInfo).
  • prior: Prior fit results (empty until sample_prior! is called).
  • posterior: Posterior fit results (empty until sample_posterior! is called).
source
ActionModels.ModelFitInfo (Type)

ModelFitInfo

Container for metadata about a model fit, including session IDs and estimated parameter names.

Fields

  • session_ids::Vector{String}: List of session identifiers.
  • estimated_parameter_names::Tuple{Vararg{Symbol}}: Names of estimated parameters.
source
ActionModels.ModelFitResult (Type)

ModelFitResult

Container for the results of a model fit (either prior or posterior), including MCMC chains and (optionally) session-level parameters.

Fields

  • chains::Chains: MCMC chains from Turing.jl.
  • session_parameters::Union{Nothing,AbstractFittingResult}: Session-level parameter results (optional).
source
ActionModels.NoSubModel (Type)

NoSubModel()

Default submodel type used when no submodel is specified in an ActionModel.

Examples

julia> ActionModels.NoSubModel()
ActionModels.NoSubModel()
source
ActionModels.Observation (Type)

Observation(; discrete=false)
Observation([T])

Construct an observation input to the model. Can be continuous (Float64), discrete (Int64), or a custom type.

Arguments

  • discrete: If true, observation is treated as discrete (Int64), otherwise as continuous (Float64).
  • T: Type of the observation (e.g., Float64, Int64, Vector{Float64}). Used for setting specific types.

Examples

julia> Observation()
Observation{Float64}(Float64)

julia> Observation(discrete=true)
Observation{Int64}(Int64)

julia> Observation(Vector{Float64})
Observation{Vector{Float64}}(Vector{Float64})

julia> Observation(String)
Observation{String}(String)
source
ActionModels.Parameter (Type)

Parameter(value; discrete=false)

Type constructor for defining parameters in action models. Can be continuous or discrete.

Arguments

  • value: The default value of the parameter. Can be a single value or an array.
  • discrete: If true, parameter is treated as discrete (Int), otherwise as continuous (Float).

Examples

julia> Parameter(0.1)
Parameter{Float64}(0.1, Float64)

julia> Parameter(1, discrete=true)
Parameter{Int64}(1, Int64)

julia> Parameter([0.0, 0.0])
Parameter{Array{Float64}}([0.0, 0.0], Array{Float64})
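The printed outputs above suggest a simple rule for the stored type: the discrete flag chooses between Int64 and Float64, and array values are wrapped in Array. The sketch below restates that rule. It is inferred from the example outputs only, and `parameter_type` is a hypothetical helper that is not part of ActionModels.

```julia
# Hypothetical helper mirroring the Parameter examples above: the discrete flag
# selects Int64 or Float64, and array values keep an Array wrapper.
function parameter_type(value; discrete::Bool = false)
    T = discrete ? Int64 : Float64
    return value isa AbstractArray ? Array{T} : T
end

parameter_type(0.1)                  # Float64
parameter_type(1; discrete = true)   # Int64
parameter_type([0.0, 0.0])           # Array{Float64}
```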
source
ActionModels.Regression (Type)

Regression

Type for specifying a regression model in a population model. Contains the formula, prior, and inverse link function.

Fields

  • formula: The regression formula (as a FormulaTerm).
  • prior: The prior for regression coefficients and error terms (as a RegressionPrior). Default is RegressionPrior().
  • inv_link: The inverse link function (default: identity).

Examples

julia> Regression(@formula(y ~ x))
Regression(y ~ x, RegressionPrior{TDist{Float64}, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}}(TDist{Float64}(ν=3.0), Truncated(TDist{Float64}(ν=3.0); lower=0.0)), identity)

julia> Regression(@formula(y ~ x + (1|ID)), exp) # With a random effect and an exponential inverse link function
Regression(y ~ x + :(1 | ID), RegressionPrior{TDist{Float64}, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}}(TDist{Float64}(ν=3.0), Truncated(TDist{Float64}(ν=3.0); lower=0.0)), exp)

julia> Regression(@formula(y ~ x), RegressionPrior(β = TDist(4), σ = truncated(TDist(4), lower = 0)), logistic) # With a custom prior and a logistic inverse link function
Regression(y ~ x, RegressionPrior{TDist{Float64}, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}}(TDist{Float64}(ν=4.0), Truncated(TDist{Float64}(ν=4.0); lower=0.0)), logistic)
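The inverse link maps the regression's linear predictor onto the parameter's scale. For example, exp keeps a parameter positive and logistic keeps it in (0, 1). The plain-Julia sketch below shows this for a formula such as y ~ x, which in StatsModels includes an implicit intercept. The coefficient values and the helper `predict_parameter` are made up for illustration. `logistic` is defined locally here, whereas the examples above take it from a package.

```julia
# Locally defined logistic function, 1 / (1 + exp(-x)).
logistic(x) = 1 / (1 + exp(-x))

# Hypothetical helper: the parameter value implied by y ~ 1 + x for
# coefficients β = [intercept, slope], passed through an inverse link.
predict_parameter(β, x; inv_link = identity) = inv_link(β[1] + β[2] * x)

predict_parameter([0.0, 1.0], 2.0)                       # identity: 2.0
predict_parameter([0.0, 1.0], 0.0; inv_link = exp)       # exp: 1.0
predict_parameter([0.0, 1.0], 0.0; inv_link = logistic)  # logistic: 0.5
```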
source
ActionModels.RegressionPrior (Type)

RegressionPrior{Dβ,Dσ}

Type for specifying priors for regression coefficients and random effects in regression population models.

Type Parameters

  • Dβ: Distribution type for the β regression coefficients.
  • Dσ: Distribution type for the σ random effect deviations.

Fields

  • β: Prior or vector of priors for regression coefficients (default: TDist(3)). If only one prior is given, it is used for all coefficients. If a vector is given, it should match the number of coefficients in the regression formula.
  • σ: Prior or vector of priors for random effect deviations (default: truncated TDist(3) at 0). If only one prior is given, it is used for all random effects. If a vector is given, it should match the number of random effects in the regression formula.

Examples

julia> RegressionPrior()
RegressionPrior{TDist{Float64}, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}}(TDist{Float64}(ν=3.0), Truncated(TDist{Float64}(ν=3.0); lower=0.0))

julia> RegressionPrior(β = TDist(4), σ = truncated(TDist(4), lower = 0))
RegressionPrior{TDist{Float64}, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}}(TDist{Float64}(ν=4.0), Truncated(TDist{Float64}(ν=4.0); lower=0.0))

julia> RegressionPrior(β = [TDist(4), TDist(2)], σ = truncated(TDist(4), lower = 0)) # For setting the prior of each coefficient separately
RegressionPrior{TDist{Float64}, Truncated{TDist{Float64}, Continuous, Float64, Float64, Nothing}}(TDist{Float64}[TDist{Float64}(ν=4.0), TDist{Float64}(ν=2.0)], Truncated(TDist{Float64}(ν=4.0); lower=0.0))
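The rule for β and σ is that a single prior is reused for every coefficient (or random effect), while a vector must contain exactly one prior per coefficient. The sketch below shows that expansion step in plain Julia. `expand_priors` is hypothetical and not ActionModels code, and Symbols stand in for distribution objects so that no packages are needed.

```julia
# Hypothetical helper, not ActionModels code: expand a prior specification into
# one prior per coefficient. A single prior is reused for every coefficient.
expand_priors(prior, n_coefficients::Int) = fill(prior, n_coefficients)

# A vector of priors must supply exactly one prior per coefficient.
function expand_priors(priors::AbstractVector, n_coefficients::Int)
    length(priors) == n_coefficients ||
        throw(ArgumentError("expected $n_coefficients priors, got $(length(priors))"))
    return collect(priors)
end

# Symbols stand in for distribution objects.
expand_priors(:tdist3, 2)             # [:tdist3, :tdist3]
expand_priors([:tdist4, :tdist2], 2)  # [:tdist4, :tdist2]
```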
source
ActionModels.RejectParameters (Type)

RejectParameters(errortext)

Custom error type for rejecting parameter samples during inference.

Arguments

  • errortext: Explanation for the rejection

Examples

julia> throw(RejectParameters("Parameter out of bounds"))
ERROR: RejectParameters("Parameter out of bounds")
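The example above only throws the error. A typical use is to throw such an error inside model code when a parameter combination is invalid, and then to catch it in the caller. The sketch below reproduces that pattern with a locally defined exception type. `MyRejectParameters`, `toy_loglik`, and `safe_loglik` are hypothetical stand-ins. Turning a rejection into -Inf log-probability is an assumed handling strategy for illustration, not a statement of what ActionModels does internally.

```julia
# Hypothetical stand-in for a rejection error; ActionModels defines its own type.
struct MyRejectParameters <: Exception
    errortext::String
end

# Toy log-likelihood that rejects non-positive noise values.
function toy_loglik(noise::Real)
    noise > 0 || throw(MyRejectParameters("noise must be positive, got $noise"))
    return -log(noise)
end

# Caller that converts a rejection into -Inf log-probability (assumed strategy)
# and rethrows any other kind of error unchanged.
function safe_loglik(noise::Real)
    try
        return toy_loglik(noise)
    catch err
        err isa MyRejectParameters || rethrow()
        return -Inf
    end
end
```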
source
ActionModels.RescorlaWagner (Type)

RescorlaWagner(; type=:continuous, initial_value=nothing, learning_rate=0.1, n_categories=nothing, response_model=nothing, response_model_parameters=nothing, response_model_observations=nothing, response_model_actions=nothing, action_noise=nothing, act_before_update=false)

Premade configuration type for the Rescorla-Wagner model. Used to construct an ActionModel for continuous, binary, or categorical learning tasks.

Keyword Arguments

  • type: Symbol specifying the model type (:continuous, :binary, or :categorical).
  • initial_value: Initial value for the expected value (Float64 or Vector{Float64}).
  • learning_rate: Learning rate parameter (Float64).
  • n_categories: Number of categories (required for categorical models).
  • response_model: Custom response model function (optional).
  • response_model_parameters: NamedTuple of response model parameters (optional).
  • response_model_observations: NamedTuple of response model observations (optional).
  • response_model_actions: NamedTuple of response model actions (optional).
  • action_noise: Action noise parameter (Float64, optional; used if no custom response model is provided).
  • act_before_update: If true, action is selected before updating the expectation.

Examples

julia> rw = RescorlaWagner(type=:continuous, initial_value=0.0, learning_rate=0.2)
RescorlaWagner(:continuous, 0.0, 0.2, nothing, ActionModels.var"#gaussian_report#270"(), (action_noise = Parameter{Float64}(1.0, Float64),), (observation = Observation{Float64}(Float64),), (report = Action{Float64, Normal}(Float64, Normal),), false)

julia> ActionModel(rw);

julia> rw_cat = RescorlaWagner(type=:categorical, n_categories=3, initial_value=[0.0, 0.0, 0.0], learning_rate=0.1)
RescorlaWagner(:categorical, [0.0, 0.0, 0.0], 0.1, 3, ActionModels.var"#categorical_report#272"(), (action_noise = Parameter{Float64}(1.0, Float64),), (observation = Observation{Int64}(Int64),), (report = Action{Int64, Categorical{P} where P<:Real}(Int64, Categorical{P} where P<:Real),), false)

julia> ActionModel(rw_cat);

julia> custom_response_model = (attributes) -> Normal(2 * attributes.submodel.expected_value, load_parameters(attributes).noise);

julia> rw_custom_resp = RescorlaWagner(response_model=custom_response_model, response_model_observations=(observation=Observation(Float64),), response_model_actions=(report=Action(Normal),), response_model_parameters=(noise=Parameter(1.0, Float64),))
RescorlaWagner(...)

julia> ActionModel(rw_custom_resp);
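The act_before_update keyword controls whether the action is chosen from the expectation held before or after the current observation is incorporated. The sketch below makes the difference concrete with a hypothetical noise-free agent that simply reports its expected value. `simulate_reports` is not part of ActionModels. The premade model instead samples its report from a response distribution, such as a Gaussian governed by action_noise.

```julia
# Hypothetical sketch, not ActionModels code: a noise-free agent that reports
# its current expected value, illustrating the act_before_update switch.
function simulate_reports(observations; initial_value = 0.0, learning_rate = 0.5,
                          act_before_update = false)
    expected_value = initial_value
    reports = Float64[]
    for obs in observations
        if act_before_update
            push!(reports, expected_value)  # act on the pre-update expectation
            expected_value += learning_rate * (obs - expected_value)
        else
            expected_value += learning_rate * (obs - expected_value)
            push!(reports, expected_value)  # act on the updated expectation
        end
    end
    return reports
end

simulate_reports([1.0, 1.0]; act_before_update = true)   # [0.0, 0.5]
simulate_reports([1.0, 1.0]; act_before_update = false)  # [0.5, 0.75]
```

Acting before the update means the first report can only reflect initial_value, which is why the two traces are offset by one timestep.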

source
ActionModels.RescorlaWagnerAttributes (Type)

RescorlaWagnerAttributes{T,AT}

Container for the parameters and state of a Rescorla-Wagner submodel. The type parameter T allows the element types supplied by Turing during sampling to be used, and the attribute type AT allows dispatching between categorical and continuous/binary models.

Type Parameters

  • T: Element type (e.g., Float64).
  • AT: Categorical vs continuous/binary typing (e.g., Float64 or Vector{Float64}).

Fields

  • initial_value: Initial value parameter for the expected value.
  • learning_rate: Learning rate parameter.
  • expected_value: Current expected value.

Examples

julia> ActionModels.RescorlaWagnerAttributes{Float64,Float64}(initial_value=0.0, learning_rate=0.1, expected_value=0.0)
ActionModels.RescorlaWagnerAttributes{Float64, Float64}(0.0, 0.1, 0.0)
source
ActionModels.SampleSaveResume (Type)

SampleSaveResume

Configuration for saving and resuming MCMC sampling progress to disk.

This struct is used to enable periodic saving of sampler state during long MCMC runs, allowing interrupted sampling to be resumed from the last save point.

Fields

  • save_every::Int: Number of iterations between saves (default: 100).
  • path::String: Directory path for saving sampler state (default: "./.samplingstate").
  • chain_prefix::String: Prefix for saved chain segment files (default: "ActionModels_chain_segment").

Examples

julia> SampleSaveResume()
SampleSaveResume(100, "./.samplingstate", "ActionModels_chain_segment")

julia> SampleSaveResume(save_every=500, path="./.tmp", chain_prefix="my_chain")
SampleSaveResume(500, "./.tmp", "my_chain")
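With save_every = 100, a run of 250 iterations would be split into three saved segments, and the last segment would be shorter than the others. The plain-Julia sketch below shows that split, assuming segments are cut at fixed iteration boundaries. `segment_ranges` is hypothetical, and the sketch does not model the file naming or storage format that ActionModels uses.

```julia
# Hypothetical helper: split n_iterations into consecutive chunks of at most
# save_every iterations, one chunk per saved segment.
function segment_ranges(n_iterations::Int, save_every::Int)
    save_every > 0 || throw(ArgumentError("save_every must be positive"))
    return [start:min(start + save_every - 1, n_iterations)
            for start in 1:save_every:n_iterations]
end

segment_ranges(250, 100)  # [1:100, 101:200, 201:250]
```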
source
ActionModels.SessionParameters (Type)

SessionParameters

Container for session-level parameter estimates from model fitting.

Fields

  • value: NamedTuple with one entry per parameter, each containing a NamedTuple with one entry per session; each session entry is an AxisArray of samples.
  • modelfit: The associated ModelFit object.
  • estimated_parameter_names: Tuple of estimated parameter names.
  • session_ids: Vector of session identifiers.
  • parameter_types: NamedTuple of parameter types.
  • n_samples: Number of posterior samples.
  • n_chains: Number of MCMC chains.
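The nested layout of value (parameter, then session, then samples) can be reduced to point estimates with nested `map` calls. The sketch below uses plain vectors in place of the AxisArrays, and the session names are made up, so it only illustrates the shape of the data and does not use the real SessionParameters object.

```julia
using Statistics

# Hypothetical stand-in for SessionParameters.value: one entry per parameter,
# each holding one vector of samples per session (plain vectors replace AxisArrays).
samples = (learning_rate = (id_1 = [0.1, 0.2, 0.3], id_2 = [0.5, 0.6, 0.7]),)

# Median of the samples for every parameter and session; `map` over a
# NamedTuple returns a NamedTuple with the same names.
medians = map(param -> map(median, param), samples)
```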
source
ActionModels.State (Type)

State(initial_value; discrete=nothing)
State(initial_value, ::Type{T})
State(; discrete=false)
State(::Type{T})

Construct a model state variable, which can be continuous, discrete, or a custom type.

Arguments

  • initial_value: The initial value for the state (can be Real, Array, or custom type). Set to missing for no initial value.
  • discrete: If true, state is treated as discrete (Int), otherwise as continuous (Float). Only valid for Real or Array{<:Real} types.
  • T: Type of the state (for non-Real types).

Examples

julia> State(0.0)
State{Float64}(0.0, Float64)

julia> State(1, discrete=true)
State{Int64}(1, Int64)

julia> State([0.0, 0.0])
State{Array{Float64}}([0.0, 0.0], Array{Float64})

julia> State(discrete=true)
State{Int64}(missing, Int64)

julia> State(String)
State{String}(missing, String)
source
ActionModels.StateTrajectories (Type)

StateTrajectories

Container for state trajectory estimates from model fitting.

Fields

  • value: NamedTuple with one entry per state, each containing a NamedTuple with one entry per session; each session entry is an AxisArray of samples per timestep.
  • modelfit: The associated ModelFit object.
  • state_names: Vector of state names.
  • session_ids: Vector of session identifiers.
  • state_types: NamedTuple of state types.
  • n_samples: Number of posterior samples.
  • n_chains: Number of MCMC chains.
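For a single session and state, the samples form a samples-by-timesteps grid, so a posterior mean trajectory is an average over the sample dimension. The sketch below uses a plain matrix instead of an AxisArray. The orientation, with rows as samples and columns as timesteps, is an assumption made for illustration.

```julia
using Statistics

# Hypothetical stand-in for one session's samples of one state: rows are
# posterior samples and columns are timesteps (ActionModels uses AxisArrays).
trajectory_samples = [0.0 0.5  0.75
                      0.0 0.25 0.5]

# Posterior mean of the state at each timestep, averaging over samples (rows).
mean_trajectory = vec(mean(trajectory_samples; dims = 1))  # [0.0, 0.375, 0.625]
```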
source
ActionModels.Variable (Type)

Variable(value)

A mutable container for a value of type T. Used for model parameters, states, and actions.

Arguments

  • value: The value to store (any type)

Examples

julia> v = ActionModels.Variable(0.5)
ActionModels.Variable{Float64}(0.5)

julia> v.value = 1.0
1.0

julia> v
ActionModels.Variable{Float64}(1.0)
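The mutable-container pattern above can be reproduced in a few lines of plain Julia. `MyVariable` is a hypothetical name chosen to avoid clashing with ActionModels.Variable. A plausible reason for making the container mutable is that values can then be overwritten between timesteps without rebuilding the surrounding NamedTuple. This is an assumption and is not stated in the source.

```julia
# Hypothetical stand-in for the pattern; ActionModels defines its own Variable.
# A parametric mutable struct lets the stored value be replaced in place.
mutable struct MyVariable{T}
    value::T
end

v = MyVariable(0.5)   # T is inferred as Float64
v.value = 1.0         # in-place update; assigned values are converted to Float64
```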
source
ActionModels.create_model (Method)
create_model(action_model::ActionModel, population_model::DynamicPPL.Model, data::DataFrame; observation_cols, action_cols, session_cols=Vector{Symbol}(), parameters_to_estimate, impute_missing_actions=false, check_parameter_rejections=false, population_model_type=CustomPopulationModel(), verbose=true)

Create a ModelFit structure that can be used for sampling posterior and prior probability distributions. A ModelFit combines an action model, a population model, and a dataset.

This function prepares the data, checks consistency with the action and population models, handles missing data, and returns a ModelFit object ready for sampling and inference.

Arguments

  • action_model::ActionModel: The action model to fit.
  • population_model::DynamicPPL.Model: The population model (e.g., a Turing model that generates parameters for each session).
  • data::DataFrame: The dataset containing observations, actions, and session/grouping columns.
  • observation_cols: Columns in data for observations. Can be a NamedTuple, Vector{Symbol}, or Symbol.
  • action_cols: Columns in data for actions. Can be a NamedTuple, Vector{Symbol}, or Symbol.
  • session_cols: Columns in data identifying sessions/groups (default: empty vector).
  • parameters_to_estimate: Tuple of parameter names to estimate.
  • impute_missing_actions: Whether to impute missing actions (default: false).
  • check_parameter_rejections: Whether to check for parameter rejections (default: false).
  • population_model_type: Type of population model (default: CustomPopulationModel()).
  • verbose: Whether to print warnings and info (default: true).

Returns

  • ModelFit: Struct containing the model, data, and metadata for fitting and inference.

Example

julia> model = create_model(action_model, population_model, data; action_cols = :action, observation_cols = :observation, session_cols = :id, parameters_to_estimate = (:learning_rate,));

julia> model isa ActionModels.ModelFit
true

Notes

  • The returned ModelFit object can be used with sample_posterior!, sample_prior!, and other inference utilities.
  • Handles missing actions according to the impute_missing_actions argument.
  • Checks that columns and types in data match the action model specification.
source
ActionModels.create_model (Method)
create_model(action_model::ActionModel, regressions::Union{Regression,FormulaTerm,Vector}, data::DataFrame; observation_cols, action_cols, session_cols=Vector{Symbol}(), verbose=true, kwargs...)

Create a hierarchical Turing model with a regression-based population model for parameter estimation.

This function builds a model where one or more agent parameters are modeled as a function of covariates using (generalized) linear regression, optionally with random effects. Returns a ModelFit object ready for sampling and inference.

Arguments

  • action_model::ActionModel: The agent/action model to fit.
  • regressions::Union{Regression, FormulaTerm, Vector}: One or more regression specifications (as Regression objects or formulae).
  • data::DataFrame: The dataset containing observations, actions, covariates, and session/grouping columns.
  • observation_cols: Columns in data for observations. Can be a NamedTuple, Vector{Symbol}, or Symbol.
  • action_cols: Columns in data for actions. Can be a NamedTuple, Vector{Symbol}, or Symbol.
  • session_cols: Columns in data identifying sessions/groups (default: empty vector).
  • verbose: Whether to print warnings and info (default: true).
  • kwargs...: Additional keyword arguments passed to the underlying model constructor.

Returns

  • ModelFit: Struct containing the model, data, and metadata for fitting and inference.

Example

julia> create_model(action_model, regression, data; action_cols = :action, observation_cols = :observation, session_cols = :id)
-- ModelFit object --
Action model: rescorla_wagner_act_after_update
Linear regression population model
1 estimated action model parameters, 2 sessions
Posterior not sampled
Prior not sampled

Notes

  • Supports both fixed and random effects (random effects via formula syntax, e.g., learning_rate ~ 1 + (1|group)).
  • Multiple regressions can be provided as a vector.
  • The returned ModelFit object can be used with sample_posterior!, sample_prior!, and other inference utilities.
  • Covariate columns must be present in data.
source
ActionModels.create_model (Method)
create_model(action_model::ActionModel, prior::NamedTuple, data::DataFrame; observation_cols, action_cols, session_cols=Vector{Symbol}(), population_model_type=IndependentPopulationModel(), verbose=true, kwargs...)

Create a hierarchical Turing model with independent session-level parameters for each session.

This function builds a model where each session's parameters are sampled independently from the specified prior distributions. Returns a ModelFit object ready for sampling and inference.

Arguments

  • action_model::ActionModel: The agent/action model to fit.
  • prior::NamedTuple: Named tuple of prior distributions for each parameter (e.g., (; learning_rate = LogitNormal())).
  • data::DataFrame: The dataset containing observations, actions, and session/grouping columns.
  • observation_cols: Columns in data for observations. Can be a NamedTuple, Vector{Symbol}, or Symbol.
  • action_cols: Columns in data for actions. Can be a NamedTuple, Vector{Symbol}, or Symbol.
  • session_cols: Columns in data identifying sessions/groups (default: empty vector).
  • population_model_type: Type of population model (default: IndependentPopulationModel()).
  • verbose: Whether to print warnings and info (default: true).
  • kwargs...: Additional keyword arguments passed to the underlying model constructor.

Returns

  • ModelFit: Struct containing the model, data, and metadata for fitting and inference.

Example

julia> model = create_model(action_model, prior, data; action_cols = :action, observation_cols = :observation, session_cols = :id);

julia> model isa ActionModels.ModelFit
true

Notes

  • Each session's parameters are sampled independently from the specified priors.
  • The returned ModelFit object can be used with sample_posterior!, sample_prior!, and other inference utilities.
  • Use this model when you do not want to share information across sessions/groups.
source
ActionModels.create_model (Method)
create_model(action_model::ActionModel, prior::NamedTuple, observations::Vector, actions::Vector; verbose=true, kwargs...)

Create a Turing model for a single session with user-supplied observations and actions.

This function builds a model where all data belong to a single session, and parameters are sampled from the specified prior distributions. Returns a ModelFit object ready for sampling and inference.

Arguments

  • action_model::ActionModel: The agent/action model to fit.
  • prior::NamedTuple: Named tuple of prior distributions for each parameter (e.g., (; learning_rate = LogitNormal())).
  • observations::Vector: Vector of observations (or tuples of observations) for the session.
  • actions::Vector: Vector of actions (or tuples of actions) for the session.
  • verbose: Whether to print warnings and info (default: true).
  • kwargs...: Additional keyword arguments passed to the underlying model constructor.

Returns

  • ModelFit: Struct containing the model, data, and metadata for fitting and inference.

Example

julia> model = create_model(action_model, prior, obs, acts); 

julia> model isa ActionModels.ModelFit
true

Notes

  • Use this model for fitting a single session or subject.
  • The returned ModelFit object can be used with sample_posterior!, sample_prior!, and other inference utilities.
  • Handles both scalar and tuple-valued observations/actions.
source
ActionModels.get_actions (Method)
get_actions(agent::Agent, target_action::Symbol)

Get a single model action for an agent.

Arguments

  • agent::Agent: The agent whose action will be retrieved.
  • target_action::Symbol: Name of the action.

Returns

  • Value of the specified action.

Example

julia> get_actions(agent, :report)
missing

julia> observe!(agent, 0.5);

julia> action = get_actions(agent, :report);

julia> action isa Real
true
source
ActionModels.get_actions (Method)
get_actions(agent::Agent, target_actions::Tuple{Vararg{Symbol}})

Get multiple previously selected actions for an agent using a tuple of names.

Arguments

  • agent::Agent: The agent whose actions will be retrieved.
  • target_actions::Tuple{Vararg{Symbol}}: Tuple of action names.

Returns

  • NamedTuple of action names and values.

Example

julia> get_actions(agent, (:report,))
(report = missing,)
source
ActionModels.get_actions (Method)
get_actions(agent::Agent)

Get all previously selected actions for an agent.

Arguments

  • agent::Agent: The agent whose actions will be retrieved.

Returns

  • NamedTuple of all action names and values.

Example

julia> get_actions(agent)
(report = missing,)
source
ActionModels.get_history (Method)
get_history(agent::Agent, target_state::Symbol)

Get the history for a single state for an agent.

Arguments

  • agent::Agent: The agent whose history will be retrieved.
  • target_state::Symbol: Name of the state.

Returns

  • Array of state values over time.

Example

julia> get_history(agent, :expected_value)
1-element Vector{Float64}:
 0.0
source
ActionModels.get_history (Method)
get_history(agent::Agent, target_state::Tuple{Vararg{Symbol}})

Get the history for multiple states for an agent.

Arguments

  • agent::Agent: The agent whose history will be retrieved.
  • target_state::Tuple{Vararg{Symbol}}: Tuple of state names.

Returns

  • NamedTuple mapping each requested state name to its history array.

Example

julia> get_history(agent, (:expected_value,))
(expected_value = [0.0],)
source
ActionModels.get_history (Method)
get_history(agent::Agent)

Get the full state history for an agent.

Arguments

  • agent::Agent: The agent whose history will be retrieved.

Returns

  • NamedTuple mapping each state name to its history array.

Example

julia> get_history(agent)
(expected_value = [0.0],)
source
ActionModels.get_parameters (Method)
get_parameters(agent::Agent, target_param::Symbol)

Get a single model parameter for an agent.

Arguments

  • agent::Agent: The agent whose parameter will be retrieved.
  • target_param::Symbol: Name of the parameter.

Returns

  • Value of the specified parameter.

Example

julia> get_parameters(agent, :learning_rate)
0.1
source
ActionModels.get_parameters (Method)
get_parameters(agent::Agent, target_parameters::Tuple{Vararg{Symbol}})

Get multiple model parameters for an agent using a tuple of names.

Arguments

  • agent::Agent: The agent whose parameters will be retrieved.
  • target_parameters::Tuple{Vararg{Symbol}}: Tuple of parameter names.

Returns

  • NamedTuple of parameter names and values.

Example

julia> get_parameters(agent, (:learning_rate, :action_noise))
(learning_rate = 0.1, action_noise = 1.0)
source
ActionModels.get_parameters (Method)
get_parameters(agent::Agent)

Get all model parameters for an agent.

Arguments

  • agent::Agent: The agent whose parameters will be retrieved.

Returns

  • NamedTuple of all parameter names and values.

Example

julia> get_parameters(agent)
(action_noise = 1.0, learning_rate = 0.1, initial_value = 0.0)
source
ActionModels.get_session_parameters! (Function)
get_session_parameters!(modelfit::ModelFit, prior_or_posterior::Symbol = :posterior; verbose::Bool = true)

Extract posterior or prior samples of session-level parameters from a fitted model.

If the requested samples have not yet been drawn, this function will call sample_posterior! or sample_prior! as needed. Returns a SessionParameters struct containing the samples for each session and parameter.

Arguments

  • modelfit::ModelFit: The fitted model object.
  • prior_or_posterior::Symbol = :posterior: Whether to extract from the posterior (:posterior) or prior (:prior).
  • verbose::Bool = true: Whether to print warnings if sampling is triggered.

Returns

  • SessionParameters: Struct containing samples for each session and parameter.

Example

julia> params = get_session_parameters!(model);

julia> params isa ActionModels.SessionParameters
true

Notes

  • Use prior_or_posterior = :prior to extract prior samples instead of posterior.
  • The returned object can be summarized with Turing.summarize.
source
ActionModels.get_state_trajectories! (Function)
get_state_trajectories!(modelfit::ModelFit, target_states::Union{Symbol,Vector{Symbol}}, prior_or_posterior::Symbol = :posterior)

Extract posterior or prior samples of state trajectories for specified states from a fitted model.

If the requested samples have not yet been drawn, this function will call sample_posterior! or sample_prior! as needed. Returns a StateTrajectories struct containing the samples for each session, state, and timestep.

Arguments

  • modelfit::ModelFit: The fitted model object.
  • target_states::Union{Symbol,Vector{Symbol}}: State or states to extract trajectories for (e.g., :expected_value).
  • prior_or_posterior::Symbol = :posterior: Whether to extract from the posterior (:posterior) or prior (:prior).

Returns

  • StateTrajectories: Struct containing samples for each session, state, and timestep.

Example

julia> trajs = get_state_trajectories!(model, :expected_value);

julia> trajs isa ActionModels.StateTrajectories
true

Notes

  • Use prior_or_posterior = :prior to extract prior samples instead of posterior.
  • The returned object can be summarized with Turing.summarize.
source
ActionModels.get_statesMethod
get_states(agent::Agent, target_state::Symbol)

Get a single model state for an agent.

Arguments

  • agent::Agent: The agent whose state will be retrieved.
  • target_state::Symbol: Name of the state.

Returns

  • Value of the specified state.

Example

julia> get_states(agent, :expected_value)
0.0
source
ActionModels.get_statesMethod
get_states(agent::Agent, target_states::Tuple{Vararg{Symbol}})

Get multiple model states for an agent using a tuple of names.

Arguments

  • agent::Agent: The agent whose states will be retrieved.
  • target_states::Tuple{Vararg{Symbol}}: Tuple of state names.

Returns

  • NamedTuple of the requested state names and values.

Example

julia> get_states(agent, (:expected_value,))
(expected_value = 0.0,)
source
ActionModels.get_statesMethod
get_states(agent::Agent)

Get all model states for an agent.

Arguments

  • agent::Agent: The agent whose states will be retrieved.

Returns

  • NamedTuple of all state names and values.

Example

julia> get_states(agent)
(expected_value = 0.0,)
source
ActionModels.init_agentMethod
init_agent(action_model::ActionModel; save_history=false)

Initialize an Agent for simulation, given an ActionModel. Optionally specify which states to save in the agent's history.

Arguments

  • action_model: The ActionModel to use for the agent.
  • save_history: If true, save all states; if false, save none; if a Symbol or Vector{Symbol}, save only those states (default: false).
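
The save_history options can be read as a rule for turning the argument into a list of state names. The following is a minimal sketch of that rule for a model with two states. The helper states_to_save and the state name :prediction_error are hypothetical, and rejecting unknown names is a choice made in this sketch rather than documented ActionModels behavior.

```julia
# Hypothetical helper (not part of ActionModels): normalize a save_history
# argument into the list of state names whose history should be kept.
states_to_save(save_history::Bool, all_states::Vector{Symbol}) =
    save_history ? copy(all_states) : Symbol[]   # true: all states, false: none

function states_to_save(save_history::Vector{Symbol}, all_states::Vector{Symbol})
    unknown = setdiff(save_history, all_states)
    isempty(unknown) || throw(ArgumentError("unknown states: $(unknown)"))
    return copy(save_history)                     # only the requested states
end

# A single Symbol is treated as a one-element list.
states_to_save(save_history::Symbol, all_states::Vector{Symbol}) =
    states_to_save([save_history], all_states)

all_states = [:expected_value, :prediction_error]   # hypothetical state names
states_to_save(true, all_states)              # [:expected_value, :prediction_error]
states_to_save(false, all_states)             # Symbol[]
states_to_save(:expected_value, all_states)   # [:expected_value]
```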

Returns

  • Agent: An initialized agent ready for simulation.

Examples

julia> am = ActionModel(RescorlaWagner());

julia> agent = init_agent(am, save_history=true)
-- ActionModels Agent --
Action model: rescorla_wagner_act_after_update
This agent has received 0 observations
source
ActionModels.load_actionsMethod

load_actions(model_attributes)

Load the actions from the model attributes. This is used within the action model definition to extract the current action values.

Arguments

  • model_attributes: The model attributes object.

Returns

  • NamedTuple of action names and current values.

Example

julia> actions = load_actions(attributes)
(report = missing,)
source
ActionModels.load_parametersMethod

load_parameters(model_attributes)

Load the parameters from the model attributes. This is used within the action model definition to extract the current parameter values.

Arguments

  • model_attributes: The model attributes object.

Returns

  • NamedTuple of parameter names and current values.

Example

julia> params = load_parameters(attributes)
(learning_rate = 0.1,)
source
ActionModels.load_statesMethod

load_states(model_attributes)

Load the states from the model attributes. This is used within the action model definition to extract the current state values.

Arguments

  • model_attributes: The model attributes object.

Returns

  • NamedTuple of state names and current values.

Example

julia> states = load_states(attributes)
(expected_value = 0.0,)
source
ActionModels.observe!Method
observe!(agent::Agent, observation)

Advance the agent by one timestep using the given observation, updating its state and returning the sampled action.

Arguments

  • agent: The Agent to update.
  • observation: The observation(s) for this timestep (tuple or value).

Returns

  • The sampled action(s) for this timestep.

Examples

julia> action = observe!(agent, 0.5);

julia> action isa Real
true
source
ActionModels.reset!Method
reset!(agent::Agent)

Reset the agent's model attributes and history to their initial states.

This function resets all model parameters, states, and actions to their initial values, clears the agent's history, and sets the timestep counter to zero. This is useful for running new simulations with the same agent instance.

Arguments

  • agent::Agent: The agent whose attributes and history will be reset.

Example

julia> reset!(agent)
source
ActionModels.sample_posterior!Function
sample_posterior!(modelfit::ModelFit, parallelization::AbstractMCMC.AbstractMCMCEnsemble = MCMCSerial(); verbose=true, resample=false, save_resume=nothing, init_params=:sample_prior, n_samples=1000, n_chains=2, adtype=AutoForwardDiff(), sampler=NUTS(; adtype=adtype), sampler_kwargs...)

Sample from the posterior distribution of a fitted model using MCMC.

This function runs MCMC sampling for the given ModelFit object, stores the results in modelfit.posterior, and returns the sampled chains. It supports saving and resuming long sampling runs, running chains in parallel, several ways of initializing the sampler, and passing additional settings through to the sampler.

Arguments

  • modelfit::ModelFit: The model structure to sample with.
  • parallelization::AbstractMCMC.AbstractMCMCEnsemble: Parallelization strategy (default: MCMCSerial()).
  • verbose::Bool: Whether to display warnings (default: true).
  • resample::Bool: Whether to force resampling even if results exist (default: false).
  • save_resume::Union{SampleSaveResume,Nothing}: Save/resume configuration (default: nothing).
  • init_params::Union{Nothing,Symbol,Vector{Float64}}: How to initialize the sampler (default: :sample_prior).
  • n_samples::Integer: Number of samples per chain (default: 1000).
  • n_chains::Integer: Number of MCMC chains (default: 2).
  • adtype: Automatic differentiation type (default: AutoForwardDiff()).
  • sampler: Sampler algorithm (default: NUTS).
  • sampler_kwargs...: Additional keyword arguments for the sampler.

Returns

  • Chains: The sampled posterior chains.

Example

julia> chns = sample_posterior!(model, sampler = HMC(0.8, 10), n_chains = 1, n_samples = 100, progress = false);

julia> chns isa Chains
true
source
ActionModels.sample_prior!Function
sample_prior!(modelfit::ModelFit, parallelization::AbstractMCMC.AbstractMCMCEnsemble = MCMCSerial(); resample=false, n_samples=1000, n_chains=2)

Sample from the prior distribution of a fitted model using MCMC.

This function samples from the prior for the provided ModelFit object, storing the results in modelfit.prior. Returns the sampled chains.

Arguments

  • modelfit::ModelFit: The model structure to sample with.
  • parallelization::AbstractMCMC.AbstractMCMCEnsemble: Parallelization strategy (default: MCMCSerial()).
  • resample::Bool: Whether to force resampling even if results exist (default: false).
  • n_samples::Integer: Number of samples per chain (default: 1000).
  • n_chains::Integer: Number of MCMC chains (default: 2).

Returns

  • Chains: The sampled prior chains.

Example

julia> chns = sample_prior!(model, progress = false);

julia> chns isa Chains
true
source
ActionModels.set_actions!Method
set_actions!(agent::Agent, actions::NamedTuple)

Set multiple actions as having been previously chosen for an agent using a NamedTuple.

Arguments

  • agent::Agent: The agent whose actions will be set.
  • actions::NamedTuple: NamedTuple of action names and values.

Example

julia> set_actions!(agent, (;report=1.,))
source
ActionModels.set_actions!Method
set_actions!(agent::Agent, target_action::Symbol, target_value::Real)

Set a single model action for an agent.

Arguments

  • agent::Agent: The agent whose action will be set.
  • target_action::Symbol: Name of the action.
  • target_value::Real: Value to set.

Example

julia> set_actions!(agent, :report, 1)
source
ActionModels.set_actions!Method
set_actions!(agent::Agent, action_names::Tuple{Vararg{Symbol}}, action_values::Tuple{Vararg{Real}})

Set multiple actions as having been previously chosen for an agent using tuples of names and values.

Arguments

  • agent::Agent: The agent whose actions will be set.
  • action_names::Tuple{Vararg{Symbol}}: Tuple of action names.
  • action_values::Tuple{Vararg{Real}}: Tuple of action values.

Example

julia> set_actions!(agent, (:report,), (1,))
source
ActionModels.set_actions!Method
set_actions!(agent::Agent, action_names::Vector{Symbol}, action_values::Tuple{Vararg{Real}})

Set multiple actions as having been previously chosen for an agent using a vector of names and a tuple of values.

Arguments

  • agent::Agent: The agent whose actions will be set.
  • action_names::Vector{Symbol}: Vector of action names.
  • action_values::Tuple{Vararg{Real}}: Tuple of action values.

Example

julia> set_actions!(agent, [:report], (1,))
source
ActionModels.set_parameters!Method
set_parameters!(agent::Agent, parameters::NamedTuple)

Set multiple model parameters for an agent using a NamedTuple.

Arguments

  • agent::Agent: The agent whose parameters will be set.
  • parameters::NamedTuple: NamedTuple of parameter names and values.

Example

julia> set_parameters!(agent, (learning_rate=0.2, action_noise=0.1))
source
ActionModels.set_parameters!Method
set_parameters!(agent::Agent, target_param::Symbol, target_value::Union{R,AbstractArray{R}})

Set a single model parameter for an agent.

Arguments

  • agent::Agent: The agent whose parameter will be set.
  • target_param::Symbol: Name of the parameter.
  • target_value::Union{R,AbstractArray{R}}: Value to set.

Example

julia> set_parameters!(agent, :learning_rate, 0.2)
source
ActionModels.set_parameters!Method
set_parameters!(agent::Agent, parameter_names::Tuple{Vararg{Symbol}}, parameter_values::Tuple{Vararg{Union{R,AbstractArray{R}}}})

Set multiple model parameters for an agent using tuples of names and values.

Arguments

  • agent::Agent: The agent whose parameters will be set.
  • parameter_names::Tuple{Vararg{Symbol}}: Tuple of parameter names.
  • parameter_values::Tuple{Vararg{Union{R,AbstractArray{R}}}}: Tuple of parameter values.

Example

julia> set_parameters!(agent, (:learning_rate, :action_noise), (0.2, 0.1))
source
ActionModels.set_parameters!Method
set_parameters!(agent::Agent, parameter_names::Vector{Symbol}, parameter_values::Tuple{Vararg{Union{R,AbstractArray{R}}}})

Set multiple model parameters for an agent using a vector of names and a tuple of values.

Arguments

  • agent::Agent: The agent whose parameters will be set.
  • parameter_names::Vector{Symbol}: Vector of parameter names.
  • parameter_values::Tuple{Vararg{Union{R,AbstractArray{R}}}}: Tuple of parameter values.

Example

julia> set_parameters!(agent, [:learning_rate, :action_noise], (0.2, 0.1))
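
The NamedTuple, tuple, and vector forms of set_parameters! all carry the same name/value pairs. As an illustration in plain Julia (not a description of how ActionModels converts its arguments internally), the tuple and vector forms can be turned into the NamedTuple form with the NamedTuple{names}(values) constructor:

```julia
# Plain-Julia illustration; not ActionModels code.
names_tuple  = (:learning_rate, :action_noise)   # tuple of names
names_vector = [:learning_rate, :action_noise]   # vector of names
param_values = (0.2, 0.1)                        # tuple of values

# NamedTuple{names}(values) pairs a tuple of names with a tuple of values.
from_tuple  = NamedTuple{names_tuple}(param_values)
from_vector = NamedTuple{Tuple(names_vector)}(param_values)  # Vector -> Tuple first
direct      = (learning_rate = 0.2, action_noise = 0.1)

from_tuple == from_vector == direct   # true
```

The set_states! and set_actions! methods use the same three argument forms.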
source
ActionModels.set_states!Method
set_states!(agent::Agent, states::NamedTuple)

Set multiple model states for an agent using a NamedTuple.

Arguments

  • agent::Agent: The agent whose states will be set.
  • states::NamedTuple: NamedTuple of state names and values.

Example

julia> set_states!(agent, (expected_value=0.0,))
source
ActionModels.set_states!Method
set_states!(agent::Agent, target_state::Symbol, target_value::Any)

Set a single model state for an agent.

Arguments

  • agent::Agent: The agent whose state will be set.
  • target_state::Symbol: Name of the state.
  • target_value::Any: Value to set.

Example

julia> set_states!(agent, :expected_value, 0.0)
source
ActionModels.set_states!Method
set_states!(agent::Agent, state_names::Tuple{Vararg{Symbol}}, state_values::Tuple{Vararg{Any}})

Set multiple model states for an agent using tuples of names and values.

Arguments

  • agent::Agent: The agent whose states will be set.
  • state_names::Tuple{Vararg{Symbol}}: Tuple of state names.
  • state_values::Tuple{Vararg{Any}}: Tuple of state values.

Example

julia> set_states!(agent, (:expected_value,), (0.0,))
source
ActionModels.set_states!Method
set_states!(agent::Agent, state_names::Vector{Symbol}, state_values::Tuple{Vararg{Any}})

Set multiple model states for an agent using a vector of names and a tuple of values.

Arguments

  • agent::Agent: The agent whose states will be set.
  • state_names::Vector{Symbol}: Vector of state names.
  • state_values::Tuple{Vararg{Any}}: Tuple of state values.

Example

julia> set_states!(agent, [:expected_value], (0.0,))
source
ActionModels.simulate!Method
simulate!(agent::Agent, observations::AbstractMatrix)

Simulate the agent forward for multiple timesteps, using a matrix of observations (each row is a timestep). Returns a vector of actions.

Arguments

  • agent: The Agent to simulate.
  • observations: Matrix of observations (each row holds the observation(s) for one timestep, with one column per observation).

Returns

  • Vector of actions, one per timestep.
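
To illustrate the matrix layout, the sketch below builds an observation matrix for a hypothetical task with two observations per timestep (the task and the values are invented for illustration) and splits it into one tuple per timestep, the shape observe! accepts for a single timestep. It only shows the layout and does not call ActionModels.

```julia
# Hypothetical two-observation task: column 1 and column 2 are separate observations.
observations = [0.5 1.0;
                0.7 0.0;
                0.9 1.0]   # 3 timesteps (rows) × 2 observations (columns)

# Each row holds everything observed at one timestep.
per_timestep = [Tuple(row) for row in eachrow(observations)]
# per_timestep == [(0.5, 1.0), (0.7, 0.0), (0.9, 1.0)]

size(observations, 1)   # 3 timesteps, so one action per row would be expected
```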
source
ActionModels.simulate!Method
simulate!(agent::Agent, observations::AbstractVector)

Simulate the agent forward for multiple timesteps, using a vector of observations. Returns a vector containing one action per timestep.

Arguments

  • agent: The Agent to simulate.
  • observations: Vector of observations (each element is a tuple or value for one timestep).

Returns

  • Vector of actions, one per timestep.

Examples

julia> actions = simulate!(agent, [0.5, 0.7, 0.9]);

julia> actions isa AbstractVector
true
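
Conceptually, simulating over a vector amounts to calling observe! once per element and collecting the actions, while reset! returns the agent to its initial state. The toy sketch below shows that lifecycle in plain Julia under stated assumptions. ToyAgent, toy_observe!, toy_simulate!, and toy_reset! are hypothetical names. The model is a continuous Rescorla-Wagner update that acts after updating, and the action is the updated expected value itself rather than a sample from an action distribution, so the output is deterministic.

```julia
# Hypothetical toy agent; NOT the ActionModels Agent type.
mutable struct ToyAgent
    learning_rate::Float64
    initial_value::Float64
    expected_value::Float64
    history::Vector{Float64}   # expected_value after each timestep
end

ToyAgent(learning_rate, initial_value) =
    ToyAgent(learning_rate, initial_value, initial_value, Float64[])

# One timestep: update first, then act. The "action" is the updated expected
# value, a noise-free stand-in for sampling from an action distribution.
function toy_observe!(agent::ToyAgent, observation::Real)
    agent.expected_value += agent.learning_rate * (observation - agent.expected_value)
    push!(agent.history, agent.expected_value)
    return agent.expected_value
end

# Simulating over a vector is a loop of single-step updates, one action per element.
toy_simulate!(agent::ToyAgent, observations::AbstractVector) =
    [toy_observe!(agent, obs) for obs in observations]

# Restore the initial state and clear the history.
function toy_reset!(agent::ToyAgent)
    agent.expected_value = agent.initial_value
    empty!(agent.history)
    return agent
end

agent = ToyAgent(0.5, 0.0)
actions = toy_simulate!(agent, [1.0, 1.0, 0.0])   # [0.5, 0.75, 0.375]
toy_reset!(agent)                                 # back to expected_value = 0.0
```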
source
ActionModels.update!Method
update!(attributes::RescorlaWagnerAttributes, observation)

Update the expected value(s) stored in a RescorlaWagnerAttributes struct according to the observation. This function has methods for continuous, binary, and categorical models.

Arguments

  • attributes: A RescorlaWagnerAttributes struct containing the current state and parameters.
  • observation: The observed outcome. Type depends on the model variant:
    • Float64 for continuous models
    • Int64 for binary or categorical models (category index)
    • Vector{Int64} for categorical models (one-hot or binary vector)
    • Vector{Float64} for categorical models (continuous vector)

Examples

julia> attrs = ActionModels.RescorlaWagnerAttributes{Float64,Float64}(initial_value=0.0, learning_rate=0.2, expected_value=0.0);

julia> ActionModels.update!(attrs, 1.0);  # Continuous update

julia> attrs.expected_value
0.2

julia> attrs_bin = ActionModels.RescorlaWagnerAttributes{Float64,Float64}(initial_value=0.0, learning_rate=0.5, expected_value=0.0);

julia> ActionModels.update!(attrs_bin, 1);  # Binary update

julia> attrs_bin.expected_value
0.25

julia> attrs_cat = ActionModels.RescorlaWagnerAttributes{Float64,Vector{Float64}}(initial_value=[0.0, 0.0, 0.0], learning_rate=0.1, expected_value=[0.0, 0.0, 0.0]);

julia> ActionModels.update!(attrs_cat, 2);  # Categorical update with category index

julia> attrs_cat.expected_value
3-element Vector{Float64}:
 -0.05
  0.05
 -0.05
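
The example outputs above can be reproduced by hand. The sketch below is a hypothetical re-implementation inferred from those numbers, not the ActionModels source. It assumes that the binary rule computes the prediction error against logistic(expected_value), and that the categorical rule with a category index applies the same rule element-wise to a one-hot encoding of the observation. The vector-observation variants are not covered.

```julia
# Hypothetical re-implementation of the example arithmetic; NOT ActionModels code.
logistic(x) = 1 / (1 + exp(-x))

# Continuous: move the expectation toward the observation.
rw_continuous(v::Float64, α::Float64, obs::Float64) = v + α * (obs - v)

# Binary (assumed): prediction error against logistic(v), i.e. v on a log-odds scale.
rw_binary(v::Float64, α::Float64, obs::Int) = v + α * (obs - logistic(v))

# Categorical with a category index (assumed): one-hot encode the observation,
# then apply the binary rule to each element.
function rw_categorical(v::Vector{Float64}, α::Float64, obs::Int)
    onehot = [i == obs ? 1.0 : 0.0 for i in eachindex(v)]
    return v .+ α .* (onehot .- logistic.(v))
end

rw_continuous(0.0, 0.2, 1.0)       # 0.2
rw_binary(0.0, 0.5, 1)             # 0.25
rw_categorical(zeros(3), 0.1, 2)   # [-0.05, 0.05, -0.05]
```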
source
ActionModels.update_state!Method

update_state!(model_attributes, state_name, state_value)

Update the value of a state in the model attributes. This function is used within the action model definition to update the state values during the simulation or fitting process.

Arguments

  • model_attributes: The model attributes object.
  • state_name: The name of the state to update.
  • state_value: The new value for the state.

Example

julia> update_state!(attributes, :expected_value, 2.0)  # Update the state expected_value to 2.0

source
MCMCChains.summarizeFunction
Turing.summarize(state_trajectories::StateTrajectories, summary_function::Function=median)

Summarize posterior samples of state trajectories into a tidy DataFrame.

Each row corresponds to a session and timestep, and each column to a state variable (or state element, for arrays). The summary statistic (e.g., median, mean) is applied to the posterior samples for each state, session, and timestep.

Arguments

  • state_trajectories::StateTrajectories: Posterior samples of state trajectories, as returned by model fitting.
  • summary_function::Function=median: Function to summarize the samples (e.g., mean, median, std).

Returns

  • DataFrame: Table with one row per session and timestep, and columns for each state (or state element).

Example

julia> df = summarize(get_state_trajectories!(model, :expected_value), mean);

julia> df isa DataFrame
true

Notes

  • For array-valued states, columns are named with indices, e.g., :expected_value[1].
  • Session identifiers are split into columns if composite.
  • The column timestep indicates the time index (starting from 0).
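
The reduction performed by summarize can be sketched for a single session and a single state: each timestep's posterior draws are collapsed to one number with the summary function, and timesteps are counted from 0. The draws below are invented for illustration, and the sketch does not build the DataFrame that summarize returns. It only shows the per-timestep reduction and the indexed column names used for array-valued states.

```julia
using Statistics  # median

# Hypothetical posterior draws for one scalar state in one session:
# rows are samples, columns are timesteps t = 0, 1, 2.
draws = [0.0 0.10 0.20;
         0.0 0.30 0.40;
         0.0 0.20 0.90]

timesteps = 0:(size(draws, 2) - 1)   # time index starts at 0
# One summary value per timestep, reducing over the samples in each column.
per_timestep_median = [median(draws[:, j]) for j in axes(draws, 2)]
# per_timestep_median == [0.0, 0.2, 0.4]

# An array-valued state gets one column per element, named with the element index.
column_names = [Symbol("expected_value[", i, "]") for i in 1:3]
```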
source
MCMCChains.summarizeFunction
Turing.summarize(session_parameters::SessionParameters, summary_function::Function=median)

Summarize posterior samples of session-level parameters into a tidy DataFrame.

Each row corresponds to a session, and each column to a parameter (or parameter element, for arrays). The summary statistic (e.g., median, mean) is applied to the posterior samples for each parameter and session.

Arguments

  • session_parameters::SessionParameters: Posterior samples of session-level parameters, as returned by model fitting.
  • summary_function::Function=median: Function to summarize the samples (e.g., mean, median, std).

Returns

  • DataFrame: Table with one row per session and columns for each parameter (or parameter element).

Example

julia> df = summarize(get_session_parameters!(model), median);

julia> df isa DataFrame
true

Notes

  • For array-valued parameters, columns are named with element indices, e.g., :initial_value[1] for the first element of a vector-valued initial_value.
  • Session identifiers are split into columns if composite.
source
RecipesBase.apply_recipeFunction
f(agent::Agent, target_state::Symbol, index::Union{Nothing,Vector{Int}}=nothing)

Plot the trajectory (history) of a state variable for an agent over time.

This plotting recipe visualizes the time course of a given state variable from an agent's history. If the state is multivariate (e.g., a vector or array), you can specify an index to plot a particular dimension. If the state is univariate, the index should be left as nothing (the default).

Arguments

  • agent::Agent: The agent whose state trajectory will be plotted.
  • target_state::Symbol: The name of the state variable to plot (e.g., :expected_value).
  • index::Union{Nothing,Vector{Int}}: (Optional) Index for multivariate states (e.g., [1] for the first element).

Example

julia> plot(agent, :expected_value)
Plot{Plots.GRBackend() n=1}

For a multivariate state (e.g., a vector):

julia> plot(agent, :expected_value, [1])
Plot{Plots.GRBackend() n=1}
source