Skip to content

Commit

Permalink
Merge pull request #13 from freshworks/scan-command-implementation
Browse files Browse the repository at this point in the history
Adding support for scan command
  • Loading branch information
dinesh-murugiah authored Mar 19, 2024
2 parents 3bfdec3 + 60f3e00 commit 8876f60
Show file tree
Hide file tree
Showing 3 changed files with 398 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ struct SupportedCommands {
static const absl::flat_hash_set<std::string>& allShardCommands() {
CONSTRUCT_ON_FIRST_USE(absl::flat_hash_set<std::string>, "script", "flushall", "pubsub", "keys", "slowlog", "config","client");
}

/**
* @return scan command
*/
static const std::string& scan() { CONSTRUCT_ON_FIRST_USE(std::string, "scan"); }

/**
* @return auth command
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "source/common/common/logger.h"
#include "source/extensions/filters/network/common/redis/supported_commands.h"
#include <string>

namespace Envoy {
namespace Extensions {
Expand Down Expand Up @@ -920,6 +921,326 @@ void MSETRequest::onChildResponse(Common::Redis::RespValuePtr&& value, uint32_t
}
}

Common::Redis::Client::PoolRequest*
makeScanRequest(const RouteSharedPtr& route, int32_t shard_index,
Common::Redis::RespValue& incoming_request,
ConnPool::PoolCallbacks& callbacks,
Common::Redis::Client::Transaction& transaction) {
std::string key = std::string();
Extensions::NetworkFilters::RedisProxy::ConnPool::InstanceImpl* req_instance =
dynamic_cast<Extensions::NetworkFilters::RedisProxy::ConnPool::InstanceImpl*>(
route->upstream(key).get());
auto handler = req_instance->makeRequestNoKey(shard_index, ConnPool::RespVariant(incoming_request),
callbacks, transaction);
return handler;
}

bool requiresValue(const std::string& arg) {
return (arg == "count" || arg == "match" || arg == "type");
}

// Generic function to build an array with bulkstring
void addBulkString(Common::Redis::RespValue& requestArray, const std::string& value) {
Common::Redis::RespValue element;
element.type(Common::Redis::RespType::BulkString);
element.asString() = value;
requestArray.asArray().emplace_back(std::move(element));
}

void ScanRequest::onChildError(Common::Redis::RespValuePtr&& value) {
// Setting null pointer to all pending requests
for (auto& request : pending_requests_) {
request.handle_ = nullptr;
}
// Clearing pending responses
if (!pending_responses_.empty()) {
pending_responses_.clear();
}

Common::Redis::RespValuePtr response_t = std::move(value);
ENVOY_LOG(debug, "response: {}", response_t->toString());
callbacks_.onResponse(std::move(response_t));

}

SplitRequestPtr ScanRequest::create(Router& router, Common::Redis::RespValuePtr&& incoming_request,
SplitCallbacks& callbacks, CommandStats& command_stats,
TimeSource& time_source, bool delay_command_latency,
const StreamInfo::StreamInfo& stream_info) {

// SCAN looks like: SCAN cursor [MATCH pattern] [COUNT count] [TYPE type]
// Ensure there are at least two args to the command or it cannot be scanned.
// Also the number of arguments should be in even number, otherwise command is invalid
if (incoming_request->asArray().size() < 2 || incoming_request->asArray().size() % 2 != 0) {
onWrongNumberOfArguments(callbacks, *incoming_request);
command_stats.error_.inc();
return nullptr;
}

std::string key = std::string();
int32_t shard_idx = 0;
int32_t numofRequests;

Common::Redis::RespValue requestArray;
requestArray.type(Common::Redis::RespType::Array);

// Getting the right cursor before sending request
// We are appending 4 digit custom value as shard index to the cursor returned from the previous scan request
// If it is a first request, the received cursor will be sent directly without any preprocessing
// Ex: If cursor length is more than 4 -> last four digits will be considered as index and the remaining will be taken as cursor
// If cursor length is less than or equal to 4 -> all digits will be considered as cursor
std::string cursor = incoming_request->asArray()[1].asString();
if (cursor.length() > 4) {
std::string index = cursor.substr(cursor.length() - 4);
cursor = cursor.substr(0, cursor.length() - 4);
shard_idx = std::stoi(index);
}

// Completely reconstructing the request to add/modify count since we can't override the incoming array directly
// Add the command and cursor to the request array
addBulkString(requestArray, "SCAN");
addBulkString(requestArray, cursor);

std::unique_ptr<ScanRequest> request_ptr{
new ScanRequest(callbacks, command_stats, time_source, delay_command_latency)};

// TODO : This value should be configurable through protobuf
// Setting default count value to 1000
request_ptr->resp_obj_count_ = "1000";

// Iterate over the arguments modify count if necessary
// We are setting 1000 as the default count value if the incoming request has more than that
for (size_t i = 2; i < incoming_request->asArray().size(); ++i) {
std::string arg = incoming_request->asArray()[i].asString();
addBulkString(requestArray, arg);

// Check if the argument requires a value
if (requiresValue(arg)) {
if (arg == "count") {
std::string count = incoming_request->asArray()[++i].asString();
if (std::stoi(count) < 1000) {
addBulkString(requestArray, count);
request_ptr->resp_obj_count_ = count;
} else {
// Override with default value only when the count value is greater than 1000
addBulkString(requestArray, request_ptr->resp_obj_count_);
}
++i;
} else {
// If the current argument requires a value, add it to the request array
std::string value = incoming_request->asArray()[++i].asString();
addBulkString(requestArray, value);
}
}
}

// If the "count" argument is not found, add it with the default value
if (incoming_request->asArray().size() == 2) {
addBulkString(requestArray, "count");
addBulkString(requestArray, request_ptr->resp_obj_count_);
}

// caching the request and route for making child request from response
request_ptr->request_ = requestArray;
request_ptr->route_ = router.upstreamPool(key, stream_info);

if (request_ptr->route_) {
Extensions::NetworkFilters::RedisProxy::ConnPool::InstanceImpl* instance =
dynamic_cast<Extensions::NetworkFilters::RedisProxy::ConnPool::InstanceImpl*>(
request_ptr->route_->upstream(key).get());

request_ptr->num_of_Shards_ = instance->getNumofRedisShards();
if (request_ptr->num_of_Shards_ == 0 ) {
callbacks.onResponse(Common::Redis::Utility::makeError(Response::get().NoUpstreamHost));
}
}
else{
callbacks.onResponse(Common::Redis::Utility::makeError(Response::get().NoUpstreamHost));
}

// If shard index is some random value, we are setting the shard to 0 to avoid crashing
// This ensures scan always returns value.
// If the current shard index is zero, we assume that we may need to send request to all the shards
if (shard_idx > request_ptr->num_of_Shards_ || shard_idx == 0) {
shard_idx = 0;
numofRequests = request_ptr->num_of_Shards_;
} else {
// If we receive shard index other than zero, we assume that the request should be sent to current shard and all the remaining next shards.
numofRequests = request_ptr->num_of_Shards_ - shard_idx;
}

// Reserving memory for pending_requests and pending_responses
request_ptr->pending_requests_.reserve(numofRequests);
request_ptr->pending_responses_.reserve(numofRequests);

// Pending requests will be popped from the back so, pending request index is incremented and and shard index is decremented
//pi => Pending request index
//si => Shard index
for(int32_t pi = 0, si = request_ptr->num_of_Shards_-1; pi < numofRequests && si >= shard_idx; pi++, si-- ) {
request_ptr->pending_requests_.emplace_back(*request_ptr, pi, si);
}

PendingRequest& pending_request = request_ptr->pending_requests_.back();
if (request_ptr->route_) {
pending_request.handle_= makeScanRequest(request_ptr->route_, shard_idx, requestArray, pending_request, callbacks.transaction());
}

if (!pending_request.handle_) {
pending_request.onResponse(Common::Redis::Utility::makeError(Response::get().NoUpstreamHost));
return nullptr;
}

return request_ptr;
}

void ScanRequest::onChildResponse(Common::Redis::RespValuePtr&& value, int32_t index, int32_t shard_index) {

if (value->type() == Common::Redis::RespType::Error){
ENVOY_LOG(debug,"recived error for index: '{}'", shard_index);
onChildError(std::move(value));
} else {
// Request handled successfully
pending_requests_[index].handle_ = nullptr;
// Moving the response to pending response for doing validation later
// Incrementing the number of pending responses, it will drained during the validation
int64_t count = std::stoi(resp_obj_count_);

// Checking the cursor and number of objects for child request
std::string cursor = value->asArray()[0].asString();
int64_t objectsReceived = value->asArray()[1].asArray().size();
std::string objectsRemaining = std::to_string(count - objectsReceived);

// Resizing pending repsonses array based on the incoming reponses
if (index >= static_cast<int32_t>(pending_responses_.size())) {
// Resize the vector to accommodate the new index
pending_responses_.resize(index + 1);
}
ENVOY_LOG(debug,"response recived for index: '{}'", shard_index);

pending_responses_[num_pending_responses_++] = std::move(value);
bool send_response = true;

// Following conditions needs to be satisfied for making a child request
// 1) If cursor is zero, objects less than count
// 1.1) If No more shards to scan, there won't be any child request
// 1.2) If shards present, increment the shard index, update the count value, set cursor to 0 and send child request
// 2) If cursor is not zero
// 2.1) If objects returned is equal to count, there won't be any child request

if (cursor == "0" && objectsReceived < count && shard_index+1 < num_of_Shards_) {
// Popping the pending request, not needed anymore
pending_requests_.pop_back();
send_response = false;
// Cursor is always zero for child request
// Create the child request based on the original request
Common::Redis::RespValue child_request = request_;

// Set the cursor to "0" in the child request
child_request.asArray()[1].asString() = "0";
// Setting the new count for the child request
for (size_t i = 2; i < child_request.asArray().size() - 1; ++i) {
if (child_request.asArray()[i].asString() == "count") {
child_request.asArray()[i + 1].asString() = objectsRemaining;
break; // Stop searching after updating count
}
}
ENVOY_LOG(debug, "Child request: {}", request_.toString());
PendingRequest& pending_request = pending_requests_.back();
pending_request.handle_= makeScanRequest(route_, pending_request.shard_index_, child_request, pending_request, callbacks_.transaction());
if (!pending_request.handle_) {
onChildError(Common::Redis::Utility::makeError(Response::get().NoUpstreamHost));
}
}

if (send_response) {
// Setting null to the stale pending_requests_
for (auto& request : pending_requests_) {
request.handle_ = nullptr;
}
Common::Redis::RespValuePtr response = std::make_unique<Common::Redis::RespValue>();
response->type(Common::Redis::RespType::Array);

// Process cursor -> Append cursor with shard index so that next request will come to corresponding shard
// if cursor is non zero, there are few elements remaining in the shard so set the index to the same shard
// if cursor is zero, but the count is satisfied, the next request should go to next shard. so increament the index
if (cursor != "0" ) {
std::string indexStr = std::to_string(shard_index);
cursor += std::string(4 - indexStr.length(), '0') + indexStr;
}
if (cursor == "0" && shard_index+1 < num_of_Shards_) {
std::string nextIndexStr = std::to_string(shard_index+1);
cursor += std::string(4 - nextIndexStr.length(), '0') + nextIndexStr;
}

// Response array will be created in the following format
// [0] -> latest scan cursor for next iteration
// [1] -> array for returned objects from multiple pending responses

// Setting the cursor from last response, since that is the latest one
// We need iterate from here in the next request.
Common::Redis::RespValue cstr;
cstr.type(Common::Redis::RespType::BulkString);
cstr.asString() = std::move(cursor);
response->asArray().emplace_back(std::move(cstr));

// Creating a temporary array to hold the objects from multiple pending responses
Common::Redis::RespValue objArray;
objArray.type(Common::Redis::RespType::Array);

// Iterate through the pending responses
for (size_t i = 0; i < pending_responses_.size(); ++i) {
if (pending_responses_[i] != nullptr) {
auto& resp = pending_responses_[i];
if (resp->type() == Common::Redis::RespType::Array) {
auto& obj = resp->asArray()[1];
if (obj.type() == Common::Redis::RespType::Array) {
// Iterate through the inner objects array and add its elements to the object array
for (size_t k = 0; k < obj.asArray().size(); ++k) {
objArray.asArray().emplace_back(std::move(obj.asArray()[k]));
}
} else {
ENVOY_LOG(debug, "received non array response for the objects while scanning");
}
} else {
ENVOY_LOG(debug, "received non array response for the scan command");
}
} else {
ENVOY_LOG(debug, "received null pointer as response from one of the shard");
}
}
pending_responses_.clear();
// Add the object array to the main response array as a nested array
response->asArray().emplace_back(std::move(objArray));
updateStats(error_count_ == 0);
Common::Redis::RespValuePtr response_t = std::move(response);
ENVOY_LOG(debug, "response: {}", response_t->toString());
callbacks_.onResponse(std::move(response_t));
}
}
}

void ScanRequestBase::onChildFailure(int32_t reqindex,int32_t shardindex) {
updateStats(false);
onChildResponse(Common::Redis::Utility::makeError(Response::get().UpstreamFailure), reqindex,shardindex);
}

ScanRequestBase::~ScanRequestBase() {
#ifndef NDEBUG
for (const PendingRequest& request : pending_requests_) {
ASSERT(!request.handle_);
}
#endif
}

void ScanRequestBase::cancel() {
for (PendingRequest& request : pending_requests_) {
if (request.handle_) {
request.handle_->cancel();
request.handle_ = nullptr;
}
}
}

SplitRequestPtr
SplitKeysSumResultRequest::create(Router& router, Common::Redis::RespValuePtr&& incoming_request,
SplitCallbacks& callbacks, CommandStats& command_stats,
Expand Down Expand Up @@ -1093,7 +1414,7 @@ InstanceImpl::InstanceImpl(RouterPtr&& router, Stats::Scope& scope, const std::s
: router_(std::move(router)), simple_command_handler_(*router_),
eval_command_handler_(*router_), mget_handler_(*router_), mset_handler_(*router_),
split_keys_sum_result_handler_(*router_),
transaction_handler_(*router_),mgmt_nokey_request_handler_(*router_),subscription_handler_(*router_),
transaction_handler_(*router_),mgmt_nokey_request_handler_(*router_),scanrequest_handler_(*router_),subscription_handler_(*router_),
blocking_client_request_handler_(*router_),stats_{ALL_COMMAND_SPLITTER_STATS(POOL_COUNTER_PREFIX(scope, stat_prefix + "splitter."))},
time_source_(time_source), fault_manager_(std::move(fault_manager)) {
for (const std::string& command : Common::Redis::SupportedCommands::simpleCommands()) {
Expand All @@ -1114,6 +1435,9 @@ InstanceImpl::InstanceImpl(RouterPtr&& router, Stats::Scope& scope, const std::s

addHandler(scope, stat_prefix, Common::Redis::SupportedCommands::mset(), latency_in_micros,
mset_handler_);

addHandler(scope, stat_prefix, Common::Redis::SupportedCommands::scan(), latency_in_micros,
scanrequest_handler_);

for (const std::string& command : Common::Redis::SupportedCommands::transactionCommands()) {
addHandler(scope, stat_prefix, command, latency_in_micros, transaction_handler_);
Expand Down
Loading

0 comments on commit 8876f60

Please sign in to comment.