diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index bfd06e677e7e..08986d3b415e 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -42,6 +42,8 @@ constexpr int kRayletStoreErrorExitCode = 100; constexpr char kObjectTablePrefix[] = "ObjectTable"; constexpr char kClusterIdKey[] = "ray_cluster_id"; +constexpr char kAuthTokenKey[] = "authorization"; +constexpr char kBearerPrefix[] = "Bearer "; constexpr char kWorkerDynamicOptionPlaceholder[] = "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER"; diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 6e8d21956162..e4e8fc1d48ef 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -35,6 +35,12 @@ RAY_CONFIG(bool, emit_main_service_metrics, true) /// Whether to enable cluster authentication. RAY_CONFIG(bool, enable_cluster_auth, true) +/// Whether to enable token-based authentication for RPC calls. +/// will be converted to AuthenticationMode enum defined in +/// rpc/authentication/authentication_mode.h +/// use GetAuthenticationMode() to get the authentication mode enum value. +RAY_CONFIG(std::string, auth_mode, "disabled") + /// The interval of periodic event loop stats print. /// -1 means the feature is disabled. In this case, stats are available /// in the associated process's log file. diff --git a/src/ray/rpc/authentication/BUILD.bazel b/src/ray/rpc/authentication/BUILD.bazel new file mode 100644 index 000000000000..8da78e5d728b --- /dev/null +++ b/src/ray/rpc/authentication/BUILD.bazel @@ -0,0 +1,34 @@ +load("//bazel:ray.bzl", "ray_cc_library") + +ray_cc_library( + name = "authentication_mode", + srcs = ["authentication_mode.cc"], + hdrs = ["authentication_mode.h"], + visibility = ["//visibility:public"], + deps = [ + "//src/ray/common:ray_config", + "@com_google_absl//absl/strings", + ], +) + +ray_cc_library( + name = "authentication_token", + hdrs = ["authentication_token.h"], + visibility = ["//visibility:public"], + deps = [ + "//src/ray/common:constants", + "@com_github_grpc_grpc//:grpc++", + ], +) + +ray_cc_library( + name = "authentication_token_loader", + srcs = ["authentication_token_loader.cc"], + hdrs = ["authentication_token_loader.h"], + visibility = ["//visibility:public"], + deps = [ + ":authentication_mode", + ":authentication_token", + "//src/ray/util:logging", + ], +) diff --git a/src/ray/rpc/authentication/authentication_mode.cc b/src/ray/rpc/authentication/authentication_mode.cc new file mode 100644 index 000000000000..1bbe209733ce --- /dev/null +++ b/src/ray/rpc/authentication/authentication_mode.cc @@ -0,0 +1,37 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/authentication/authentication_mode.h" + +#include +#include + +#include "absl/strings/ascii.h" +#include "ray/common/ray_config.h" + +namespace ray { +namespace rpc { + +AuthenticationMode GetAuthenticationMode() { + std::string auth_mode_lower = absl::AsciiStrToLower(RayConfig::instance().auth_mode()); + + if (auth_mode_lower == "token") { + return AuthenticationMode::TOKEN; + } else { + return AuthenticationMode::DISABLED; + } +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_mode.h b/src/ray/rpc/authentication/authentication_mode.h new file mode 100644 index 000000000000..21bd165fd34b --- /dev/null +++ b/src/ray/rpc/authentication/authentication_mode.h @@ -0,0 +1,33 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace ray { +namespace rpc { + +enum class AuthenticationMode { + DISABLED, + TOKEN, +}; + +/// Get the authentication mode from the RayConfig. +/// \return The authentication mode enum value. returns AuthenticationMode::DISABLED if +/// the authentication mode is not set or is invalid. +AuthenticationMode GetAuthenticationMode(); + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_token.h b/src/ray/rpc/authentication/authentication_token.h new file mode 100644 index 000000000000..4f32310784de --- /dev/null +++ b/src/ray/rpc/authentication/authentication_token.h @@ -0,0 +1,156 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "ray/common/constants.h" + +namespace ray { +namespace rpc { + +/// Secure wrapper for authentication tokens. +/// - Wipes memory on destruction +/// - Constant-time comparison +/// - Redacted output when logged or printed +class AuthenticationToken { + public: + AuthenticationToken() = default; + explicit AuthenticationToken(std::string value) : secret_(value.begin(), value.end()) {} + + AuthenticationToken(const AuthenticationToken &other) : secret_(other.secret_) {} + AuthenticationToken &operator=(const AuthenticationToken &other) { + if (this != &other) { + SecureClear(); + secret_ = other.secret_; + } + return *this; + } + + // Move operations + AuthenticationToken(AuthenticationToken &&other) noexcept { + MoveFrom(std::move(other)); + } + AuthenticationToken &operator=(AuthenticationToken &&other) noexcept { + if (this != &other) { + SecureClear(); + MoveFrom(std::move(other)); + } + return *this; + } + ~AuthenticationToken() { SecureClear(); } + + bool empty() const noexcept { return secret_.empty(); } + + /// Constant-time equality comparison + bool Equals(const AuthenticationToken &other) const noexcept { + return ConstTimeEqual(secret_, other.secret_); + } + + /// Equality operator (constant-time) + bool operator==(const AuthenticationToken &other) const noexcept { + return Equals(other); + } + + /// Inequality operator + bool operator!=(const AuthenticationToken &other) const noexcept { + return !(*this == other); + } + + /// Set authentication metadata on a gRPC client context + /// Only call this from client-side code + void SetMetadata(grpc::ClientContext &context) const { + if (!secret_.empty()) { + context.AddMetadata(kAuthTokenKey, + kBearerPrefix + std::string(secret_.begin(), secret_.end())); + } + } + + /// Create AuthenticationToken from gRPC metadata value + /// Strips "Bearer " prefix and creates token object + /// @param metadata_value The raw value from server metadata (should include "Bearer " + /// prefix) + /// @return AuthenticationToken object (empty if format invalid) + static AuthenticationToken FromMetadata(std::string_view metadata_value) { + const std::string_view prefix(kBearerPrefix); + if (metadata_value.size() < prefix.size() || + metadata_value.substr(0, prefix.size()) != prefix) { + return AuthenticationToken(); // Invalid format, return empty + } + std::string_view token_part = metadata_value.substr(prefix.size()); + return AuthenticationToken(std::string(token_part)); + } + + friend std::ostream &operator<<(std::ostream &os, const AuthenticationToken &t) { + return os << ""; + } + + private: + std::vector secret_; + + // Constant-time string comparison to avoid timing attacks. + // https://en.wikipedia.org/wiki/Timing_attack + static bool ConstTimeEqual(const std::vector &a, + const std::vector &b) noexcept { + if (a.size() != b.size()) { + return false; + } + unsigned char diff = 0; + for (size_t i = 0; i < a.size(); ++i) { + diff |= a[i] ^ b[i]; + } + return diff == 0; + } + + // replace the characters in the memory with 0 + static void ExplicitBurn(void *p, size_t n) noexcept { +#if defined(_MSC_VER) + SecureZeroMemory(p, n); +#elif defined(__STDC_LIB_EXT1__) + memset_s(p, n, 0, n); +#else + // Using array indexing instead of pointer arithmetic + volatile auto *vp = static_cast(p); + for (size_t i = 0; i < n; ++i) { + vp[i] = 0; + } +#endif + } + + void SecureClear() noexcept { + if (!secret_.empty()) { + ExplicitBurn(secret_.data(), secret_.size()); + secret_.clear(); + } + } + + void MoveFrom(AuthenticationToken &&other) noexcept { + secret_ = std::move(other.secret_); + // Clear the moved-from object explicitly for security + // Note: 'other' is already an rvalue reference, no need to move again + other.SecureClear(); + } +}; + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_token_loader.cc b/src/ray/rpc/authentication/authentication_token_loader.cc new file mode 100644 index 000000000000..621f28fe351c --- /dev/null +++ b/src/ray/rpc/authentication/authentication_token_loader.cc @@ -0,0 +1,173 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/authentication/authentication_token_loader.h" + +#include +#include +#include + +#include "ray/util/logging.h" + +#ifdef _WIN32 +#ifndef _WINDOWS_ +#ifndef WIN32_LEAN_AND_MEAN // Sorry for the inconvenience. Please include any related + // headers you need manually. + // (https://stackoverflow.com/a/8294669) +#define WIN32_LEAN_AND_MEAN // Prevent inclusion of WinSock2.h +#endif +#include // Force inclusion of WinGDI here to resolve name conflict +#endif +#endif + +namespace ray { +namespace rpc { + +AuthenticationTokenLoader &AuthenticationTokenLoader::instance() { + static AuthenticationTokenLoader instance; + return instance; +} + +std::optional AuthenticationTokenLoader::GetToken() { + std::lock_guard lock(token_mutex_); + + // If already loaded, return cached value + if (cached_token_.has_value()) { + return cached_token_; + } + + // If token auth is not enabled, return std::nullopt + if (GetAuthenticationMode() != AuthenticationMode::TOKEN) { + cached_token_ = std::nullopt; + return std::nullopt; + } + + // Token auth is enabled, try to load from sources + AuthenticationToken token = LoadTokenFromSources(); + + // If no token found and auth is enabled, fail with RAY_CHECK + RAY_CHECK(!token.empty()) + << "Token authentication is enabled but Ray couldn't find an authentication token. " + << "Set the RAY_AUTH_TOKEN environment variable, or set RAY_AUTH_TOKEN_PATH to " + "point to a file with the token, " + << "or create a token file at ~/.ray/auth_token."; + + // Cache and return the loaded token + cached_token_ = std::move(token); + return *cached_token_; +} + +// Read token from the first line of the file. trim whitespace. +// Returns empty string if file cannot be opened or is empty. +std::string AuthenticationTokenLoader::ReadTokenFromFile(const std::string &file_path) { + std::ifstream token_file(file_path); + if (!token_file.is_open()) { + return ""; + } + + std::string token; + std::getline(token_file, token); + token_file.close(); + return token; +} + +AuthenticationToken AuthenticationTokenLoader::LoadTokenFromSources() { + // Precedence 1: RAY_AUTH_TOKEN environment variable + const char *env_token = std::getenv("RAY_AUTH_TOKEN"); + if (env_token != nullptr) { + std::string token_str(env_token); + if (!token_str.empty()) { + RAY_LOG(DEBUG) << "Loaded authentication token from RAY_AUTH_TOKEN environment " + "variable"; + return AuthenticationToken(TrimWhitespace(token_str)); + } + } + + // Precedence 2: RAY_AUTH_TOKEN_PATH environment variable + const char *env_token_path = std::getenv("RAY_AUTH_TOKEN_PATH"); + if (env_token_path != nullptr) { + std::string path_str(env_token_path); + if (!path_str.empty()) { + std::string token_str = TrimWhitespace(ReadTokenFromFile(path_str)); + RAY_CHECK(!token_str.empty()) + << "RAY_AUTH_TOKEN_PATH is set but file cannot be opened or is empty: " + << path_str; + RAY_LOG(DEBUG) << "Loaded authentication token from file: " << path_str; + return AuthenticationToken(token_str); + } + } + + // Precedence 3: Default token path ~/.ray/auth_token + std::string default_path = GetDefaultTokenPath(); + std::string token_str = TrimWhitespace(ReadTokenFromFile(default_path)); + if (!token_str.empty()) { + RAY_LOG(DEBUG) << "Loaded authentication token from default path: " << default_path; + return AuthenticationToken(token_str); + } + + // No token found + RAY_LOG(DEBUG) << "No authentication token found in any source"; + return AuthenticationToken(); +} + +std::string AuthenticationTokenLoader::GetDefaultTokenPath() { + std::string home_dir; + +#ifdef _WIN32 + const char *path_separator = "\\"; + const char *userprofile = std::getenv("USERPROFILE"); + if (userprofile != nullptr) { + home_dir = userprofile; + } else { + const char *homedrive = std::getenv("HOMEDRIVE"); + const char *homepath = std::getenv("HOMEPATH"); + if (homedrive != nullptr && homepath != nullptr) { + home_dir = std::string(homedrive) + std::string(homepath); + } + } +#else + const char *path_separator = "/"; + const char *home = std::getenv("HOME"); + if (home != nullptr) { + home_dir = home; + } +#endif + + const std::string token_subpath = + std::string(path_separator) + ".ray" + std::string(path_separator) + "auth_token"; + + if (home_dir.empty()) { + RAY_LOG(WARNING) << "Cannot determine home directory for token storage"; + return "." + token_subpath; + } + + return home_dir + token_subpath; +} + +std::string AuthenticationTokenLoader::TrimWhitespace(const std::string &str) { + std::string whitespace = " \t\n\r\f\v"; + std::string trimmed_str = str; + trimmed_str.erase(0, trimmed_str.find_first_not_of(whitespace)); + + // if the string is empty, return it + if (trimmed_str.empty()) { + return trimmed_str; + } + + trimmed_str.erase(trimmed_str.find_last_not_of(whitespace) + 1); + return trimmed_str; +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/authentication/authentication_token_loader.h b/src/ray/rpc/authentication/authentication_token_loader.h new file mode 100644 index 000000000000..4034ecbc78dd --- /dev/null +++ b/src/ray/rpc/authentication/authentication_token_loader.h @@ -0,0 +1,72 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions +// and limitations under the License. + +#pragma once + +#include +#include +#include + +#include "ray/rpc/authentication/authentication_mode.h" +#include "ray/rpc/authentication/authentication_token.h" + +namespace ray { +namespace rpc { + +/// Singleton class for loading and caching authentication tokens. +/// Supports loading tokens from multiple sources with precedence: +/// 1. RAY_AUTH_TOKEN environment variable +/// 2. RAY_AUTH_TOKEN_PATH environment variable (path to token file) +/// 3. Default token path: ~/.ray/auth_token (Unix) or %USERPROFILE%\.ray\auth_token +/// +/// Thread-safe with internal caching to avoid repeated file I/O. +class AuthenticationTokenLoader { + public: + static AuthenticationTokenLoader &instance(); + + /// Get the authentication token. + /// If token authentication is enabled but no token is found, fails with RAY_CHECK. + /// \return The authentication token, or std::nullopt if auth is disabled. + std::optional GetToken(); + + void ResetCache() { + std::lock_guard lock(token_mutex_); + cached_token_.reset(); + } + + AuthenticationTokenLoader(const AuthenticationTokenLoader &) = delete; + AuthenticationTokenLoader &operator=(const AuthenticationTokenLoader &) = delete; + + private: + AuthenticationTokenLoader() = default; + ~AuthenticationTokenLoader() = default; + + /// Read and trim token from file. + std::string ReadTokenFromFile(const std::string &file_path); + + /// Load token from environment or file. + AuthenticationToken LoadTokenFromSources(); + + /// Default token file path (~/.ray/auth_token or %USERPROFILE%\.ray\auth_token). + std::string GetDefaultTokenPath(); + + /// Trim whitespace from the beginning and end of the string. + std::string TrimWhitespace(const std::string &str); + + std::mutex token_mutex_; + std::optional cached_token_; +}; + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/tests/BUILD.bazel b/src/ray/rpc/tests/BUILD.bazel index 5fa8b14cc4db..d5113ae0d3aa 100644 --- a/src/ray/rpc/tests/BUILD.bazel +++ b/src/ray/rpc/tests/BUILD.bazel @@ -40,3 +40,30 @@ ray_cc_test( "@com_google_googletest//:gtest_main", ], ) + +ray_cc_test( + name = "authentication_token_loader_test", + size = "small", + srcs = [ + "authentication_token_loader_test.cc", + ], + tags = ["team:core"], + deps = [ + "//src/ray/common:ray_config", + "//src/ray/rpc/authentication:authentication_token_loader", + "@com_google_googletest//:gtest_main", + ], +) + +ray_cc_test( + name = "authentication_token_test", + size = "small", + srcs = [ + "authentication_token_test.cc", + ], + tags = ["team:core"], + deps = [ + "//src/ray/rpc/authentication:authentication_token", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/src/ray/rpc/tests/authentication_token_loader_test.cc b/src/ray/rpc/tests/authentication_token_loader_test.cc new file mode 100644 index 000000000000..2332c6d09313 --- /dev/null +++ b/src/ray/rpc/tests/authentication_token_loader_test.cc @@ -0,0 +1,344 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/authentication/authentication_token_loader.h" + +#include +#include + +#include "gtest/gtest.h" +#include "ray/common/ray_config.h" + +#if defined(__APPLE__) || defined(__linux__) +#include +#include +#endif + +#ifdef _WIN32 +#ifndef _WINDOWS_ +#ifndef WIN32_LEAN_AND_MEAN // Sorry for the inconvenience. Please include any related + // headers you need manually. + // (https://stackoverflow.com/a/8294669) +#define WIN32_LEAN_AND_MEAN // Prevent inclusion of WinSock2.h +#endif +#include // Force inclusion of WinGDI here to resolve name conflict +#endif +#include // For _mkdir on Windows +#include // For _getpid on Windows +#endif + +namespace ray { +namespace rpc { + +class AuthenticationTokenLoaderTest : public ::testing::Test { + protected: + void SetUp() override { + // Enable token authentication for tests + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); + + // If HOME is not set (e.g., in Bazel sandbox), set it to a test directory + // This ensures tests work in environments where HOME isn't provided +#ifdef _WIN32 + if (std::getenv("USERPROFILE") == nullptr) { + const char *test_tmpdir = std::getenv("TEST_TMPDIR"); + if (test_tmpdir != nullptr) { + test_home_dir_ = std::string(test_tmpdir) + "\\ray_test_home"; + } else { + test_home_dir_ = "C:\\Windows\\Temp\\ray_test_home"; + } + _putenv(("USERPROFILE=" + test_home_dir_).c_str()); + } + const char *home_dir = std::getenv("USERPROFILE"); + default_token_path_ = std::string(home_dir) + "\\.ray\\auth_token"; +#else + if (std::getenv("HOME") == nullptr) { + const char *test_tmpdir = std::getenv("TEST_TMPDIR"); + if (test_tmpdir != nullptr) { + test_home_dir_ = std::string(test_tmpdir) + "/ray_test_home"; + } else { + test_home_dir_ = "/tmp/ray_test_home"; + } + setenv("HOME", test_home_dir_.c_str(), 1); + } + const char *home_dir = std::getenv("HOME"); + if (home_dir != nullptr) { + default_token_path_ = std::string(home_dir) + "/.ray/auth_token"; + test_home_dir_ = home_dir; + } else { + default_token_path_ = ".ray/auth_token"; + } +#endif + cleanup_env(); + // Reset the singleton's cached state for test isolation + AuthenticationTokenLoader::instance().ResetCache(); + } + + void TearDown() override { + // Clean up after test + cleanup_env(); + // Reset the singleton's cached state for test isolation + AuthenticationTokenLoader::instance().ResetCache(); + // Disable token auth after tests + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + } + + void cleanup_env() { + unset_env_var("RAY_AUTH_TOKEN"); + unset_env_var("RAY_AUTH_TOKEN_PATH"); + remove(default_token_path_.c_str()); + } + + std::string get_temp_token_path() { +#ifdef _WIN32 + return "C:\\Windows\\Temp\\ray_test_token_" + std::to_string(_getpid()); +#else + return "/tmp/ray_test_token_" + std::to_string(getpid()); +#endif + } + + void set_env_var(const char *name, const char *value) { +#ifdef _WIN32 + _putenv_s(name, value); +#else + setenv(name, value, 1); +#endif + } + + void unset_env_var(const char *name) { +#ifdef _WIN32 + _putenv_s(name, "") +#else + unsetenv(name); +#endif + } + + void ensure_ray_dir_exists() { +#ifdef _WIN32 + const char *home_dir = std::getenv("USERPROFILE"); + _mkdir(home_dir); // Create parent directory + std::string ray_dir = std::string(home_dir) + "\\.ray"; + _mkdir(ray_dir.c_str()); +#else + // Always ensure the home directory exists (it might be a test temp dir we created) + if (!test_home_dir_.empty()) { + mkdir(test_home_dir_.c_str(), + 0700); // Create if it doesn't exist (ignore error if it does) + } + + const char *home_dir = std::getenv("HOME"); + if (home_dir != nullptr) { + std::string ray_dir = std::string(home_dir) + "/.ray"; + mkdir(ray_dir.c_str(), 0700); + } +#endif + } + + void write_token_file(const std::string &path, const std::string &content) { + std::ofstream token_file(path); + token_file << content; + token_file.close(); + } + + std::string default_token_path_; + std::string test_home_dir_; // Fallback home directory for tests +}; + +TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvVariable) { + // Set token in environment variable + set_env_var("RAY_AUTH_TOKEN", "test-token-from-env"); + + // Create a new instance to avoid cached state + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("test-token-from-env"); + EXPECT_TRUE(token_opt->Equals(expected)); + EXPECT_TRUE(loader.GetToken().has_value()); +} + +TEST_F(AuthenticationTokenLoaderTest, TestLoadFromEnvPath) { + // Create a temporary token file + std::string temp_token_path = get_temp_token_path(); + write_token_file(temp_token_path, "test-token-from-file"); + + // Set path in environment variable + set_env_var("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str()); + + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("test-token-from-file"); + EXPECT_TRUE(token_opt->Equals(expected)); + EXPECT_TRUE(loader.GetToken().has_value()); + + // Clean up + remove(temp_token_path.c_str()); +} + +TEST_F(AuthenticationTokenLoaderTest, TestLoadFromDefaultPath) { + // Create directory and token file in default location + ensure_ray_dir_exists(); + write_token_file(default_token_path_, "test-token-from-default"); + + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("test-token-from-default"); + EXPECT_TRUE(token_opt->Equals(expected)); + EXPECT_TRUE(loader.GetToken().has_value()); +} + +// Parametrized test for token loading precedence: env var > user-specified file > default +// file + +struct TokenSourceConfig { + bool set_env = false; + bool set_file = false; + bool set_default = false; + std::string expected_token; + std::string env_token = "token-from-env"; + std::string file_token = "token-from-path"; + std::string default_token = "token-from-default"; +}; + +class AuthenticationTokenLoaderPrecedenceTest + : public AuthenticationTokenLoaderTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_SUITE_P(TokenPrecedenceCases, + AuthenticationTokenLoaderPrecedenceTest, + ::testing::Values( + // All set: env should win + TokenSourceConfig{true, true, true, "token-from-env"}, + // File and default file set: file should win + TokenSourceConfig{false, true, true, "token-from-path"}, + // Only default file set + TokenSourceConfig{ + false, false, true, "token-from-default"})); + +TEST_P(AuthenticationTokenLoaderPrecedenceTest, Precedence) { + const auto ¶m = GetParam(); + + // Optionally set environment variable + if (param.set_env) { + set_env_var("RAY_AUTH_TOKEN", param.env_token.c_str()); + } else { + unset_env_var("RAY_AUTH_TOKEN"); + } + + // Optionally create file and set path + std::string temp_token_path = get_temp_token_path(); + if (param.set_file) { + write_token_file(temp_token_path, param.file_token); + set_env_var("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str()); + } else { + unset_env_var("RAY_AUTH_TOKEN_PATH"); + } + + // Optionally create default file + ensure_ray_dir_exists(); + if (param.set_default) { + write_token_file(default_token_path_, param.default_token); + } else { + remove(default_token_path_.c_str()); + } + + // Always create a new instance to avoid cached state + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected(param.expected_token); + EXPECT_TRUE(token_opt->Equals(expected)); + + // Clean up token file if it was written + if (param.set_file) { + remove(temp_token_path.c_str()); + } + // Clean up default file if it was written + if (param.set_default) { + remove(default_token_path_.c_str()); + } +} + +TEST_F(AuthenticationTokenLoaderTest, TestNoTokenFoundWhenAuthDisabled) { + // Disable auth for this specific test + RayConfig::instance().initialize(R"({"auth_mode": "disabled"})"); + AuthenticationTokenLoader::instance().ResetCache(); + + // No token set anywhere, but auth is disabled + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + EXPECT_FALSE(token_opt.has_value()); + EXPECT_FALSE(loader.GetToken().has_value()); + + // Re-enable for other tests + RayConfig::instance().initialize(R"({"auth_mode": "token"})"); +} + +TEST_F(AuthenticationTokenLoaderTest, TestErrorWhenAuthEnabledButNoToken) { + // Token auth is already enabled in SetUp() + // No token exists, should trigger RAY_CHECK failure + EXPECT_DEATH( + { + auto &loader = AuthenticationTokenLoader::instance(); + loader.GetToken(); + }, + "Token authentication is enabled but Ray couldn't find an authentication token."); +} + +TEST_F(AuthenticationTokenLoaderTest, TestCaching) { + // Set token in environment + set_env_var("RAY_AUTH_TOKEN", "cached-token"); + + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt1 = loader.GetToken(); + + // Change environment variable (shouldn't affect cached value) + set_env_var("RAY_AUTH_TOKEN", "new-token"); + auto token_opt2 = loader.GetToken(); + + // Should still return the cached token + ASSERT_TRUE(token_opt1.has_value()); + ASSERT_TRUE(token_opt2.has_value()); + EXPECT_TRUE(token_opt1->Equals(*token_opt2)); + AuthenticationToken expected("cached-token"); + EXPECT_TRUE(token_opt2->Equals(expected)); +} + +TEST_F(AuthenticationTokenLoaderTest, TestWhitespaceHandling) { + // Create token file with whitespace + ensure_ray_dir_exists(); + write_token_file(default_token_path_, " token-with-spaces \n\t"); + + auto &loader = AuthenticationTokenLoader::instance(); + auto token_opt = loader.GetToken(); + + // Whitespace should be trimmed + ASSERT_TRUE(token_opt.has_value()); + AuthenticationToken expected("token-with-spaces"); + EXPECT_TRUE(token_opt->Equals(expected)); +} + +} // namespace rpc +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/rpc/tests/authentication_token_test.cc b/src/ray/rpc/tests/authentication_token_test.cc new file mode 100644 index 000000000000..77ae4eb7cfc2 --- /dev/null +++ b/src/ray/rpc/tests/authentication_token_test.cc @@ -0,0 +1,120 @@ +// Copyright 2025 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/rpc/authentication/authentication_token.h" + +#include +#include +#include + +#include "gtest/gtest.h" + +namespace ray { +namespace rpc { + +class AuthenticationTokenTest : public ::testing::Test {}; + +TEST_F(AuthenticationTokenTest, TestDefaultConstructor) { + AuthenticationToken token; + EXPECT_TRUE(token.empty()); +} + +TEST_F(AuthenticationTokenTest, TestConstructorWithValue) { + AuthenticationToken token("test-token-value"); + EXPECT_FALSE(token.empty()); + AuthenticationToken expected("test-token-value"); + EXPECT_TRUE(token.Equals(expected)); +} + +TEST_F(AuthenticationTokenTest, TestMoveConstructor) { + AuthenticationToken token1("original-token"); + AuthenticationToken token2(std::move(token1)); + + EXPECT_FALSE(token2.empty()); + AuthenticationToken expected("original-token"); + EXPECT_TRUE(token2.Equals(expected)); + EXPECT_TRUE(token1.empty()); +} + +TEST_F(AuthenticationTokenTest, TestMoveAssignment) { + AuthenticationToken token1("first-token"); + AuthenticationToken token2("second-token"); + + token2 = std::move(token1); + + EXPECT_FALSE(token2.empty()); + AuthenticationToken expected("first-token"); + EXPECT_TRUE(token2.Equals(expected)); + EXPECT_TRUE(token1.empty()); +} + +TEST_F(AuthenticationTokenTest, TestEquals) { + AuthenticationToken token1("same-token"); + AuthenticationToken token2("same-token"); + AuthenticationToken token3("different-token"); + + EXPECT_TRUE(token1.Equals(token2)); + EXPECT_FALSE(token1.Equals(token3)); + EXPECT_TRUE(token1 == token2); + EXPECT_FALSE(token1 == token3); + EXPECT_FALSE(token1 != token2); + EXPECT_TRUE(token1 != token3); +} + +TEST_F(AuthenticationTokenTest, TestEqualityDifferentLengths) { + AuthenticationToken token1("short"); + AuthenticationToken token2("much-longer-token"); + + EXPECT_FALSE(token1.Equals(token2)); +} + +TEST_F(AuthenticationTokenTest, TestEqualityEmptyTokens) { + AuthenticationToken token1; + AuthenticationToken token2; + + EXPECT_TRUE(token1.Equals(token2)); +} + +TEST_F(AuthenticationTokenTest, TestEqualityEmptyVsNonEmpty) { + AuthenticationToken token1; + AuthenticationToken token2("non-empty"); + + EXPECT_FALSE(token1.Equals(token2)); + EXPECT_FALSE(token2.Equals(token1)); +} + +TEST_F(AuthenticationTokenTest, TestRedactedOutput) { + AuthenticationToken token("super-secret-token"); + + std::ostringstream oss; + oss << token; + + std::string output = oss.str(); + EXPECT_EQ(output, ""); + EXPECT_EQ(output.find("super-secret-token"), std::string::npos); +} + +TEST_F(AuthenticationTokenTest, TestEmptyString) { + AuthenticationToken token(""); + EXPECT_TRUE(token.empty()); + AuthenticationToken expected(""); + EXPECT_TRUE(token.Equals(expected)); +} +} // namespace rpc +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}