Skip to content

Commit

Permalink
redis auth support select db
Browse files Browse the repository at this point in the history
  • Loading branch information
zwkno1 authored and zhouwk committed Jun 6, 2022
1 parent c1dc6e8 commit b86814a
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 36 deletions.
3 changes: 2 additions & 1 deletion src/brpc/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ void Controller::ResetPods() {
_request_stream = INVALID_STREAM_ID;
_response_stream = INVALID_STREAM_ID;
_remote_stream_settings = NULL;
_auth_flags = 0;
}

Controller::Call::Call(Controller::Call* rhs)
Expand Down Expand Up @@ -1162,7 +1163,7 @@ void Controller::IssueRPC(int64_t start_realtime_us) {
wopt.id_wait = cid;
wopt.abstime = pabstime;
wopt.pipelined_count = _pipelined_count;
wopt.with_auth = has_flag(FLAGS_REQUEST_WITH_AUTH);
wopt.auth_flags = _auth_flags;
wopt.ignore_eovercrowded = has_flag(FLAGS_IGNORE_EOVERCROWDED);
int rc;
size_t packet_size = 0;
Expand Down
8 changes: 7 additions & 1 deletion src/brpc/controller.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
// To brpc developers: This is a header included by user, don't depend
// on internal structures, use opaque pointers instead.

#include <cstdint>
#include <gflags/gflags.h> // Users often need gflags
#include <string>
#include "butil/intrusive_ptr.hpp" // butil::intrusive_ptr
Expand Down Expand Up @@ -138,7 +139,6 @@ friend void policy::ProcessThriftRequest(InputMessageBase*);
static const uint32_t FLAGS_PB_BYTES_TO_BASE64 = (1 << 11);
static const uint32_t FLAGS_ALLOW_DONE_TO_RUN_IN_PLACE = (1 << 12);
static const uint32_t FLAGS_USED_BY_RPC = (1 << 13);
static const uint32_t FLAGS_REQUEST_WITH_AUTH = (1 << 15);
static const uint32_t FLAGS_PB_JSONIFY_EMPTY_ARRAY = (1 << 16);
static const uint32_t FLAGS_ENABLED_CIRCUIT_BREAKER = (1 << 17);
static const uint32_t FLAGS_ALWAYS_PRINT_PRIMITIVE_FIELDS = (1 << 18);
Expand Down Expand Up @@ -554,6 +554,10 @@ friend void policy::ProcessThriftRequest(InputMessageBase*);
// -1 means no deadline.
int64_t deadline_us() const { return _deadline_us; }

void add_auth_flags(uint32_t auth_flags) { _auth_flags = auth_flags; }

void clear_auth_flags() { _auth_flags = 0; }

private:
struct CompletionInfo {
CallId id; // call_id of the corresponding request
Expand Down Expand Up @@ -807,6 +811,8 @@ friend void policy::ProcessThriftRequest(InputMessageBase*);

// Thrift method name, only used when thrift protocol enabled
std::string _thrift_method_name;

uint32_t _auth_flags;
};

// Advises the RPC system that the caller desires that the RPC call be
Expand Down
8 changes: 4 additions & 4 deletions src/brpc/details/controller_private_accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ class ControllerPrivateAccessor {
void set_readable_progressive_attachment(ReadableProgressiveAttachment* s)
{ _cntl->_rpa.reset(s); }

void add_with_auth() {
_cntl->add_flag(Controller::FLAGS_REQUEST_WITH_AUTH);
void add_auth_flags(uint32_t auth_flags) {
_cntl->add_auth_flags(auth_flags);
}

void clear_with_auth() {
_cntl->clear_flag(Controller::FLAGS_REQUEST_WITH_AUTH);
void clear_auth_flags() {
_cntl->clear_auth_flags();
}

std::string& protocol_param() { return _cntl->protocol_param(); }
Expand Down
7 changes: 6 additions & 1 deletion src/brpc/policy/redis_authenticator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ namespace policy {

int RedisAuthenticator::GenerateCredential(std::string* auth_str) const {
butil::IOBuf buf;
brpc::RedisCommandFormat(&buf, "AUTH %s", passwd_.c_str());
if(!passwd_.empty()) {
brpc::RedisCommandFormat(&buf, "AUTH %s", passwd_.c_str());
}
if(db_ >= 0) {
brpc::RedisCommandFormat(&buf, "SELECT %d", db_);
}
*auth_str = buf.to_string();
return 0;
}
Expand Down
16 changes: 14 additions & 2 deletions src/brpc/policy/redis_authenticator.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ namespace policy {
// Request to redis for authentication.
class RedisAuthenticator : public Authenticator {
public:
RedisAuthenticator(const std::string& passwd)
: passwd_(passwd) {}
RedisAuthenticator(const std::string& passwd, int db = -1)
: passwd_(passwd), db_(db) {}

int GenerateCredential(std::string* auth_str) const;

Expand All @@ -36,8 +36,20 @@ class RedisAuthenticator : public Authenticator {
return 0;
}

uint32_t GetAuthFlags() const {
uint32_t n = 0;
if(!passwd_.empty()) {
++n;
}
if(db_ != -1) {
++n;
}
return n;
}

private:
const std::string passwd_;
int db_;
};

} // namespace policy
Expand Down
27 changes: 17 additions & 10 deletions src/brpc/policy/redis_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <google/protobuf/descriptor.h> // MethodDescriptor
#include <google/protobuf/message.h> // Message
#include <gflags/gflags.h>
#include "brpc/policy/redis_authenticator.h"
#include "butil/logging.h" // LOG()
#include "butil/time.h"
#include "butil/iobuf.h" // butil::IOBuf
Expand Down Expand Up @@ -214,26 +215,28 @@ ParseResult ParseRedisMessage(butil::IOBuf* source, Socket* socket,
socket->reset_parsing_context(msg);
}

const int consume_count = (pi.with_auth ? 1 : pi.count);
const int consume_count = (pi.auth_flags ? pi.auth_flags : pi.count);

ParseError err = msg->response.ConsumePartialIOBuf(*source, consume_count);
if (err != PARSE_OK) {
socket->GivebackPipelinedInfo(pi);
return MakeParseError(err);
}

if (pi.with_auth) {
if (msg->response.reply_size() != 1 ||
!(msg->response.reply(0).type() == brpc::REDIS_REPLY_STATUS &&
msg->response.reply(0).data().compare("OK") == 0)) {
LOG(ERROR) << "Redis Auth failed: " << msg->response;
return MakeParseError(PARSE_ERROR_NO_RESOURCE,
if (pi.auth_flags) {
for(int i = 0; i < (int)pi.auth_flags; ++i) {
if (i >= msg->response.reply_size() ||
!(msg->response.reply(i).type() == brpc::REDIS_REPLY_STATUS &&
msg->response.reply(i).data().compare("OK") == 0)) {
LOG(ERROR) << "Redis Auth failed: " << msg->response;
return MakeParseError(PARSE_ERROR_NO_RESOURCE,
"Fail to authenticate with Redis");
}
}

DestroyingPtr<InputResponse> auth_msg(
static_cast<InputResponse*>(socket->release_parsing_context()));
pi.with_auth = false;
pi.auth_flags = 0;
continue;
}

Expand Down Expand Up @@ -333,9 +336,13 @@ void PackRedisRequest(butil::IOBuf* buf,
return cntl->SetFailed(EREQUEST, "Fail to generate credential");
}
buf->append(auth_str);
ControllerPrivateAccessor(cntl).add_with_auth();
const RedisAuthenticator * redis_auth = dynamic_cast<const RedisAuthenticator *>(auth);
if (redis_auth == NULL) {
return cntl->SetFailed(EREQUEST, "Fail to generate credential");
}
ControllerPrivateAccessor(cntl).add_auth_flags(redis_auth->GetAuthFlags());
} else {
ControllerPrivateAccessor(cntl).clear_with_auth();
ControllerPrivateAccessor(cntl).clear_auth_flags();
}

buf->append(request);
Expand Down
26 changes: 13 additions & 13 deletions src/brpc/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ bool Socket::CreatedByConnect() const {
}

SocketMessage* const DUMMY_USER_MESSAGE = (SocketMessage*)0x1;
const uint32_t MAX_PIPELINED_COUNT = 32768;
const uint32_t MAX_PIPELINED_COUNT = 16384;

struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
static WriteRequest* const UNCONNECTED;
Expand All @@ -306,12 +306,12 @@ struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
Socket* socket;

uint32_t pipelined_count() const {
return (_pc_and_udmsg >> 48) & 0x7FFF;
return (_pc_and_udmsg >> 48) & 0x3FFF;
}
bool is_with_auth() const {
return _pc_and_udmsg & 0x8000000000000000ULL;
uint32_t get_auth_flags() const {
return (_pc_and_udmsg >> 62) & 0x03;
}
void clear_pipelined_count_and_with_auth() {
void clear_pipelined_count_and_auth_flags() {
_pc_and_udmsg &= 0xFFFFFFFFFFFFULL;
}
SocketMessage* user_message() const {
Expand All @@ -321,9 +321,9 @@ struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
_pc_and_udmsg &= 0xFFFF000000000000ULL;
}
void set_pipelined_count_and_user_message(
uint32_t pc, SocketMessage* msg, bool with_auth) {
if (with_auth) {
pc |= (1 << 15);
uint32_t pc, SocketMessage* msg, uint32_t auth_flags) {
if(auth_flags) {
pc |= (auth_flags & 0x03) << 14;
}
_pc_and_udmsg = ((uint64_t)pc << 48) | (uint64_t)(uintptr_t)msg;
}
Expand All @@ -337,7 +337,7 @@ struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
// is already failed.
(void)msg->AppendAndDestroySelf(&dummy_buf, NULL);
}
set_pipelined_count_and_user_message(0, NULL, false);
set_pipelined_count_and_user_message(0, NULL, 0);
return true;
}
return false;
Expand Down Expand Up @@ -376,9 +376,9 @@ void Socket::WriteRequest::Setup(Socket* s) {
// The struct will be popped when reading a message from the socket.
PipelinedInfo pi;
pi.count = pc;
pi.with_auth = is_with_auth();
pi.auth_flags = get_auth_flags();
pi.id_wait = id_wait;
clear_pipelined_count_and_with_auth(); // avoid being pushed again
clear_pipelined_count_and_auth_flags(); // avoid being pushed again
s->PushPipelinedInfo(pi);
}
}
Expand Down Expand Up @@ -1462,7 +1462,7 @@ int Socket::Write(butil::IOBuf* data, const WriteOptions* options_in) {
req->next = WriteRequest::UNCONNECTED;
req->id_wait = opt.id_wait;
req->set_pipelined_count_and_user_message(
opt.pipelined_count, DUMMY_USER_MESSAGE, opt.with_auth);
opt.pipelined_count, DUMMY_USER_MESSAGE, opt.auth_flags);
return StartWrite(req, opt);
}

Expand Down Expand Up @@ -1497,7 +1497,7 @@ int Socket::Write(SocketMessagePtr<>& msg, const WriteOptions* options_in) {
// wait until it points to a valid WriteRequest or NULL.
req->next = WriteRequest::UNCONNECTED;
req->id_wait = opt.id_wait;
req->set_pipelined_count_and_user_message(opt.pipelined_count, msg.release(), opt.with_auth);
req->set_pipelined_count_and_user_message(opt.pipelined_count, msg.release(), opt.auth_flags);
return StartWrite(req, opt);
}

Expand Down
8 changes: 4 additions & 4 deletions src/brpc/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ struct PipelinedInfo {
PipelinedInfo() { reset(); }
void reset() {
count = 0;
with_auth = false;
auth_flags = 0;
id_wait = INVALID_BTHREAD_ID;
}
uint32_t count;
bool with_auth;
uint32_t auth_flags;
bthread_id_t id_wait;
};

Expand Down Expand Up @@ -256,15 +256,15 @@ friend class policy::H2GlobalStreamCreator;
// The request contains authenticating information which will be
// responded by the server and processed specially when dealing
// with the response.
bool with_auth;
uint32_t auth_flags;

// Do not return EOVERCROWDED
// Default: false
bool ignore_eovercrowded;

WriteOptions()
: id_wait(INVALID_BTHREAD_ID), abstime(NULL)
, pipelined_count(0), with_auth(false)
, pipelined_count(0), auth_flags(0)
, ignore_eovercrowded(false) {}
};
int Write(butil::IOBuf *msg, const WriteOptions* options = NULL);
Expand Down

0 comments on commit b86814a

Please sign in to comment.