[RLlib] Initial design for Ray-Data based offline RL Algos (on new API stack). #44969

Merged (74 commits) on Jul 22, 2024. The changes shown below are from 50 of those commits.

Commits
3f13747
Sketched a simple offline data class with usage in 'Algorithm'.
simonsays1980 Apr 25, 2024
f33e595
LINTER.
simonsays1980 Apr 25, 2024
1496723
Merged master
simonsays1980 May 6, 2024
801911d
Implemented test for 'OfflineData' class and stored intermediate work.
simonsays1980 May 6, 2024
ff46fa2
LINTER.
simonsays1980 May 6, 2024
caa48d0
Added a basic workflow to convert batches into list of episodes for L…
simonsays1980 May 6, 2024
c748df8
Changed comment.
simonsays1980 May 10, 2024
6409007
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 13, 2024
d2f9030
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 14, 2024
a3416a8
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 15, 2024
8582ad9
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 16, 2024
b565f34
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 21, 2024
c0eed1f
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 22, 2024
341cb95
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 22, 2024
b76807f
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 24, 2024
c84fab8
Merged master.
simonsays1980 May 24, 2024
85cf954
Initial commit for BC with offline data API in new stack.
simonsays1980 May 24, 2024
69157f5
Implemented BC in new API stack with Ray Data API and Learner API usi…
simonsays1980 May 27, 2024
af9c9e9
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 27, 2024
780b49d
Merge branch 'master' into offline-data-new-stack
simonsays1980 May 27, 2024
7337d08
Added new test to BUILD file.
simonsays1980 May 27, 2024
da83264
Added new test for offline data to BUILD file.
simonsays1980 May 27, 2024
82ae5bd
Added functionality to map batches directly to episode lists in an it…
simonsays1980 May 27, 2024
06ed2e5
Fixed bug in test.
simonsays1980 May 28, 2024
bafdcba
Added @sven1977's review.
simonsays1980 May 28, 2024
72cd797
Added locality hints for distributed training.
simonsays1980 May 28, 2024
848a205
Multi-learner initialization.
simonsays1980 May 31, 2024
d39ef0a
LINTER.
simonsays1980 May 31, 2024
4952de7
Tryout with callable class in 'map_batches'.
simonsays1980 Jun 6, 2024
d4479ff
Merged master.
simonsays1980 Jun 24, 2024
372a107
Added resampled JSONL CartPole-v1 dataset from cartpole-small.json wi…
simonsays1980 Jun 24, 2024
5ffa94b
Modified episode conversion to work with new data format for offline …
simonsays1980 Jun 25, 2024
dcf9524
Added 'batch_size' to the 'map_batches' and modified '_map_to_episode…
simonsays1980 Jun 25, 2024
04ce0f0
Added large CartPole-v1 data in new format. Modified BC algorithm, fi…
simonsays1980 Jun 25, 2024
72088fa
Merge branch 'master' into offline-data-new-stack
simonsays1980 Jun 25, 2024
40e9b35
Started multi-learner setup.
simonsays1980 Jun 25, 2024
89b06fe
Added tuned example for BC with new offline API.
simonsays1980 Jun 25, 2024
0da5db1
Merge branch 'master' into offline-data-new-stack
simonsays1980 Jun 26, 2024
8d5f1bd
Set up multi-learner training and tested it.
simonsays1980 Jun 26, 2024
5128bdb
Merged master.
simonsays1980 Jun 27, 2024
5d9dd97
Set default to 'parquet' files. Tested for different learner setups a…
simonsays1980 Jun 27, 2024
19d99a5
Added parquet files for cartpole and pendulum data. Also reset the de…
simonsays1980 Jun 27, 2024
c437210
Merge branch 'master' into offline-data-new-stack
simonsays1980 Jun 28, 2024
f1a7663
Added @sven1977's review. Also added 'override_num_blocks' to tuned e…
simonsays1980 Jun 28, 2024
dea037b
Disabled hybrid stack. Tested old stack and made some cleanups.
simonsays1980 Jun 28, 2024
356222a
Much refactoring and fixing smaller and larger bugs related to transf…
simonsays1980 Jul 3, 2024
fb3eea0
Merge branch 'master' into offline-data-new-stack
simonsays1980 Jul 3, 2024
8b16489
Readded the test for BC in old stack.
simonsays1980 Jul 3, 2024
68fabec
Fixed a small bug, due to the fact that 'Algorithm' objects do not ha…
simonsays1980 Jul 3, 2024
980ebcc
Reset concurrency. This was a relict from testing.
simonsays1980 Jul 3, 2024
8d45a78
Fixed some minor bugs that let tests failing.
simonsays1980 Jul 4, 2024
38b43cf
Merge branch 'master' into offline-data-new-stack
simonsays1980 Jul 4, 2024
b7b0a34
Another small bug fix due to the hybrid stack.
simonsays1980 Jul 4, 2024
9870e5f
Set training step such that old and hybrid stacks are training on MAR…
simonsays1980 Jul 4, 2024
56e5de5
Refactored hybrid and new stack training logic into two separate meth…
simonsays1980 Jul 4, 2024
34fa38e
Some small nits.
simonsays1980 Jul 4, 2024
1a81cd9
More small nits fixed in test file for OfflineData.
simonsays1980 Jul 7, 2024
6b4939f
Merged master.
simonsays1980 Jul 9, 2024
a3fe5fb
Added changes to BC to enable multi-learner.
simonsays1980 Jul 9, 2024
88786b6
Fixed a bug in offline data tests and refactored. In addition changed…
simonsays1980 Jul 10, 2024
01a4f54
Fixed a bug in BC with policies being a set not a list.
simonsays1980 Jul 10, 2024
e8ade96
Fixed data path in BUILD.
simonsays1980 Jul 10, 2024
2b3af33
Fixed a small bug in 'LearnerGroup' due to mistyped arguments.
simonsays1980 Jul 11, 2024
b226c2e
Small modification of learning rate in multi-agent SAC test.
simonsays1980 Jul 11, 2024
b406295
Merge branch 'master' into offline-data-new-stack
simonsays1980 Jul 12, 2024
6c855ad
Merged master.
simonsays1980 Jul 15, 2024
7b928d2
Merge branch 'master' into offline-data-new-stack
simonsays1980 Jul 16, 2024
6582320
Added a further check to deal in Offline Data setups with 'PolicyServ…
simonsays1980 Jul 16, 2024
30c3f8d
Readded '_set_optimizer_state' after erroneously removing it.
simonsays1980 Jul 16, 2024
59cf300
Merge branch 'master' into offline-data-new-stack
simonsays1980 Jul 17, 2024
a1cf4f1
Merged master and modified 'OfflineData' to integrate the newest chan…
simonsays1980 Jul 18, 2024
be43bcd
Saving state.
simonsays1980 Jul 18, 2024
30f4170
Fixed bug with synching the weights between learner and local worker …
simonsays1980 Jul 18, 2024
7591410
Merge branch 'master' into offline-data-new-stack
simonsays1980 Jul 19, 2024
18 changes: 18 additions & 0 deletions rllib/BUILD
@@ -782,6 +782,14 @@ py_test(
)

# BC
py_test(
name = "test_bc_old_stack",
tags = ["team:rllib", "algorithms_dir"],
size = "medium",
# Include the json data file.
data = ["tests/data/cartpole/large.json"],
srcs = ["algorithms/bc/tests/test_bc_old_stack.py"]
)

Contributor comment: nice! Old stack is no longer the norm :D

py_test(
name = "test_bc",
tags = ["team:rllib", "algorithms_dir"],
@@ -1531,6 +1539,16 @@ py_test(
srcs = ["offline/estimators/tests/test_dr_learning.py"],
)

py_test(
name = "test_offline_data",
tags = ["team:rllib", "offline"],
size = "small",
srcs = ["offline/tests/test_offline_data.py"],
data = [
"tests/data/pendulum/small.json",
],
)

# --------------------------------------------------------------------
# Policies
# rllib/policy/
43 changes: 42 additions & 1 deletion rllib/algorithms/algorithm.py
@@ -620,12 +620,28 @@ def setup(self, config: AlgorithmConfig) -> None:
validate_env=self.validate_env,
default_policy_class=self.get_default_policy_class(self.config),
config=self.config,
num_env_runners=self.config.num_env_runners,
num_env_runners=0 if self.config.input_ else self.config.num_env_runners,
local_env_runner=True,
logdir=self.logdir,
tune_trial_id=self.trial_id,
)

# Ensure remote workers are initially in sync with the local worker.
self.workers.sync_weights(inference_only=True)

Contributor comment: dumb question: Why do we need this additional sync? Which (currently existing) sync does this replace?

Contributor comment: I think this is a left-over from a missed merge with master? Could you check?

# If an input path is available and we are on the new API stack, generate
# an `OfflineData` instance.
if (
self.config.input_
and self.config.input_ != "sampler"
and self.config._enable_new_api_stack
):
from ray.rllib.offline.offline_data import OfflineData

self.offline_data = OfflineData(self.config)
# Otherwise set the attribute to `None`.
else:
self.offline_data = None

# Compile, validate, and freeze an evaluation config.
self.evaluation_config = self.config.get_evaluation_config_object()
self.evaluation_config.validate()
@@ -706,6 +722,8 @@ def setup(self, config: AlgorithmConfig) -> None:
# Need to add back method_type in case Algorithm is restored from checkpoint
method_config["type"] = method_type

# TODO (sven): Probably obsolete b/c the learner group is already None.
self.learner_group = None
if self.config.enable_rl_module_and_learner:
local_worker = self.workers.local_worker()
env = spaces = None
@@ -781,6 +799,29 @@ def setup(self, config: AlgorithmConfig) -> None:
),
)

if self.offline_data:
# If the learners are remote we need to provide specific
# information and the learner's actor handles.
if self.learner_group.is_remote:
# If learners run on different nodes, locality hints help
# to use the nearest learner in the workers that do the
# data preprocessing.
learner_node_ids = self.learner_group.foreach_learner(
lambda l: ray.get_runtime_context().get_node_id()
)
self.offline_data.locality_hints = [
node_id.get() for node_id in learner_node_ids
]
# Provide the actor handles for the learners for module
# updating during preprocessing.
self.offline_data.learner_handles = self.learner_group._workers
# Provide the module_spec. Note, in the remote case this is needed
# because the learner module cannot be copied, but must be built.
self.offline_data.module_spec = module_spec
# Otherwise we can simply pass in the local learner.
else:
self.offline_data.learner_handles = [self.learner_group._learner]

# Run `on_algorithm_init` callback after initialization is done.
self.callbacks.on_algorithm_init(algorithm=self, metrics_logger=self.metrics)

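A hedged illustration (not code from this PR) of how the locality hints gathered above can be consumed downstream: Ray Data's `streaming_split()` accepts node IDs so that each learner's data shard is produced close to that learner. The dataset path and node IDs are placeholders.

import ray

# Placeholder node IDs; in the PR they come from
# learner_group.foreach_learner(lambda l: ray.get_runtime_context().get_node_id()).
locality_hints = ["node-id-of-learner-0", "node-id-of-learner-1"]

# Placeholder path to offline data stored as Parquet.
ds = ray.data.read_parquet("/tmp/cartpole/parquet")

# One streaming iterator per learner, each produced near the hinted node.
per_learner_iterators = ds.streaming_split(
    n=len(locality_hints), equal=True, locality_hints=locality_hints
)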
35 changes: 34 additions & 1 deletion rllib/algorithms/algorithm_config.py
@@ -430,6 +430,10 @@ def __init__(self, algo_class: Optional[type] = None):

# `self.offline_data()`
self.input_ = "sampler"
self.input_read_method = "read_parquet"
self.input_read_method_kwargs = {}
self.prelearner_module_synch_period = 10
self.dataset_num_iters_per_learner = None
self.input_config = {}
self.actions_in_input_normalized = False
self.postprocess_inputs = False
@@ -2368,6 +2372,10 @@ def offline_data(
self,
*,
input_=NotProvided,
input_read_method=NotProvided,
input_read_method_kwargs=NotProvided,
prelearner_module_synch_period=NotProvided,
dataset_num_iters_per_learner=NotProvided,
input_config=NotProvided,
actions_in_input_normalized=NotProvided,
input_evaluation=NotProvided,
@@ -2392,7 +2400,24 @@
- A callable that takes an `IOContext` object as only arg and returns a
ray.rllib.offline.InputReader.
- A string key that indexes a callable with tune.registry.register_input
input_config: Arguments that describe the settings for reading the input.
input_read_method: Read method for the `ray.data.Dataset` to read in the
offline data from `input_`. The default is `read_parquet` for Parquet files.
See https://docs.ray.io/en/latest/data/api/input_output.html for more
info about available read methods in `ray.data`.
input_read_method_kwargs: kwargs for the `input_read_method`. These will be
passed into the read method without checking.
prelearner_module_synch_period: The period (number of batches converted)
after which the `RLModule` held by the `PreLearner` should sync weights.
The `PreLearner` is used to preprocess batches for the learners. The
higher this value, the more off-policy the `PreLearner`'s module will be.
Values that are too small will force the `PreLearner` to sync a lot with the
`Learner` and will slow down the data pipeline. The default value chosen
by the `OfflinePreLearner` is 10.
dataset_num_iters_per_learner: Number of iterations to run in each learner
during a single training iteration. If `None`, each learner runs a
complete epoch over its data block (the dataset is partitioned into
as many blocks as there are learners). The default is `None`.
input_config: Arguments that describe the settings for reading the input.
If input is `sampler`, this will be the environment configuration, e.g.
`env_name` and `env_config`, etc. See `EnvContext` for more info.
If the input is `dataset`, this will be e.g. `format`, `path`.
@@ -2430,6 +2455,14 @@
"""
if input_ is not NotProvided:
self.input_ = input_
if input_read_method is not NotProvided:
self.input_read_method = input_read_method
if input_read_method_kwargs is not NotProvided:
self.input_read_method_kwargs = input_read_method_kwargs
if prelearner_module_synch_period is not NotProvided:
self.prelearner_module_synch_period = prelearner_module_synch_period
if dataset_num_iters_per_learner is not NotProvided:
self.dataset_num_iters_per_learner = dataset_num_iters_per_learner
if input_config is not NotProvided:
if not isinstance(input_config, dict):
raise ValueError(
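For orientation, a minimal sketch (assumed usage, not taken from this PR) of how the `offline_data()` options documented above could be combined on a BC config; the data path is a placeholder.

from ray.rllib.algorithms.bc import BCConfig

config = (
    BCConfig()
    # New API stack (RLModule + Learner).
    .api_stack(enable_rl_module_and_learner=True)
    .offline_data(
        # Placeholder path to recorded experiences stored as Parquet.
        input_="/tmp/cartpole/parquet",
        # Any ray.data read_* method can be used here.
        input_read_method="read_parquet",
        # Passed through to the read method unchecked.
        input_read_method_kwargs={},
        # Sync the PreLearner's RLModule every 10 converted batches.
        prelearner_module_synch_period=10,
        # None lets each learner run a full pass over its data block.
        dataset_num_iters_per_learner=None,
    )
)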
119 changes: 56 additions & 63 deletions rllib/algorithms/bc/bc.py
@@ -4,16 +4,18 @@
from ray.rllib.algorithms.bc.bc_catalog import BCCatalog
from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
ALL_MODULES,
NUM_AGENT_STEPS_SAMPLED,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
LEARNER_RESULTS,
LEARNER_UPDATE_TIMER,
OFFLINE_SAMPLING_TIMER,
NUM_ENV_STEPS_TRAINED,
SAMPLE_TIMER,
NUM_ENV_STEPS_TRAINED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
NUM_MODULE_STEPS_TRAINED_LIFETIME,
SYNCH_WORKER_WEIGHTS_TIMER,
TIMERS,
)
from ray.rllib.utils.typing import RLModuleSpec, ResultDict

@@ -74,8 +76,10 @@ def __init__(self, algo_class=None):
# Advantages (calculated during postprocessing)
# not important for behavioral cloning.
self.postprocess_inputs = False
# Set RLModule as default.
self.api_stack(enable_rl_module_and_learner=True)
# Set RLModule as default if the `EnvRunner`s are used.
if self.enable_env_runner_and_connector_v2:
self.api_stack(enable_rl_module_and_learner=True)

# __sphinx_doc_end__
# fmt: on

@@ -144,75 +148,64 @@ def training_step(self) -> ResultDict:
return super().training_step()
else:
# Implement logic using RLModule and Learner API.
# TODO (sven): Remove RolloutWorkers/EnvRunners for
# datasets. Use RolloutWorker/EnvRunner only for
# env stepping.
# TODO (simon): Take care of sampler metrics: right
# now all rewards are `nan`, which possibly confuses
# the user that sth. is not right, although it is as
# we do not step the env.
with self._timers[SAMPLE_TIMER]:
with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
# Sampling from offline data.
# TODO (simon): We have to remove the `RolloutWorker`
# here and just use the already distributed `dataset`
# for sampling. Only in online evaluation
# `RolloutWorker/EnvRunner` should be used.
if self.config.count_steps_by == "agent_steps":
train_batch = synchronous_parallel_sample(
worker_set=self.workers,
max_agent_steps=self.config.train_batch_size,
)
else:
train_batch = synchronous_parallel_sample(
worker_set=self.workers,
max_env_steps=self.config.train_batch_size,
)

# TODO (sven): Use metrics API as soon as we moved to new API stack
# (from currently hybrid stack).
# self.metrics.log_dict(
# {
# NUM_AGENT_STEPS_SAMPLED_LIFETIME: len(train_batch),
# NUM_ENV_STEPS_SAMPLED_LIFETIME: len(train_batch),
# },
# reduce="sum",
# )
self._counters[NUM_AGENT_STEPS_SAMPLED] += len(train_batch)
self._counters[NUM_ENV_STEPS_SAMPLED] += len(train_batch)

# Updating the policy.
train_results = self.learner_group.update_from_batch(batch=train_batch)
# TODO (sven): Use metrics API as soon as we moved to new API stack
# (from currently hybrid stack).
# self.metrics.log_dict(
# {
# NUM_AGENT_STEPS_TRAINED_LIFETIME: len(train_batch),
# NUM_ENV_STEPS_TRAINED_LIFETIME: len(train_batch),
# },
# reduce="sum",
# )
self._counters[NUM_AGENT_STEPS_TRAINED] += len(train_batch)
self._counters[NUM_ENV_STEPS_TRAINED] += len(train_batch)

batch = self.offline_data.sample(
num_samples=self.config.train_batch_size_per_learner,
num_shards=self.config.num_learners,
return_iterator=True if self.config.num_learners > 1 else False,
)

with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
# Updating the policy.
# TODO (simon, sven): Check, if we should execute directly s.th. like
# update_from_iterator.
learner_results = self.learner_group.update_from_batch(
batch,
minibatch_size=self.config.train_batch_size_per_learner,
num_iters=self.config.dataset_num_iters_per_learner,
)

# Log training results.
self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)
self.metrics.log_value(
NUM_ENV_STEPS_TRAINED_LIFETIME,
self.metrics.peek(
(LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
),
reduce="sum",
)
self.metrics.log_dict(
{
(LEARNER_RESULTS, mid, NUM_MODULE_STEPS_TRAINED_LIFETIME): (
stats[NUM_MODULE_STEPS_TRAINED]
)
for mid, stats in self.metrics.peek(LEARNER_RESULTS).items()
},
reduce="sum",
)
# Synchronize weights.
# As the results contain for each policy the loss and in addition the
# total loss over all policies is returned, this total loss has to be
# removed.
policies_to_update = set(train_results.keys()) - {ALL_MODULES}
modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES}

# with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
# Update weights - after learning on the local worker -
# on all remote workers.
with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
if self.workers.num_remote_workers() > 0:
self.workers.sync_weights(
from_worker_or_learner_group=self.learner_group,
policies=policies_to_update,
policies=modules_to_update,
inference_only=True,
)
# Get weights from Learner to local worker.
# Then we must have a local worker.
else:
self.workers.local_worker().set_weights(
self.learner_group.get_weights()
)
weights = self.learner_group.get_weights(inference_only=True)
self.workers.local_worker().set_weights(weights)

# TODO (sven): Use metrics API as soon as we moved to new API stack
# (from currently hybrid stack).
return train_results
return self.metrics.reduce()
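A rough, hedged sketch of what the offline sampling path in the new `training_step()` boils down to (illustrative only; the actual logic lives in `OfflineData` and the prelearner): the input is a `ray.data` dataset that is read once and then iterated in batches.

import ray

# Placeholder path; with several learners the dataset would instead be split
# into one streaming shard per learner (the return_iterator=True case above).
ds = ray.data.read_parquet("/tmp/cartpole/parquet")

for batch in ds.iter_batches(batch_size=1024, batch_format="numpy"):
    # Each columnar batch gets converted into episodes / a train batch before
    # LearnerGroup.update_from_batch() is called.
    break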
25 changes: 25 additions & 0 deletions rllib/algorithms/bc/bc_learner.py
@@ -0,0 +1,25 @@
from ray.rllib.core.learner.learner import Learner
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.utils.annotations import (
override,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)


class BCLearner(Learner):
@OverrideToImplementCustomLogic_CallToSuperRecommended
@override(Learner)
def build(self) -> None:
super().build()
# Prepend a NEXT_OBS from episodes to train batch connector piece (right
# after the observation default piece).
if self.config.add_default_connectors_to_learner_pipeline:
self._learner_connector.insert_after(
AddObservationsFromEpisodesToBatch,
AddNextObservationsFromEpisodesToTrainBatch(),
)

Contributor comment: Cool!

Collaborator (author) comment: Yeah, one has to know however which connectors are needed :D

Contributor comment: I think it's actually fine.
  • We need to document well all the off-the-shelf RLlib connectors, e.g. AddObsToBatch, Flatten, etc.
  • An algo now has the chance to assemble its connector pipeline based on needs (and assumptions that certain pieces will always be there, so the algo's custom ones can be prepended/appended to these).
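Picking up the review thread above, a hedged sketch of how another algorithm's Learner could assemble its connector pipeline in the same way; `MyExtraPiece` is a hypothetical connector, not part of this PR.

from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
    AddObservationsFromEpisodesToBatch,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.utils.annotations import override


class MyOfflineLearner(Learner):
    @override(Learner)
    def build(self) -> None:
        super().build()
        if self.config.add_default_connectors_to_learner_pipeline:
            # Insert a custom piece right after the default observation piece,
            # mirroring what BCLearner does above with NEXT_OBS.
            self._learner_connector.insert_after(
                AddObservationsFromEpisodesToBatch,
                MyExtraPiece(),  # hypothetical custom connector piece
            )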