Skip to content

Commit

Permalink
[RLlib] Fix rest of PPO RL Modules tests (ray-project#35672)
Browse files Browse the repository at this point in the history
Signed-off-by: Artur Niederfahrenhorst <artur@anyscale.com>
  • Loading branch information
ArturNiederfahrenhorst authored and scv119 committed Jun 11, 2023
1 parent ae66e7d commit 1f1e88c
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 25 deletions.
2 changes: 1 addition & 1 deletion rllib/evaluation/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def compute_gae_for_sample_batch(
# This implementation right now will compute even the action_dist which
# will not be needed but takes time to compute.
if policy.framework == "torch":
input_dict = convert_to_torch_tensor(input_dict)
input_dict = convert_to_torch_tensor(input_dict, device=policy.device)
# TODO (sven): Fix this once we support RNNs on the new stack.
input_dict[STATE_IN] = input_dict[SampleBatch.SEQ_LENS] = None
input_dict = NestedDict(input_dict)
Expand Down
14 changes: 10 additions & 4 deletions rllib/examples/deterministic_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,14 @@
check(results1["hist_stats"], results2["hist_stats"])
# As well as training behavior (minibatch sequence during SGD
# iterations).
check(
results1["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"],
results2["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"],
)
if config._enable_learner_api:
check(
results1["info"][LEARNER_INFO][DEFAULT_POLICY_ID],
results2["info"][LEARNER_INFO][DEFAULT_POLICY_ID],
)
else:
check(
results1["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"],
results2["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"],
)
ray.shutdown()
2 changes: 1 addition & 1 deletion rllib/examples/env/greyscale_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def env_creator(config):
lr=0.0001,
grad_clip=100,
sgd_minibatch_size=500,
train_batch_size=5000,
train_batch_size=5000 if not args.as_test else 1000,
model={"vf_share_layers": True},
)
.resources(num_gpus=1 if not args.as_test else 0)
Expand Down
11 changes: 11 additions & 0 deletions rllib/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,6 +1494,17 @@ def _initialize_loss_from_dummy_batch(
stats_fn(self, train_batch)
if hasattr(self, "stats_fn") and not self.config["in_evaluation"]:
self.stats_fn(train_batch)
else:
# This is not needed to run training with the Learner API, but useful if
# we want to create a batch of data for training from view requirements.
for key in set(postprocessed_batch.keys()).difference(
set(new_batch.keys())
):
# Add all columns generated by postprocessing to view requirements.
if key not in self.view_requirements and key != SampleBatch.SEQ_LENS:
self.view_requirements[key] = ViewRequirement(
used_for_compute_actions=False
)

# Re-enable tracing.
self._no_tracing = False
Expand Down
62 changes: 49 additions & 13 deletions rllib/utils/debug/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,21 @@ def code():
if test:
results_per_category["policy"].extend(test)

# Call `learn_on_batch()` n times.
dummy_batch = policy._get_dummy_batch_from_view_requirements(batch_size=16)

test = _test_some_code_for_memory_leaks(
desc="Calling `learn_on_batch()`.",
init=None,
code=lambda: policy.learn_on_batch(dummy_batch),
# How many times to repeat the function call?
repeats=repeats or 100,
max_num_trials=max_num_trials,
)
if test:
results_per_category["policy"].extend(test)
# Testing this only makes sense if the learner API is disabled.
if not policy.config.get("_enable_learner_api", False):
# Call `learn_on_batch()` n times.
dummy_batch = policy._get_dummy_batch_from_view_requirements(batch_size=16)

test = _test_some_code_for_memory_leaks(
desc="Calling `learn_on_batch()`.",
init=None,
code=lambda: policy.learn_on_batch(dummy_batch),
# How many times to repeat the function call?
repeats=repeats or 100,
max_num_trials=max_num_trials,
)
if test:
results_per_category["policy"].extend(test)

# Test only the model.
if "model" in to_check:
Expand Down Expand Up @@ -170,4 +172,38 @@ def code():
if test:
results_per_category["rollout_worker"].extend(test)

if "learner" in to_check and algorithm.config.get("_enable_learner_api", False):
learner_group = algorithm.learner_group
assert learner_group._is_local, (
"This test will miss leaks hidden in remote "
"workers. Please make sure that there is a "
"local learner inside the learner group for "
"this test."
)

dummy_batch = (
algorithm.get_policy()
._get_dummy_batch_from_view_requirements(batch_size=16)
.as_multi_agent()
)

print("Looking for leaks in Learner")

def code():
learner_group.update(dummy_batch)

# Call `compute_actions_from_input_dict()` n times.
test = _test_some_code_for_memory_leaks(
desc="Calling `LearnerGroup.update()`.",
init=None,
code=code,
# How many times to repeat the function call?
repeats=repeats or 400,
# How many times to re-try if we find a suspicious memory
# allocation?
max_num_trials=max_num_trials,
)
if test:
results_per_category["learner"].extend(test)

return results_per_category
16 changes: 12 additions & 4 deletions rllib/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,10 +1147,18 @@ def check_reproducibilty(
check(results1["hist_stats"], results2["hist_stats"])
# As well as training behavior (minibatch sequence during SGD
# iterations).
check(
results1["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"],
results2["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"],
)
if algo_config._enable_learner_api:
check(
results1["info"][LEARNER_INFO][DEFAULT_POLICY_ID],
results2["info"][LEARNER_INFO][DEFAULT_POLICY_ID],
)
else:
check(
results1["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"],
results2["info"][LEARNER_INFO][DEFAULT_POLICY_ID]["learner_stats"],
)


def get_cartpole_dataset_reader(batch_size: int = 1) -> "DatasetReader":
Expand Down
4 changes: 2 additions & 2 deletions rllib/utils/tests/run_memory_leak_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@
parser.add_argument(
"--to-check",
nargs="+",
default=["env", "policy", "rollout_worker"],
help="List of 'env', 'policy', 'rollout_worker', 'model'.",
default=["env", "policy", "rollout_worker", "learner"],
help="List of 'env', 'policy', 'rollout_worker', 'model', 'learner'.",
)

# Obsoleted arg, use --dir instead.
Expand Down

0 comments on commit 1f1e88c

Please sign in to comment.