@@ -51,10 +51,6 @@ function RLBase.optimise!(policy::AbstractPolicy, stage::AbstractStage, trajecto
5151
5252@functor Agent (policy,)
5353
54- function Base. push! (agent:: Agent , :: PreActStage , env:: AbstractEnv )
55- push! (agent, state (env))
56- end
57-
5854# !!! TODO : In async scenarios, parameters of the policy may still be updating
5955# (partially), which will result to incorrect action. This should be addressed
6056# in Oolong.jl with a wrapper
@@ -64,11 +60,8 @@ function RLBase.plan!(agent::Agent{P,T,C}, env::AbstractEnv) where {P,T,C}
6460 action
6561end
6662
67- # Multiagent Version
68- function RLBase. plan! (agent:: Agent{P,T,C} , env:: E , p:: Symbol ) where {P,T,C,E<: AbstractEnv }
69- action = RLBase. plan! (agent. policy, env, p)
70- push! (agent. trajectory, agent. cache, action)
71- action
63+ function Base. push! (agent:: Agent , :: PreActStage , env:: AbstractEnv )
64+ push! (agent, state (env))
7265end
7366
7467function Base. push! (agent:: Agent{P,T,C} , :: PostActStage , env:: E ) where {P,T,C,E<: AbstractEnv }
@@ -79,11 +72,26 @@ function Base.push!(agent::Agent, ::PostExperimentStage, env::E) where {E<:Abstr
7972 RLBase. reset! (agent. cache)
8073end
8174
82- function Base. push! (agent:: Agent , :: PostExperimentStage , env:: E , player:: Symbol ) where {E<: AbstractEnv }
83- RLBase. reset! (agent. cache)
84- end
85-
8675function Base. push! (agent:: Agent{P,T,C} , state:: S ) where {P,T,C,S}
8776 push! (agent. cache, state)
8877end
8978
79+ # Multiagent Version
80+ function RLBase. plan! (agent:: Agent{P,T,C} , env:: E , p:: Symbol ) where {P,T,C,E<: AbstractEnv }
81+ action = RLBase. plan! (agent. policy, env, p)
82+ push! (agent. trajectory, agent. cache, action)
83+ action
84+ end
85+
86+ # for simultaneous DynamicStyle environments, we have to define push! operations
87+ function Base. push! (agent:: Agent , :: PreActStage , env:: AbstractEnv , player:: Symbol )
88+ push! (agent, state (env, player))
89+ end
90+
91+ function Base. push! (agent:: Agent{P,T,C} , :: PostActStage , env:: E , player:: Symbol ) where {P,T,C,E<: AbstractEnv }
92+ push! (agent. cache, reward (env, player), is_terminated (env))
93+ end
94+
95+ function Base. push! (agent:: Agent , :: PostExperimentStage , env:: E , player:: Symbol ) where {E<: AbstractEnv }
96+ RLBase. reset! (agent. cache)
97+ end
0 commit comments