Skip to content

Commit 092f29e

Browse files
author
sampan
committed
[Core][Tests] Add unit tests for Python auth_token_loader
- Test token loading from RAY_AUTH_TOKEN environment variable - Test token loading from RAY_AUTH_TOKEN_PATH file - Test token loading from default ~/.ray/auth_token path - Test precedence order (env var > path env var > default file) - Test token generation with generate_if_not_found=True - Test token caching behavior across multiple calls - Test has_auth_token() function - Test thread safety with concurrent loads and generation - Test whitespace handling and empty values - Test file permissions on Unix systems (0600) - Test error handling for permission errors - Test integration with fixtures and cleanup Signed-off-by: sampan <sampan@anyscale.com>
1 parent 91f783e commit 092f29e

File tree

1 file changed

+320
-0
lines changed

1 file changed

+320
-0
lines changed
Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
"""Unit tests for ray._private.auth_token_loader module."""
2+
3+
import os
4+
import tempfile
5+
import threading
6+
from pathlib import Path
7+
from unittest import mock
8+
9+
import pytest
10+
11+
from ray._private import auth_token_loader
12+
13+
14+
@pytest.fixture(autouse=True)
15+
def reset_cached_token():
16+
"""Reset the cached token before each test."""
17+
auth_token_loader._cached_token = None
18+
yield
19+
auth_token_loader._cached_token = None
20+
21+
22+
@pytest.fixture(autouse=True)
23+
def clean_env_vars():
24+
"""Clean up environment variables before and after each test."""
25+
env_vars = ["RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH"]
26+
old_values = {var: os.environ.get(var) for var in env_vars}
27+
28+
# Clear environment variables
29+
for var in env_vars:
30+
if var in os.environ:
31+
del os.environ[var]
32+
33+
yield
34+
35+
# Restore old values
36+
for var, value in old_values.items():
37+
if value is not None:
38+
os.environ[var] = value
39+
elif var in os.environ:
40+
del os.environ[var]
41+
42+
43+
@pytest.fixture
44+
def temp_token_file():
45+
"""Create a temporary token file and clean it up after the test."""
46+
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".token") as f:
47+
temp_path = f.name
48+
f.write("test-token-from-file")
49+
yield temp_path
50+
try:
51+
os.unlink(temp_path)
52+
except FileNotFoundError:
53+
pass
54+
55+
56+
@pytest.fixture
57+
def default_token_path():
58+
"""Return the default token path and clean it up after the test."""
59+
path = Path.home() / ".ray" / "auth_token"
60+
yield path
61+
try:
62+
path.unlink()
63+
except FileNotFoundError:
64+
pass
65+
66+
67+
class TestLoadAuthToken:
68+
"""Tests for load_auth_token function."""
69+
70+
def test_load_from_env_variable(self):
71+
"""Test loading token from RAY_AUTH_TOKEN environment variable."""
72+
os.environ["RAY_AUTH_TOKEN"] = "token-from-env"
73+
token = auth_token_loader.load_auth_token(generate_if_not_found=False)
74+
assert token == "token-from-env"
75+
76+
def test_load_from_env_path(self, temp_token_file):
77+
"""Test loading token from RAY_AUTH_TOKEN_PATH environment variable."""
78+
os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file
79+
token = auth_token_loader.load_auth_token(generate_if_not_found=False)
80+
assert token == "test-token-from-file"
81+
82+
def test_load_from_default_path(self, default_token_path):
83+
"""Test loading token from default ~/.ray/auth_token path."""
84+
default_token_path.parent.mkdir(parents=True, exist_ok=True)
85+
default_token_path.write_text("token-from-default")
86+
87+
token = auth_token_loader.load_auth_token(generate_if_not_found=False)
88+
assert token == "token-from-default"
89+
90+
def test_precedence_order(self, temp_token_file, default_token_path):
91+
"""Test that token loading follows correct precedence order."""
92+
# Set all three sources
93+
os.environ["RAY_AUTH_TOKEN"] = "token-from-env"
94+
os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file
95+
96+
default_token_path.parent.mkdir(parents=True, exist_ok=True)
97+
default_token_path.write_text("token-from-default")
98+
99+
# Environment variable should have highest precedence
100+
token = auth_token_loader.load_auth_token(generate_if_not_found=False)
101+
assert token == "token-from-env"
102+
103+
def test_env_path_over_default(self, temp_token_file, default_token_path):
104+
"""Test that RAY_AUTH_TOKEN_PATH has precedence over default path."""
105+
os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file
106+
107+
default_token_path.parent.mkdir(parents=True, exist_ok=True)
108+
default_token_path.write_text("token-from-default")
109+
110+
token = auth_token_loader.load_auth_token(generate_if_not_found=False)
111+
assert token == "test-token-from-file"
112+
113+
def test_no_token_found(self):
114+
"""Test behavior when no token is found."""
115+
token = auth_token_loader.load_auth_token(generate_if_not_found=False)
116+
assert token == ""
117+
118+
def test_whitespace_handling(self, temp_token_file):
119+
"""Test that whitespace is properly trimmed from token files."""
120+
# Overwrite the temp file with whitespace
121+
with open(temp_token_file, "w") as f:
122+
f.write(" token-with-spaces \n\t")
123+
124+
os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file
125+
token = auth_token_loader.load_auth_token(generate_if_not_found=False)
126+
assert token == "token-with-spaces"
127+
128+
def test_empty_env_variable(self):
129+
"""Test that empty environment variable is ignored."""
130+
os.environ["RAY_AUTH_TOKEN"] = " " # Empty after strip
131+
token = auth_token_loader.load_auth_token(generate_if_not_found=False)
132+
assert token == ""
133+
134+
def test_nonexistent_path_in_env(self):
135+
"""Test that nonexistent path in RAY_AUTH_TOKEN_PATH is handled gracefully."""
136+
os.environ["RAY_AUTH_TOKEN_PATH"] = "/nonexistent/path/to/token"
137+
token = auth_token_loader.load_auth_token(generate_if_not_found=False)
138+
assert token == ""
139+
140+
141+
class TestTokenGeneration:
142+
"""Tests for token generation functionality."""
143+
144+
def test_generate_token(self, default_token_path):
145+
"""Test token generation when no token exists."""
146+
token = auth_token_loader.load_auth_token(generate_if_not_found=True)
147+
148+
# Token should be a 32-character hex string (UUID without dashes)
149+
assert len(token) == 32
150+
assert all(c in "0123456789abcdef" for c in token)
151+
152+
# Token should be saved to default path
153+
assert default_token_path.exists()
154+
saved_token = default_token_path.read_text().strip()
155+
assert saved_token == token
156+
157+
def test_no_generation_without_flag(self):
158+
"""Test that token is not generated when flag is False."""
159+
token = auth_token_loader.load_auth_token(generate_if_not_found=False)
160+
assert token == ""
161+
162+
def test_dont_generate_when_token_exists(self):
163+
"""Test that token is not generated when one already exists."""
164+
os.environ["RAY_AUTH_TOKEN"] = "existing-token"
165+
token = auth_token_loader.load_auth_token(generate_if_not_found=True)
166+
assert token == "existing-token"
167+
168+
169+
class TestTokenCaching:
170+
"""Tests for token caching behavior."""
171+
172+
def test_caching_behavior(self):
173+
"""Test that token is cached after first load."""
174+
os.environ["RAY_AUTH_TOKEN"] = "cached-token"
175+
token1 = auth_token_loader.load_auth_token(generate_if_not_found=False)
176+
177+
# Change environment variable (shouldn't affect cached value)
178+
os.environ["RAY_AUTH_TOKEN"] = "new-token"
179+
token2 = auth_token_loader.load_auth_token(generate_if_not_found=False)
180+
181+
# Should still return the cached token
182+
assert token1 == token2 == "cached-token"
183+
184+
def test_cache_empty_result(self):
185+
"""Test that even empty results are cached."""
186+
# First call with no token
187+
token1 = auth_token_loader.load_auth_token(generate_if_not_found=False)
188+
assert token1 == ""
189+
190+
# Set environment variable after first call
191+
os.environ["RAY_AUTH_TOKEN"] = "new-token"
192+
token2 = auth_token_loader.load_auth_token(generate_if_not_found=False)
193+
194+
# Should still return cached empty string
195+
assert token2 == ""
196+
197+
198+
class TestHasAuthToken:
199+
"""Tests for has_auth_token function."""
200+
201+
def test_has_token_true(self):
202+
"""Test has_auth_token returns True when token exists."""
203+
os.environ["RAY_AUTH_TOKEN"] = "test-token"
204+
assert auth_token_loader.has_auth_token() is True
205+
206+
def test_has_token_false(self):
207+
"""Test has_auth_token returns False when no token exists."""
208+
assert auth_token_loader.has_auth_token() is False
209+
210+
def test_has_token_caches_result(self):
211+
"""Test that has_auth_token doesn't trigger generation."""
212+
# This should return False without generating a token
213+
assert auth_token_loader.has_auth_token() is False
214+
215+
# Verify no token was generated
216+
default_path = Path.home() / ".ray" / "auth_token"
217+
assert not default_path.exists()
218+
219+
220+
class TestThreadSafety:
221+
"""Tests for thread safety of token loading."""
222+
223+
def test_concurrent_loads(self):
224+
"""Test that concurrent token loads are thread-safe."""
225+
os.environ["RAY_AUTH_TOKEN"] = "thread-safe-token"
226+
227+
results = []
228+
threads = []
229+
230+
def load_token():
231+
token = auth_token_loader.load_auth_token(generate_if_not_found=False)
232+
results.append(token)
233+
234+
# Create multiple threads that try to load token simultaneously
235+
for _ in range(10):
236+
thread = threading.Thread(target=load_token)
237+
threads.append(thread)
238+
thread.start()
239+
240+
# Wait for all threads to complete
241+
for thread in threads:
242+
thread.join()
243+
244+
# All threads should get the same token
245+
assert len(results) == 10
246+
assert all(result == "thread-safe-token" for result in results)
247+
248+
def test_concurrent_generation(self, default_token_path):
249+
"""Test that concurrent token generation is thread-safe."""
250+
results = []
251+
threads = []
252+
253+
def generate_token():
254+
token = auth_token_loader.load_auth_token(generate_if_not_found=True)
255+
results.append(token)
256+
257+
# Create multiple threads that try to generate token simultaneously
258+
for _ in range(5):
259+
thread = threading.Thread(target=generate_token)
260+
threads.append(thread)
261+
thread.start()
262+
263+
# Wait for all threads to complete
264+
for thread in threads:
265+
thread.join()
266+
267+
# All threads should get the same token (only generated once)
268+
assert len(results) == 5
269+
assert len(set(results)) == 1 # All tokens should be identical
270+
271+
272+
class TestFilePermissions:
273+
"""Tests for file permissions when saving tokens."""
274+
275+
def test_file_permissions_on_unix(self, default_token_path, monkeypatch):
276+
"""Test that token file has 0600 permissions on Unix systems."""
277+
# Skip on Windows
278+
if os.name == "nt":
279+
pytest.skip("Test only relevant on Unix systems")
280+
281+
token = auth_token_loader.load_auth_token(generate_if_not_found=True)
282+
assert token
283+
284+
# Check file permissions (should be 0600)
285+
stat_info = default_token_path.stat()
286+
assert stat_info.st_mode & 0o777 == 0o600
287+
288+
def test_file_permissions_error_handling(self, monkeypatch):
289+
"""Test that permission errors are handled gracefully."""
290+
291+
# Mock os.chmod to raise an exception
292+
def mock_chmod(path, mode):
293+
raise OSError("Permission denied")
294+
295+
monkeypatch.setattr(os, "chmod", mock_chmod)
296+
297+
# Should still generate and return token, just not set permissions
298+
token = auth_token_loader.load_auth_token(generate_if_not_found=True)
299+
assert len(token) == 32
300+
301+
302+
class TestIntegration:
303+
"""Integration tests with ray.init() and ray start CLI."""
304+
305+
def test_token_loader_with_ray_init(self, default_token_path):
306+
"""Test that token loader works with ray.init() enable_token_auth parameter."""
307+
# This is more of a smoke test to ensure the module can be imported
308+
# and used in the context where it will be called
309+
from ray._private import auth_token_loader
310+
311+
# Generate a token
312+
token = auth_token_loader.load_auth_token(generate_if_not_found=True)
313+
assert token
314+
assert len(token) == 32
315+
316+
# Verify it was saved
317+
assert default_token_path.exists()
318+
saved_token = default_token_path.read_text().strip()
319+
assert saved_token == token
320+

0 commit comments

Comments
 (0)