[RLlib] Initial design for Ray-Data based offline RL Algos (on new API stack). #44969
Changes from 50 commits
```diff
@@ -620,12 +620,28 @@ def setup(self, config: AlgorithmConfig) -> None:
             validate_env=self.validate_env,
             default_policy_class=self.get_default_policy_class(self.config),
             config=self.config,
-            num_env_runners=self.config.num_env_runners,
+            num_env_runners=0 if self.config.input_ else self.config.num_env_runners,
             local_env_runner=True,
             logdir=self.logdir,
             tune_trial_id=self.trial_id,
         )
+
+        # Ensure remote workers are initially in sync with the local worker.
+        self.workers.sync_weights(inference_only=True)
```
Review comment: Dumb question: Why do we need this additional sync? Which (currently existing) sync does this replace?

Review comment: I think this is a left-over from a missed merge with master? Could you check?
```diff
+
+        # If an input path is available and we are on the new API stack, generate
+        # an `OfflineData` instance.
+        if (
+            self.config.input_
+            and self.config.input_ != "sampler"
+            and self.config._enable_new_api_stack
+        ):
+            from ray.rllib.offline.offline_data import OfflineData
+
+            self.offline_data = OfflineData(self.config)
+        # Otherwise, set the attribute to `None`.
+        else:
+            self.offline_data = None
 
         # Compile, validate, and freeze an evaluation config.
         self.evaluation_config = self.config.get_evaluation_config_object()
         self.evaluation_config.validate()
```
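For illustration only (not part of this diff): a minimal sketch of a config that would take the `OfflineData` branch above. The dataset path is a placeholder, and the exact setter names (e.g. `experimental(_enable_new_api_stack=...)`) may differ across RLlib versions.

```python
# Hypothetical usage sketch -- the data path and flag spelling are assumptions.
# Any truthy, non-"sampler" `input_` makes `setup()` build an `OfflineData`
# instance and forces `num_env_runners=0` (no env sampling during training).
from ray.rllib.algorithms.bc import BCConfig

config = (
    BCConfig()
    .environment(env="CartPole-v1")
    # Path to previously recorded offline experiences (placeholder).
    .offline_data(input_="/tmp/cartpole/offline_data")
    # Enable the new API stack, checked via `config._enable_new_api_stack`.
    .experimental(_enable_new_api_stack=True)
)

algo = config.build()
assert algo.offline_data is not None  # The Ray-Data based branch was taken.
```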
```diff
@@ -706,6 +722,8 @@ def setup(self, config: AlgorithmConfig) -> None:
             # Need to add back method_type in case Algorithm is restored from checkpoint
             method_config["type"] = method_type
 
+        # TODO (sven): Probably obsolete b/c the learner group is already None.
+        self.learner_group = None
         if self.config.enable_rl_module_and_learner:
             local_worker = self.workers.local_worker()
             env = spaces = None
```
```diff
@@ -781,6 +799,29 @@ def setup(self, config: AlgorithmConfig) -> None:
                 ),
             )
 
+            if self.offline_data:
+                # If the learners are remote, we need to provide specific
+                # information and the learners' actor handles.
+                if self.learner_group.is_remote:
+                    # If learners run on different nodes, locality hints help
+                    # the workers that do the data preprocessing to use the
+                    # nearest learner.
+                    learner_node_ids = self.learner_group.foreach_learner(
+                        lambda l: ray.get_runtime_context().get_node_id()
+                    )
+                    self.offline_data.locality_hints = [
+                        node_id.get() for node_id in learner_node_ids
+                    ]
+                    # Provide the learners' actor handles for module updating
+                    # during preprocessing.
+                    self.offline_data.learner_handles = self.learner_group._workers
+                    # Provide the module_spec. Note, in the remote case this is
+                    # needed because the learner module cannot be copied, but
+                    # must be built.
+                    self.offline_data.module_spec = module_spec
+                # Otherwise, we can simply pass in the local learner.
+                else:
+                    self.offline_data.learner_handles = [self.learner_group._learner]
 
         # Run `on_algorithm_init` callback after initialization is done.
         self.callbacks.on_algorithm_init(algorithm=self, metrics_logger=self.metrics)
```
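As a reading aid, here is a hedged sketch (not from this PR) of how such node-ID locality hints are typically consumed on the Ray Data side; `OfflineData` internals may differ, but `Dataset.streaming_split` accepts exactly this kind of node-ID list to co-locate output iterators with their consumers.

```python
# Hypothetical sketch -- the dataset path is a placeholder; `OfflineData`
# may wire this up differently internally.
import ray

ds = ray.data.read_parquet("/tmp/cartpole/offline_data")

# One node ID per (remote) learner, analogous to what `foreach_learner()`
# gathers above. Here we simply take all alive nodes as stand-ins.
locality_hints = [node["NodeID"] for node in ray.nodes() if node["Alive"]]

# One streaming iterator per learner; Ray Data tries to serve each iterator
# from blocks that live on (or near) the corresponding node.
iterators = ds.streaming_split(
    n=len(locality_hints),
    equal=True,
    locality_hints=locality_hints,
)
```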
```diff
@@ -0,0 +1,25 @@
+from ray.rllib.core.learner.learner import Learner
+from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
+    AddObservationsFromEpisodesToBatch,
+)
+from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import (  # noqa
+    AddNextObservationsFromEpisodesToTrainBatch,
+)
+from ray.rllib.utils.annotations import (
+    override,
+    OverrideToImplementCustomLogic_CallToSuperRecommended,
+)
+
+
+class BCLearner(Learner):
+    @OverrideToImplementCustomLogic_CallToSuperRecommended
+    @override(Learner)
+    def build(self) -> None:
+        super().build()
+        # Prepend a NEXT_OBS from episodes to train batch connector piece (right
+        # after the observation default piece).
+        if self.config.add_default_connectors_to_learner_pipeline:
+            self._learner_connector.insert_after(
+                AddObservationsFromEpisodesToBatch,
+                AddNextObservationsFromEpisodesToTrainBatch(),
+            )
```

Review comment: Cool!

Reply: Yeah, one has to know, however, which connectors are needed :D

Reply: I think it's actually fine.
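To make concrete what the inserted connector piece adds to the train batch, here is a toy sketch (mine, not from this PR) of the shift-by-one indexing it performs within an episode; the real piece operates on RLlib episode objects and `Columns` keys, not raw arrays.

```python
# Toy illustration of the NEXT_OBS construction -- only the indexing is
# mirrored here; the actual connector consumes episode objects.
import numpy as np

episode_obs = np.arange(6)  # Observations o_0 .. o_5 of a single episode.

train_batch = {
    "obs": episode_obs[:-1],     # o_0 .. o_4: inputs at timesteps 0..4.
    "new_obs": episode_obs[1:],  # o_1 .. o_5: NEXT_OBS at those timesteps.
}
# Downstream, the learner's loss can read both columns from the train batch.
```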
Review comment: nice! Old stack is no longer the norm :D