[RLlib; Offline RL] Add tests for OfflineSingleAgentEnvRunner. #47133

Merged: 4 commits, Aug 15, 2024
7 changes: 7 additions & 0 deletions rllib/BUILD
@@ -1598,6 +1598,13 @@ py_test(
srcs = ["offline/estimators/tests/test_dr_learning.py"],
)

py_test(
Contributor:

awesome!

super nit: let's keep test cases alphabetically ordered (within their category), so let's move this below the next record (test_offline_data).

name = "test_offline_env_runner",
tags = ["team:rllib", "offline"],
size = "small",
srcs = ["offline/tests/test_offline_env_runner.py"],
)

py_test(
name = "test_offline_data",
tags = ["team:rllib", "offline"],
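Per the reviewer's nit above, the alphabetized ordering would look roughly like this (only the attributes visible in this diff are shown; the remaining `test_offline_data` attributes stay as in the existing rule):

py_test(
    name = "test_offline_data",
    tags = ["team:rllib", "offline"],
    # ... remaining attributes unchanged from the existing rule ...
)

py_test(
    name = "test_offline_env_runner",
    tags = ["team:rllib", "offline"],
    size = "small",
    srcs = ["offline/tests/test_offline_env_runner.py"],
)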
7 changes: 5 additions & 2 deletions rllib/algorithms/algorithm_config.py
@@ -2456,7 +2456,9 @@ def offline_data(
to define the output data format when recording.
input_compress_columns: What input columns are compressed with LZ4 in the
input data. If data is stored in `RLlib`'s `SingleAgentEpisode` (
`MultiAgentEpisode` not supported, yet).
`MultiAgentEpisode` not supported, yet). Note,
`rllib.core.columns.Columns.OBS` will also try to decompress
Contributor:

Awesome. Thanks for clarifying this for the users. Makes perfect sense that they won't have to specify NEXT_OBS explicitly.

`rllib.core.columns.Columns.NEXT_OBS`.
map_batches_kwargs: `kwargs` for the `map_batches` method. These will be
passed into the `ray.data.Dataset.map_batches` method when sampling
without checking. If no arguments passed in the default arguments `{
@@ -2513,7 +2515,8 @@ def offline_data(
output_config: Arguments accessible from the IOContext for configuring
custom output.
output_compress_columns: What sample batch columns to LZ4 compress in the
output data.
output data. Note, `rllib.core.columns.Columns.OBS` will also compress
Contributor:

👍

`rllib.core.columns.Columns.NEXT_OBS`.
output_max_file_size: Max output file size (in bytes) before rolling over
to a new file.
output_max_rows_per_file: Max output row numbers before rolling over to a
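To make the two docstring notes above concrete, here is a minimal, hypothetical sketch of how the settings pair up when recording and then reading back data (the paths and the surrounding config are assumptions for illustration, not part of this PR):

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.core.columns import Columns

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .offline_data(
        # Hypothetical output location.
        output="/tmp/cartpole-out",
        # Listing Columns.OBS is enough; per the docstring above,
        # Columns.NEXT_OBS is compressed alongside it.
        output_compress_columns=[Columns.OBS, Columns.ACTIONS],
    )
)

# When training on the recorded data later, mirror the setting on the
# input side so the same columns get decompressed again.
config = config.offline_data(
    input_=["/tmp/cartpole-out"],
    input_compress_columns=[Columns.OBS, Columns.ACTIONS],
)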
14 changes: 5 additions & 9 deletions rllib/offline/offline_env_runner.py
@@ -42,7 +42,7 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
# Set the worker-specific path name. Note, this is
# specifically to enable multi-threaded writing into
# the same directory.
self.worker_path = "run-" + f"{self.worker_index}".zfill(6) + "-"
self.worker_path = "run-" + f"{self.worker_index}".zfill(6)

# If a specific filesystem is given, set it up. Note, this could
# be `gcsfs` for GCS, `pyarrow` for S3 or `adlfs` for Azure Blob Storage.
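As a quick sanity check on the naming scheme this hunk changes (a sketch, not code from this PR; the file-count suffix is appended by the writer elsewhere, yielding the names asserted in the new tests):

# With the trailing "-" dropped from the prefix itself, worker 1
# produces the prefix "run-000001"; the writer's file-count suffix
# then yields file names like "run-000001-00001".
worker_index = 1
worker_path = "run-" + f"{worker_index}".zfill(6)
assert worker_path == "run-000001"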
@@ -193,7 +193,7 @@ def sample(
getattr(samples_ds, self.output_write_method)(
path.as_posix(), **self.output_write_method_kwargs
)
logger.info("Wrote samples to storage.")
logger.info(f"Wrote samples to storage at {path}")
Contributor:

👍

except Exception as e:
logger.error(e)

@@ -224,24 +224,20 @@ def _map_episodes_to_data(self, samples: List[EpisodeType]) -> None:
to_jsonable_if_needed(sample.get_observations(i), obs_space)
)
if Columns.OBS in self.output_compress_columns
else obs_space.to_jsonable_if_needed(
sample.get_observations(i), obs_space
),
else to_jsonable_if_needed(sample.get_observations(i), obs_space),
# Compress actions, if requested.
Columns.ACTIONS: pack_if_needed(
to_jsonable_if_needed(sample.get_actions(i), action_space)
)
if Columns.OBS in self.output_compress_columns
else action_space.to_jsonable_if_needed(
sample.get_actions(i), action_space
),
else to_jsonable_if_needed(sample.get_actions(i), action_space),
Columns.REWARDS: sample.get_rewards(i),
# Compress next observations, if requested.
Columns.NEXT_OBS: pack_if_needed(
to_jsonable_if_needed(sample.get_observations(i + 1), obs_space)
)
if Columns.OBS in self.output_compress_columns
else obs_space.to_jsonable_if_needed(
else to_jsonable_if_needed(
sample.get_observations(i + 1), obs_space
),
Columns.TERMINATEDS: False
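For readers unfamiliar with the helpers in this hunk: `pack_if_needed` LZ4-compresses array data and `unpack_if_needed` restores it. A rough round-trip sketch (assumes the `lz4` package is installed; not taken from this PR):

import numpy as np
from ray.rllib.utils.compression import pack_if_needed, unpack_if_needed

obs = np.arange(4, dtype=np.float32)
packed = pack_if_needed(obs)         # compresses numpy arrays
restored = unpack_if_needed(packed)  # decompresses back to the array
assert np.allclose(obs, restored)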
1 change: 0 additions & 1 deletion rllib/offline/offline_prelearner.py
@@ -250,7 +250,6 @@ def convert(sample, space):
return sample

episodes = []
# TODO (simon): Give users possibility to provide a custom schema.
Contributor:

Nice!

for i, obs in enumerate(batch[schema[Columns.OBS]]):

# If multi-agent we need to extract the agent ID.
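The removed TODO reflects that the prelearner now resolves column names through the `schema` mapping used in the loop above, so datasets recorded with non-default column names can be read. A hypothetical sketch of such a mapping (these dataset column names are made up for illustration):

from ray.rllib.core.columns import Columns

# Map RLlib's canonical column constants to the names actually present
# in a recorded dataset.
schema = {
    Columns.OBS: "observations",
    Columns.ACTIONS: "actions",
    Columns.REWARDS: "rewards",
}
# The prelearner then indexes batches through it, as in the loop above:
#     for i, obs in enumerate(batch[schema[Columns.OBS]]): ...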
205 changes: 205 additions & 0 deletions rllib/offline/tests/test_offline_env_runner.py
@@ -0,0 +1,205 @@
import pathlib
import shutil
import unittest

import ray
from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.core.columns import Columns
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.offline.offline_data import OfflineData
from ray.rllib.offline.offline_env_runner import OfflineSingleAgentEnvRunner


class TestOfflineEnvRunner(unittest.TestCase):
def setUp(self) -> None:
self.base_path = pathlib.Path("/tmp/")
self.config = (
PPOConfig()
# Enable new API stack and use EnvRunner.
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.env_runners(
# This defines how many rows per file we will
# have (given `num_rows_per_file` in the
# `output_write_method_kwargs` is not set).
rollout_fragment_length=1000,
num_env_runners=0,
# Note, this means that written episodes (if
# `output_write_episodes=True`) will be incomplete
# in many cases.
batch_mode="truncate_episodes",
)
.environment("CartPole-v1")
.rl_module(
# Use a small network for this test.
model_config_dict={
"fcnet_hiddens": [32],
"fcnet_activation": "linear",
"vf_share_layers": True,
}
)
)
ray.init()

def tearDown(self) -> None:
ray.shutdown()

def test_offline_env_runner_record_episodes(self):
"""Tests recording of episodes.

Note, in this case each row of the dataset is an episode
that could potentially contain hundreds of steps.
"""
data_dir = "local://" / self.base_path / "cartpole-episodes"
config = self.config.offline_data(
output=data_dir.as_posix(),
# Store experiences in episodes.
output_write_episodes=True,
)

offline_env_runner = OfflineSingleAgentEnvRunner(config, worker_index=1)
# Sample 100 episodes.
_ = offline_env_runner.sample(
num_episodes=100,
random_actions=True,
)

data_path = data_dir / self.config.env.lower()
records = list(data_path.iterdir())

self.assertEqual(len(records), 1)
self.assertEqual(records[0].name, "run-000001-00001")

# Now read in episodes.
config = self.config.offline_data(
input_=[data_path.as_posix()],
input_read_episodes=True,
)
offline_data = OfflineData(config)
# Assert the dataset has only 100 rows (each row containing an episode).
self.assertEqual(offline_data.data.count(), 100)
# Take a single row and ensure it's a `SingleAgentEpisode` instance.
self.assertIsInstance(offline_data.data.take(1)[0]["item"], SingleAgentEpisode)
# The batch now contains episodes (in a numpy NDArray).
episodes = offline_data.data.take_batch(100)["item"]
# The batch should contain 100 episodes (not 100 env steps).
self.assertEqual(len(episodes), 100)
# Remove all data.
shutil.rmtree(data_dir)

def test_offline_env_runner_record_column_data(self):
"""Tests recording of single time steps in column format.

Note, in this case each row in the dataset contains only a single
timestep of the agent.
"""
data_dir = "local://" / self.base_path / "cartpole-columns"
config = self.config.offline_data(
output=data_dir.as_posix(),
# Store experiences in columns (tabular data), not episodes.
output_write_episodes=False,
# Do not compress columns.
output_compress_columns=[],
)

offline_env_runner = OfflineSingleAgentEnvRunner(config, worker_index=1)

_ = offline_env_runner.sample(
num_timesteps=100,
random_actions=True,
)

data_path = data_dir / self.config.env.lower()
records = list(data_path.iterdir())

self.assertEqual(len(records), 1)
self.assertEqual(records[0].name, "run-000001-00001")

# Now read in the column data.
config = self.config.offline_data(
input_=[data_path.as_posix()],
input_read_episodes=False,
)
offline_data = OfflineData(config)
# Assert the dataset has only 100 rows.
self.assertEqual(offline_data.data.count(), 100)
# The batch now contains columns (as numpy arrays).
batch = offline_data.data.take_batch(100)
# The batch should contain 100 rows (one per env step).
self.assertTrue(len(batch[Columns.OBS]) == 100)
# Remove all data.
shutil.rmtree(data_dir)

def test_offline_env_runner_compress_columns(self):
"""Tests recording of timesteps with compressed columns.

Note, only the columns listed in `output_compress_columns` get
compressed (and must be listed in `input_compress_columns` when
reading them back). `Columns.OBS` also covers `Columns.NEXT_OBS`.
"""
data_dir = "local://" / self.base_path / "cartpole-columns"
config = self.config.offline_data(
output=data_dir.as_posix(),
# Store experiences in columns (tabular data), not episodes.
output_write_episodes=False,
# LZ4-compress columns 'obs', 'new_obs', and 'actions' to
# save disk space and increase performance. Note, this means
# that you have to use `input_compress_columns` in the same
# way when using the data for training in `RLlib`.
output_compress_columns=[Columns.OBS, Columns.ACTIONS],
# In addition compress the complete file.
# TODO (simon): This does not work. It looks as if there
# is an error in the write/read methods for parquet in
# ray.data. Neither `arrow_open_stream_args` nor
# `arrow_parquet_args` works here.
# output_write_method_kwargs={
# "arrow_open_stream_args": {
# "compression": "gzip",
# }
# }
)

offline_env_runner = OfflineSingleAgentEnvRunner(config, worker_index=1)

_ = offline_env_runner.sample(
num_timesteps=100,
random_actions=True,
)

data_path = data_dir / self.config.env.lower()
records = list(data_path.iterdir())

self.assertEqual(len(records), 1)
self.assertEqual(records[0].name, "run-000001-00001")

# Now read in the column data.
config = self.config.offline_data(
input_=[(data_path / "run-000001-00001").as_posix()],
input_read_episodes=False,
# Also uncompress files and columns.
# TODO (simon): Activate as soon as the bug is fixed
# in ray.data.
# input_read_method_kwargs={
# "arrow_open_stream_args": {
# "compression": "gzip",
# }
# },
input_compress_columns=[Columns.OBS, Columns.ACTIONS],
)
offline_data = OfflineData(config)
# Assert the dataset has only 100 rows.
self.assertEqual(offline_data.data.count(), 100)
# The batch now contains columns (as numpy arrays).
batch = offline_data.data.take_batch(100)
# The batch should contain 100 rows (one per env step).
self.assertTrue(len(batch[Columns.OBS]) == 100)
# Remove all data.
shutil.rmtree(data_dir)


if __name__ == "__main__":
import sys
import pytest

sys.exit(pytest.main(["-v", __file__]))