
Commit

[RLlib; Offline RL] Add tests for OfflineSingleAgentEnvRunner. (#47133)
simonsays1980 authored Aug 15, 2024
1 parent e21af4b commit 50d44cc
Showing 5 changed files with 222 additions and 12 deletions.
7 changes: 7 additions & 0 deletions rllib/BUILD
@@ -1598,6 +1598,13 @@ py_test(
srcs = ["offline/estimators/tests/test_dr_learning.py"],
)

py_test(
name = "test_offline_env_runner",
tags = ["team:rllib", "offline"],
size = "small",
srcs = ["offline/tests/test_offline_env_runner.py"],
)

py_test(
name = "test_offline_data",
tags = ["team:rllib", "offline"],
7 changes: 5 additions & 2 deletions rllib/algorithms/algorithm_config.py
@@ -2456,7 +2456,9 @@ def offline_data(
to define the output data format when recording.
input_compress_columns: Which input columns are LZ4-compressed in the
input data. Applies if data is stored in `RLlib`'s `SingleAgentEpisode`
format (`MultiAgentEpisode` is not supported yet). Note, if
`rllib.core.columns.Columns.OBS` is listed, RLlib will also try to
decompress `rllib.core.columns.Columns.NEXT_OBS`.
map_batches_kwargs: `kwargs` for the `map_batches` method. These will be
passed into the `ray.data.Dataset.map_batches` method when sampling
without checking. If no arguments are passed in, the default arguments `{
@@ -2513,7 +2515,8 @@ def offline_data(
output_config: Arguments accessible from the IOContext for configuring
custom output.
output_compress_columns: Which sample batch columns to LZ4-compress in
the output data. Note, if `rllib.core.columns.Columns.OBS` is listed,
`rllib.core.columns.Columns.NEXT_OBS` will also be compressed.
output_max_file_size: Max output file size (in bytes) before rolling over
to a new file.
output_max_rows_per_file: Max number of output rows before rolling over to a
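To make the pairing of these two settings concrete, here is a minimal sketch of recording with compressed columns and reading the data back. The path and the use of PPO are illustrative only; the keyword arguments are the ones documented above.

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.core.columns import Columns

# Record experiences with LZ4-compressed observations and actions. Listing
# Columns.OBS also covers Columns.NEXT_OBS on both the write and read path.
record_config = PPOConfig().offline_data(
    output="/tmp/cartpole-records",
    output_compress_columns=[Columns.OBS, Columns.ACTIONS],
)

# When training from the recorded data, mirror the same column list so that
# the compressed columns are decompressed on read.
train_config = PPOConfig().offline_data(
    input_=["/tmp/cartpole-records"],
    input_compress_columns=[Columns.OBS, Columns.ACTIONS],
)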
14 changes: 5 additions & 9 deletions rllib/offline/offline_env_runner.py
@@ -42,7 +42,7 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
# Set the worker-specific path name. Note, this is
# specifically to enable multi-threaded writing into
# the same directory.
self.worker_path = "run-" + f"{self.worker_index}".zfill(6) + "-"
self.worker_path = "run-" + f"{self.worker_index}".zfill(6)

# If a specific filesystem is given, set it up. Note, this could
# be `gcsfs` for GCS, `pyarrow` for S3 or `adlfs` for Azure Blob Storage.
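As an illustration of the resulting naming scheme (inferred from the record-name assertions in the new test file, not stated in this hunk): worker index 1 now yields the prefix `run-000001`, and the write method appends its own file counter.

# Illustration only: how the per-worker prefix composes after this change.
worker_index = 1
worker_path = "run-" + f"{worker_index}".zfill(6)
assert worker_path == "run-000001"  # the write method appends e.g. "-00001"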
@@ -193,7 +193,7 @@ def sample(
getattr(samples_ds, self.output_write_method)(
path.as_posix(), **self.output_write_method_kwargs
)
logger.info("Wrote samples to storage.")
logger.info(f"Wrote samples to storage at {path}")
except Exception as e:
logger.error(e)

@@ -224,24 +224,20 @@ def _map_episodes_to_data(self, samples: List[EpisodeType]) -> None:
to_jsonable_if_needed(sample.get_observations(i), obs_space)
)
if Columns.OBS in self.output_compress_columns
else obs_space.to_jsonable_if_needed(
sample.get_observations(i), obs_space
),
else to_jsonable_if_needed(sample.get_observations(i), obs_space),
# Compress actions, if requested.
Columns.ACTIONS: pack_if_needed(
to_jsonable_if_needed(sample.get_actions(i), action_space)
)
if Columns.ACTIONS in self.output_compress_columns
else action_space.to_jsonable_if_needed(
sample.get_actions(i), action_space
),
else to_jsonable_if_needed(sample.get_actions(i), action_space),
Columns.REWARDS: sample.get_rewards(i),
# Compress next observations, if requested.
Columns.NEXT_OBS: pack_if_needed(
to_jsonable_if_needed(sample.get_observations(i + 1), obs_space)
)
if Columns.OBS in self.output_compress_columns
else obs_space.to_jsonable_if_needed(
else to_jsonable_if_needed(
sample.get_observations(i + 1), obs_space
),
Columns.TERMINATEDS: False
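For readers skimming the hunk above: the mapping now calls the module-level `to_jsonable_if_needed` helper instead of invoking it as a method on the space object, and it packs a column only when compression was requested for it. A self-contained sketch of that compress-if-requested pattern, with `lz4.frame` standing in for RLlib's `pack_if_needed` helper (names here are illustrative, not RLlib's API):

import lz4.frame
import numpy as np

def pack_column_if_requested(value, column, compress_columns):
    # LZ4-pack a column's value only if compression was requested for that
    # column; otherwise pass the value through unchanged.
    if column in compress_columns:
        return lz4.frame.compress(np.asarray(value).tobytes())
    return value

# Observations get packed, rewards pass through untouched.
row = {
    "obs": pack_column_if_requested([0.1, 0.2], "obs", ["obs"]),
    "rewards": pack_column_if_requested(1.0, "rewards", ["obs"]),
}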
1 change: 0 additions & 1 deletion rllib/offline/offline_prelearner.py
@@ -250,7 +250,6 @@ def convert(sample, space):
return sample

episodes = []
# TODO (simon): Give users possibility to provide a custom schema.
for i, obs in enumerate(batch[schema[Columns.OBS]]):

# If multi-agent we need to extract the agent ID.
205 changes: 205 additions & 0 deletions rllib/offline/tests/test_offline_env_runner.py
@@ -0,0 +1,205 @@
import pathlib
import shutil
import unittest

import ray
from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.core.columns import Columns
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.offline.offline_data import OfflineData
from ray.rllib.offline.offline_env_runner import OfflineSingleAgentEnvRunner


class TestOfflineEnvRunner(unittest.TestCase):
def setUp(self) -> None:
self.base_path = pathlib.Path("/tmp/")
self.config = (
PPOConfig()
# Enable new API stack and use EnvRunner.
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.env_runners(
# This defines how many rows per file we will
# have (given `num_rows_per_file` in the
# `output_write_method_kwargs` is not set).
rollout_fragment_length=1000,
num_env_runners=0,
# Note, this means that written episodes, if
# `output_write_episodes=True`, will be incomplete
# in many cases.
batch_mode="truncate_episodes",
)
.environment("CartPole-v1")
.rl_module(
# Use a small network for this test.
model_config_dict={
"fcnet_hiddens": [32],
"fcnet_activation": "linear",
"vf_share_layers": True,
}
)
)
ray.init()

def tearDown(self) -> None:
ray.shutdown()

def test_offline_env_runner_record_episodes(self):
"""Tests recording of episodes.
Note, in this case each row of the dataset is an episode
that could potentially contain hundreds of steps.
"""
data_dir = "local://" / self.base_path / "cartpole-episodes"
config = self.config.offline_data(
output=data_dir.as_posix(),
# Store experiences in episodes.
output_write_episodes=True,
)

offline_env_runner = OfflineSingleAgentEnvRunner(config, worker_index=1)
# Sample 100 episodes.
_ = offline_env_runner.sample(
num_episodes=100,
random_actions=True,
)

data_path = data_dir / self.config.env.lower()
records = list(data_path.iterdir())

self.assertEqual(len(records), 1)
self.assertEqual(records[0].name, "run-000001-00001")

# Now read in episodes.
config = self.config.offline_data(
input_=[data_path.as_posix()],
input_read_episodes=True,
)
offline_data = OfflineData(config)
# Assert the dataset has only 100 rows (each row containing an episode).
self.assertEqual(offline_data.data.count(), 100)
# Take a single row and ensure it's a `SingleAgentEpisode` instance.
self.assertIsInstance(offline_data.data.take(1)[0]["item"], SingleAgentEpisode)
# The batch now contains episodes (in a numpy NDArray).
episodes = offline_data.data.take_batch(100)["item"]
# The batch should contain 100 episodes (not 100 env steps).
self.assertEqual(len(episodes), 100)
# Remove all data.
shutil.rmtree(data_dir)

def test_offline_env_runner_record_column_data(self):
"""Tests recording of single time steps in column format.
Note, in this case each row in the dataset contains only a single
timestep of the agent.
"""
data_dir = "local://" / self.base_path / "cartpole-columns"
config = self.config.offline_data(
output=data_dir.as_posix(),
# Store experiences in column format (not in episodes).
output_write_episodes=False,
# Do not compress columns.
output_compress_columns=[],
)

offline_env_runner = OfflineSingleAgentEnvRunner(config, worker_index=1)

_ = offline_env_runner.sample(
num_timesteps=100,
random_actions=True,
)

data_path = data_dir / self.config.env.lower()
records = list(data_path.iterdir())

self.assertEqual(len(records), 1)
self.assertEqual(records[0].name, "run-000001-00001")

# Now read the data back in column format.
config = self.config.offline_data(
input_=[data_path.as_posix()],
input_read_episodes=False,
)
offline_data = OfflineData(config)
# Assert the dataset has only 100 rows.
self.assertEqual(offline_data.data.count(), 100)
# The batch now contains columns of single time steps.
batch = offline_data.data.take_batch(100)
# The batch should contain 100 env steps (not 100 episodes).
self.assertEqual(len(batch[Columns.OBS]), 100)
# Remove all data.
shutil.rmtree(data_dir)

def test_offline_env_runner_compress_columns(self):
"""Tests recording of timesteps with compressed columns.
Note, `input_compress_columns` will compress only the columns
listed. `Columns.OBS` will also compress `Columns.NEXT_OBS`.
"""
data_dir = "local://" / self.base_path / "cartpole-columns"
config = self.config.offline_data(
output=data_dir.as_posix(),
# Store experiences in column format (not in episodes).
output_write_episodes=False,
# LZ4-compress columns 'obs', 'new_obs', and 'actions' to
# save disk space and increase performance. Note, this means
# that you have to use `input_compress_columns` in the same
# way when using the data for training in `RLlib`.
output_compress_columns=[Columns.OBS, Columns.ACTIONS],
# In addition, compress the complete file.
# TODO (simon): This does not work. It looks as if there
# is an error in the write/read methods for parquet in
# ray.data. Neither `arrow_open_stream_args` nor
# `arrow_parquet_args` works here.
# output_write_method_kwargs={
# "arrow_open_stream_args": {
# "compression": "gzip",
# }
# }
)

offline_env_runner = OfflineSingleAgentEnvRunner(config, worker_index=1)

_ = offline_env_runner.sample(
num_timesteps=100,
random_actions=True,
)

data_path = data_dir / self.config.env.lower()
records = list(data_path.iterdir())

self.assertEqual(len(records), 1)
self.assertEqual(records[0].name, "run-000001-00001")

# Now read the compressed data back in column format.
config = self.config.offline_data(
input_=[(data_path / "run-000001-00001").as_posix()],
input_read_episodes=False,
# Also decompress files and columns.
# TODO (simon): Activate as soon as the bug is fixed
# in ray.data.
# input_read_method_kwargs={
# "arrow_open_stream_args": {
# "compression": "gzip",
# }
# },
input_compress_columns=[Columns.OBS, Columns.ACTIONS],
)
offline_data = OfflineData(config)
# Assert the dataset has only 100 rows.
self.assertEqual(offline_data.data.count(), 100)
# The batch now contains columns of single time steps.
batch = offline_data.data.take_batch(100)
# The batch should contain 100 env steps (not 100 episodes).
self.assertEqual(len(batch[Columns.OBS]), 100)
# Remove all data.
shutil.rmtree(data_dir)


if __name__ == "__main__":
import sys
import pytest

sys.exit(pytest.main(["-v", __file__]))
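Aside: assuming the default `output_write_method` of `write_parquet`, the records these tests produce can also be inspected directly with `ray.data`, without going through `OfflineData`. A sketch, using the directory layout from the tests above:

import ray

# Read the recorded rows back directly (path assumes the episode test above).
ds = ray.data.read_parquet("/tmp/cartpole-episodes/cartpole-v1")
print(ds.count())  # 100 rows, one per recorded episode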
