
Commit 3ad9365

[RLlib] Attention Net prep PR #2: Smaller cleanups. (#12449)
1 parent e72147d commit 3ad9365

12 files changed (+68 lines, -55 lines)

rllib/evaluation/collectors/sample_collector.py

Lines changed: 8 additions & 4 deletions
@@ -31,7 +31,8 @@ class _SampleCollector(metaclass=ABCMeta):
 
     @abstractmethod
     def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
-                     policy_id: PolicyID, init_obs: TensorType) -> None:
+                     policy_id: PolicyID, t: int,
+                     init_obs: TensorType) -> None:
         """Adds an initial obs (after reset) to this collector.
 
         Since the very first observation in an environment is collected w/o
@@ -48,6 +49,8 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
                 values for.
             env_id (EnvID): The environment index (in a vectorized setup).
             policy_id (PolicyID): Unique id for policy controlling the agent.
+            t (int): The time step (episode length - 1). The initial obs has
+                ts=-1(!), then an action/reward/next-obs at t=0, etc.
             init_obs (TensorType): Initial observation (after env.reset()).
 
         Examples:
@@ -172,9 +175,10 @@ def postprocess_episode(self,
                 MultiAgentBatch. Used for batch_mode=`complete_episodes`.
 
         Returns:
-            Any: An ID that can be used in `build_multi_agent_batch` to
-                retrieve the samples that have been postprocessed as a
-                ready-built MultiAgentBatch.
+            Optional[MultiAgentBatch]: If `build` is True, the
+                SampleBatch or MultiAgentBatch built from `episode` (either
+                just from that episode or from the `_PolicyCollectorGroup`
+                in the `episode.batch_builder` property).
         """
         raise NotImplementedError
 
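The new `t` argument follows the convention described in the docstring: the reset observation is stored at t = episode length - 1 = -1, and the first action/reward/next-obs row lands at t = 0. A minimal, self-contained sketch of that bookkeeping (plain Python stand-ins, not the actual RLlib collector classes):

    # Sketch of the t-convention only: t = episode_length - 1,
    # with the reset observation at t == -1.
    rows = []

    def add_init_obs(t, obs):
        # Record the reset observation one step before all other columns.
        rows.append({"t": t, "obs": obs})

    episode_length = 0
    add_init_obs(t=episode_length - 1, obs="obs_after_reset")  # t == -1
    for step in range(3):
        episode_length += 1
        # Each env step appends a row at t = 0, 1, 2, ...
        rows.append({"t": episode_length - 1, "obs": "obs_%d" % step})

    print([r["t"] for r in rows])  # -> [-1, 0, 1, 2]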

rllib/evaluation/collectors/simple_list_collector.py

Lines changed: 21 additions & 10 deletions
@@ -52,17 +52,19 @@ def __init__(self, shift_before: int = 0):
         # each time a (non-initial!) observation is added.
         self.count = 0
 
-    def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
-                     env_id: EnvID, init_obs: TensorType,
+    def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
+                     env_id: EnvID, t: int, init_obs: TensorType,
                      view_requirements: Dict[str, ViewRequirement]) -> None:
         """Adds an initial observation (after reset) to the Agent's trajectory.
 
         Args:
             episode_id (EpisodeID): Unique ID for the episode we are adding the
                 initial observation for.
-            agent_id (AgentID): Unique ID for the agent we are adding the
-                initial observation for.
+            agent_index (int): Unique int index (starting from 0) for the agent
+                within its episode.
             env_id (EnvID): The environment index (in a vectorized setup).
+            t (int): The time step (episode length - 1). The initial obs has
+                ts=-1(!), then an action/reward/next-obs at t=0, etc.
             init_obs (TensorType): The initial observation tensor (after
                 `env.reset()`).
             view_requirements (Dict[str, ViewRequirements])
@@ -72,10 +74,15 @@ def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
             single_row={
                 SampleBatch.OBS: init_obs,
                 SampleBatch.EPS_ID: episode_id,
-                SampleBatch.AGENT_INDEX: agent_id,
+                SampleBatch.AGENT_INDEX: agent_index,
                 "env_id": env_id,
+                "t": t,
             })
         self.buffers[SampleBatch.OBS].append(init_obs)
+        self.buffers[SampleBatch.EPS_ID].append(episode_id)
+        self.buffers[SampleBatch.AGENT_INDEX].append(agent_index)
+        self.buffers["env_id"].append(env_id)
+        self.buffers["t"].append(t)
 
     def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
             None:
@@ -133,7 +140,7 @@ def build(self, view_requirements: Dict[str, ViewRequirement]) -> \
                 continue
             # OBS are already shifted by -1 (the initial obs starts one ts
             # before all other data columns).
-            shift = view_req.shift - \
+            shift = view_req.data_rel_pos - \
                 (1 if data_col == SampleBatch.OBS else 0)
             if data_col not in np_data:
                 np_data[data_col] = to_float_np_array(self.buffers[data_col])
@@ -187,7 +194,10 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
         for col, data in single_row.items():
             if col in self.buffers:
                 continue
-            shift = self.shift_before - (1 if col == SampleBatch.OBS else 0)
+            shift = self.shift_before - (1 if col in [
+                SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
+                "env_id", "t"
+            ] else 0)
             # Python primitive or dict (e.g. INFOs).
             if isinstance(data, (int, float, bool, str, dict)):
                 self.buffers[col] = [0 for _ in range(shift)]
@@ -360,7 +370,7 @@ def episode_step(self, episode_id: EpisodeID) -> None:
 
     @override(_SampleCollector)
     def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
-                     env_id: EnvID, policy_id: PolicyID,
+                     env_id: EnvID, policy_id: PolicyID, t: int,
                      init_obs: TensorType) -> None:
         # Make sure our mappings are up to date.
         agent_key = (episode.episode_id, agent_id)
@@ -378,8 +388,9 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
             self.agent_collectors[agent_key] = _AgentCollector()
         self.agent_collectors[agent_key].add_init_obs(
             episode_id=episode.episode_id,
-            agent_id=agent_id,
+            agent_index=episode._agent_index(agent_id),
             env_id=env_id,
+            t=t,
             init_obs=init_obs,
             view_requirements=view_reqs)
 
@@ -429,7 +440,7 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \
             # Create the batch of data from the different buffers.
             data_col = view_req.data_col or view_col
             time_indices = \
-                view_req.shift - (
+                view_req.data_rel_pos - (
                     1 if data_col in [SampleBatch.OBS, "t", "env_id",
                                       SampleBatch.EPS_ID,
                                       SampleBatch.AGENT_INDEX] else 0)
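The `_build_buffers` change above encodes one rule: columns that already receive a row at reset time (OBS, EPS_ID, AGENT_INDEX, "env_id", "t") need one slot less of initial padding than columns that only start at t=0 (e.g. actions and rewards). A self-contained sketch of that rule under the assumption shift_before=1 (not the actual RLlib code, just the arithmetic):

    # Columns filled at reset time get one fewer zero-padding slot.
    shift_before = 1  # assumed: keep 1 ts of history before episode start
    reset_time_cols = {"obs", "eps_id", "agent_index", "env_id", "t"}

    def initial_padding(col: str) -> int:
        return shift_before - (1 if col in reset_time_cols else 0)

    print(initial_padding("obs"))      # -> 0 (the reset obs fills the slot)
    print(initial_padding("actions"))  # -> 1 (one zero row of padding)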

rllib/evaluation/rollout_worker.py

Lines changed: 5 additions & 3 deletions
@@ -272,9 +272,11 @@ def __init__(
             output_creator (Callable[[IOContext], OutputWriter]): Function that
                 returns an OutputWriter object for saving generated
                 experiences.
-            remote_worker_envs (bool): If using num_envs > 1, whether to create
-                those new envs in remote processes instead of in the current
-                process. This adds overheads, but can make sense if your envs
+            remote_worker_envs (bool): If using num_envs_per_worker > 1,
+                whether to create those new envs in remote processes instead of
+                in the current process. This adds overheads, but can make sense
+                if your envs are expensive to step/reset (e.g., for StarCraft).
+                Use this cautiously; overheads are significant!
             remote_env_batch_wait_ms (float): Timeout that remote workers
                 are waiting when polling environments. 0 (continue when at
                 least one env is ready) is a reasonable default, but optimal
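Since the docstring now refers to `num_envs_per_worker`, the flag is normally set alongside that option. A hedged usage sketch using the standard RLlib config keys named in this docstring (trainer wiring omitted):

    config = {
        "num_workers": 2,
        # Vectorize 4 env copies per rollout worker ...
        "num_envs_per_worker": 4,
        # ... and step/reset each copy in its own remote process. Only worth
        # the extra IPC overhead when a single env step/reset is expensive.
        "remote_worker_envs": True,
        # How long a worker waits when polling the remote envs
        # (0 = continue as soon as at least one env is ready).
        "remote_env_batch_wait_ms": 0,
    }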

rllib/evaluation/sampler.py

Lines changed: 4 additions & 2 deletions
@@ -1040,7 +1040,8 @@ def _process_observations_w_trajectory_view_api(
             # Record transition info if applicable.
             if last_observation is None:
                 _sample_collector.add_init_obs(episode, agent_id, env_id,
-                                               policy_id, filtered_obs)
+                                               policy_id, episode.length - 1,
+                                               filtered_obs)
             else:
                 # Add actions, rewards, next-obs to collectors.
                 values_dict = {
@@ -1158,7 +1159,8 @@ def _process_observations_w_trajectory_view_api(
 
             # Add initial obs to buffer.
             _sample_collector.add_init_obs(
-                new_episode, agent_id, env_id, policy_id, filtered_obs)
+                new_episode, agent_id, env_id, policy_id,
+                new_episode.length - 1, filtered_obs)
 
             item = PolicyEvalData(
                 env_id, agent_id, filtered_obs,

rllib/evaluation/tests/test_trajectory_view_api.py

Lines changed: 4 additions & 4 deletions
@@ -59,7 +59,7 @@ def test_traj_view_normal_case(self):
                 assert view_req_policy[key].data_col is None
             else:
                 assert view_req_policy[key].data_col == SampleBatch.OBS
-                assert view_req_policy[key].shift == 1
+                assert view_req_policy[key].data_rel_pos == 1
         rollout_worker = trainer.workers.local_worker()
         sample_batch = rollout_worker.sample()
         expected_count = \
@@ -99,18 +99,18 @@ def test_traj_view_lstm_prev_actions_and_rewards(self):
 
             if key == SampleBatch.PREV_ACTIONS:
                 assert view_req_policy[key].data_col == SampleBatch.ACTIONS
-                assert view_req_policy[key].shift == -1
+                assert view_req_policy[key].data_rel_pos == -1
             elif key == SampleBatch.PREV_REWARDS:
                 assert view_req_policy[key].data_col == SampleBatch.REWARDS
-                assert view_req_policy[key].shift == -1
+                assert view_req_policy[key].data_rel_pos == -1
             elif key not in [
                     SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
                     SampleBatch.PREV_REWARDS
             ]:
                 assert view_req_policy[key].data_col is None
             else:
                 assert view_req_policy[key].data_col == SampleBatch.OBS
-                assert view_req_policy[key].shift == 1
+                assert view_req_policy[key].data_rel_pos == 1
         trainer.stop()
 
     def test_traj_view_simple_performance(self):

rllib/examples/policy/episode_env_aware_policy.py

Lines changed: 6 additions & 4 deletions
@@ -28,22 +28,24 @@ class _fake_model:
             "t": ViewRequirement(),
             SampleBatch.OBS: ViewRequirement(),
             SampleBatch.PREV_ACTIONS: ViewRequirement(
-                SampleBatch.ACTIONS, space=self.action_space, shift=-1),
+                SampleBatch.ACTIONS, space=self.action_space, data_rel_pos=-1),
             SampleBatch.PREV_REWARDS: ViewRequirement(
-                SampleBatch.REWARDS, shift=-1),
+                SampleBatch.REWARDS, data_rel_pos=-1),
         }
         for i in range(2):
             self.model.inference_view_requirements["state_in_{}".format(i)] = \
                 ViewRequirement(
-                    "state_out_{}".format(i), shift=-1, space=self.state_space)
+                    "state_out_{}".format(i),
+                    data_rel_pos=-1,
+                    space=self.state_space)
             self.model.inference_view_requirements[
                 "state_out_{}".format(i)] = \
                 ViewRequirement(space=self.state_space)
 
         self.view_requirements = dict(
             **{
                 SampleBatch.NEXT_OBS: ViewRequirement(
-                    SampleBatch.OBS, shift=1),
+                    SampleBatch.OBS, data_rel_pos=1),
                 SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
                 SampleBatch.REWARDS: ViewRequirement(),
                 SampleBatch.DONES: ViewRequirement(),
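The rename from `shift` to `data_rel_pos` touches every ViewRequirement constructed in this commit: the first positional argument names the source data column, and `data_rel_pos` gives the relative time offset into it. A hedged, minimal sketch of such a requirements dict as it looks at this point in the code base (import paths assumed from RLlib at the time of this commit; later versions use a different keyword again):

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement

    # PREV_ACTIONS/PREV_REWARDS read the ACTIONS/REWARDS column one ts back;
    # NEXT_OBS reads the OBS column one ts ahead.
    view_requirements = {
        SampleBatch.OBS: ViewRequirement(data_rel_pos=0),
        SampleBatch.PREV_ACTIONS: ViewRequirement(
            SampleBatch.ACTIONS, data_rel_pos=-1),
        SampleBatch.PREV_REWARDS: ViewRequirement(
            SampleBatch.REWARDS, data_rel_pos=-1),
        SampleBatch.NEXT_OBS: ViewRequirement(
            SampleBatch.OBS, data_rel_pos=1),
    }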

rllib/examples/policy/rock_paper_scissors_dummies.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ def __init__(self, *args, **kwargs):
         self.view_requirements.update({
             "state_in_0": ViewRequirement(
                 "state_out_0",
-                shift=-1,
+                data_rel_pos=-1,
                 space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32))
         })

rllib/models/modelv2.py

Lines changed: 2 additions & 1 deletion
@@ -61,7 +61,8 @@ def __init__(self, obs_space: gym.spaces.Space,
         self.time_major = self.model_config.get("_time_major")
         # Basic view requirement for all models: Use the observation as input.
         self.inference_view_requirements = {
-            SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
+            SampleBatch.OBS: ViewRequirement(
+                data_rel_pos=0, space=self.obs_space),
         }
 
         # TODO: (sven): Get rid of `get_initial_state` once Trajectory

rllib/models/tf/recurrent_net.py

Lines changed: 2 additions & 2 deletions
@@ -178,10 +178,10 @@ def __init__(self, obs_space: gym.spaces.Space,
         if model_config["lstm_use_prev_action"]:
             self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
                 ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
-                                shift=-1)
+                                data_rel_pos=-1)
         if model_config["lstm_use_prev_reward"]:
             self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
-                ViewRequirement(SampleBatch.REWARDS, shift=-1)
+                ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)
 
     @override(RecurrentNetwork)
     def forward(self, input_dict: Dict[str, TensorType],

rllib/models/torch/recurrent_net.py

Lines changed: 2 additions & 2 deletions
@@ -159,10 +159,10 @@ def __init__(self, obs_space: gym.spaces.Space,
         if model_config["lstm_use_prev_action"]:
             self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
                 ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
-                                shift=-1)
+                                data_rel_pos=-1)
         if model_config["lstm_use_prev_reward"]:
             self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
-                ViewRequirement(SampleBatch.REWARDS, shift=-1)
+                ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)
 
     @override(RecurrentNetwork)
     def forward(self, input_dict: Dict[str, TensorType],
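Both the TF and Torch recurrent wrappers only register the PREV_ACTIONS / PREV_REWARDS requirements (now with data_rel_pos=-1) when the corresponding model-config flags are set. A hedged sketch of the model config that exercises these branches, using the flag names that appear in the diff above (trainer setup omitted):

    config = {
        "model": {
            "use_lstm": True,
            # Feed a[t-1] and r[t-1] into the LSTM; these flags are what make
            # the recurrent wrappers add the extra ViewRequirements.
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        },
    }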
