@@ -52,17 +52,19 @@ def __init__(self, shift_before: int = 0):
5252 # each time a (non-initial!) observation is added.
5353 self .count = 0
5454
55- def add_init_obs (self , episode_id : EpisodeID , agent_id : AgentID ,
56- env_id : EnvID , init_obs : TensorType ,
55+ def add_init_obs (self , episode_id : EpisodeID , agent_index : int ,
56+ env_id : EnvID , t : int , init_obs : TensorType ,
5757 view_requirements : Dict [str , ViewRequirement ]) -> None :
5858 """Adds an initial observation (after reset) to the Agent's trajectory.
5959
6060 Args:
6161 episode_id (EpisodeID): Unique ID for the episode we are adding the
6262 initial observation for.
63- agent_id (AgentID ): Unique ID for the agent we are adding the
64- initial observation for .
63+ agent_index (int ): Unique int index (starting from 0) for the agent
64+ within its episode .
6565 env_id (EnvID): The environment index (in a vectorized setup).
66+ t (int): The time step (episode length - 1). The initial obs has
67+ ts=-1(!), then an action/reward/next-obs at t=0, etc..
6668 init_obs (TensorType): The initial observation tensor (after
6769 `env.reset()`).
6870 view_requirements (Dict[str, ViewRequirements])
@@ -72,10 +74,15 @@ def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
7274 single_row = {
7375 SampleBatch .OBS : init_obs ,
7476 SampleBatch .EPS_ID : episode_id ,
75- SampleBatch .AGENT_INDEX : agent_id ,
77+ SampleBatch .AGENT_INDEX : agent_index ,
7678 "env_id" : env_id ,
79+ "t" : t ,
7780 })
7881 self .buffers [SampleBatch .OBS ].append (init_obs )
82+ self .buffers [SampleBatch .EPS_ID ].append (episode_id )
83+ self .buffers [SampleBatch .AGENT_INDEX ].append (agent_index )
84+ self .buffers ["env_id" ].append (env_id )
85+ self .buffers ["t" ].append (t )
7986
8087 def add_action_reward_next_obs (self , values : Dict [str , TensorType ]) -> \
8188 None :
@@ -133,7 +140,7 @@ def build(self, view_requirements: Dict[str, ViewRequirement]) -> \
133140 continue
134141 # OBS are already shifted by -1 (the initial obs starts one ts
135142 # before all other data columns).
136- shift = view_req .shift - \
143+ shift = view_req .data_rel_pos - \
137144 (1 if data_col == SampleBatch .OBS else 0 )
138145 if data_col not in np_data :
139146 np_data [data_col ] = to_float_np_array (self .buffers [data_col ])
@@ -187,7 +194,10 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
187194 for col , data in single_row .items ():
188195 if col in self .buffers :
189196 continue
190- shift = self .shift_before - (1 if col == SampleBatch .OBS else 0 )
197+ shift = self .shift_before - (1 if col in [
198+ SampleBatch .OBS , SampleBatch .EPS_ID , SampleBatch .AGENT_INDEX ,
199+ "env_id" , "t"
200+ ] else 0 )
191201 # Python primitive or dict (e.g. INFOs).
192202 if isinstance (data , (int , float , bool , str , dict )):
193203 self .buffers [col ] = [0 for _ in range (shift )]
@@ -360,7 +370,7 @@ def episode_step(self, episode_id: EpisodeID) -> None:
360370
361371 @override (_SampleCollector )
362372 def add_init_obs (self , episode : MultiAgentEpisode , agent_id : AgentID ,
363- env_id : EnvID , policy_id : PolicyID ,
373+ env_id : EnvID , policy_id : PolicyID , t : int ,
364374 init_obs : TensorType ) -> None :
365375 # Make sure our mappings are up to date.
366376 agent_key = (episode .episode_id , agent_id )
@@ -378,8 +388,9 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
378388 self .agent_collectors [agent_key ] = _AgentCollector ()
379389 self .agent_collectors [agent_key ].add_init_obs (
380390 episode_id = episode .episode_id ,
381- agent_id = agent_id ,
391+ agent_index = episode . _agent_index ( agent_id ) ,
382392 env_id = env_id ,
393+ t = t ,
383394 init_obs = init_obs ,
384395 view_requirements = view_reqs )
385396
@@ -429,7 +440,7 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \
429440 # Create the batch of data from the different buffers.
430441 data_col = view_req .data_col or view_col
431442 time_indices = \
432- view_req .shift - (
443+ view_req .data_rel_pos - (
433444 1 if data_col in [SampleBatch .OBS , "t" , "env_id" ,
434445 SampleBatch .EPS_ID ,
435446 SampleBatch .AGENT_INDEX ] else 0 )
0 commit comments