diff --git a/source/extensions/filters/network/common/redis/client_impl.cc b/source/extensions/filters/network/common/redis/client_impl.cc index 35710d1d7b46..673e771928e8 100644 --- a/source/extensions/filters/network/common/redis/client_impl.cc +++ b/source/extensions/filters/network/common/redis/client_impl.cc @@ -86,7 +86,6 @@ ClientImpl::ClientImpl(Upstream::HostConstSharedPtr host, Event::Dispatcher& dis time_source_(dispatcher.timeSource()), redis_command_stats_(redis_command_stats), scope_(scope), is_transaction_client_(is_transaction_client), is_pubsub_client_(is_pubsub_client), is_blocking_client_(is_blocking_client) { - ENVOY_LOG(debug,"ClientImpl Constructor creating client of type: is_transaction_client: {}, is_pubsub_client: {}, is_blocking_client: {}", is_transaction_client_, is_pubsub_client_, is_blocking_client_); Upstream::ClusterTrafficStats& traffic_stats = *host->cluster().trafficStats(); traffic_stats.upstream_cx_total_.inc(); host->stats().cx_total_.inc(); @@ -96,6 +95,17 @@ ClientImpl::ClientImpl(Upstream::HostConstSharedPtr host, Event::Dispatcher& dis if (is_pubsub_client_){ pubsub_cb_ = std::move(pubsubcb); } + if (!is_transaction_client_ && !is_pubsub_client_ && !is_blocking_client_){ + ENVOY_LOG(debug, "Upstream Client created of type ThreadLocal Active Client"); +}else if (is_transaction_client_){ + ENVOY_LOG(debug, "Upstream Client created of type Transaction Client"); + }else if (is_pubsub_client_){ + ENVOY_LOG(debug, "Upstream Client created of type Pubsub Client"); + }else if (is_blocking_client_){ + ENVOY_LOG(debug, "Upstream Client created of type Blocking Client"); + }else{ + ENVOY_LOG(error, "Upstream Client created of type Unknown Client"); + } } ClientImpl::~ClientImpl() { @@ -104,7 +114,7 @@ ClientImpl::~ClientImpl() { host_->cluster().trafficStats()->upstream_cx_active_.dec(); host_->stats().cx_active_.dec(); pubsub_cb_.reset(); - + ENVOY_LOG(debug, "Upstream Client destroyed"); } void ClientImpl::close() { @@ -202,7 +212,6 @@ bool ClientImpl::makePubSubRequest(const RespValue& request) { void ClientImpl::onConnectOrOpTimeout() { - ENVOY_LOG(debug, "Upstream Client Connection or Operation timeout occurred, is blocking client: {}", is_blocking_client_); putOutlierEvent(Upstream::Outlier::Result::LocalOriginTimeout); if (connected_) { host_->cluster().trafficStats()->upstream_rq_timeout_.inc(); @@ -217,6 +226,17 @@ void ClientImpl::onConnectOrOpTimeout() { } else { ENVOY_LOG(debug, "Ignoring timeout for connection close for Blocking clients!"); } + if (!is_transaction_client_ && !is_pubsub_client_ && !is_blocking_client_){ + ENVOY_LOG(debug, "Upstream Client onConnectOrOpTimeout for ThreadLocal Active Client"); + }else if (is_transaction_client_){ + ENVOY_LOG(debug, "Upstream Client onConnectOrOpTimeout for Transaction Client"); + }else if (is_pubsub_client_){ + ENVOY_LOG(debug, "Upstream Client onConnectOrOpTimeout for Pubsub Client"); + }else if (is_blocking_client_){ + ENVOY_LOG(debug, "Upstream Client onConnectOrOpTimeout for Blocking Client"); + }else{ + ENVOY_LOG(error, "Upstream Client onConnectOrOpTimeout for Unknown Client"); + } } void ClientImpl::onData(Buffer::Instance& data) { @@ -332,7 +352,7 @@ void ClientImpl::onRespValue(RespValuePtr&& value) { pending_requests_.pop_front(); if (canceled) { host_->cluster().trafficStats()->upstream_rq_cancelled_.inc(); - } else if (config_.enableRedirection() && (!is_blocking_client_ || !is_transaction_client_) && + } else if (config_.enableRedirection() && (!is_blocking_client_ && !is_transaction_client_ && !is_pubsub_client_) && (value->type() == Common::Redis::RespType::Error)) { std::vector err = StringUtil::splitToken(value->asString(), " ", false); if (err.size() == 3 && diff --git a/source/extensions/filters/network/common/redis/supported_commands.h b/source/extensions/filters/network/common/redis/supported_commands.h index 21fa226aa7fb..39ba46a330ec 100644 --- a/source/extensions/filters/network/common/redis/supported_commands.h +++ b/source/extensions/filters/network/common/redis/supported_commands.h @@ -33,11 +33,10 @@ struct SupportedCommands { "zrangebylex", "zrangebyscore", "zrank", "zrem", "zremrangebylex", "zremrangebyrank", "zremrangebyscore", "zrevrange", "zrevrangebylex", "zrevrangebyscore", "zrevrank", "zscan", "zscore", "rpoplpush", "smove", "sunion", "sdiff", "sinter", "sinterstore", "zunionstore", - "zinterstore", "pfmerge", "georadius", "georadiusbymember", "xadd", "xlen", "xdel", "xtrim", - "xrange", "xrevrange", "rename", "getex", "sort", "zmscore", "sdiffstore", "msetnx", "substr", - "zrangestore", "zunion", "echo", "zdiff", "xautoclaim", "xinfo", "sunionstore", "smismember", + "zinterstore", "pfmerge", "georadius", "georadiusbymember", "rename", "getex", "sort", "zmscore", "sdiffstore", "msetnx", "substr", + "zrangestore", "zunion", "echo", "zdiff", "sunionstore", "smismember", "hrandfield", "geosearchstore", "zdiffstore", "geosearch", "randomkey", "zinter", "zrandmember", - "bitop", "xclaim", "lpos", "renamenx", "xgroup","xreadnonblock"); + "bitop", "lpos", "renamenx","xread_simple_command"); } /** @@ -93,7 +92,7 @@ struct SupportedCommands { * @return commands that are called blocking commands but not pubsub commands. */ static const absl::flat_hash_set& blockingCommands() { - CONSTRUCT_ON_FIRST_USE(absl::flat_hash_set, "blpop", "brpop", "brpoplpush", "bzpopmax", "bzpopmin", "xreadblock", "xreadgroup", "blmove"); + CONSTRUCT_ON_FIRST_USE(absl::flat_hash_set, "blpop", "brpop", "brpoplpush", "bzpopmax", "bzpopmin", "xread_blocking_command", "blmove"); } /** @@ -124,6 +123,20 @@ struct SupportedCommands { CONSTRUCT_ON_FIRST_USE(absl::flat_hash_set, "script", "flushall", "flushdb", "pubsub", "keys", "slowlog", "config", "client", "info", "select", "unwatch"); } + /** + * @return commands which handle Redis Streams. + */ + static const absl::flat_hash_set& streamCommands() { + CONSTRUCT_ON_FIRST_USE(absl::flat_hash_set, "xack", "xadd", "xautoclaim", "xclaim", "xdel", "xgroup", "xinfo", "xlen", "xpending", "xrange", "xread","xreadgroup","xrevrange","xtrim"); + } + + /** + * @return List of stream commands which can be configured in blocking mode. + */ + static const absl::flat_hash_set& streamBlockingCommands() { + CONSTRUCT_ON_FIRST_USE(absl::flat_hash_set, "xread","xreadgroup"); + } + /** * @return scan command */ @@ -174,10 +187,6 @@ struct SupportedCommands { */ static const std::string& info() { CONSTRUCT_ON_FIRST_USE(std::string, "info"); } - /** - * @return special stream commands - */ - static const std::string& spl_strm_commands() { CONSTRUCT_ON_FIRST_USE(std::string, "xread"); } /** * @return commands which alters the state of redis */ diff --git a/source/extensions/filters/network/redis_proxy/command_splitter_impl.cc b/source/extensions/filters/network/redis_proxy/command_splitter_impl.cc index d93233a33433..9b9fc980faf1 100644 --- a/source/extensions/filters/network/redis_proxy/command_splitter_impl.cc +++ b/source/extensions/filters/network/redis_proxy/command_splitter_impl.cc @@ -61,13 +61,6 @@ AdminRespHandlerType getresponseHandlerType(const std::string& command_name) { {"publish", AdminRespHandlerType::singleshardresponse}, {"cluster", AdminRespHandlerType::singleshardresponse}, {"flushdb", AdminRespHandlerType::allresponses_mustbe_same}, - {"xadd", AdminRespHandlerType::singleshardresponse}, - {"xread", AdminRespHandlerType::singleshardresponse}, - {"xlen", AdminRespHandlerType::singleshardresponse}, - {"xdel", AdminRespHandlerType::singleshardresponse}, - {"xtrim", AdminRespHandlerType::singleshardresponse}, - {"xrange", AdminRespHandlerType::singleshardresponse}, - {"xrevrange", AdminRespHandlerType::singleshardresponse}, {"rename", AdminRespHandlerType::singleshardresponse}, {"unwatch", AdminRespHandlerType::allresponses_mustbe_same}, // Add more mappings as needed @@ -95,9 +88,9 @@ int32_t getShardIndex(const std::string command, int32_t requestsCount,int32_t r bool isBlockingCommand = Common::Redis::SupportedCommands::blockingCommands().contains(command); bool isAllShardCommand = Common::Redis::SupportedCommands::allShardCommands().contains(command); - bool isXreadBlockingCommand = (command == "xread" || command == "xreadgroup"); + - if (!isBlockingCommand && !isAllShardCommand && requestsCount == 1 && !isXreadBlockingCommand){ + if (!isBlockingCommand && !isAllShardCommand && requestsCount == 1 ){ // Send request to a random shard so that we donot allways send to the same shard shard_index = rand() % redisShardsCount; } @@ -270,35 +263,19 @@ SplitRequestPtr SimpleRequest::create(Router& router, TimeSource& time_source, bool delay_command_latency, const StreamInfo::StreamInfo& stream_info) { std::string command_name = absl::AsciiStrToLower(incoming_request->asArray()[0].asString()); - std::string key =""; + int32_t shardKeyIndex = getShardingKeyIndex(command_name,*incoming_request); + if (shardKeyIndex < 0) { + ENVOY_LOG(debug, "unexpected command : '{}'", incoming_request->toString()); + callbacks.onResponse(Common::Redis::Utility::makeError(fmt::format("unexpected command format"))); + return nullptr; + } + std::string key =incoming_request->asArray()[shardKeyIndex].asString();; std::unique_ptr request_ptr{ new SimpleRequest(callbacks, command_stats, time_source, delay_command_latency)}; - const auto route = router.upstreamPool(incoming_request->asArray()[1].asString(), stream_info); + const auto route = router.upstreamPool(incoming_request->asArray()[shardKeyIndex].asString(), stream_info); if (route) { Common::Redis::RespValueSharedPtr base_request = std::move(incoming_request); - if (command_name == "xread"){ - int32_t index = 0; - int32_t count = base_request->asArray().size(); - while (count > 0) { - if (absl::AsciiStrToLower(base_request->asArray()[index].asString())== "streams") { - index++; - key = base_request->asArray()[index].asString(); - break; - } - index++; - count--; - } - if (key.empty()) { - ENVOY_LOG(debug, "unexpected command : '{}'", base_request->toString()); - callbacks.onResponse(Common::Redis::Utility::makeError(fmt::format("unexpected command format"))); - return nullptr; - } - - }else { - key = base_request->asArray()[1].asString(); - } - request_ptr->handle_ = makeSingleServerRequest( route, base_request->asArray()[0].asString(), key, base_request, *request_ptr, callbacks.transaction()); @@ -315,6 +292,31 @@ SplitRequestPtr SimpleRequest::create(Router& router, return request_ptr; } +int32_t SimpleRequest::getShardingKeyIndex(const std::string command_name, const Common::Redis::RespValue& request) { + if (command_name == "xread" || command_name == "xreadgroup") { + int32_t count = request.asArray().size(); + for (int32_t index = 0; index < count; ++index) { + if (absl::AsciiStrToLower(request.asArray()[index].asString()) == "streams") { + if (index + 1 < count) { + return index + 1; // Return the index of the key after "streams" + } else { + return -1; // "streams" is the last element + } + } + } + return -1; // "streams" not found + } else if (command_name == "xgroup" || command_name == "xinfo") { + if (request.asArray().size() > 2) { + return 2; // Return index 2 if there are more than 2 elements + } else { + return -1; // Not enough elements + } + } else { + return 1; // Default case for other commands + } +} + + SplitRequestPtr EvalRequest::create(Router& router, Common::Redis::RespValuePtr&& incoming_request, SplitCallbacks& callbacks, CommandStats& command_stats, TimeSource& time_source, bool delay_command_latency, @@ -375,6 +377,7 @@ AdministrationRequest::~AdministrationRequest() { ASSERT(!request.handle_); } #endif +ENVOY_LOG(debug, "AdministrationRequest::~AdministrationRequest()"); } void AdministrationRequest::cancel() { @@ -589,7 +592,8 @@ void mgmtNoKeyRequest::onallChildRespAgrregate(Common::Redis::RespValuePtr&& val if (!pending_responses_.empty()) { Common::Redis::RespValuePtr response = std::move(pending_responses_[response_index_]); callbacks_.onResponse(std::move(response)); - pending_responses_.clear(); + //pending_responses_.clear(); + return; } } else { bool positiveresponse = true; @@ -616,7 +620,8 @@ void mgmtNoKeyRequest::onallChildRespAgrregate(Common::Redis::RespValuePtr&& val response->asString() += infoProcessor.getInfoCmdResponseString(); callbacks_.onResponse(std::move(response)); } - pending_responses_.clear(); + //pending_responses_.clear(); + return; } if ( rediscommand == "pubsub" || rediscommand == "keys" || rediscommand == "slowlog"|| rediscommand == "client") { if ((redisarg == "numpat" || redisarg == "len") && (rediscommand == "pubsub" || rediscommand == "slowlog")) { @@ -648,8 +653,9 @@ void mgmtNoKeyRequest::onallChildRespAgrregate(Common::Redis::RespValuePtr&& val if (positiveresponse) { response->asInteger() = sum; callbacks_.onResponse(std::move(response)); - pending_responses_.clear(); + //pending_responses_.clear(); } + return; } else { Common::Redis::RespValuePtr response = std::make_unique(); Common::Redis::RespValue innerResponse; @@ -732,15 +738,18 @@ void mgmtNoKeyRequest::onallChildRespAgrregate(Common::Redis::RespValuePtr&& val ENVOY_LOG(debug, "all response not same: '{}'", pending_responses_[0]->toString()); callbacks_.onResponse(Common::Redis::Utility::makeError( fmt::format("all responses not same"))); + /* if (!pending_responses_.empty()) { - pending_responses_.clear(); + //pending_responses_.clear(); } + */ } } if (positiveresponse) { callbacks_.onResponse(std::move(response)); - pending_responses_.clear(); + //pending_responses_.clear(); } + return; } } } @@ -755,7 +764,7 @@ void mgmtNoKeyRequest::onSingleShardresponse(Common::Redis::RespValuePtr&& value ENVOY_LOG(debug, "response: {}", value->toString()); updateStats(true); callbacks_.onResponse(std::move(value)); - pending_responses_.clear(); + //pending_responses_.clear(); } void mgmtNoKeyRequest::onAllChildResponseSame(Common::Redis::RespValuePtr&& value, int32_t reqindex, int32_t shardindex) { @@ -782,22 +791,24 @@ void mgmtNoKeyRequest::onAllChildResponseSame(Common::Redis::RespValuePtr&& valu ENVOY_LOG(debug, "Error Response received: '{}'", pending_responses_[response_index_]->toString()); Common::Redis::RespValuePtr response = std::move(pending_responses_[response_index_]); callbacks_.onResponse(std::move(response)); - pending_responses_.clear(); + //pending_responses_.clear(); } } else if(! areAllResponsesSame(pending_responses_)) { updateStats(false); ENVOY_LOG(debug, "all response not same: '{}'", pending_responses_[0]->toString()); callbacks_.onResponse(Common::Redis::Utility::makeError( fmt::format("all responses not same"))); + /* if (!pending_responses_.empty()) - pending_responses_.clear(); + // pending_responses_.clear(); + */ }else { updateStats(true); if (!pending_responses_.empty()) { Common::Redis::RespValuePtr response = std::move(pending_responses_[0]); ENVOY_LOG(debug, "response: {}", response->toString()); callbacks_.onResponse(std::move(response)); - pending_responses_.clear(); + //pending_responses_.clear(); } } } @@ -810,36 +821,20 @@ SplitRequestPtr BlockingClientRequest::create(Router& router, Common::Redis::Res // For blocking requests which operate on a single key, we can hash the key to a single //must send shard index as negative to indicate that it is a blocking request that acts on key std::string command_name = absl::AsciiStrToLower(incoming_request->asArray()[0].asString()); - std::string key =""; - int32_t shard_index=getShardIndex(command_name,1,1); + uint32_t key_index =getShardingKeyIndex(command_name,*incoming_request); + if (key_index < 0) { + ENVOY_LOG(debug, "unexpected command : '{}'", incoming_request->toString()); + callbacks.onResponse(Common::Redis::Utility::makeError(fmt::format("unexpected command format"))); + return nullptr; + } + std::string key = incoming_request->asArray()[key_index].asString(); + int32_t shard_index=-1; Common::Redis::Client::Transaction& transaction = callbacks.transaction(); std::unique_ptr request_ptr{ new BlockingClientRequest(callbacks, command_stats, time_source, delay_command_latency)}; - if (command_name == "xread"){ - int32_t index = 0; - int32_t count = incoming_request->asArray().size(); - while (count > 0) { - if (absl::AsciiStrToLower(incoming_request->asArray()[index].asString())== "streams") { - index++; - key = incoming_request->asArray()[index].asString(); - break; - } - index++; - count--; - } - if (key.empty()) { - ENVOY_LOG(debug, "unexpected command : '{}'", incoming_request->toString()); - callbacks.onResponse(Common::Redis::Utility::makeError(fmt::format("unexpected command format"))); - return nullptr; - } - - }else { - key = incoming_request->asArray()[1].asString(); - } - - if (transaction.active_ ){ + if (transaction.active_){ // when we are in blocking command, we cannnot accept any other commands if (transaction.isBlockingCommand()) { callbacks.onResponse( @@ -852,15 +847,9 @@ SplitRequestPtr BlockingClientRequest::create(Router& router, Common::Redis::Res return nullptr; } }else { - if (Common::Redis::SupportedCommands::blockingCommands().contains(command_name) || command_name == "xread") { transaction.clients_.resize(1); transaction.setBlockingCommand(); transaction.start(); - }else{ - ENVOY_LOG(debug, "unexpected command : '{}'", command_name); - callbacks.onResponse(Common::Redis::Utility::makeError(fmt::format("unexpected error"))); - return nullptr; - } } const auto route = router.upstreamPool(incoming_request->asArray()[1].asString(), stream_info); if (route) { @@ -882,6 +871,26 @@ SplitRequestPtr BlockingClientRequest::create(Router& router, Common::Redis::Res return request_ptr; } +int32_t BlockingClientRequest::getShardingKeyIndex(const std::string command_name, const Common::Redis::RespValue& request) { + if (Common::Redis::SupportedCommands::streamBlockingCommands().contains(command_name)) { + int32_t count = request.asArray().size(); + + for (int32_t index = 0; index < count; ++index) { + if (absl::AsciiStrToLower(request.asArray()[index].asString()) == "streams") { + // Check if the next index is within bounds + if (index + 1 < count) { + return index + 1; + } else { + return -1; // "streams" is the last element, so return -1 + } + } + } + return -1; // "streams" not found + } + return 1; // Default for non-stream blocking commands +} + + bool isKeyspaceArgument(const std::string& argument) { std::string keyspacepattern = "__keyspace@0__"; std::string keyeventpattern = "__keyevent@0__"; @@ -1113,7 +1122,9 @@ void PubSubMessageHandler::handleChannelMessageCustom(Common::Redis::RespValuePt void PubSubMessageHandler::onFailure() { ENVOY_LOG(debug, "failure in pubsub message handler"); - downstream_callbacks_->onFailure(); + if (downstream_callbacks_) { + downstream_callbacks_->onFailure(); + } } void MGETRequest::onChildResponse(Common::Redis::RespValuePtr&& value, uint32_t index) { @@ -1755,6 +1766,7 @@ InstanceImpl::InstanceImpl(RouterPtr&& router, Stats::Scope& scope, const std::s SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request, SplitCallbacks& callbacks, Event::Dispatcher& dispatcher, const StreamInfo::StreamInfo& stream_info) { + // Validate request type and contents. if ((request->type() != Common::Redis::RespType::Array) || request->asArray().empty()) { ENVOY_LOG(debug,"invalid request - not an array or empty"); onInvalidRequest(callbacks); @@ -1769,16 +1781,16 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request, } } + // Extract command name std::string command_name = absl::AsciiStrToLower(request->asArray()[0].asString()); + // Respond to HELLO locally adding this before auth, since hello will be issued before auth command if (command_name == Common::Redis::SupportedCommands::hello()) { - // Respond to HELLO locally - // Adding this before auth, since hello will be issued before auth command callbacks.onResponse(Common::Redis::Utility::makeError(Response::get().UnKnownCommandHello)); return nullptr; - } + // Handle AUTH command if (command_name == Common::Redis::SupportedCommands::auth()) { if (request->asArray().size() < 2) { ENVOY_LOG(debug,"invalid request - not enough arguments for auth command"); @@ -1794,12 +1806,14 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request, return nullptr; } + // Ensure connection is allowed or auth is required. if (!callbacks.connectionAllowed()) { stats_.auth_failure_.inc(); callbacks.onResponse(Common::Redis::Utility::makeError(Response::get().AuthRequiredError)); return nullptr; } + // Handle PING command locally if (command_name == Common::Redis::SupportedCommands::ping()) { // Respond to PING locally. Common::Redis::RespValuePtr pong(new Common::Redis::RespValue()); @@ -1809,6 +1823,7 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request, return nullptr; } + // Handle TIME command locally if (command_name == Common::Redis::SupportedCommands::time()) { // Respond to TIME locally. Common::Redis::RespValuePtr time_resp(new Common::Redis::RespValue()); @@ -1833,7 +1848,8 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request, callbacks.onResponse(std::move(time_resp)); return nullptr; } - // For transaction type commands and blockingcommands , quit needs to be handled from within the command handler + + // Hadle QUIT and EXIT commands locally if its not part of transaction or subscribed state if ((command_name == Common::Redis::SupportedCommands::quit() || command_name == Common::Redis::SupportedCommands::exit()) && !callbacks.transaction().active_) { callbacks.onQuit(); return nullptr; @@ -1848,6 +1864,7 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request, return nullptr; } + // Handle CLIENT command locally if (command_name == "client") { std::string sub_command = absl::AsciiStrToLower(request->asArray()[1].asString()); if (Common::Redis::SupportedCommands::clientSubCommands().count(sub_command) == 0) { @@ -1877,12 +1894,21 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request, // Get the handler for the downstream request auto handler = handler_lookup_table_.find(command_name.c_str()); - if (handler == nullptr && !callbacks.transaction().isSubscribedMode() && command_name!=Common::Redis::SupportedCommands::spl_strm_commands()) { - stats_.unsupported_command_.inc(); - ENVOY_LOG(debug, "unsupported command '{}'", request->asArray()[0].asString()); - callbacks.onResponse(Common::Redis::Utility::makeError( + if (handler == nullptr ){ + if (callbacks.transaction().active_ && callbacks.transaction().isSubscribedMode() && !Common::Redis::SupportedCommands::subcrStateallowedCommands().contains(command_name)) { + callbacks.onResponse(Common::Redis::Utility::makeError("command not supported in subscribed state")); + return nullptr; + }else if(Common::Redis::SupportedCommands::streamCommands().contains(command_name)){ + //Stream commands are not listed directly under any handler , we need to check if it is a blocking or simple command and choose appropriate handler + handler=getHandlerForStreamsCommand(command_name,request); + + }else{ + stats_.unsupported_command_.inc(); + ENVOY_LOG(debug, "unsupported command '{}'", request->asArray()[0].asString()); + callbacks.onResponse(Common::Redis::Utility::makeError( fmt::format("unsupported command '{}'", request->asArray()[0].asString()))); - return nullptr; + return nullptr; + } } // If we are within a transaction, forward all requests to the transaction handler (i.e. handler @@ -1896,18 +1922,6 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request, handler = handler_lookup_table_.find("subscribe"); } - //If the command is xread we need to check if its a blocking command or not - if (command_name == "xread") { - if (((request->asArray().size() > 1) && (absl::AsciiStrToLower(request->asArray()[1].asString()) == "block")) || - ((request->asArray().size() > 3) && (absl::AsciiStrToLower(request->asArray()[3].asString()) == "block"))) { - handler = handler_lookup_table_.find("xreadblock"); - } else { - handler = handler_lookup_table_.find("xreadnonblock"); - } - - } - - // Fault Injection Check const Common::Redis::Fault* fault_ptr = fault_manager_->getFaultForCommand(command_name); @@ -1973,6 +1987,21 @@ void InstanceImpl::addHandler(Stats::Scope& scope, const std::string& stat_prefi handler})); } +InstanceImpl::HandlerDataPtr InstanceImpl::getHandlerForStreamsCommand(const std::string& command_name, const Common::Redis::RespValuePtr& request) { + // Check if the command is a stream blocking command. + if (Common::Redis::SupportedCommands::streamBlockingCommands().contains(command_name)) { + // Check for "block" keyword in the appropriate positions of the request array. + if ((request->asArray().size() > 1 && absl::AsciiStrToLower(request->asArray()[1].asString()) == "block") || + (request->asArray().size() > 3 && absl::AsciiStrToLower(request->asArray()[3].asString()) == "block")) { + return handler_lookup_table_.find("xread_blocking_command"); + } + } + + // Default to "xread_simple_command" if it's not a blocking command. + return handler_lookup_table_.find("xread_simple_command"); +} + + } // namespace CommandSplitter } // namespace RedisProxy } // namespace NetworkFilters diff --git a/source/extensions/filters/network/redis_proxy/command_splitter_impl.h b/source/extensions/filters/network/redis_proxy/command_splitter_impl.h index d34c7afa10a3..38f7774cb65c 100644 --- a/source/extensions/filters/network/redis_proxy/command_splitter_impl.h +++ b/source/extensions/filters/network/redis_proxy/command_splitter_impl.h @@ -360,6 +360,8 @@ class BlockingClientRequest : public SingleServerRequest { BlockingClientRequest(SplitCallbacks& callbacks, CommandStats& command_stats, TimeSource& time_source, bool delay_command_latency) : SingleServerRequest(callbacks, command_stats, time_source, delay_command_latency) {} + + static int32_t getShardingKeyIndex(const std::string command_name,const Common::Redis::RespValue& request); }; @@ -377,6 +379,8 @@ class SimpleRequest : public SingleServerRequest { SimpleRequest(SplitCallbacks& callbacks, CommandStats& command_stats, TimeSource& time_source, bool delay_command_latency) : SingleServerRequest(callbacks, command_stats, time_source, delay_command_latency) {} + + static int32_t getShardingKeyIndex(const std::string command_name,const Common::Redis::RespValue& request); }; /** @@ -597,6 +601,8 @@ class InstanceImpl : public Instance, Logger::Loggable { bool latency_in_micros, CommandHandler& handler); void onInvalidRequest(SplitCallbacks& callbacks); + HandlerDataPtr getHandlerForStreamsCommand(const std::string& command_name, const Common::Redis::RespValuePtr& request); + RouterPtr router_; CommandHandlerFactory simple_command_handler_; CommandHandlerFactory eval_command_handler_; diff --git a/source/extensions/filters/network/redis_proxy/conn_pool_impl.cc b/source/extensions/filters/network/redis_proxy/conn_pool_impl.cc index 71989c80034e..28c4fca0ba52 100644 --- a/source/extensions/filters/network/redis_proxy/conn_pool_impl.cc +++ b/source/extensions/filters/network/redis_proxy/conn_pool_impl.cc @@ -529,36 +529,19 @@ InstanceImpl::ThreadLocalPool::makeRequestNoKey(int32_t shard_index, RespVariant Upstream::HostConstSharedPtr host = (*hosts)[shard_index]; pending_requests_.emplace_back(*this, std::move(request), callbacks, host); PendingRequest& pending_request = pending_requests_.back(); - - uint32_t client_idx = 0; - // If there is an active transaction, establish a new connection if necessary. - if (transaction.active_) { - ENVOY_LOG(error,"Ideally transanction client should not be used for admin commands ERROR!!!"); - client_idx = transaction.current_client_idx_; - if ((!transaction.connection_established_ && transaction.is_subscribed_mode_) || (!transaction.connection_established_ && transaction.is_blocking_command_)) { - transaction.clients_[client_idx] = - client_factory_.create(host, dispatcher_, *config_, redis_command_stats_, *(stats_scope_), - auth_username_, auth_password_, false,true,true,nullptr); - if (transaction.connection_cb_) { - transaction.clients_[client_idx]->addConnectionCallbacks(*transaction.connection_cb_); - } - } - pending_request.request_handler_ = transaction.clients_[client_idx]->makeRequest( - getRequest(pending_request.incoming_request_), pending_request); - }else { - ThreadLocalActiveClientPtr& client = this->threadLocalActiveClient(host); - if (!client) { - ENVOY_LOG(debug, "redis connection is rate limited, erasing empty client"); - pending_request.request_handler_ = nullptr; - onRequestCompleted(); - client_map_.erase(host); - return nullptr; - } - pending_request.request_handler_ = client->redis_client_->makeRequest( - getRequest(pending_request.incoming_request_), pending_request); + ThreadLocalActiveClientPtr& client = this->threadLocalActiveClient(host); + if (!client) { + ENVOY_LOG(debug, "redis connection is rate limited, erasing empty client"); + pending_request.request_handler_ = nullptr; + onRequestCompleted(); + client_map_.erase(host); + return nullptr; } + pending_request.request_handler_ = client->redis_client_->makeRequest( + getRequest(pending_request.incoming_request_), pending_request); + if (pending_request.request_handler_) { return &pending_request;