Skip to content

Commit db942cb

Browse files
committed
add a separate class AuthenticationTokenValidator for validating tokens
Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
1 parent 8017316 commit db942cb

11 files changed

+173
-98
lines changed

python/ray/includes/rpc_token_authentication.pxd

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ cdef extern from "ray/rpc/authentication/authentication_token_loader.h" namespac
2727
@staticmethod
2828
CAuthenticationTokenLoader& instance()
2929
c_bool HasToken()
30-
c_bool ValidateToken(const CAuthenticationToken& token)
3130
void ResetCache()
3231
optional[CAuthenticationToken] GetToken()
32+
33+
cdef extern from "ray/rpc/authentication/authentication_token_validator.h" namespace "ray::rpc" nogil:
34+
cdef cppclass CAuthenticationTokenValidator "ray::rpc::AuthenticationTokenValidator":
35+
@staticmethod
36+
CAuthenticationTokenValidator& instance()
37+
c_bool ValidateToken(const optional[CAuthenticationToken]& expected_token, const CAuthenticationToken& provided_token)

python/ray/includes/rpc_token_authentication.pxi

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ from ray.includes.rpc_token_authentication cimport (
33
GetAuthenticationMode,
44
CAuthenticationToken,
55
CAuthenticationTokenLoader,
6+
CAuthenticationTokenValidator,
67
)
78
from ray._private.authentication.authentication_constants import AUTHORIZATION_HEADER_NAME
89
import logging
@@ -38,13 +39,18 @@ def validate_authentication_token(provided_token: str) -> bool:
3839
Returns:
3940
bool: True if token is valid, False otherwise
4041
"""
42+
cdef optional[CAuthenticationToken] expected_opt = CAuthenticationTokenLoader.instance().GetToken()
43+
44+
if not expected_opt.has_value():
45+
return False
46+
4147
# Parse provided token from Bearer format
4248
cdef CAuthenticationToken provided = CAuthenticationToken.FromMetadata(provided_token.encode())
4349

4450
if provided.empty():
4551
return False
4652

47-
return CAuthenticationTokenLoader.instance().ValidateToken(provided)
53+
return CAuthenticationTokenValidator.instance().ValidateToken(expected_opt, provided)
4854

4955

5056
class AuthenticationTokenLoader:

src/ray/rpc/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ ray_cc_library(
110110
"//src/ray/rpc/authentication:authentication_mode",
111111
"//src/ray/rpc/authentication:authentication_token",
112112
"//src/ray/rpc/authentication:authentication_token_loader",
113+
"//src/ray/rpc/authentication:authentication_token_validator",
113114
"//src/ray/stats:stats_metric",
114115
"@com_github_grpc_grpc//:grpc++",
115116
],

src/ray/rpc/authentication/BUILD.bazel

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,18 @@ ray_cc_library(
4343
deps = [
4444
":authentication_mode",
4545
":authentication_token",
46-
":k8s_util",
4746
"//src/ray/util:logging",
4847
],
4948
)
49+
50+
ray_cc_library(
51+
name = "authentication_token_validator",
52+
srcs = ["authentication_token_validator.cc"],
53+
hdrs = ["authentication_token_validator.h"],
54+
visibility = ["//visibility:public"],
55+
deps = [
56+
":authentication_mode",
57+
":authentication_token",
58+
":k8s_util",
59+
],
60+
)

src/ray/rpc/authentication/authentication_token.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,5 +174,12 @@ class AuthenticationToken {
174174
}
175175
};
176176

177+
// Hash function for AuthenticationToken
178+
struct AuthenticationTokenHash {
179+
std::size_t operator()(const AuthenticationToken &token) const {
180+
return std::hash<std::string>()(token.ToValue());
181+
}
182+
};
183+
177184
} // namespace rpc
178185
} // namespace ray

src/ray/rpc/authentication/authentication_token_loader.cc

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#include <string>
2222
#include <utility>
2323

24-
#include "ray/rpc/authentication/k8s_util.h"
2524
#include "ray/util/logging.h"
2625

2726
#ifdef _WIN32
@@ -38,61 +37,11 @@
3837
namespace ray {
3938
namespace rpc {
4039

41-
namespace {
42-
const std::chrono::minutes kCacheTTL(5);
43-
} // namespace
44-
4540
AuthenticationTokenLoader &AuthenticationTokenLoader::instance() {
4641
static AuthenticationTokenLoader instance;
4742
return instance;
4843
}
4944

50-
bool AuthenticationTokenLoader::ValidateToken(const AuthenticationToken &provided_token) {
51-
if (GetAuthenticationMode() == AuthenticationMode::TOKEN) {
52-
auto expected_token = GetToken();
53-
if (!expected_token.has_value()) {
54-
return false;
55-
}
56-
return expected_token->Equals(provided_token);
57-
} else if (GetAuthenticationMode() == AuthenticationMode::K8S) {
58-
std::call_once(k8s::k8s_client_config_flag, k8s::InitK8sClientConfig);
59-
if (!k8s::k8s_client_initialized) {
60-
return false;
61-
}
62-
63-
// Check cache first.
64-
{
65-
std::lock_guard<std::mutex> lock(k8s_token_cache_mutex_);
66-
auto it = k8s_token_cache_.find(provided_token);
67-
if (it != k8s_token_cache_.end()) {
68-
if (std::chrono::steady_clock::now() < it->second.expiration) {
69-
return it->second.allowed;
70-
} else {
71-
k8s_token_cache_.erase(it);
72-
}
73-
}
74-
}
75-
76-
bool is_allowed = false;
77-
is_allowed = k8s::ValidateToken(provided_token);
78-
79-
// Only cache validated tokens for now. We don't want to invalidate a token
80-
// due to unrelated errors from Kubernetes API server. This has the downside of
81-
// causing more load if an unauthenticated client continues to make calls.
82-
// TODO(andrewsykim): cache invalid tokens once k8s::ValidateToken can distinguish
83-
// between invalid token errors and server errors.
84-
if (is_allowed) {
85-
std::lock_guard<std::mutex> lock(k8s_token_cache_mutex_);
86-
k8s_token_cache_[provided_token] = {is_allowed,
87-
std::chrono::steady_clock::now() + kCacheTTL};
88-
}
89-
90-
return is_allowed;
91-
}
92-
93-
return true;
94-
}
95-
9645
std::optional<AuthenticationToken> AuthenticationTokenLoader::GetToken() {
9746
std::lock_guard<std::mutex> lock(token_mutex_);
9847

src/ray/rpc/authentication/authentication_token_loader.h

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,13 @@
1818
#include <mutex>
1919
#include <optional>
2020
#include <string>
21-
#include <unordered_map>
2221

2322
#include "ray/rpc/authentication/authentication_mode.h"
2423
#include "ray/rpc/authentication/authentication_token.h"
25-
#include "ray/rpc/authentication/k8s_util.h"
2624

2725
namespace ray {
2826
namespace rpc {
2927

30-
// Hash function for AuthenticationToken
31-
struct AuthenticationTokenHash {
32-
std::size_t operator()(const AuthenticationToken &token) const {
33-
return std::hash<std::string>()(token.ToValue());
34-
}
35-
};
36-
3728
/// Singleton class for loading and caching authentication tokens.
3829
/// Supports loading tokens from multiple sources with precedence:
3930
/// 1. RAY_AUTH_TOKEN environment variable
@@ -55,14 +46,6 @@ class AuthenticationTokenLoader {
5546
/// \return true if a token exists, false otherwise.
5647
bool HasToken();
5748

58-
/// Validate the provided authentication token.
59-
/// For TOKEN mode, it compares with the loaded token.
60-
/// For K8S mode, it uses Kubernetes TokenReview and SubjectAccessReview APIs.
61-
/// The results for K8S mode are cached.
62-
/// \param provided_token The token to validate.
63-
/// \return true if the token is valid, false otherwise.
64-
bool ValidateToken(const AuthenticationToken &provided_token);
65-
6649
void ResetCache() {
6750
std::lock_guard<std::mutex> lock(token_mutex_);
6851
cached_token_.reset();
@@ -89,15 +72,6 @@ class AuthenticationTokenLoader {
8972

9073
std::mutex token_mutex_;
9174
std::optional<AuthenticationToken> cached_token_;
92-
93-
// Cache for K8s tokens.
94-
struct K8sCacheEntry {
95-
bool allowed;
96-
std::chrono::steady_clock::time_point expiration;
97-
};
98-
std::mutex k8s_token_cache_mutex_;
99-
std::unordered_map<AuthenticationToken, K8sCacheEntry, AuthenticationTokenHash>
100-
k8s_token_cache_;
10175
};
10276

10377
} // namespace rpc
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Copyright 2025 The Ray Authors.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing,
10+
// software distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions
13+
// and limitations under the License.
14+
15+
#include "ray/rpc/authentication/authentication_token_validator.h"
16+
17+
#include "ray/rpc/authentication/authentication_mode.h"
18+
#include "ray/rpc/authentication/k8s_util.h"
19+
#include "ray/util/logging.h"
20+
21+
namespace ray {
22+
namespace rpc {
23+
24+
const std::chrono::minutes kCacheTTL(5);
25+
26+
AuthenticationTokenValidator &AuthenticationTokenValidator::instance() {
27+
static AuthenticationTokenValidator instance;
28+
return instance;
29+
}
30+
31+
bool AuthenticationTokenValidator::ValidateToken(
32+
const std::optional<AuthenticationToken> &expected_token,
33+
const AuthenticationToken &provided_token) {
34+
if (GetAuthenticationMode() == AuthenticationMode::TOKEN) {
35+
if (!expected_token.has_value() || expected_token->empty()) {
36+
return true; // No auth required on server side
37+
}
38+
39+
return expected_token->Equals(provided_token);
40+
} else if (GetAuthenticationMode() == AuthenticationMode::K8S) {
41+
std::call_once(k8s::k8s_client_config_flag, k8s::InitK8sClientConfig);
42+
if (!k8s::k8s_client_initialized) {
43+
RAY_LOG(WARNING) << "Kubernetes client not initialized, K8s authentication failed.";
44+
return false;
45+
}
46+
47+
// Check cache first.
48+
{
49+
std::lock_guard<std::mutex> lock(k8s_token_cache_mutex_);
50+
auto it = k8s_token_cache_.find(provided_token);
51+
if (it != k8s_token_cache_.end()) {
52+
if (std::chrono::steady_clock::now() < it->second.expiration) {
53+
RAY_LOG(DEBUG) << "K8s token found in cache and is valid.";
54+
return it->second.allowed;
55+
} else {
56+
RAY_LOG(DEBUG) << "K8s token in cache expired, removing from cache.";
57+
k8s_token_cache_.erase(it);
58+
}
59+
}
60+
}
61+
62+
bool is_allowed = false;
63+
is_allowed = k8s::ValidateToken(provided_token);
64+
65+
// Only cache validated tokens for now. We don't want to invalidate a token
66+
// due to unrelated errors from Kubernetes API server. This has the downside of
67+
// causing more load if an unauthenticated client continues to make calls.
68+
// TODO(andrewsykim): cache invalid tokens once k8s::ValidateToken can distinguish
69+
// between invalid token errors and server errors.
70+
if (is_allowed) {
71+
std::lock_guard<std::mutex> lock(k8s_token_cache_mutex_);
72+
k8s_token_cache_[provided_token] = {is_allowed,
73+
std::chrono::steady_clock::now() + kCacheTTL};
74+
RAY_LOG(DEBUG) << "K8s token validated and saved to cached.";
75+
}
76+
77+
return is_allowed;
78+
}
79+
80+
RAY_LOG(DEBUG) << "Authentication mode is disabled, token considered valid.";
81+
return true;
82+
}
83+
84+
} // namespace rpc
85+
} // namespace ray
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright 2025 The Ray Authors.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing,
10+
// software distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions
13+
// and limitations under the License.
14+
15+
#pragma once
16+
17+
#include <optional>
18+
#include <unordered_map>
19+
20+
#include "ray/rpc/authentication/authentication_token.h"
21+
#include "ray/rpc/authentication/k8s_util.h"
22+
23+
namespace ray {
24+
namespace rpc {
25+
26+
class AuthenticationTokenValidator {
27+
public:
28+
static AuthenticationTokenValidator &instance();
29+
/// Validate the provided authentication token against the expected token.
30+
/// When auth_mode=token, this is a simple equality check.
31+
/// When auth_mode=k8s, provided_token is validated against Kubernetes API.
32+
/// \param expected_token The expected token (optional).
33+
/// \param provided_token The token to validate.
34+
/// \return true if the tokens are equal, false otherwise.
35+
bool ValidateToken(const std::optional<AuthenticationToken> &expected_token,
36+
const AuthenticationToken &provided_token);
37+
38+
private:
39+
// Cache for K8s tokens.
40+
struct K8sCacheEntry {
41+
bool allowed;
42+
std::chrono::steady_clock::time_point expiration;
43+
};
44+
std::mutex k8s_token_cache_mutex_;
45+
std::unordered_map<AuthenticationToken, K8sCacheEntry, AuthenticationTokenHash>
46+
k8s_token_cache_;
47+
};
48+
49+
} // namespace rpc
50+
} // namespace ray

src/ray/rpc/grpc_server.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,7 @@ class GrpcServer {
107107
if (auth_token.has_value()) {
108108
auth_token_ = std::move(auth_token.value());
109109
} else {
110-
if (GetAuthenticationMode() == AuthenticationMode::TOKEN) {
111-
auth_token_ = AuthenticationTokenLoader::instance().GetToken();
112-
}
110+
auth_token_ = AuthenticationTokenLoader::instance().GetToken();
113111
}
114112
Init();
115113
}

0 commit comments

Comments
 (0)