Skip to content

Commit

Permalink
[Core] Remove shard context from RedisClient (ray-project#42095)
Browse files Browse the repository at this point in the history
We only support single shard redis.

Fix a TODO `// TODO (iycheng) Remove shard context from RedisClient`

Signed-off-by: Jiajun Yao <jeromeyjj@gmail.com>
  • Loading branch information
jjyao authored and vickytsang committed Jan 12, 2024
1 parent fa04fab commit 1e4cef0
Show file tree
Hide file tree
Showing 14 changed files with 39 additions and 169 deletions.
6 changes: 3 additions & 3 deletions python/ray/includes/global_state_accessor.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ cdef extern from * namespace "ray::gcs" nogil:
ray::RayLogLevel::WARNING,
"" /* log_dir */);
RedisClientOptions options(host, port, password, false, use_ssl);
RedisClientOptions options(host, port, password, use_ssl);
std::string config_list;
RAY_CHECK(absl::Base64Unescape(config, &config_list));
Expand Down Expand Up @@ -138,7 +138,7 @@ cdef extern from * namespace "ray::gcs" nogil:
const std::string& password,
bool use_ssl,
const std::string& key) {
RedisClientOptions options(host, port, password, false, use_ssl);
RedisClientOptions options(host, port, password, use_ssl);
auto cli = std::make_unique<RedisClient>(options);
instrumented_io_context io_service;
Expand All @@ -156,7 +156,7 @@ cdef extern from * namespace "ray::gcs" nogil:
auto status = cli->Connect(io_service);
RAY_CHECK(status.ok()) << "Failed to connect to redis: " << status.ToString();
auto context = cli->GetShardContext(key);
auto context = cli->GetPrimaryContext();
auto cmd = std::vector<std::string>{"DEL", key};
auto reply = context->RunArgvSync(cmd);
if(reply->ReadAsInteger() == 1) {
Expand Down
2 changes: 0 additions & 2 deletions src/ray/gcs/gcs_client/test/gcs_client_reconnection_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,11 @@ class GcsClientReconnectionTest : public ::testing::Test {

void SetUp() override {
config_.redis_address = "127.0.0.1";
config_.enable_sharding_conn = false;
config_.redis_port = TEST_REDIS_SERVER_PORTS.front();
config_.grpc_server_port = GetFreePort();
config_.grpc_server_name = "MockedGcsServer";
config_.grpc_server_thread_num = 1;
config_.node_ip_address = "127.0.0.1";
config_.enable_sharding_conn = false;
}

void TearDown() override {
Expand Down
2 changes: 0 additions & 2 deletions src/ray/gcs/gcs_client/test/gcs_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
void SetUp() override {
if (!no_redis_) {
config_.redis_address = "127.0.0.1";
config_.enable_sharding_conn = false;
config_.redis_port = TEST_REDIS_SERVER_PORTS.front();
} else {
config_.redis_port = 0;
Expand All @@ -67,7 +66,6 @@ class GcsClientTest : public ::testing::TestWithParam<bool> {
config_.grpc_server_name = "MockedGcsServer";
config_.grpc_server_thread_num = 1;
config_.node_ip_address = "127.0.0.1";
config_.enable_sharding_conn = false;

// Tests legacy code paths. The poller and broadcaster have their own dedicated unit
// test targets.
Expand Down
2 changes: 1 addition & 1 deletion src/ray/gcs/gcs_server/gcs_redis_failure_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void GcsRedisFailureDetector::DetectRedis() {
callback_();
}
};
auto cxt = redis_client_->GetShardContext("");
auto cxt = redis_client_->GetPrimaryContext();
cxt->RunArgvAsync({"PING"}, redis_callback);
}

Expand Down
1 change: 0 additions & 1 deletion src/ray/gcs/gcs_server/gcs_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ RedisClientOptions GcsServer::GetRedisClientOptions() const {
return RedisClientOptions(config_.redis_address,
config_.redis_port,
config_.redis_password,
config_.enable_sharding_conn,
config_.enable_redis_ssl);
}

Expand Down
5 changes: 1 addition & 4 deletions src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ class RedisGcsTableStorageTest : public gcs::GcsTableStorageTestBase {
static void TearDownTestCase() { TestSetupUtil::ShutDownRedisServers(); }

void SetUp() override {
gcs::RedisClientOptions options("127.0.0.1",
TEST_REDIS_SERVER_PORTS.front(),
"",
/*enable_sharding_conn=*/false);
gcs::RedisClientOptions options("127.0.0.1", TEST_REDIS_SERVER_PORTS.front(), "");
redis_client_ = std::make_shared<gcs::RedisClient>(options);
RAY_CHECK_OK(redis_client_->Connect(io_service_pool_->GetAll()));

Expand Down
93 changes: 1 addition & 92 deletions src/ray/gcs/redis_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,51 +88,6 @@ static int DoGetNextJobID(redisContext *context) {
return counter;
}

static void GetRedisShards(redisContext *context,
std::vector<std::string> *addresses,
std::vector<int> *ports) {
// Get the total number of Redis shards in the system.
redisReply *reply = nullptr;
bool under_retry_limit = RunRedisCommandWithRetries(
context, "GET NumRedisShards", &reply, [](const redisReply *reply) {
return reply != nullptr && reply->type != REDIS_REPLY_NIL;
});
RAY_CHECK(under_retry_limit) << "No entry found for NumRedisShards";
RAY_CHECK(reply->type == REDIS_REPLY_STRING)
<< "Expected string, found Redis type " << reply->type << " for NumRedisShards";
int num_redis_shards = atoi(reply->str);
RAY_CHECK(num_redis_shards >= 1) << "Expected at least one Redis shard, "
<< "found " << num_redis_shards;
freeReplyObject(reply);

// Get the addresses of all of the Redis shards.
under_retry_limit = RunRedisCommandWithRetries(
context,
"LRANGE RedisShards 0 -1",
&reply,
[&num_redis_shards](const redisReply *reply) {
return static_cast<int>(reply->elements) == num_redis_shards;
});
RAY_CHECK(under_retry_limit) << "Expected " << num_redis_shards
<< " Redis shard addresses, found " << reply->elements;

// Parse the Redis shard addresses.
for (size_t i = 0; i < reply->elements; ++i) {
// Parse the shard addresses and ports.
RAY_CHECK(reply->element[i]->type == REDIS_REPLY_STRING);
std::string addr;
std::stringstream ss(reply->element[i]->str);
getline(ss, addr, ':');
addresses->emplace_back(std::move(addr));
int port;
ss >> port;
ports->emplace_back(port);
RAY_LOG(DEBUG) << "Received Redis shard address " << addr << ":" << port
<< " from head GCS.";
}
freeReplyObject(reply);
}

RedisClient::RedisClient(const RedisClientOptions &options) : options_(options) {}

Status RedisClient::Connect(instrumented_io_context &io_service) {
Expand All @@ -154,44 +109,9 @@ Status RedisClient::Connect(std::vector<instrumented_io_context *> io_services)

RAY_CHECK_OK(primary_context_->Connect(options_.server_ip_,
options_.server_port_,
/*sharding=*/options_.enable_sharding_conn_,
/*password=*/options_.password_,
/*enable_ssl=*/options_.enable_ssl_));

if (options_.enable_sharding_conn_) {
// Moving sharding into constructor defaultly means that sharding = true.
// This design decision may worth a look.
std::vector<std::string> addresses;
std::vector<int> ports;
GetRedisShards(primary_context_->sync_context(), &addresses, &ports);
if (addresses.empty()) {
RAY_CHECK(ports.empty());
addresses.push_back(options_.server_ip_);
ports.push_back(options_.server_port_);
}

for (size_t i = 0; i < addresses.size(); ++i) {
size_t io_service_index = (i + 1) % io_services.size();
instrumented_io_context &io_service = *io_services[io_service_index];
// Populate shard_contexts.
shard_contexts_.push_back(std::make_shared<RedisContext>(io_service));
// Only async context is used in sharding context, so we disable the other two.
RAY_CHECK_OK(shard_contexts_[i]->Connect(addresses[i],
ports[i],
/*sharding=*/true,
/*password=*/options_.password_,
/*enable_ssl=*/options_.enable_ssl_));
}
} else {
shard_contexts_.push_back(std::make_shared<RedisContext>(*io_services[0]));
// Only async context is used in sharding context, so wen disable the other two.
RAY_CHECK_OK(shard_contexts_[0]->Connect(options_.server_ip_,
options_.server_port_,
/*sharding=*/true,
/*password=*/options_.password_,
/*enable_ssl=*/options_.enable_ssl_));
}

Attach();

is_connected_ = true;
Expand All @@ -202,12 +122,7 @@ Status RedisClient::Connect(std::vector<instrumented_io_context *> io_services)

void RedisClient::Attach() {
// Take care of sharding contexts.
RAY_CHECK(shard_asio_async_clients_.empty()) << "Attach shall be called only once";
for (std::shared_ptr<RedisContext> context : shard_contexts_) {
instrumented_io_context &io_service = context->io_service();
shard_asio_async_clients_.emplace_back(
new RedisAsioClient(io_service, context->async_context()));
}
RAY_CHECK(!asio_async_auxiliary_client_) << "Attach shall be called only once";
instrumented_io_context &io_service = primary_context_->io_service();
asio_async_auxiliary_client_.reset(
new RedisAsioClient(io_service, primary_context_->async_context()));
Expand All @@ -219,12 +134,6 @@ void RedisClient::Disconnect() {
RAY_LOG(DEBUG) << "RedisClient disconnected.";
}

std::shared_ptr<RedisContext> RedisClient::GetShardContext(const std::string &shard_key) {
// TODO (iycheng) Remove shard context from RedisClient
RAY_CHECK(shard_contexts_.size() == 1);
return shard_contexts_[0];
}

int RedisClient::GetNextJobID() {
RAY_CHECK(primary_context_);
return DoGetNextJobID(primary_context_->sync_context());
Expand Down
16 changes: 1 addition & 15 deletions src/ray/gcs/redis_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@ class RedisClientOptions {
RedisClientOptions(const std::string &ip,
int port,
const std::string &password,
bool enable_sharding_conn = false,
bool enable_ssl = false)
: server_ip_(ip),
server_port_(port),
password_(password),
enable_sharding_conn_(enable_sharding_conn),
enable_ssl_(enable_ssl) {}

// Redis server address
Expand All @@ -47,9 +45,6 @@ class RedisClientOptions {
// Password of Redis.
std::string password_;

// Whether we enable sharding for accessing data.
bool enable_sharding_conn_ = false;

// Whether to use tls/ssl for redis connection
bool enable_ssl_ = false;
};
Expand Down Expand Up @@ -82,12 +77,6 @@ class RedisClient {
/// Disconnect with Redis. Non-thread safe.
void Disconnect();

std::vector<std::shared_ptr<RedisContext>> GetShardContexts() {
return shard_contexts_;
}

std::shared_ptr<RedisContext> GetShardContext(const std::string &shard_key);

std::shared_ptr<RedisContext> GetPrimaryContext() { return primary_context_; }

int GetNextJobID();
Expand All @@ -102,11 +91,8 @@ class RedisClient {
/// Whether this client is connected to redis.
bool is_connected_{false};

// The following contexts write to the data shard
std::vector<std::shared_ptr<RedisContext>> shard_contexts_;
std::vector<std::unique_ptr<RedisAsioClient>> shard_asio_async_clients_;
std::unique_ptr<RedisAsioClient> asio_async_auxiliary_client_;
// The following context writes everything to the primary shard
std::unique_ptr<RedisAsioClient> asio_async_auxiliary_client_;
std::shared_ptr<RedisContext> primary_context_;
};

Expand Down
3 changes: 1 addition & 2 deletions src/ray/gcs/redis_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,6 @@ std::vector<std::string> ResolveDNS(const std::string &address, int port) {

Status RedisContext::Connect(const std::string &address,
int port,
bool sharding,
const std::string &password,
bool enable_ssl) {
// Connect to the leader of the Redis cluster:
Expand Down Expand Up @@ -532,7 +531,7 @@ Status RedisContext::Connect(const std::string &address,
// Connect to the true leader.
RAY_LOG(INFO) << "Redis cluster leader is " << ip << ":" << port
<< ". Reconnect to it.";
return Connect(ip, port, sharding, password, enable_ssl);
return Connect(ip, port, password, enable_ssl);
} else {
RAY_LOG(INFO) << "Redis cluster leader is " << ip_addresses[0] << ":" << port;
freeReplyObject(redis_reply);
Expand Down
1 change: 0 additions & 1 deletion src/ray/gcs/redis_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ class RedisContext {

Status Connect(const std::string &address,
int port,
bool sharding,
const std::string &password,
bool enable_ssl = false);

Expand Down
64 changes: 27 additions & 37 deletions src/ray/gcs/store_client/redis_store_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ void RedisStoreClient::SendRedisCmd(std::vector<std::string> keys,
}
}
// Send the actual request
auto cxt = redis_client_->GetShardContext("");
auto cxt = redis_client_->GetPrimaryContext();
cxt->RunArgvAsync(std::move(args),
[this,
keys = std::move(keys),
Expand Down Expand Up @@ -362,7 +362,7 @@ Status RedisStoreClient::DeleteByKeys(const std::vector<std::string> &keys,
auto total_count = del_cmds.size();
auto finished_count = std::make_shared<size_t>(0);
auto num_deleted = std::make_shared<int64_t>(0);
auto context = redis_client_->GetShardContext("");
auto context = redis_client_->GetPrimaryContext();
for (auto &command : del_cmds) {
std::vector<std::string> partition_keys(command.begin() + 2, command.end());
auto delete_callback = [num_deleted, finished_count, total_count, callback](
Expand All @@ -388,9 +388,7 @@ RedisStoreClient::RedisScanner::RedisScanner(
: table_name_(table_name),
external_storage_namespace_(external_storage_namespace),
redis_client_(std::move(redis_client)) {
for (size_t index = 0; index < redis_client_->GetShardContexts().size(); ++index) {
shard_to_cursor_[index] = 0;
}
cursor_ = 0;
}

Status RedisStoreClient::RedisScanner::ScanKeysAndValues(
Expand All @@ -405,58 +403,50 @@ Status RedisStoreClient::RedisScanner::ScanKeysAndValues(

void RedisStoreClient::RedisScanner::Scan(const std::string &match_pattern,
const StatusCallback &callback) {
// This lock guards the iterator over shard_to_cursor_ because the callbacks
// can remove items from the shard_to_cursor_ map. If performance is a concern,
// This lock guards cursor_ because the callbacks
// can modify cursor_. If performance is a concern,
// we should consider using a reader-writer lock.
absl::MutexLock lock(&mutex_);
if (shard_to_cursor_.empty()) {
if (!cursor_.has_value()) {
callback(Status::OK());
return;
}

size_t batch_count = RayConfig::instance().maximum_gcs_storage_operation_batch_size();
for (const auto &item : shard_to_cursor_) {
++pending_request_count_;

size_t shard_index = item.first;
size_t cursor = item.second;

auto scan_callback = [this, match_pattern, shard_index, callback](
const std::shared_ptr<CallbackReply> &reply) {
OnScanCallback(match_pattern, shard_index, reply, callback);
};
// Scan by prefix from Redis.
std::vector<std::string> args = {"HSCAN",
external_storage_namespace_,
std::to_string(cursor),
"MATCH",
match_pattern,
"COUNT",
std::to_string(batch_count)};
auto shard_context = redis_client_->GetShardContexts()[shard_index];
shard_context->RunArgvAsync(args, scan_callback);
}
++pending_request_count_;

auto scan_callback =
[this, match_pattern, callback](const std::shared_ptr<CallbackReply> &reply) {
OnScanCallback(match_pattern, reply, callback);
};
// Scan by prefix from Redis.
std::vector<std::string> args = {"HSCAN",
external_storage_namespace_,
std::to_string(cursor_.value()),
"MATCH",
match_pattern,
"COUNT",
std::to_string(batch_count)};
auto primary_context = redis_client_->GetPrimaryContext();
primary_context->RunArgvAsync(args, scan_callback);
}

void RedisStoreClient::RedisScanner::OnScanCallback(
const std::string &match_pattern,
size_t shard_index,
const std::shared_ptr<CallbackReply> &reply,
const StatusCallback &callback) {
RAY_CHECK(reply);
std::vector<std::string> scan_result;
size_t cursor = reply->ReadAsScanArray(&scan_result);
// Update shard cursors and results_.
// Update cursor and results_.
{
absl::MutexLock lock(&mutex_);
auto shard_it = shard_to_cursor_.find(shard_index);
RAY_CHECK(shard_it != shard_to_cursor_.end());
// If cursor is equal to 0, it means that the scan of this shard is finished, so we
// erase it from shard_to_cursor_.
// If cursor is equal to 0, it means that the scan is finished, so we
// reset cursor_.
if (cursor == 0) {
shard_to_cursor_.erase(shard_it);
cursor_.reset();
} else {
shard_it->second = cursor;
cursor_ = cursor;
}
RAY_CHECK(scan_result.size() % 2 == 0);
for (size_t i = 0; i < scan_result.size(); i += 2) {
Expand Down
Loading

0 comments on commit 1e4cef0

Please sign in to comment.