 _, tf, _ = try_import_tf()


+# TODO (sven): Use SingleAgentEnvRunner instead of this as soon as we have the new
+# ConnectorV2 example classes to make Atari work properly with these (w/o requiring the
+# classes at the bottom of this file here, e.g. `ActionClip`).
 class DreamerV3EnvRunner(EnvRunner):
     """An environment runner to collect data from vectorized gymnasium environments."""

@@ -144,6 +147,7 @@ def __init__(

         self._needs_initial_reset = True
         self._episodes = [None for _ in range(self.num_envs)]
+        self._states = [None for _ in range(self.num_envs)]

         # TODO (sven): Move metrics temp storage and collection out of EnvRunner
         # and RolloutWorkers. These classes should not continue tracking some data
@@ -254,10 +258,8 @@ def _sample_timesteps(

             # Set initial obs and states in the episodes.
             for i in range(self.num_envs):
-                self._episodes[i].add_initial_observation(
-                    initial_observation=obs[i],
-                    initial_state={k: s[i] for k, s in states.items()},
-                )
+                self._episodes[i].add_env_reset(observation=obs[i])
+                self._states[i] = {k: s[i] for k, s in states.items()}
         # Don't reset existing envs; continue in already started episodes.
         else:
             # Pick up stored observations and states from previous timesteps.
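
Note on the reset path above: what used to be a single `add_initial_observation()` call is now split, with the episode recording only the reset observation via `add_env_reset()` and the per-sub-env model state going into the runner's `self._states` cache. A minimal sketch of that split, using the `SingleAgentEpisode` methods named in this diff; the import path, array shapes, and the `env_runner_states` variable are assumptions for illustration:

```python
# Sketch only; mirrors the split introduced in this diff, not the full runner code.
import numpy as np
from ray.rllib.env.single_agent_episode import SingleAgentEpisode  # path may vary by Ray version

num_envs = 2
obs = np.zeros((num_envs, 4))                # batched reset observations (dummy)
states = {"h": np.zeros((num_envs, 8))}      # batched initial model states (dummy)

episodes = [SingleAgentEpisode() for _ in range(num_envs)]
env_runner_states = [None] * num_envs        # stands in for `self._states`

for i in range(num_envs):
    # The episode only tracks env-side data (the reset observation) ...
    episodes[i].add_env_reset(observation=obs[i])
    # ... while the model state is cached per sub-env on the runner.
    env_runner_states[i] = {k: s[i] for k, s in states.items()}
```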
@@ -268,7 +270,9 @@ def _sample_timesteps(
             states = {
                 k: np.stack(
                     [
-                        initial_states[k][i] if eps.states is None else eps.states[k]
+                        initial_states[k][i]
+                        if self._states[i] is None
+                        else self._states[i][k]
                         for i, eps in enumerate(self._episodes)
                     ]
                 )
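
For the continue-existing-episodes branch, the batched model state is rebuilt by stacking the cached per-sub-env states and falling back to the module's initial state wherever the cache is still `None`. A small self-contained sketch of that stacking logic; all variable names and shapes here are illustrative:

```python
import numpy as np

# Module-provided initial states, already batched: shape (num_envs, state_dim) per key.
initial_states = {"h": np.zeros((3, 8))}

# Per-sub-env cache; only env 1 has stepped at least once so far.
cached = [None, {"h": np.ones(8)}, None]

# Rebuild the batched state dict: cached state where available, initial state otherwise.
states = {
    k: np.stack(
        [
            initial_states[k][i] if cached[i] is None else cached[i][k]
            for i in range(len(cached))
        ]
    )
    for k in initial_states
}

assert states["h"].shape == (3, 8)
assert states["h"][1].sum() == 8.0  # env 1 kept its cached (all-ones) state
```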
@@ -278,7 +282,7 @@ def _sample_timesteps(
             # to 1.0, otherwise 0.0.
             is_first = np.zeros((self.num_envs,))
             for i, eps in enumerate(self._episodes):
-                if eps.states is None:
+                if len(eps) == 0:
                     is_first[i] = 1.0

         # Loop through env for n timesteps.
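
Since episodes no longer carry model states, `eps.states is None` can no longer signal "fresh episode"; the diff switches to `len(eps) == 0`, i.e. an episode that has been reset but not yet stepped. A tiny sketch of that length semantics, assuming the `SingleAgentEpisode` behavior implied by the diff (import path may differ by Ray version):

```python
import numpy as np
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode()
episode.add_env_reset(observation=np.zeros(4))
assert len(episode) == 0   # reset only: no env step taken yet -> is_first should be 1.0

episode.add_env_step(observation=np.ones(4), action=1, reward=0.5)
assert len(episode) == 1   # after the first env step -> is_first becomes 0.0
```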
@@ -319,37 +323,39 @@ def _sample_timesteps(
                 if terminateds[i] or truncateds[i]:
                     # Finish the episode with the actual terminal observation stored in
                     # the info dict.
-                    self._episodes[i].add_timestep(
-                        infos["final_observation"][i],
-                        actions[i],
-                        rewards[i],
-                        state=s,
-                        is_terminated=terminateds[i],
-                        is_truncated=truncateds[i],
+                    self._episodes[i].add_env_step(
+                        observation=infos["final_observation"][i],
+                        action=actions[i],
+                        reward=rewards[i],
+                        terminated=terminateds[i],
+                        truncated=truncateds[i],
                     )
+                    self._states[i] = s
                     # Reset h-states to the model's initial ones b/c we are starting a
                     # new episode.
                     for k, v in self.module.get_initial_state().items():
                         states[k][i] = v.numpy()
                     is_first[i] = True
                     done_episodes_to_return.append(self._episodes[i])
                     # Create a new episode object.
-                    self._episodes[i] = SingleAgentEpisode(
-                        observations=[obs[i]], states=s
-                    )
+                    self._episodes[i] = SingleAgentEpisode(observations=[obs[i]])
                 else:
-                    self._episodes[i].add_timestep(
-                        obs[i], actions[i], rewards[i], state=s
+                    self._episodes[i].add_env_step(
+                        observation=obs[i],
+                        action=actions[i],
+                        reward=rewards[i],
                     )
                     is_first[i] = False

+                self._states[i] = s
+
         # Return done episodes ...
         self._done_episodes_for_metrics.extend(done_episodes_to_return)
         # ... and all ongoing episode chunks. Also, make sure, we return
         # a copy and start new chunks so that callers of this function
         # don't alter our ongoing and returned Episode objects.
         ongoing_episodes = self._episodes
-        self._episodes = [eps.create_successor() for eps in self._episodes]
+        self._episodes = [eps.cut() for eps in self._episodes]
         for eps in ongoing_episodes:
             self._ongoing_episodes_for_metrics[eps.id_].append(eps)

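
The chunking at the end of `_sample_timesteps()` now goes through `cut()` instead of `create_successor()`: the returned object is a fresh, empty chunk that continues the same ongoing episode from its last observation, so the already collected chunk can be handed out without callers mutating it. A hedged sketch of stepping plus chunking with the methods named in the diff (dummy observations/actions; import path may vary by Ray version):

```python
import numpy as np
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode()
episode.add_env_reset(observation=np.zeros(2))

# Step a few times with the keyword-based API used in the diff.
for t in range(3):
    episode.add_env_step(observation=np.full(2, t + 1.0), action=0, reward=1.0)

# `cut()` starts a new chunk that picks up where this one left off (same episode,
# last observation carried over), while the old chunk stays frozen for metrics/replay.
ongoing_chunk = episode
new_chunk = episode.cut()
assert len(ongoing_chunk) == 3
assert len(new_chunk) == 0
```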
@@ -385,10 +391,9 @@ def _sample_episodes(
             render_images = [e.render() for e in self.env.envs]

         for i in range(self.num_envs):
-            episodes[i].add_initial_observation(
-                initial_observation=obs[i],
-                initial_state={k: s[i] for k, s in states.items()},
-                initial_render_image=render_images[i],
+            episodes[i].add_env_reset(
+                observation=obs[i],
+                render_image=render_images[i],
             )

         eps = 0
@@ -419,19 +424,17 @@ def _sample_episodes(
                 render_images = [e.render() for e in self.env.envs]

             for i in range(self.num_envs):
-                s = {k: s[i] for k, s in states.items()}
                 # The last entry in self.observations[i] is already the reset
                 # obs of the new episode.
                 if terminateds[i] or truncateds[i]:
                     eps += 1

-                    episodes[i].add_timestep(
-                        infos["final_observation"][i],
-                        actions[i],
-                        rewards[i],
-                        state=s,
-                        is_terminated=terminateds[i],
-                        is_truncated=truncateds[i],
+                    episodes[i].add_env_step(
+                        observation=infos["final_observation"][i],
+                        action=actions[i],
+                        reward=rewards[i],
+                        terminated=terminateds[i],
+                        truncated=truncateds[i],
                     )
                     done_episodes_to_return.append(episodes[i])

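
As in `_sample_timesteps()` above, the terminal observation is pulled from `infos["final_observation"]` rather than from `obs`, because gymnasium's autoresetting vector envs already place the next episode's reset observation in `obs` on the step that ends an episode. A hedged illustration of that convention, assuming the pre-1.0 gymnasium vector API; the environment id, seed, and loop bounds are arbitrary:

```python
# Standalone illustration of the gymnasium (<1.0) autoreset convention assumed here.
import gymnasium as gym

env = gym.vector.make("CartPole-v1", num_envs=2)
obs, _ = env.reset(seed=0)

for _ in range(200):
    obs, rewards, terminateds, truncateds, infos = env.step(env.action_space.sample())
    for i in range(env.num_envs):
        if terminateds[i] or truncateds[i]:
            # True last observation of the finished episode:
            terminal_obs = infos["final_observation"][i]
            # `obs[i]` is already the reset observation of the next episode.
            reset_obs = obs[i]
```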
@@ -448,15 +451,15 @@ def _sample_episodes(

                     episodes[i] = SingleAgentEpisode(
                         observations=[obs[i]],
-                        states=s,
-                        render_images=[render_images[i]],
+                        render_images=(
+                            [render_images[i]] if with_render_data else None
+                        ),
                     )
                 else:
-                    episodes[i].add_timestep(
-                        obs[i],
-                        actions[i],
-                        rewards[i],
-                        state=s,
+                    episodes[i].add_env_step(
+                        observation=obs[i],
+                        action=actions[i],
+                        reward=rewards[i],
                         render_image=render_images[i],
                     )
                     is_first[i] = False
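
One more detail in the new-episode construction above: rendered frames are only attached when `with_render_data` was requested; otherwise `render_images` is passed as `None`. A small illustrative sketch of just that conditional; the helper function and dummy arrays are hypothetical, only the `observations=`/`render_images=` keyword arguments come from the diff:

```python
import numpy as np

def new_episode_kwargs(reset_obs, frame, with_render_data):
    """Hypothetical helper: builds constructor kwargs for a fresh episode chunk."""
    return dict(
        observations=[reset_obs],
        # Keep the rendered frame only if render data was requested; passing None
        # avoids storing potentially large image arrays in every episode.
        render_images=[frame] if with_render_data else None,
    )

kwargs = new_episode_kwargs(np.zeros(4), np.zeros((64, 64, 3), np.uint8), with_render_data=False)
assert kwargs["render_images"] is None
```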