|
| 1 | +"""Unit tests for ray._private.auth_token_loader module.""" |
| 2 | + |
| 3 | +import os |
| 4 | +import tempfile |
| 5 | +import threading |
| 6 | +from pathlib import Path |
| 7 | +from unittest import mock |
| 8 | + |
| 9 | +import pytest |
| 10 | + |
| 11 | +from ray._private import auth_token_loader |
| 12 | + |
| 13 | + |
| 14 | +@pytest.fixture(autouse=True) |
| 15 | +def reset_cached_token(): |
| 16 | + """Reset the cached token before each test.""" |
| 17 | + auth_token_loader._cached_token = None |
| 18 | + yield |
| 19 | + auth_token_loader._cached_token = None |
| 20 | + |
| 21 | + |
| 22 | +@pytest.fixture(autouse=True) |
| 23 | +def clean_env_vars(): |
| 24 | + """Clean up environment variables before and after each test.""" |
| 25 | + env_vars = ["RAY_AUTH_TOKEN", "RAY_AUTH_TOKEN_PATH"] |
| 26 | + old_values = {var: os.environ.get(var) for var in env_vars} |
| 27 | + |
| 28 | + # Clear environment variables |
| 29 | + for var in env_vars: |
| 30 | + if var in os.environ: |
| 31 | + del os.environ[var] |
| 32 | + |
| 33 | + yield |
| 34 | + |
| 35 | + # Restore old values |
| 36 | + for var, value in old_values.items(): |
| 37 | + if value is not None: |
| 38 | + os.environ[var] = value |
| 39 | + elif var in os.environ: |
| 40 | + del os.environ[var] |
| 41 | + |
| 42 | + |
| 43 | +@pytest.fixture |
| 44 | +def temp_token_file(): |
| 45 | + """Create a temporary token file and clean it up after the test.""" |
| 46 | + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".token") as f: |
| 47 | + temp_path = f.name |
| 48 | + f.write("test-token-from-file") |
| 49 | + yield temp_path |
| 50 | + try: |
| 51 | + os.unlink(temp_path) |
| 52 | + except FileNotFoundError: |
| 53 | + pass |
| 54 | + |
| 55 | + |
| 56 | +@pytest.fixture |
| 57 | +def default_token_path(): |
| 58 | + """Return the default token path and clean it up after the test.""" |
| 59 | + path = Path.home() / ".ray" / "auth_token" |
| 60 | + yield path |
| 61 | + try: |
| 62 | + path.unlink() |
| 63 | + except FileNotFoundError: |
| 64 | + pass |
| 65 | + |
| 66 | + |
| 67 | +class TestLoadAuthToken: |
| 68 | + """Tests for load_auth_token function.""" |
| 69 | + |
| 70 | + def test_load_from_env_variable(self): |
| 71 | + """Test loading token from RAY_AUTH_TOKEN environment variable.""" |
| 72 | + os.environ["RAY_AUTH_TOKEN"] = "token-from-env" |
| 73 | + token = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 74 | + assert token == "token-from-env" |
| 75 | + |
| 76 | + def test_load_from_env_path(self, temp_token_file): |
| 77 | + """Test loading token from RAY_AUTH_TOKEN_PATH environment variable.""" |
| 78 | + os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file |
| 79 | + token = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 80 | + assert token == "test-token-from-file" |
| 81 | + |
| 82 | + def test_load_from_default_path(self, default_token_path): |
| 83 | + """Test loading token from default ~/.ray/auth_token path.""" |
| 84 | + default_token_path.parent.mkdir(parents=True, exist_ok=True) |
| 85 | + default_token_path.write_text("token-from-default") |
| 86 | + |
| 87 | + token = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 88 | + assert token == "token-from-default" |
| 89 | + |
| 90 | + def test_precedence_order(self, temp_token_file, default_token_path): |
| 91 | + """Test that token loading follows correct precedence order.""" |
| 92 | + # Set all three sources |
| 93 | + os.environ["RAY_AUTH_TOKEN"] = "token-from-env" |
| 94 | + os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file |
| 95 | + |
| 96 | + default_token_path.parent.mkdir(parents=True, exist_ok=True) |
| 97 | + default_token_path.write_text("token-from-default") |
| 98 | + |
| 99 | + # Environment variable should have highest precedence |
| 100 | + token = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 101 | + assert token == "token-from-env" |
| 102 | + |
| 103 | + def test_env_path_over_default(self, temp_token_file, default_token_path): |
| 104 | + """Test that RAY_AUTH_TOKEN_PATH has precedence over default path.""" |
| 105 | + os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file |
| 106 | + |
| 107 | + default_token_path.parent.mkdir(parents=True, exist_ok=True) |
| 108 | + default_token_path.write_text("token-from-default") |
| 109 | + |
| 110 | + token = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 111 | + assert token == "test-token-from-file" |
| 112 | + |
| 113 | + def test_no_token_found(self): |
| 114 | + """Test behavior when no token is found.""" |
| 115 | + token = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 116 | + assert token == "" |
| 117 | + |
| 118 | + def test_whitespace_handling(self, temp_token_file): |
| 119 | + """Test that whitespace is properly trimmed from token files.""" |
| 120 | + # Overwrite the temp file with whitespace |
| 121 | + with open(temp_token_file, "w") as f: |
| 122 | + f.write(" token-with-spaces \n\t") |
| 123 | + |
| 124 | + os.environ["RAY_AUTH_TOKEN_PATH"] = temp_token_file |
| 125 | + token = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 126 | + assert token == "token-with-spaces" |
| 127 | + |
| 128 | + def test_empty_env_variable(self): |
| 129 | + """Test that empty environment variable is ignored.""" |
| 130 | + os.environ["RAY_AUTH_TOKEN"] = " " # Empty after strip |
| 131 | + token = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 132 | + assert token == "" |
| 133 | + |
| 134 | + def test_nonexistent_path_in_env(self): |
| 135 | + """Test that nonexistent path in RAY_AUTH_TOKEN_PATH is handled gracefully.""" |
| 136 | + os.environ["RAY_AUTH_TOKEN_PATH"] = "/nonexistent/path/to/token" |
| 137 | + token = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 138 | + assert token == "" |
| 139 | + |
| 140 | + |
| 141 | +class TestTokenGeneration: |
| 142 | + """Tests for token generation functionality.""" |
| 143 | + |
| 144 | + def test_generate_token(self, default_token_path): |
| 145 | + """Test token generation when no token exists.""" |
| 146 | + token = auth_token_loader.load_auth_token(generate_if_not_found=True) |
| 147 | + |
| 148 | + # Token should be a 32-character hex string (UUID without dashes) |
| 149 | + assert len(token) == 32 |
| 150 | + assert all(c in "0123456789abcdef" for c in token) |
| 151 | + |
| 152 | + # Token should be saved to default path |
| 153 | + assert default_token_path.exists() |
| 154 | + saved_token = default_token_path.read_text().strip() |
| 155 | + assert saved_token == token |
| 156 | + |
| 157 | + def test_no_generation_without_flag(self): |
| 158 | + """Test that token is not generated when flag is False.""" |
| 159 | + token = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 160 | + assert token == "" |
| 161 | + |
| 162 | + def test_dont_generate_when_token_exists(self): |
| 163 | + """Test that token is not generated when one already exists.""" |
| 164 | + os.environ["RAY_AUTH_TOKEN"] = "existing-token" |
| 165 | + token = auth_token_loader.load_auth_token(generate_if_not_found=True) |
| 166 | + assert token == "existing-token" |
| 167 | + |
| 168 | + |
| 169 | +class TestTokenCaching: |
| 170 | + """Tests for token caching behavior.""" |
| 171 | + |
| 172 | + def test_caching_behavior(self): |
| 173 | + """Test that token is cached after first load.""" |
| 174 | + os.environ["RAY_AUTH_TOKEN"] = "cached-token" |
| 175 | + token1 = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 176 | + |
| 177 | + # Change environment variable (shouldn't affect cached value) |
| 178 | + os.environ["RAY_AUTH_TOKEN"] = "new-token" |
| 179 | + token2 = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 180 | + |
| 181 | + # Should still return the cached token |
| 182 | + assert token1 == token2 == "cached-token" |
| 183 | + |
| 184 | + def test_cache_empty_result(self): |
| 185 | + """Test that even empty results are cached.""" |
| 186 | + # First call with no token |
| 187 | + token1 = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 188 | + assert token1 == "" |
| 189 | + |
| 190 | + # Set environment variable after first call |
| 191 | + os.environ["RAY_AUTH_TOKEN"] = "new-token" |
| 192 | + token2 = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 193 | + |
| 194 | + # Should still return cached empty string |
| 195 | + assert token2 == "" |
| 196 | + |
| 197 | + |
| 198 | +class TestHasAuthToken: |
| 199 | + """Tests for has_auth_token function.""" |
| 200 | + |
| 201 | + def test_has_token_true(self): |
| 202 | + """Test has_auth_token returns True when token exists.""" |
| 203 | + os.environ["RAY_AUTH_TOKEN"] = "test-token" |
| 204 | + assert auth_token_loader.has_auth_token() is True |
| 205 | + |
| 206 | + def test_has_token_false(self): |
| 207 | + """Test has_auth_token returns False when no token exists.""" |
| 208 | + assert auth_token_loader.has_auth_token() is False |
| 209 | + |
| 210 | + def test_has_token_caches_result(self): |
| 211 | + """Test that has_auth_token doesn't trigger generation.""" |
| 212 | + # This should return False without generating a token |
| 213 | + assert auth_token_loader.has_auth_token() is False |
| 214 | + |
| 215 | + # Verify no token was generated |
| 216 | + default_path = Path.home() / ".ray" / "auth_token" |
| 217 | + assert not default_path.exists() |
| 218 | + |
| 219 | + |
| 220 | +class TestThreadSafety: |
| 221 | + """Tests for thread safety of token loading.""" |
| 222 | + |
| 223 | + def test_concurrent_loads(self): |
| 224 | + """Test that concurrent token loads are thread-safe.""" |
| 225 | + os.environ["RAY_AUTH_TOKEN"] = "thread-safe-token" |
| 226 | + |
| 227 | + results = [] |
| 228 | + threads = [] |
| 229 | + |
| 230 | + def load_token(): |
| 231 | + token = auth_token_loader.load_auth_token(generate_if_not_found=False) |
| 232 | + results.append(token) |
| 233 | + |
| 234 | + # Create multiple threads that try to load token simultaneously |
| 235 | + for _ in range(10): |
| 236 | + thread = threading.Thread(target=load_token) |
| 237 | + threads.append(thread) |
| 238 | + thread.start() |
| 239 | + |
| 240 | + # Wait for all threads to complete |
| 241 | + for thread in threads: |
| 242 | + thread.join() |
| 243 | + |
| 244 | + # All threads should get the same token |
| 245 | + assert len(results) == 10 |
| 246 | + assert all(result == "thread-safe-token" for result in results) |
| 247 | + |
| 248 | + def test_concurrent_generation(self, default_token_path): |
| 249 | + """Test that concurrent token generation is thread-safe.""" |
| 250 | + results = [] |
| 251 | + threads = [] |
| 252 | + |
| 253 | + def generate_token(): |
| 254 | + token = auth_token_loader.load_auth_token(generate_if_not_found=True) |
| 255 | + results.append(token) |
| 256 | + |
| 257 | + # Create multiple threads that try to generate token simultaneously |
| 258 | + for _ in range(5): |
| 259 | + thread = threading.Thread(target=generate_token) |
| 260 | + threads.append(thread) |
| 261 | + thread.start() |
| 262 | + |
| 263 | + # Wait for all threads to complete |
| 264 | + for thread in threads: |
| 265 | + thread.join() |
| 266 | + |
| 267 | + # All threads should get the same token (only generated once) |
| 268 | + assert len(results) == 5 |
| 269 | + assert len(set(results)) == 1 # All tokens should be identical |
| 270 | + |
| 271 | + |
| 272 | +class TestFilePermissions: |
| 273 | + """Tests for file permissions when saving tokens.""" |
| 274 | + |
| 275 | + def test_file_permissions_on_unix(self, default_token_path, monkeypatch): |
| 276 | + """Test that token file has 0600 permissions on Unix systems.""" |
| 277 | + # Skip on Windows |
| 278 | + if os.name == "nt": |
| 279 | + pytest.skip("Test only relevant on Unix systems") |
| 280 | + |
| 281 | + token = auth_token_loader.load_auth_token(generate_if_not_found=True) |
| 282 | + assert token |
| 283 | + |
| 284 | + # Check file permissions (should be 0600) |
| 285 | + stat_info = default_token_path.stat() |
| 286 | + assert stat_info.st_mode & 0o777 == 0o600 |
| 287 | + |
| 288 | + def test_file_permissions_error_handling(self, monkeypatch): |
| 289 | + """Test that permission errors are handled gracefully.""" |
| 290 | + |
| 291 | + # Mock os.chmod to raise an exception |
| 292 | + def mock_chmod(path, mode): |
| 293 | + raise OSError("Permission denied") |
| 294 | + |
| 295 | + monkeypatch.setattr(os, "chmod", mock_chmod) |
| 296 | + |
| 297 | + # Should still generate and return token, just not set permissions |
| 298 | + token = auth_token_loader.load_auth_token(generate_if_not_found=True) |
| 299 | + assert len(token) == 32 |
| 300 | + |
| 301 | + |
| 302 | +class TestIntegration: |
| 303 | + """Integration tests with ray.init() and ray start CLI.""" |
| 304 | + |
| 305 | + def test_token_loader_with_ray_init(self, default_token_path): |
| 306 | + """Test that token loader works with ray.init() enable_token_auth parameter.""" |
| 307 | + # This is more of a smoke test to ensure the module can be imported |
| 308 | + # and used in the context where it will be called |
| 309 | + from ray._private import auth_token_loader |
| 310 | + |
| 311 | + # Generate a token |
| 312 | + token = auth_token_loader.load_auth_token(generate_if_not_found=True) |
| 313 | + assert token |
| 314 | + assert len(token) == 32 |
| 315 | + |
| 316 | + # Verify it was saved |
| 317 | + assert default_token_path.exists() |
| 318 | + saved_token = default_token_path.read_text().strip() |
| 319 | + assert saved_token == token |
| 320 | + |
0 commit comments