Skip to content

Commit

Permalink
Remove rllib dependency from core tests (ray-project#51171)
Browse files Browse the repository at this point in the history
More work towards breaking the core -> libraries dependencies.

- `test_output.py`: changed to use ray core API instead.
- `test_client_library_integration.py`: moved to a test inside rllib
directory.

---------

Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
  • Loading branch information
edoakes authored and elimelt committed Mar 9, 2025
1 parent 4988153 commit 5c574d9
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 94 deletions.
1 change: 0 additions & 1 deletion python/ray/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ py_test_module_list(
"test_client_multi.py",
"test_client_references.py",
"test_client_warnings.py",
"test_client_library_integration.py",
],
size = "medium",
tags = ["exclusive", "client_tests", "team:core"],
Expand Down
58 changes: 0 additions & 58 deletions python/ray/tests/test_client_library_integration.py

This file was deleted.

48 changes: 13 additions & 35 deletions python/ray/tests/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,44 +741,21 @@ def test_output_ray_cluster(call_ray_start):
def test_output_on_driver_shutdown(ray_start_cluster):
cluster = ray_start_cluster
cluster.add_node(num_cpus=16)
# many_ppo.py script.
script = """
import ray
from ray.tune import run_experiments
from ray.tune.utils.release_test_util import ProgressCallback
object_store_memory = 10**9
num_nodes = 3
message = ("Make sure there is enough memory on this machine to run this "
"workload. We divide the system memory by 2 to provide a buffer.")
assert (num_nodes * object_store_memory <
ray._private.utils.get_system_memory() / 2), message
ray.init(address="auto")
# Simulate a cluster on one machine.
@ray.remote
def f(i: int):
return i
ray.init(address="auto")
obj_refs = [f.remote(i) for i in range(100)]
# Run the workload.
run_experiments(
{
"PPO": {
"run": "PPO",
"env": "CartPole-v0",
"num_samples": 10,
"config": {
"framework": "torch",
"num_workers": 1,
"num_gpus": 0,
"num_sgd_iter": 1,
},
"stop": {
"timesteps_total": 1,
},
}
},
callbacks=[ProgressCallback()])
while True:
assert len(obj_refs) == 100
ready, pending = ray.wait(obj_refs, num_returns=10)
for i in ray.get(ready):
obj_refs[i] = f.remote(i)
"""

proc = run_string_as_driver_nonblocking(script)
Expand All @@ -795,13 +772,14 @@ def test_output_on_driver_shutdown(ray_start_cluster):
time.sleep(0.1)
os.kill(proc.pid, signal.SIGINT)
try:
proc.wait(timeout=10)
proc.wait(timeout=5)
except subprocess.TimeoutExpired:
print("Script wasn't terminated by SIGINT. Try SIGTERM.")
os.kill(proc.pid, signal.SIGTERM)
print(proc.wait(timeout=10))
print(proc.wait(timeout=5))
err_str = proc.stderr.read().decode("ascii")
assert len(err_str) > 0
assert "KeyboardInterrupt" in err_str
assert "StackTrace Information" not in err_str
print(err_str)

Expand Down
7 changes: 7 additions & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1822,6 +1822,13 @@ py_test(
srcs = ["tests/test_timesteps.py"]
)

py_test(
name = "tests/test_ray_client",
tags = ["team:rllib", "tests_dir"],
size = "medium",
srcs = ["tests/test_ray_client.py"]
)

# --------------------------------------------------------------------
# examples/ directory
#
Expand Down
27 changes: 27 additions & 0 deletions rllib/tests/test_ray_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import sys

import pytest

from ray.rllib.algorithms import dqn
from ray.util.client.ray_client_helpers import ray_start_client_server
from ray._private.client_mode_hook import enable_client_mode, client_mode_should_convert


def test_basic_dqn():
with ray_start_client_server():
# Need to enable this for client APIs to be used.
with enable_client_mode():
# Confirming mode hook is enabled.
assert client_mode_should_convert()
config = (
dqn.DQNConfig()
.environment("CartPole-v1")
.env_runners(num_env_runners=0, compress_observations=True)
)
trainer = config.build()
for i in range(2):
trainer.train()


if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))

0 comments on commit 5c574d9

Please sign in to comment.