Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove rllib dependency from core tests #51171

Merged
merged 23 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion python/ray/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ py_test_module_list(
"test_client_multi.py",
"test_client_references.py",
"test_client_warnings.py",
"test_client_library_integration.py",
],
size = "medium",
tags = ["exclusive", "client_tests", "team:core"],
Expand Down
58 changes: 0 additions & 58 deletions python/ray/tests/test_client_library_integration.py

This file was deleted.

48 changes: 13 additions & 35 deletions python/ray/tests/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,44 +741,21 @@ def test_output_ray_cluster(call_ray_start):
def test_output_on_driver_shutdown(ray_start_cluster):
cluster = ray_start_cluster
cluster.add_node(num_cpus=16)
# many_ppo.py script.
script = """
import ray
from ray.tune import run_experiments
from ray.tune.utils.release_test_util import ProgressCallback

object_store_memory = 10**9
num_nodes = 3

message = ("Make sure there is enough memory on this machine to run this "
"workload. We divide the system memory by 2 to provide a buffer.")
assert (num_nodes * object_store_memory <
ray._private.utils.get_system_memory() / 2), message
ray.init(address="auto")

# Simulate a cluster on one machine.
@ray.remote
def f(i: int):
return i

ray.init(address="auto")
obj_refs = [f.remote(i) for i in range(100)]

# Run the workload.

run_experiments(
{
"PPO": {
"run": "PPO",
"env": "CartPole-v0",
"num_samples": 10,
"config": {
"framework": "torch",
"num_workers": 1,
"num_gpus": 0,
"num_sgd_iter": 1,
},
"stop": {
"timesteps_total": 1,
},
}
},
callbacks=[ProgressCallback()])
while True:
assert len(obj_refs) == 100
ready, pending = ray.wait(obj_refs, num_returns=10)
for i in ray.get(ready):
obj_refs[i] = f.remote(i)
"""

proc = run_string_as_driver_nonblocking(script)
Expand All @@ -795,13 +772,14 @@ def test_output_on_driver_shutdown(ray_start_cluster):
time.sleep(0.1)
os.kill(proc.pid, signal.SIGINT)
try:
proc.wait(timeout=10)
proc.wait(timeout=5)
except subprocess.TimeoutExpired:
print("Script wasn't terminated by SIGINT. Try SIGTERM.")
os.kill(proc.pid, signal.SIGTERM)
print(proc.wait(timeout=10))
print(proc.wait(timeout=5))
err_str = proc.stderr.read().decode("ascii")
assert len(err_str) > 0
assert "KeyboardInterrupt" in err_str
assert "StackTrace Information" not in err_str
print(err_str)

Expand Down
7 changes: 7 additions & 0 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1822,6 +1822,13 @@ py_test(
srcs = ["tests/test_timesteps.py"]
)

py_test(
name = "tests/test_ray_client",
tags = ["team:rllib", "tests_dir"],
size = "medium",
srcs = ["tests/test_ray_client.py"]
)

# --------------------------------------------------------------------
# examples/ directory
#
Expand Down
27 changes: 27 additions & 0 deletions rllib/tests/test_ray_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import sys

import pytest

from ray.rllib.algorithms import dqn
from ray.util.client.ray_client_helpers import ray_start_client_server
from ray._private.client_mode_hook import enable_client_mode, client_mode_should_convert


def test_basic_dqn():
with ray_start_client_server():
# Need to enable this for client APIs to be used.
with enable_client_mode():
# Confirming mode hook is enabled.
assert client_mode_should_convert()
config = (
dqn.DQNConfig()
.environment("CartPole-v1")
.env_runners(num_env_runners=0, compress_observations=True)
)
trainer = config.build()
for i in range(2):
trainer.train()


if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))