Commit

[RLlib] Initial design for Ray-Data based offline RL Algos (on new API stack). (#44969)

simonsays1980 authored Jul 22, 2024
1 parent 648a0e6 commit 66b68d0
Showing 423 changed files with 5,433 additions and 196 deletions.
20 changes: 19 additions & 1 deletion rllib/BUILD
@@ -326,7 +326,7 @@ py_test(
name = "learning_tests_pendulum_cql_old_api_stack",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "learning_tests_with_ray_data"],
size = "medium",
size = "large",
srcs = ["tests/run_regression_tests.py"],
# Include the zipped json data file as well.
data = [
@@ -839,6 +839,14 @@ py_test(
)

# BC
py_test(
name = "test_bc_old_stack",
tags = ["team:rllib", "algorithms_dir"],
size = "medium",
# Include the json data file.
data = ["tests/data/cartpole/large.json"],
srcs = ["algorithms/bc/tests/test_bc_old_stack.py"]
)
py_test(
name = "test_bc",
tags = ["team:rllib", "algorithms_dir"],
@@ -1574,6 +1582,16 @@ py_test(
srcs = ["offline/estimators/tests/test_dr_learning.py"],
)

py_test(
name = "test_offline_data",
tags = ["team:rllib", "offline"],
size = "small",
srcs = ["offline/tests/test_offline_data.py"],
data = [
"tests/data/cartpole/cartpole-v1_large",
],
)

# --------------------------------------------------------------------
# Policies
# rllib/policy/
65 changes: 63 additions & 2 deletions rllib/algorithms/algorithm.py
@@ -658,12 +658,50 @@ def setup(self, config: AlgorithmConfig) -> None:
validate_env=self.validate_env,
default_policy_class=self.get_default_policy_class(self.config),
config=self.config,
num_env_runners=self.config.num_env_runners,
num_env_runners=(
0
if (
self.config.input_
and (
isinstance(self.config.input_, str)
or (
isinstance(self.config.input_, list)
and isinstance(self.config.input_[0], str)
)
)
and self.config.input_ != "sampler"
and self.config.enable_rl_module_and_learner
and self.config.enable_env_runner_and_connector_v2
)
else self.config.num_env_runners
),
local_env_runner=True,
logdir=self.logdir,
tune_trial_id=self.trial_id,
)

# If an input path is available and we are on the new API stack, generate
# an `OfflineData` instance.
if (
self.config.input_
and (
isinstance(self.config.input_, str)
or (
isinstance(self.config.input_, list)
and isinstance(self.config.input_[0], str)
)
)
and self.config.input_ != "sampler"
and self.config.enable_rl_module_and_learner
and self.config.enable_env_runner_and_connector_v2
):
from ray.rllib.offline.offline_data import OfflineData

self.offline_data = OfflineData(self.config)
# Otherwise set the attribute to `None`.
else:
self.offline_data = None

# Compile, validate, and freeze an evaluation config.
self.evaluation_config = self.config.get_evaluation_config_object()
self.evaluation_config.validate()
@@ -743,7 +781,7 @@ def setup(self, config: AlgorithmConfig) -> None:
# TODO (Rohan138): Refactor this and remove deprecated methods
# Need to add back method_type in case Algorithm is restored from checkpoint
method_config["type"] = method_type

self.learner_group = None
if self.config.enable_rl_module_and_learner:
local_worker = self.workers.local_worker()
env = spaces = None
@@ -819,6 +857,29 @@ def setup(self, config: AlgorithmConfig) -> None:
),
)

if self.offline_data:
# If the learners are remote, we need to provide specific
# information and the learners' actor handles.
if self.learner_group.is_remote:
# If learners run on different nodes, locality hints help
# the workers that do the data preprocessing to use the
# nearest learner.
learner_node_ids = self.learner_group.foreach_learner(
lambda l: ray.get_runtime_context().get_node_id()
)
self.offline_data.locality_hints = [
node_id.get() for node_id in learner_node_ids
]
# Provide the actor handles for the learners for module
# updating during preprocessing.
self.offline_data.learner_handles = self.learner_group._workers
# Provide the module_spec. Note, in the remote case this is needed
# because the learner module cannot be copied, but must be built.
self.offline_data.module_spec = module_spec
# Otherwise we can simply pass in the local learner.
else:
self.offline_data.learner_handles = [self.learner_group._learner]

# Run `on_algorithm_init` callback after initialization is done.
self.callbacks.on_algorithm_init(algorithm=self, metrics_logger=self.metrics)

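The input check above appears twice in `setup()`: once to force `num_env_runners` to 0 and once to decide whether to build an `OfflineData` instance. Below is a minimal sketch of that condition factored into a standalone predicate; the helper name `_uses_new_offline_api` is hypothetical and not part of this commit.

def _uses_new_offline_api(config) -> bool:
    """Return True if `config` selects offline data on the new API stack.

    Mirrors the condition in `Algorithm.setup()` above: `input_` must be a
    path (str) or a non-empty list of paths, must not be the live "sampler"
    input, and both new-API-stack flags must be enabled.
    """
    input_ = config.input_
    is_path_input = bool(input_) and (
        isinstance(input_, str)
        or (isinstance(input_, list) and isinstance(input_[0], str))
    )
    return (
        is_path_input
        and input_ != "sampler"
        and config.enable_rl_module_and_learner
        and config.enable_env_runner_and_connector_v2
    )

When this condition holds, `setup()` starts no env runners for sampling and hands the `LearnerGroup`'s handles to `OfflineData` (plus locality hints and the module spec when the learners are remote actors).
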
35 changes: 34 additions & 1 deletion rllib/algorithms/algorithm_config.py
@@ -430,6 +430,10 @@ def __init__(self, algo_class: Optional[type] = None):

# `self.offline_data()`
self.input_ = "sampler"
self.input_read_method = "read_parquet"
self.input_read_method_kwargs = {}
self.prelearner_module_synch_period = 10
self.dataset_num_iters_per_learner = None
self.input_config = {}
self.actions_in_input_normalized = False
self.postprocess_inputs = False
@@ -2376,6 +2380,10 @@ def offline_data(
self,
*,
input_=NotProvided,
input_read_method=NotProvided,
input_read_method_kwargs=NotProvided,
prelearner_module_synch_period=NotProvided,
dataset_num_iters_per_learner=NotProvided,
input_config=NotProvided,
actions_in_input_normalized=NotProvided,
input_evaluation=NotProvided,
@@ -2400,7 +2408,24 @@ def offline_data(
- A callable that takes an `IOContext` object as only arg and returns a
ray.rllib.offline.InputReader.
- A string key that indexes a callable with tune.registry.register_input
input_config: Arguments that describe the settings for reading the input.
input_read_method: Read method for the `ray.data.Dataset` to read in the
offline data from `input_`. The default is `read_parquet` (for Parquet files).
See https://docs.ray.io/en/latest/data/api/input_output.html for more
info about available read methods in `ray.data`.
input_read_method_kwargs: kwargs for the `input_read_method`. These will be
passed into the read method without checking.
prelearner_module_synch_period: The period (number of batches converted)
after which the `RLModule` held by the `PreLearner` should sync weights.
The `PreLearner` is used to preprocess batches for the learners. The
higher this value the more off-policy the `PreLearner`'s module will be.
Values that are too small force the `PreLearner` to sync very often with
the `Learner` and slow down the data pipeline. The default value used
by the `OfflinePreLearner` is 10.
dataset_num_iters_per_learner: Number of iterations to run in each learner
during a single training iteration. If `None`, each learner runs a
complete epoch over its data block (the dataset is partitioned into
as many blocks as there are learners). The default is `None`.
input_config: Arguments that describe the settings for reading the input.
If input is `sampler`, this will be the environment configuration, e.g.
`env_name` and `env_config`, etc. See `EnvContext` for more info.
If the input is `dataset`, this will be e.g. `format`, `path`.
@@ -2438,6 +2463,14 @@
"""
if input_ is not NotProvided:
self.input_ = input_
if input_read_method is not NotProvided:
self.input_read_method = input_read_method
if input_read_method_kwargs is not NotProvided:
self.input_read_method_kwargs = input_read_method_kwargs
if prelearner_module_synch_period is not NotProvided:
self.prelearner_module_synch_period = prelearner_module_synch_period
if dataset_num_iters_per_learner is not NotProvided:
self.dataset_num_iters_per_learner = dataset_num_iters_per_learner
if input_config is not NotProvided:
if not isinstance(input_config, dict):
raise ValueError(
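
For orientation, here is a minimal usage sketch of the new `offline_data()` options on the new API stack. The choice of `BCConfig`, the environment, and the data path are illustrative assumptions, not part of this commit.

from ray.rllib.algorithms.bc import BCConfig

config = (
    BCConfig()
    # New-API-stack flags that the offline-data path in `Algorithm.setup()` checks.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment(env="CartPole-v1")
    .offline_data(
        # Placeholder path to a previously recorded offline dataset.
        input_="/tmp/cartpole-offline-data",
        # `read_parquet` is the default set in `AlgorithmConfig.__init__`.
        input_read_method="read_parquet",
        # Passed to the `ray.data` read method without checking.
        input_read_method_kwargs={},
        # Sync the `PreLearner`'s RLModule every 10 converted batches.
        prelearner_module_synch_period=10,
        # `None`: each learner runs a full epoch over its data block.
        dataset_num_iters_per_learner=None,
    )
)

Because `input_` is a path and both new-stack flags are set, `Algorithm.setup()` creates an `OfflineData` instance and starts no env runners for sampling.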
