[RLlib; Offline RL] Add tests for OfflineSingleAgentEnvRunner
#47133
Changes from all commits
f9552e6
ca36962
9bafd9a
3ac9f0b
@@ -2456,7 +2456,9 @@ def offline_data(
                 to define the output data format when recording.
             input_compress_columns: What input columns are compressed with LZ4 in the
                 input data. If data is stored in `RLlib`'s `SingleAgentEpisode` (
-                `MultiAgentEpisode` not supported, yet).
+                `MultiAgentEpisode` not supported, yet). Note,
+                `rllib.core.columns.Columns.OBS` will also try to decompress
+                `rllib.core.columns.Columns.NEXT_OBS`.
             map_batches_kwargs: `kwargs` for the `map_batches` method. These will be
                 passed into the `ray.data.Dataset.map_batches` method when sampling
                 without checking. If no arguments passed in the default arguments `{

Review comment: Awesome. Thanks for clarifying this for the users. Makes perfect sense that they won't have to specify NEXT_OBS explicitly.
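To illustrate the read-side behavior the note describes, here is a minimal, hedged config sketch: listing `Columns.OBS` alone is enough, and `NEXT_OBS` is decompressed automatically. The path and the choice of `BCConfig` are illustrative, not part of this PR.

```python
from ray.rllib.algorithms.bc import BCConfig
from ray.rllib.core.columns import Columns

# Hypothetical path to a previously recorded dataset.
config = BCConfig().offline_data(
    input_=["/tmp/cartpole-columns"],
    # Listing OBS is enough: NEXT_OBS gets decompressed as well.
    input_compress_columns=[Columns.OBS, Columns.ACTIONS],
)
```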
@@ -2513,7 +2515,8 @@ def offline_data(
             output_config: Arguments accessible from the IOContext for configuring
                 custom output.
             output_compress_columns: What sample batch columns to LZ4 compress in the
-                output data.
+                output data. Note, `rllib.core.columns.Columns.OBS` will also compress
+                `rllib.core.columns.Columns.NEXT_OBS`.
             output_max_file_size: Max output file size (in bytes) before rolling over
                 to a new file.
             output_max_rows_per_file: Max output row numbers before rolling over to a

Review comment: 👍
@@ -42,7 +42,7 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
         # Set the worker-specific path name. Note, this is
         # specifically to enable multi-threaded writing into
         # the same directory.
-        self.worker_path = "run-" + f"{self.worker_index}".zfill(6) + "-"
+        self.worker_path = "run-" + f"{self.worker_index}".zfill(6)

         # If a specific filesystem is given, set it up. Note, this could
         # be `gcsfs` for GCS, `pyarrow` for S3 or `adlfs` for Azure Blob Storage.
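The trailing `"-"` is dropped because a separator and a file counter are appended later. A small sketch of the resulting naming scheme; the counter format is inferred from the test assertions further below, not confirmed from the writer code:

```python
worker_index = 1
# zfill(6) left-pads the worker index with zeros:
prefix = "run-" + f"{worker_index}".zfill(6)  # -> "run-000001"
# A per-write file counter (appended elsewhere, shown here only for
# illustration) then yields names like those asserted in the tests:
file_name = prefix + "-" + f"{1}".zfill(5)    # -> "run-000001-00001"
```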
@@ -193,7 +193,7 @@ def sample(
             getattr(samples_ds, self.output_write_method)(
                 path.as_posix(), **self.output_write_method_kwargs
             )
-            logger.info("Wrote samples to storage.")
+            logger.info(f"Wrote samples to storage at {path}")

         except Exception as e:
             logger.error(e)

Review comment: 👍
@@ -224,24 +224,20 @@ def _map_episodes_to_data(self, samples: List[EpisodeType]) -> None:
                     to_jsonable_if_needed(sample.get_observations(i), obs_space)
                 )
                 if Columns.OBS in self.output_compress_columns
-                else obs_space.to_jsonable_if_needed(
-                    sample.get_observations(i), obs_space
-                ),
+                else to_jsonable_if_needed(sample.get_observations(i), obs_space),
                 # Compress actions, if requested.
                 Columns.ACTIONS: pack_if_needed(
                     to_jsonable_if_needed(sample.get_actions(i), action_space)
                 )
                 if Columns.OBS in self.output_compress_columns
-                else action_space.to_jsonable_if_needed(
-                    sample.get_actions(i), action_space
-                ),
+                else to_jsonable_if_needed(sample.get_actions(i), action_space),
                 Columns.REWARDS: sample.get_rewards(i),
                 # Compress next observations, if requested.
                 Columns.NEXT_OBS: pack_if_needed(
                     to_jsonable_if_needed(sample.get_observations(i + 1), obs_space)
                 )
                 if Columns.OBS in self.output_compress_columns
-                else obs_space.to_jsonable_if_needed(
+                else to_jsonable_if_needed(
                     sample.get_observations(i + 1), obs_space
                 ),
                 Columns.TERMINATEDS: False
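The fix above consistently calls the free function `to_jsonable_if_needed` instead of a non-existent method on the space objects. For context, here is a hedged sketch of the LZ4 round trip that the compression helpers perform, assuming RLlib's long-standing `ray.rllib.utils.compression` utilities:

```python
import numpy as np
from ray.rllib.utils.compression import pack, unpack

obs = np.random.random((4,)).astype(np.float32)
packed = pack(obs)         # LZ4-compressed, base64-encoded string
restored = unpack(packed)  # back to the original array
assert np.allclose(obs, restored)
```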
@@ -250,7 +250,6 @@ def convert(sample, space):
             return sample

-        episodes = []
         # TODO (simon): Give users possibility to provide a custom schema.

         for i, obs in enumerate(batch[schema[Columns.OBS]]):

             # If multi-agent we need to extract the agent ID.

Review comment: Nice!
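The TODO refers to letting users remap dataset column names. As a purely hypothetical illustration (this dict is invented here, not RLlib's actual default schema), such a mapping could look like:

```python
from ray.rllib.core.columns import Columns

# Hypothetical schema: maps RLlib's canonical column names to the
# (possibly different) column names used in a stored dataset.
schema = {
    Columns.OBS: "observations",
    Columns.ACTIONS: "actions",
    Columns.REWARDS: "rewards",
}
# Reading then goes through the mapping, e.g. batch[schema[Columns.OBS]].
```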
@@ -0,0 +1,205 @@
import pathlib
import shutil
import unittest

import ray
from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.core.columns import Columns
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.offline.offline_data import OfflineData
from ray.rllib.offline.offline_env_runner import OfflineSingleAgentEnvRunner


class TestOfflineEnvRunner(unittest.TestCase):
    def setUp(self) -> None:
        self.base_path = pathlib.Path("/tmp/")
        self.config = (
            PPOConfig()
            # Enable new API stack and use EnvRunner.
            .api_stack(
                enable_rl_module_and_learner=True,
                enable_env_runner_and_connector_v2=True,
            )
            .env_runners(
                # This defines how many rows per file we will
                # have (given `num_rows_per_file` in the
                # `output_write_method_kwargs` is not set).
                rollout_fragment_length=1000,
                num_env_runners=0,
                # Note, this means that written episodes, if
                # `output_write_episodes=True`, will be incomplete
                # in many cases.
                batch_mode="truncate_episodes",
            )
            .environment("CartPole-v1")
            .rl_module(
                # Use a small network for this test.
                model_config_dict={
                    "fcnet_hiddens": [32],
                    "fcnet_activation": "linear",
                    "vf_share_layers": True,
                }
            )
        )
        ray.init()

    def tearDown(self) -> None:
        ray.shutdown()

    def test_offline_env_runner_record_episodes(self):
        """Tests recording of episodes.

        Note, in this case each row of the dataset is an episode
        that could potentially contain hundreds of steps.
        """
        data_dir = "local://" / self.base_path / "cartpole-episodes"
        config = self.config.offline_data(
            output=data_dir.as_posix(),
            # Store experiences in episodes.
            output_write_episodes=True,
        )

        offline_env_runner = OfflineSingleAgentEnvRunner(config, worker_index=1)
        # Sample 100 episodes.
        _ = offline_env_runner.sample(
            num_episodes=100,
            random_actions=True,
        )

        data_path = data_dir / self.config.env.lower()
        records = list(data_path.iterdir())

        self.assertEqual(len(records), 1)
        self.assertEqual(records[0].name, "run-000001-00001")

        # Now read in episodes.
        config = self.config.offline_data(
            input_=[data_path.as_posix()],
            input_read_episodes=True,
        )
        offline_data = OfflineData(config)
        # Assert the dataset has only 100 rows (each row containing an episode).
        self.assertEqual(offline_data.data.count(), 100)
        # Take a single row and ensure it's a `SingleAgentEpisode` instance.
        self.assertIsInstance(offline_data.data.take(1)[0]["item"], SingleAgentEpisode)
        # The batch now contains episodes (in a numpy.NDArray).
        episodes = offline_data.data.take_batch(100)["item"]
        # The batch should contain 100 episodes (not 100 env steps).
        self.assertEqual(len(episodes), 100)
        # Remove all data.
        shutil.rmtree(data_dir)

    def test_offline_env_runner_record_column_data(self):
        """Tests recording of single time steps in column format.

        Note, in this case each row in the dataset contains only a single
        timestep of the agent.
        """
        data_dir = "local://" / self.base_path / "cartpole-columns"
        config = self.config.offline_data(
            output=data_dir.as_posix(),
            # Store experiences in column format.
            output_write_episodes=False,
            # Do not compress columns.
            output_compress_columns=[],
        )

        offline_env_runner = OfflineSingleAgentEnvRunner(config, worker_index=1)

        _ = offline_env_runner.sample(
            num_timesteps=100,
            random_actions=True,
        )

        data_path = data_dir / self.config.env.lower()
        records = list(data_path.iterdir())

        self.assertEqual(len(records), 1)
        self.assertEqual(records[0].name, "run-000001-00001")

        # Now read in the column data.
        config = self.config.offline_data(
            input_=[data_path.as_posix()],
            input_read_episodes=False,
        )
        offline_data = OfflineData(config)
        # Assert the dataset has only 100 rows.
        self.assertEqual(offline_data.data.count(), 100)
        # The batch now contains columns of single time steps.
        batch = offline_data.data.take_batch(100)
        # The batch should contain 100 time steps (not episodes).
        self.assertTrue(len(batch[Columns.OBS]) == 100)
        # Remove all data.
        shutil.rmtree(data_dir)

    def test_offline_env_runner_compress_columns(self):
        """Tests recording of timesteps with compressed columns.

        Note, compression is applied only to the columns listed in
        `output_compress_columns`; `Columns.OBS` will also compress
        `Columns.NEXT_OBS`. On the read path, `input_compress_columns`
        must list the same columns.
        """
        data_dir = "local://" / self.base_path / "cartpole-columns"
        config = self.config.offline_data(
            output=data_dir.as_posix(),
            # Store experiences in column format.
            output_write_episodes=False,
            # LZ4-compress columns 'obs', 'new_obs', and 'actions' to
            # save disk space and increase performance. Note, this means
            # that you have to use `input_compress_columns` in the same
            # way when using the data for training in `RLlib`.
            output_compress_columns=[Columns.OBS, Columns.ACTIONS],
            # In addition, compress the complete file.
            # TODO (simon): This does not work. It looks as if there
            # is an error in the write/read methods for parquet in
            # ray.data. Neither `arrow_open_stream_args` nor
            # `arrow_parquet_args` work here.
            # output_write_method_kwargs={
            #     "arrow_open_stream_args": {
            #         "compression": "gzip",
            #     }
            # }
        )

        offline_env_runner = OfflineSingleAgentEnvRunner(config, worker_index=1)

        _ = offline_env_runner.sample(
            num_timesteps=100,
            random_actions=True,
        )

        data_path = data_dir / self.config.env.lower()
        records = list(data_path.iterdir())

        self.assertEqual(len(records), 1)
        self.assertEqual(records[0].name, "run-000001-00001")

        # Now read in the compressed column data.
        config = self.config.offline_data(
            input_=[(data_path / "run-000001-00001").as_posix()],
            input_read_episodes=False,
            # Also uncompress files and columns.
            # TODO (simon): Activate as soon as the bug is fixed
            # in ray.data.
            # input_read_method_kwargs={
            #     "arrow_open_stream_args": {
            #         "compression": "gzip",
            #     }
            # },
            input_compress_columns=[Columns.OBS, Columns.ACTIONS],
        )
        offline_data = OfflineData(config)
        # Assert the dataset has only 100 rows.
        self.assertEqual(offline_data.data.count(), 100)
        # The batch now contains columns of single time steps.
        batch = offline_data.data.take_batch(100)
        # The batch should contain 100 time steps (not episodes).
        self.assertTrue(len(batch[Columns.OBS]) == 100)
        # Remove all data.
        shutil.rmtree(data_dir)


if __name__ == "__main__":
    import sys
    import pytest

    sys.exit(pytest.main(["-v", __file__]))
Review comment: awesome!

super nit: let's keep test cases alphabetically ordered (within their category), so let's move this below the next record (test_offline_data).
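For readers who want to consume such recordings in training, here is a hedged end-to-end sketch mirroring the first test. The algorithm choice and the dataset path (derived from `data_dir / env_name.lower()` in the tests) are illustrative, not part of this PR.

```python
from ray.rllib.algorithms.bc import BCConfig

# Illustrative only: train behavior cloning from the recorded episodes.
config = (
    BCConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .offline_data(
        input_=["/tmp/cartpole-episodes/cartpole-v1"],  # hypothetical path
        input_read_episodes=True,
    )
)
algo = config.build()
print(algo.train())
```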