[RLlib] Add "shuffle batch per epoch" option. (ray-project#47458)
Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
sven1977 authored and ujjawal-khare committed Oct 15, 2024
1 parent 70e487b commit ae30748
Showing 19 changed files with 356 additions and 52 deletions.
16 changes: 1 addition & 15 deletions rllib/algorithms/algorithm_config.py
@@ -2184,7 +2184,6 @@ def training(
learner_config_dict: Optional[Dict[str, Any]] = NotProvided,
# Deprecated args.
num_sgd_iter=DEPRECATED_VALUE,
max_requests_in_flight_per_sampler_worker=DEPRECATED_VALUE,
) -> "AlgorithmConfig":
"""Sets the training related configuration.
@@ -2240,7 +2239,7 @@ def training(
minibatch_size: The size of minibatches to use to further split the train
batch into.
shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
If the train batch has a time rank (axis=1), shuffling only takes
If the train batch has a time rank (axis=1), shuffling will only take
place along the batch axis to not disturb any intact (episode)
trajectories.
model: Arguments passed into the policy model. See models/catalog.py for a
@@ -2284,19 +2283,6 @@ def training(
error=False,
)
num_epochs = num_sgd_iter
if max_requests_in_flight_per_sampler_worker != DEPRECATED_VALUE:
deprecation_warning(
old="AlgorithmConfig.training("
"max_requests_in_flight_per_sampler_worker=...)",
new="AlgorithmConfig.env_runners("
"max_requests_in_flight_per_env_runner=...)",
error=False,
)
self.env_runners(
max_requests_in_flight_per_env_runner=(
max_requests_in_flight_per_sampler_worker
),
)

if gamma is not NotProvided:
self.gamma = gamma
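For orientation, a minimal usage sketch of the new option (illustrative only, not taken from this commit; the concrete values are made up):

    from ray.rllib.algorithms.ppo import PPOConfig

    # `num_epochs` replaces the deprecated `num_sgd_iter`;
    # `shuffle_batch_per_epoch` is the option documented in the diff above.
    config = (
        PPOConfig()
        .training(
            train_batch_size=4000,
            minibatch_size=128,
            num_epochs=10,
            shuffle_batch_per_epoch=True,
        )
    )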
12 changes: 10 additions & 2 deletions rllib/algorithms/ppo/ppo.py
@@ -237,7 +237,7 @@ def training(
lambda_: The lambda parameter for General Advantage Estimation (GAE).
Defines the exponential weight used between actually measured rewards
vs value function estimates over multiple time steps. Specifically,
`lambda_` balances short-term, low-variance estimates with longer-term,
`lambda_` balances short-term, low-variance estimates against long-term,
high-variance returns. A `lambda_` of 0.0 makes the GAE rely only on
immediate rewards (and vf predictions from there on, reducing variance,
but increasing bias), while a `lambda_` of 1.0 only incorporates vf
@@ -532,7 +532,15 @@ def _training_step_old_api_stack(self) -> ResultDict:
# Standardize advantages.
train_batch = standardize_fields(train_batch, ["advantages"])

if self.config.simple_optimizer:
# Perform a train step on the collected batch.
if self.config.enable_rl_module_and_learner:
train_results = self.learner_group.update_from_batch(
batch=train_batch,
minibatch_size=self.config.minibatch_size,
num_epochs=self.config.num_epochs,
)

elif self.config.simple_optimizer:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
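The `lambda_` docstring above describes the bias/variance trade-off in Generalized Advantage Estimation. As a framework-free illustration only (the helper name `gae_advantages` is made up and is not RLlib code):

    import numpy as np

    def gae_advantages(rewards, values, gamma=0.99, lambda_=0.95):
        # `values` holds one extra entry: the bootstrap value of the final state.
        rewards = np.asarray(rewards, dtype=np.float32)
        values = np.asarray(values, dtype=np.float32)
        deltas = rewards + gamma * values[1:] - values[:-1]  # one-step TD errors
        advantages = np.zeros_like(deltas)
        gae = 0.0
        for t in reversed(range(len(deltas))):
            # lambda_=0.0 keeps only the immediate TD error (low variance, more bias);
            # lambda_=1.0 folds in all future TD errors (Monte-Carlo-like, high variance).
            gae = deltas[t] + gamma * lambda_ * gae
            advantages[t] = gae
        return advantages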
85 changes: 81 additions & 4 deletions rllib/algorithms/ppo/tests/test_ppo.py
@@ -60,16 +60,93 @@ def setUpClass(cls):
def tearDownClass(cls):
ray.shutdown()

def test_ppo_compilation_w_connectors(self):
"""Test whether PPO can be built with all frameworks w/ connectors."""

# Build a PPOConfig object.
config = (
ppo.PPOConfig()
.training(
num_epochs=2,
# Setup lr schedule for testing.
lr_schedule=[[0, 5e-5], [128, 0.0]],
# Set entropy_coeff to a faulty value to prove that it'll get
# overridden by the schedule below (which is expected).
entropy_coeff=100.0,
entropy_coeff_schedule=[[0, 0.1], [256, 0.0]],
train_batch_size=128,
model=dict(
# Settings in case we use an LSTM.
lstm_cell_size=10,
max_seq_len=20,
),
)
.env_runners(
num_env_runners=1,
# Test with compression.
compress_observations=True,
enable_connectors=True,
)
.callbacks(MyCallbacks)
.evaluation(
evaluation_duration=2,
evaluation_duration_unit="episodes",
evaluation_num_env_runners=1,
)
) # For checking lr-schedule correctness.

num_iterations = 2

for env in ["FrozenLake-v1", "ALE/MsPacman-v5"]:
print("Env={}".format(env))
for lstm in [False, True]:
print("LSTM={}".format(lstm))
config.training(
model=dict(
use_lstm=lstm,
lstm_use_prev_action=lstm,
lstm_use_prev_reward=lstm,
)
)

algo = config.build(env=env)
policy = algo.get_policy()
entropy_coeff = algo.get_policy().entropy_coeff
lr = policy.cur_lr
check(entropy_coeff, 0.1)
check(lr, config.lr)

for i in range(num_iterations):
results = algo.train()
check_train_results(results)
print(results)

algo.evaluate()

check_inference_w_connectors(policy, env_name=env)
algo.stop()

def test_ppo_compilation_and_schedule_mixins(self):
"""Test whether PPO can be built with all frameworks."""

# Build a PPOConfig object with the `SingleAgentEnvRunner` class.
config = (
ppo.PPOConfig()
# Enable new API stack and use EnvRunner.
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.training(
# Setup lr schedule for testing.
lr_schedule=[[0, 5e-5], [256, 0.0]],
# Set entropy_coeff to a faulty value to prove that it'll get
# overridden by the schedule below (which is expected).
entropy_coeff=100.0,
entropy_coeff_schedule=[[0, 0.1], [512, 0.0]],
train_batch_size=256,
minibatch_size=128,
num_epochs=2,
model=dict(
# Settings in case we use an LSTM.
lstm_cell_size=10,
max_seq_len=20,
),
)
.env_runners(num_env_runners=0)
.training(
@@ -330,7 +330,7 @@ def __call__(
# [:-1]: Shift state outs by one, ignore very last
# STATE_OUT (but therefore add the lookback/init state at
# the beginning).
lambda i, o: np.concatenate([[i], o[:-1]])[::max_seq_len],
lambda i, o, m=max_seq_len: np.concatenate([[i], o[:-1]])[::m],
look_back_state,
state_outs,
),
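The change above binds `max_seq_len` into the lambda through a default argument (`m=max_seq_len`), which captures the value at definition time instead of relying on Python's late-binding closure lookup. A small standalone demonstration of the difference (illustrative only, not from this diff):

    # Late binding: every lambda sees the final value of `step` (3) when called.
    late = [lambda x: x[::step] for step in (1, 2, 3)]
    print([f(list(range(6))) for f in late])    # [[0, 3], [0, 3], [0, 3]]

    # Default-argument binding freezes the value per lambda at definition time.
    bound = [lambda x, s=step: x[::s] for step in (1, 2, 3)]
    print([f(list(range(6))) for f in bound])   # [[0, 1, 2, 3, 4, 5], [0, 2, 4], [0, 3]]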
3 changes: 2 additions & 1 deletion rllib/core/learner/learner.py
@@ -889,7 +889,7 @@ def compute_losses(
fwd_out: Output from a call to the `forward_train()` method of the
underlying MultiRLModule (`self.module`) during training
(`self.update()`).
batch: The training batch that was used to compute `fwd_out`.
batch: The train batch that was used to compute `fwd_out`.
Returns:
A dictionary mapping module IDs to individual loss terms.
@@ -1094,6 +1094,7 @@ def update_from_iterator(
)

self._check_is_built()
# minibatch_size = minibatch_size or 32

# Call `before_gradient_based_update` to allow for non-gradient based
# preparations-, logging-, and update logic to happen.
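Conceptually, the Learner side consumes the train batch over `num_epochs` epochs in minibatches of `minibatch_size` (see `update_from_batch` in the ppo.py diff above). Purely as a sketch of what `shuffle_batch_per_epoch` means, not RLlib's actual implementation (the helper names `epoch_minibatch_loop` and `update_fn` are made up):

    import numpy as np

    def epoch_minibatch_loop(batch, minibatch_size, num_epochs,
                             shuffle_batch_per_epoch, update_fn):
        # `batch` is assumed to be a dict of arrays whose first axis is the
        # batch axis; a second axis, if present, is the time rank.
        n = len(next(iter(batch.values())))
        for _ in range(num_epochs):
            idx = np.arange(n)
            if shuffle_batch_per_epoch:
                # Shuffle along the batch axis only, so trajectories stored
                # along the time rank (axis=1) stay intact.
                np.random.shuffle(idx)
            for start in range(0, n, minibatch_size):
                sel = idx[start:start + minibatch_size]
                update_fn({k: v[sel] for k, v in batch.items()})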
1 change: 1 addition & 0 deletions rllib/core/learner/learner_group.py
@@ -38,6 +38,7 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.checkpoints import Checkpointable
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
from ray.rllib.utils.minibatch_utils import (
ShardBatchIterator,
ShardEpisodesIterator,