Skip to content

Commit 7bde811

Browse files
author
sampan
committed
Merge remote-tracking branch 'upstream/grpc_auth_1' into grpc_auth_2
Signed-off-by: sampan <sampan@anyscale.com>
2 parents 34aa7a3 + c079298 commit 7bde811

File tree

5 files changed

+234
-21
lines changed

5 files changed

+234
-21
lines changed

src/ray/gcs/gcs_server.cc

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -215,21 +215,21 @@ void GcsServer::Start() {
215215
// Init KV Manager. This needs to be initialized first here so that
216216
// it can be used to retrieve the cluster ID.
217217
InitKVManager();
218-
gcs_init_data->AsyncLoad({[this, gcs_init_data] {
219-
GetOrGenerateClusterId(
220-
{[this, gcs_init_data](ClusterID cluster_id) {
221-
rpc_server_.SetClusterId(cluster_id);
222-
// Load and set authentication token if enabled
223-
if (RayConfig::instance().enable_token_auth()) {
224-
rpc_server_.SetAuthToken(
225-
rpc::RayAuthTokenLoader::instance().GetToken(
226-
false));
227-
}
228-
DoStart(*gcs_init_data);
229-
},
230-
io_context_provider_.GetDefaultIOContext()});
231-
},
232-
io_context_provider_.GetDefaultIOContext()});
218+
gcs_init_data->AsyncLoad(
219+
{[this, gcs_init_data] {
220+
GetOrGenerateClusterId(
221+
{[this, gcs_init_data](ClusterID cluster_id) {
222+
rpc_server_.SetClusterId(cluster_id);
223+
// Load and set authentication token if enabled
224+
if (RayConfig::instance().enable_token_auth()) {
225+
rpc_server_.SetAuthToken(
226+
rpc::RayAuthTokenLoader::instance().GetToken(false));
227+
}
228+
DoStart(*gcs_init_data);
229+
},
230+
io_context_provider_.GetDefaultIOContext()});
231+
},
232+
io_context_provider_.GetDefaultIOContext()});
233233
}
234234

235235
void GcsServer::GetOrGenerateClusterId(

src/ray/rpc/auth_token_loader.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <fstream>
1818
#include <random>
1919
#include <sstream>
20+
#include <string>
2021

2122
#include "ray/util/logging.h"
2223
#include "ray/util/util.h"
@@ -102,8 +103,7 @@ std::string RayAuthTokenLoader::LoadTokenFromSources() {
102103
token.erase(0, token.find_first_not_of(" \t\n\r\f\v"));
103104
token.erase(token.find_last_not_of(" \t\n\r\f\v") + 1);
104105
if (!token.empty()) {
105-
RAY_LOG(DEBUG) << "Loaded authentication token from default path: "
106-
<< default_path;
106+
RAY_LOG(DEBUG) << "Loaded authentication token from default path: " << default_path;
107107
return token;
108108
}
109109
}
@@ -151,8 +151,7 @@ std::string RayAuthTokenLoader::GenerateToken() {
151151
chmod(token_path.c_str(), S_IRUSR | S_IWUSR);
152152
#endif
153153

154-
RAY_LOG(INFO) << "Generated new authentication token and saved to "
155-
<< token_path;
154+
RAY_LOG(INFO) << "Generated new authentication token and saved to " << token_path;
156155
} else {
157156
RAY_LOG(WARNING) << "Failed to save generated token to " << token_path
158157
<< ". Token will only be available in memory.";
@@ -195,4 +194,3 @@ std::string RayAuthTokenLoader::GetDefaultTokenPath() {
195194

196195
} // namespace rpc
197196
} // namespace ray
198-

src/ray/rpc/auth_token_loader.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,3 @@ class RayAuthTokenLoader {
6666

6767
} // namespace rpc
6868
} // namespace ray
69-
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
// Copyright 2017 The Ray Authors.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "ray/rpc/auth_token_loader.h"
16+
17+
#include <fstream>
18+
#include <string>
19+
#include <thread>
20+
#include <vector>
21+
22+
#include "gtest/gtest.h"
23+
#include "ray/util/logging.h"
24+
25+
namespace ray {
26+
namespace rpc {
27+
28+
class RayAuthTokenLoaderTest : public ::testing::Test {
29+
protected:
30+
void SetUp() override {
31+
// Clean up environment variables before each test
32+
unsetenv("RAY_AUTH_TOKEN");
33+
unsetenv("RAY_AUTH_TOKEN_PATH");
34+
35+
// Clean up default token file
36+
std::string home_dir = getenv("HOME");
37+
default_token_path_ = home_dir + "/.ray/auth_token";
38+
remove(default_token_path_.c_str());
39+
}
40+
41+
void TearDown() override {
42+
// Clean up after test
43+
unsetenv("RAY_AUTH_TOKEN");
44+
unsetenv("RAY_AUTH_TOKEN_PATH");
45+
remove(default_token_path_.c_str());
46+
}
47+
48+
std::string default_token_path_;
49+
};
50+
51+
TEST_F(RayAuthTokenLoaderTest, TestLoadFromEnvVariable) {
52+
// Set token in environment variable
53+
setenv("RAY_AUTH_TOKEN", "test-token-from-env", 1);
54+
55+
// Create a new instance to avoid cached state
56+
auto &loader = RayAuthTokenLoader::instance();
57+
std::string token = loader.GetToken(false);
58+
59+
EXPECT_EQ(token, "test-token-from-env");
60+
EXPECT_TRUE(loader.HasToken());
61+
}
62+
63+
TEST_F(RayAuthTokenLoaderTest, TestLoadFromEnvPath) {
64+
// Create a temporary token file
65+
std::string temp_token_path = "/tmp/ray_test_token_" + std::to_string(getpid());
66+
std::ofstream token_file(temp_token_path);
67+
token_file << "test-token-from-file";
68+
token_file.close();
69+
70+
// Set path in environment variable
71+
setenv("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str(), 1);
72+
73+
auto &loader = RayAuthTokenLoader::instance();
74+
std::string token = loader.GetToken(false);
75+
76+
EXPECT_EQ(token, "test-token-from-file");
77+
EXPECT_TRUE(loader.HasToken());
78+
79+
// Clean up
80+
remove(temp_token_path.c_str());
81+
}
82+
83+
TEST_F(RayAuthTokenLoaderTest, TestLoadFromDefaultPath) {
84+
// Create directory
85+
std::string ray_dir = std::string(getenv("HOME")) + "/.ray";
86+
mkdir(ray_dir.c_str(), 0700);
87+
88+
// Create token file in default location
89+
std::ofstream token_file(default_token_path_);
90+
token_file << "test-token-from-default";
91+
token_file.close();
92+
93+
auto &loader = RayAuthTokenLoader::instance();
94+
std::string token = loader.GetToken(false);
95+
96+
EXPECT_EQ(token, "test-token-from-default");
97+
EXPECT_TRUE(loader.HasToken());
98+
}
99+
100+
TEST_F(RayAuthTokenLoaderTest, TestPrecedenceOrder) {
101+
// Set all three sources
102+
setenv("RAY_AUTH_TOKEN", "token-from-env", 1);
103+
104+
std::string temp_token_path = "/tmp/ray_test_token_" + std::to_string(getpid());
105+
std::ofstream temp_file(temp_token_path);
106+
temp_file << "token-from-path";
107+
temp_file.close();
108+
setenv("RAY_AUTH_TOKEN_PATH", temp_token_path.c_str(), 1);
109+
110+
std::string ray_dir = std::string(getenv("HOME")) + "/.ray";
111+
mkdir(ray_dir.c_str(), 0700);
112+
std::ofstream default_file(default_token_path_);
113+
default_file << "token-from-default";
114+
default_file.close();
115+
116+
// Environment variable should have highest precedence
117+
auto &loader = RayAuthTokenLoader::instance();
118+
std::string token = loader.GetToken(false);
119+
120+
EXPECT_EQ(token, "token-from-env");
121+
122+
// Clean up
123+
remove(temp_token_path.c_str());
124+
}
125+
126+
TEST_F(RayAuthTokenLoaderTest, TestNoTokenFound) {
127+
// No token set anywhere
128+
auto &loader = RayAuthTokenLoader::instance();
129+
std::string token = loader.GetToken(false);
130+
131+
EXPECT_EQ(token, "");
132+
EXPECT_FALSE(loader.HasToken());
133+
}
134+
135+
TEST_F(RayAuthTokenLoaderTest, TestGenerateToken) {
136+
// No token exists, but request generation
137+
auto &loader = RayAuthTokenLoader::instance();
138+
std::string token = loader.GetToken(true);
139+
140+
// Token should be generated (32 character hex string)
141+
EXPECT_EQ(token.length(), 32);
142+
EXPECT_TRUE(loader.HasToken());
143+
144+
// Token should be saved to default path
145+
std::ifstream token_file(default_token_path_);
146+
EXPECT_TRUE(token_file.is_open());
147+
std::string saved_token;
148+
std::getline(token_file, saved_token);
149+
EXPECT_EQ(saved_token, token);
150+
}
151+
152+
TEST_F(RayAuthTokenLoaderTest, TestCaching) {
153+
// Set token in environment
154+
setenv("RAY_AUTH_TOKEN", "cached-token", 1);
155+
156+
auto &loader = RayAuthTokenLoader::instance();
157+
std::string token1 = loader.GetToken(false);
158+
159+
// Change environment variable (shouldn't affect cached value)
160+
setenv("RAY_AUTH_TOKEN", "new-token", 1);
161+
std::string token2 = loader.GetToken(false);
162+
163+
// Should still return the cached token
164+
EXPECT_EQ(token1, token2);
165+
EXPECT_EQ(token2, "cached-token");
166+
}
167+
168+
TEST_F(RayAuthTokenLoaderTest, TestThreadSafety) {
169+
// Set a token
170+
setenv("RAY_AUTH_TOKEN", "thread-safe-token", 1);
171+
172+
auto &loader = RayAuthTokenLoader::instance();
173+
174+
// Create multiple threads that try to get token simultaneously
175+
std::vector<std::thread> threads;
176+
std::vector<std::string> results(10);
177+
178+
for (int i = 0; i < 10; i++) {
179+
threads.emplace_back(
180+
[&loader, &results, i]() { results[i] = loader.GetToken(false); });
181+
}
182+
183+
// Wait for all threads to complete
184+
for (auto &thread : threads) {
185+
thread.join();
186+
}
187+
188+
// All threads should get the same token
189+
for (const auto &result : results) {
190+
EXPECT_EQ(result, "thread-safe-token");
191+
}
192+
}
193+
194+
TEST_F(RayAuthTokenLoaderTest, TestWhitespaceHandling) {
195+
// Create token file with whitespace
196+
std::string ray_dir = std::string(getenv("HOME")) + "/.ray";
197+
mkdir(ray_dir.c_str(), 0700);
198+
std::ofstream token_file(default_token_path_);
199+
token_file << " token-with-spaces \n\t";
200+
token_file.close();
201+
202+
auto &loader = RayAuthTokenLoader::instance();
203+
std::string token = loader.GetToken(false);
204+
205+
// Whitespace should be trimmed
206+
EXPECT_EQ(token, "token-with-spaces");
207+
}
208+
209+
} // namespace rpc
210+
} // namespace ray
211+
212+
int main(int argc, char **argv) {
213+
::testing::InitGoogleTest(&argc, argv);
214+
return RUN_ALL_TESTS();
215+
}

src/ray/rpc/tests/grpc_server_client_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include <chrono>
1616
#include <memory>
17+
#include <string>
1718
#include <vector>
1819

1920
#include "gtest/gtest.h"

0 commit comments

Comments
 (0)