[RLlib] Add "shuffle batch per epoch" option. (ray-project#47458)
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent b78d77b commit fe0239e
Showing 13 changed files with 348 additions and 31 deletions.
16 changes: 1 addition & 15 deletions rllib/algorithms/algorithm_config.py
@@ -2184,7 +2184,6 @@ def training(
learner_config_dict: Optional[Dict[str, Any]] = NotProvided,
# Deprecated args.
num_sgd_iter=DEPRECATED_VALUE,
max_requests_in_flight_per_sampler_worker=DEPRECATED_VALUE,
) -> "AlgorithmConfig":
"""Sets the training related configuration.
@@ -2240,7 +2239,7 @@ def training(
minibatch_size: The size of minibatches to use to further split the train
batch into.
shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
If the train batch has a time rank (axis=1), shuffling will only take
place along the batch axis to not disturb any intact (episode)
trajectories.
model: Arguments passed into the policy model. See models/catalog.py for a
@@ -2284,19 +2283,6 @@ def training(
error=False,
)
num_epochs = num_sgd_iter
if max_requests_in_flight_per_sampler_worker != DEPRECATED_VALUE:
deprecation_warning(
old="AlgorithmConfig.training("
"max_requests_in_flight_per_sampler_worker=...)",
new="AlgorithmConfig.env_runners("
"max_requests_in_flight_per_env_runner=...)",
error=False,
)
self.env_runners(
max_requests_in_flight_per_env_runner=(
max_requests_in_flight_per_sampler_worker
),
)

if gamma is not NotProvided:
self.gamma = gamma
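
The docstring above introduces the new shuffle_batch_per_epoch flag alongside minibatch_size and num_epochs. A minimal usage sketch, assuming only the parameter names shown in this diff (the environment and the concrete values are illustrative):

    from ray.rllib.algorithms.ppo import PPOConfig

    # Illustrative values only; the parameter names (num_epochs, minibatch_size,
    # shuffle_batch_per_epoch) are the ones documented in the diff above.
    config = (
        PPOConfig()
        .environment("CartPole-v1")
        .training(
            train_batch_size=4000,
            minibatch_size=128,            # split the train batch into minibatches of this size
            num_epochs=10,                 # number of passes over the collected train batch
            shuffle_batch_per_epoch=True,  # reshuffle along the batch axis once per epoch
        )
    )
    algo = config.build()
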
10 changes: 9 additions & 1 deletion rllib/algorithms/ppo/ppo.py
@@ -532,7 +532,15 @@ def _training_step_old_api_stack(self) -> ResultDict:
# Standardize advantages.
train_batch = standardize_fields(train_batch, ["advantages"])

# Perform a train step on the collected batch.
if self.config.enable_rl_module_and_learner:
train_results = self.learner_group.update_from_batch(
batch=train_batch,
minibatch_size=self.config.minibatch_size,
num_epochs=self.config.num_epochs,
)

elif self.config.simple_optimizer:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
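
On the new API stack, the branch above hands the whole train batch to learner_group.update_from_batch() together with minibatch_size and num_epochs. Conceptually, each epoch is one pass over the batch in minibatches, and with shuffle_batch_per_epoch the rows are reshuffled along the batch axis before every pass. A rough NumPy sketch of that loop (not RLlib's actual implementation; names are illustrative):

    import numpy as np

    def iterate_minibatches(train_batch, minibatch_size, num_epochs, shuffle=True, seed=None):
        """Yield minibatches for `num_epochs` passes over a dict of equally sized arrays."""
        rng = np.random.default_rng(seed)
        batch_size = len(next(iter(train_batch.values())))
        for _ in range(num_epochs):
            idx = np.arange(batch_size)
            if shuffle:
                # Shuffle along the batch axis only; a time rank (axis=1) and thus
                # intact (episode) trajectories are left untouched.
                rng.shuffle(idx)
            for start in range(0, batch_size, minibatch_size):
                mb = idx[start:start + minibatch_size]
                yield {k: v[mb] for k, v in train_batch.items()}

Each yielded minibatch would then back one gradient update.
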
85 changes: 81 additions & 4 deletions rllib/algorithms/ppo/tests/test_ppo.py
@@ -60,16 +60,93 @@ def setUpClass(cls):
def tearDownClass(cls):
ray.shutdown()

def test_ppo_compilation_w_connectors(self):
"""Test whether PPO can be built with all frameworks w/ connectors."""

# Build a PPOConfig object.
config = (
ppo.PPOConfig()
.training(
num_epochs=2,
# Setup lr schedule for testing.
lr_schedule=[[0, 5e-5], [128, 0.0]],
# Set entropy_coeff to a faulty value to prove that it'll get
# overridden by the schedule below (which is expected).
entropy_coeff=100.0,
entropy_coeff_schedule=[[0, 0.1], [256, 0.0]],
train_batch_size=128,
model=dict(
# Settings in case we use an LSTM.
lstm_cell_size=10,
max_seq_len=20,
),
)
.env_runners(
num_env_runners=1,
# Test with compression.
compress_observations=True,
enable_connectors=True,
)
.callbacks(MyCallbacks)
.evaluation(
evaluation_duration=2,
evaluation_duration_unit="episodes",
evaluation_num_env_runners=1,
)
) # For checking lr-schedule correctness.

num_iterations = 2

for env in ["FrozenLake-v1", "ALE/MsPacman-v5"]:
print("Env={}".format(env))
for lstm in [False, True]:
print("LSTM={}".format(lstm))
config.training(
model=dict(
use_lstm=lstm,
lstm_use_prev_action=lstm,
lstm_use_prev_reward=lstm,
)
)

algo = config.build(env=env)
policy = algo.get_policy()
entropy_coeff = algo.get_policy().entropy_coeff
lr = policy.cur_lr
check(entropy_coeff, 0.1)
check(lr, config.lr)

for i in range(num_iterations):
results = algo.train()
check_train_results(results)
print(results)

algo.evaluate()

check_inference_w_connectors(policy, env_name=env)
algo.stop()

def test_ppo_compilation_and_schedule_mixins(self):
"""Test whether PPO can be built with all frameworks."""

# Build a PPOConfig object with the `SingleAgentEnvRunner` class.
config = (
ppo.PPOConfig()
# Enable new API stack and use EnvRunner.
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.training(
# Setup lr schedule for testing.
lr_schedule=[[0, 5e-5], [256, 0.0]],
# Set entropy_coeff to a faulty value to prove that it'll get
# overridden by the schedule below (which is expected).
entropy_coeff=100.0,
entropy_coeff_schedule=[[0, 0.1], [512, 0.0]],
train_batch_size=256,
minibatch_size=128,
num_epochs=2,
model=dict(
# Settings in case we use an LSTM.
lstm_cell_size=10,
max_seq_len=20,
),
)
.env_runners(num_env_runners=0)
.training(
1 change: 1 addition & 0 deletions rllib/core/learner/learner.py
@@ -1095,6 +1095,7 @@ def update_from_iterator(
)

self._check_is_built()
# minibatch_size = minibatch_size or 32

# Call `before_gradient_based_update` to allow for non-gradient based
# preparations-, logging-, and update logic to happen.
1 change: 1 addition & 0 deletions rllib/core/learner/learner_group.py
@@ -38,6 +38,7 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.checkpoints import Checkpointable
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
from ray.rllib.utils.minibatch_utils import (
ShardBatchIterator,
ShardEpisodesIterator,
212 changes: 212 additions & 0 deletions rllib/env/tests/test_multi_agent_env.py
@@ -680,6 +680,7 @@ def test_multi_agent_with_flex_agents(self):
PPOConfig()
.environment("flex_agents_multi_agent")
.env_runners(num_env_runners=0)
.framework("tf")
.training(train_batch_size=50, minibatch_size=50, num_epochs=1)
)
algo = config.build()
@@ -810,6 +811,217 @@ def is_recurrent(self):
check(batch["state_in_0"][i], h)
check(batch["state_out_0"][i], h)

def test_returning_model_based_rollouts_data(self):
# TODO(avnishn): This test only works with the old api

class ModelBasedPolicy(DQNTFPolicy):
def compute_actions_from_input_dict(
self, input_dict, explore=None, timestep=None, episodes=None, **kwargs
):
obs_batch = input_dict["obs"]
# In policy loss initialization phase, no episodes are passed
# in.
if episodes is not None:
# Pretend we did a model-based rollout and want to return
# the extra trajectory.
env_id = episodes[0].env_id
fake_eps = Episode(
episodes[0].policy_map,
episodes[0].policy_mapping_fn,
lambda: None,
lambda x: None,
env_id,
)
builder = get_global_worker().sampler.sample_collector
agent_id = "extra_0"
policy_id = "p1" # use p1 so we can easily check it
builder.add_init_obs(
episode=fake_eps,
agent_id=agent_id,
policy_id=policy_id,
env_id=env_id,
init_obs=obs_batch[0],
init_infos={},
)
for t in range(4):
builder.add_action_reward_next_obs(
episode_id=fake_eps.episode_id,
agent_id=agent_id,
env_id=env_id,
policy_id=policy_id,
agent_done=t == 3,
values=dict(
t=t,
actions=0,
rewards=0,
terminateds=False,
truncateds=t == 3,
infos={},
new_obs=obs_batch[0],
),
)
batch = builder.postprocess_episode(episode=fake_eps, build=True)
episodes[0].add_extra_batch(batch)

# Just return zeros for actions
return [0] * len(obs_batch), [], {}

ev = RolloutWorker(
env_creator=lambda _: MultiAgentCartPole({"num_agents": 2}),
default_policy_class=ModelBasedPolicy,
config=DQNConfig()
.framework("tf")
.env_runners(
rollout_fragment_length=5,
num_env_runners=0,
enable_connectors=False, # only works with old episode API
)
.multi_agent(
policies={"p0", "p1"},
policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "p0",
),
)
batch = ev.sample()
# 5 environment steps (rollout_fragment_length).
check(batch.count, 5)
# 10 agent steps for p0: 2 agents, both using p0 as their policy.
check(batch.policy_batches["p0"].count, 10)
# 20 agent steps for p1: Each time both(!) agents take 1 step,
# p1 takes 4: 5 (rollout-fragment length) * 4 = 20.
check(batch.policy_batches["p1"].count, 20)

def test_train_multi_agent_cartpole_single_policy(self):
n = 10
register_env(
"multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": n})
)
config = (
PPOConfig()
.environment("multi_agent_cartpole")
.env_runners(num_env_runners=0)
.framework("tf")
)

algo = config.build()
for i in range(50):
result = algo.train()
print(
"Iteration {}, reward {}, timesteps {}".format(
i,
result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN],
result[NUM_ENV_STEPS_SAMPLED_LIFETIME],
)
)
if result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= 50 * n:
algo.stop()
return
raise Exception("failed to improve reward")

def test_train_multi_agent_cartpole_multi_policy(self):
n = 10
register_env(
"multi_agent_cartpole", lambda _: MultiAgentCartPole({"num_agents": n})
)

def gen_policy():
config = PPOConfig.overrides(
gamma=random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
lr=random.choice([0.001, 0.002, 0.003]),
)
return PolicySpec(config=config)

config = (
PPOConfig()
.environment("multi_agent_cartpole")
.env_runners(num_env_runners=0)
.multi_agent(
policies={
"policy_1": gen_policy(),
"policy_2": gen_policy(),
},
policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
"policy_1"
),
)
.framework("tf")
.training(train_batch_size=50, minibatch_size=50, num_epochs=1)
)

algo = config.build()
# Just check that it runs without crashing
for i in range(10):
result = algo.train()
print(
"Iteration {}, reward {}, timesteps {}".format(
i,
result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN],
result[NUM_ENV_STEPS_SAMPLED_LIFETIME],
)
)
self.assertTrue(
algo.compute_single_action([0, 0, 0, 0], policy_id="policy_1") in [0, 1]
)
self.assertTrue(
algo.compute_single_action([0, 0, 0, 0], policy_id="policy_2") in [0, 1]
)
self.assertRaisesRegex(
KeyError,
"not found in PolicyMap",
lambda: algo.compute_single_action([0, 0, 0, 0], policy_id="policy_3"),
)

def test_space_in_preferred_format(self):
env = NestedMultiAgentEnv()
action_space_in_preferred_format = (
env._check_if_action_space_maps_agent_id_to_sub_space()
)
obs_space_in_preferred_format = (
env._check_if_obs_space_maps_agent_id_to_sub_space()
)
assert action_space_in_preferred_format, "Act space is not in preferred format."
assert obs_space_in_preferred_format, "Obs space is not in preferred format."

env2 = make_multi_agent("CartPole-v1")()
action_spaces_in_preferred_format = (
env2._check_if_action_space_maps_agent_id_to_sub_space()
)
obs_space_in_preferred_format = (
env2._check_if_obs_space_maps_agent_id_to_sub_space()
)
assert (
action_spaces_in_preferred_format
), "Action space should be in preferred format but isn't."
assert (
obs_space_in_preferred_format
), "Observation space should be in preferred format but isn't."

def test_spaces_sample_contain_in_preferred_format(self):
env = NestedMultiAgentEnv()
# This environment has spaces in the preferred format for multi-agent
# environments: Dict spaces mapping agent IDs to sub-spaces.
obs = env.observation_space_sample()
assert env.observation_space_contains(
obs
), "Observation space does not contain obs"

action = env.action_space_sample()
assert env.action_space_contains(action), "Action space does not contain action"
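
The comments in the surrounding tests refer to the "preferred format" for multi-agent spaces, i.e. one Dict space keyed by agent ID. A minimal sketch of such spaces, assuming gymnasium and two hypothetical agent IDs:

    import gymnasium as gym

    # Hypothetical agent IDs; the point is that the top-level space is a Dict
    # mapping each agent ID to that agent's own sub-space.
    observation_space = gym.spaces.Dict(
        {
            "agent_0": gym.spaces.Box(low=-1.0, high=1.0, shape=(4,)),
            "agent_1": gym.spaces.Discrete(5),
        }
    )
    action_space = gym.spaces.Dict(
        {
            "agent_0": gym.spaces.Discrete(2),
            "agent_1": gym.spaces.Discrete(2),
        }
    )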

def test_spaces_sample_contain_not_in_preferred_format(self):
env = make_multi_agent("CartPole-v1")({"num_agents": 2})
# This environment's spaces are not in the preferred format for multi-agent
# environments. For such environments, users must override the
# observation_space_contains, action_space_contains, observation_space_sample,
# and action_space_sample methods in order to do proper checks.
obs = env.observation_space_sample()
assert env.observation_space_contains(
obs
), "Observation space does not contain obs"
action = env.action_space_sample()
assert env.action_space_contains(action), "Action space does not contain action"


if __name__ == "__main__":
import pytest