Skip to content

Commit

Permalink
Merge pull request #1768 from zwkno1/redis_auth
Browse files Browse the repository at this point in the history
redis auth support select db
  • Loading branch information
lorinlee authored Jul 19, 2022
2 parents a6de1a9 + 52d95b8 commit 53cbd1a
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 38 deletions.
3 changes: 2 additions & 1 deletion src/brpc/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ void Controller::ResetPods() {
_request_stream = INVALID_STREAM_ID;
_response_stream = INVALID_STREAM_ID;
_remote_stream_settings = NULL;
_auth_flags = 0;
}

Controller::Call::Call(Controller::Call* rhs)
Expand Down Expand Up @@ -1162,7 +1163,7 @@ void Controller::IssueRPC(int64_t start_realtime_us) {
wopt.id_wait = cid;
wopt.abstime = pabstime;
wopt.pipelined_count = _pipelined_count;
wopt.with_auth = has_flag(FLAGS_REQUEST_WITH_AUTH);
wopt.auth_flags = _auth_flags;
wopt.ignore_eovercrowded = has_flag(FLAGS_IGNORE_EOVERCROWDED);
int rc;
size_t packet_size = 0;
Expand Down
3 changes: 2 additions & 1 deletion src/brpc/controller.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ friend void policy::ProcessThriftRequest(InputMessageBase*);
static const uint32_t FLAGS_PB_BYTES_TO_BASE64 = (1 << 11);
static const uint32_t FLAGS_ALLOW_DONE_TO_RUN_IN_PLACE = (1 << 12);
static const uint32_t FLAGS_USED_BY_RPC = (1 << 13);
static const uint32_t FLAGS_REQUEST_WITH_AUTH = (1 << 15);
static const uint32_t FLAGS_PB_JSONIFY_EMPTY_ARRAY = (1 << 16);
static const uint32_t FLAGS_ENABLED_CIRCUIT_BREAKER = (1 << 17);
static const uint32_t FLAGS_ALWAYS_PRINT_PRIMITIVE_FIELDS = (1 << 18);
Expand Down Expand Up @@ -807,6 +806,8 @@ friend void policy::ProcessThriftRequest(InputMessageBase*);

// Thrift method name, only used when thrift protocol enabled
std::string _thrift_method_name;

uint32_t _auth_flags;
};

// Advises the RPC system that the caller desires that the RPC call be
Expand Down
8 changes: 3 additions & 5 deletions src/brpc/details/controller_private_accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,11 @@ class ControllerPrivateAccessor {
void set_readable_progressive_attachment(ReadableProgressiveAttachment* s)
{ _cntl->_rpa.reset(s); }

void add_with_auth() {
_cntl->add_flag(Controller::FLAGS_REQUEST_WITH_AUTH);
void set_auth_flags(uint32_t auth_flags) {
_cntl->_auth_flags = auth_flags;
}

void clear_with_auth() {
_cntl->clear_flag(Controller::FLAGS_REQUEST_WITH_AUTH);
}
void clear_auth_flags() { _cntl->_auth_flags = 0; }

std::string& protocol_param() { return _cntl->protocol_param(); }
const std::string& protocol_param() const { return _cntl->protocol_param(); }
Expand Down
7 changes: 6 additions & 1 deletion src/brpc/policy/redis_authenticator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ namespace policy {

int RedisAuthenticator::GenerateCredential(std::string* auth_str) const {
butil::IOBuf buf;
brpc::RedisCommandFormat(&buf, "AUTH %s", passwd_.c_str());
if (!passwd_.empty()) {
brpc::RedisCommandFormat(&buf, "AUTH %s", passwd_.c_str());
}
if (db_ >= 0) {
brpc::RedisCommandFormat(&buf, "SELECT %d", db_);
}
*auth_str = buf.to_string();
return 0;
}
Expand Down
17 changes: 15 additions & 2 deletions src/brpc/policy/redis_authenticator.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ namespace policy {
// Request to redis for authentication.
class RedisAuthenticator : public Authenticator {
public:
RedisAuthenticator(const std::string& passwd)
: passwd_(passwd) {}
RedisAuthenticator(const std::string& passwd, int db = -1)
: passwd_(passwd), db_(db) {}

int GenerateCredential(std::string* auth_str) const;

Expand All @@ -36,8 +36,21 @@ class RedisAuthenticator : public Authenticator {
return 0;
}

uint32_t GetAuthFlags() const {
uint32_t n = 0;
if (!passwd_.empty()) {
++n;
}
if (db_ >= 0) {
++n;
}
return n;
}

private:
const std::string passwd_;

int db_;
};

} // namespace policy
Expand Down
32 changes: 21 additions & 11 deletions src/brpc/policy/redis_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <google/protobuf/descriptor.h> // MethodDescriptor
#include <google/protobuf/message.h> // Message
#include <gflags/gflags.h>
#include "brpc/policy/redis_authenticator.h"
#include "butil/logging.h" // LOG()
#include "butil/time.h"
#include "butil/iobuf.h" // butil::IOBuf
Expand Down Expand Up @@ -214,26 +215,29 @@ ParseResult ParseRedisMessage(butil::IOBuf* source, Socket* socket,
socket->reset_parsing_context(msg);
}

const int consume_count = (pi.with_auth ? 1 : pi.count);
const int consume_count = (pi.auth_flags ? pi.auth_flags : pi.count);

ParseError err = msg->response.ConsumePartialIOBuf(*source, consume_count);
if (err != PARSE_OK) {
socket->GivebackPipelinedInfo(pi);
return MakeParseError(err);
}

if (pi.with_auth) {
if (msg->response.reply_size() != 1 ||
!(msg->response.reply(0).type() == brpc::REDIS_REPLY_STATUS &&
msg->response.reply(0).data().compare("OK") == 0)) {
LOG(ERROR) << "Redis Auth failed: " << msg->response;
return MakeParseError(PARSE_ERROR_NO_RESOURCE,
"Fail to authenticate with Redis");
if (pi.auth_flags) {
for (int i = 0; i < (int)pi.auth_flags; ++i) {
if (i >= msg->response.reply_size() ||
!(msg->response.reply(i).type() ==
brpc::REDIS_REPLY_STATUS &&
msg->response.reply(i).data().compare("OK") == 0)) {
LOG(ERROR) << "Redis Auth failed: " << msg->response;
return MakeParseError(PARSE_ERROR_NO_RESOURCE,
"Fail to authenticate with Redis");
}
}

DestroyingPtr<InputResponse> auth_msg(
static_cast<InputResponse*>(socket->release_parsing_context()));
pi.with_auth = false;
pi.auth_flags = 0;
continue;
}

Expand Down Expand Up @@ -333,9 +337,15 @@ void PackRedisRequest(butil::IOBuf* buf,
return cntl->SetFailed(EREQUEST, "Fail to generate credential");
}
buf->append(auth_str);
ControllerPrivateAccessor(cntl).add_with_auth();
const RedisAuthenticator* redis_auth =
dynamic_cast<const RedisAuthenticator*>(auth);
if (redis_auth == NULL) {
return cntl->SetFailed(EREQUEST, "Fail to generate credential");
}
ControllerPrivateAccessor(cntl).set_auth_flags(
redis_auth->GetAuthFlags());
} else {
ControllerPrivateAccessor(cntl).clear_with_auth();
ControllerPrivateAccessor(cntl).clear_auth_flags();
}

buf->append(request);
Expand Down
26 changes: 13 additions & 13 deletions src/brpc/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ bool Socket::CreatedByConnect() const {
}

SocketMessage* const DUMMY_USER_MESSAGE = (SocketMessage*)0x1;
const uint32_t MAX_PIPELINED_COUNT = 32768;
const uint32_t MAX_PIPELINED_COUNT = 16384;

struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
static WriteRequest* const UNCONNECTED;
Expand All @@ -306,12 +306,12 @@ struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
Socket* socket;

uint32_t pipelined_count() const {
return (_pc_and_udmsg >> 48) & 0x7FFF;
return (_pc_and_udmsg >> 48) & 0x3FFF;
}
bool is_with_auth() const {
return _pc_and_udmsg & 0x8000000000000000ULL;
uint32_t get_auth_flags() const {
return (_pc_and_udmsg >> 62) & 0x03;
}
void clear_pipelined_count_and_with_auth() {
void clear_pipelined_count_and_auth_flags() {
_pc_and_udmsg &= 0xFFFFFFFFFFFFULL;
}
SocketMessage* user_message() const {
Expand All @@ -321,9 +321,9 @@ struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
_pc_and_udmsg &= 0xFFFF000000000000ULL;
}
void set_pipelined_count_and_user_message(
uint32_t pc, SocketMessage* msg, bool with_auth) {
if (with_auth) {
pc |= (1 << 15);
uint32_t pc, SocketMessage* msg, uint32_t auth_flags) {
if (auth_flags) {
pc |= (auth_flags & 0x03) << 14;
}
_pc_and_udmsg = ((uint64_t)pc << 48) | (uint64_t)(uintptr_t)msg;
}
Expand All @@ -337,7 +337,7 @@ struct BAIDU_CACHELINE_ALIGNMENT Socket::WriteRequest {
// is already failed.
(void)msg->AppendAndDestroySelf(&dummy_buf, NULL);
}
set_pipelined_count_and_user_message(0, NULL, false);
set_pipelined_count_and_user_message(0, NULL, 0);
return true;
}
return false;
Expand Down Expand Up @@ -376,9 +376,9 @@ void Socket::WriteRequest::Setup(Socket* s) {
// The struct will be popped when reading a message from the socket.
PipelinedInfo pi;
pi.count = pc;
pi.with_auth = is_with_auth();
pi.auth_flags = get_auth_flags();
pi.id_wait = id_wait;
clear_pipelined_count_and_with_auth(); // avoid being pushed again
clear_pipelined_count_and_auth_flags(); // avoid being pushed again
s->PushPipelinedInfo(pi);
}
}
Expand Down Expand Up @@ -1462,7 +1462,7 @@ int Socket::Write(butil::IOBuf* data, const WriteOptions* options_in) {
req->next = WriteRequest::UNCONNECTED;
req->id_wait = opt.id_wait;
req->set_pipelined_count_and_user_message(
opt.pipelined_count, DUMMY_USER_MESSAGE, opt.with_auth);
opt.pipelined_count, DUMMY_USER_MESSAGE, opt.auth_flags);
return StartWrite(req, opt);
}

Expand Down Expand Up @@ -1497,7 +1497,7 @@ int Socket::Write(SocketMessagePtr<>& msg, const WriteOptions* options_in) {
// wait until it points to a valid WriteRequest or NULL.
req->next = WriteRequest::UNCONNECTED;
req->id_wait = opt.id_wait;
req->set_pipelined_count_and_user_message(opt.pipelined_count, msg.release(), opt.with_auth);
req->set_pipelined_count_and_user_message(opt.pipelined_count, msg.release(), opt.auth_flags);
return StartWrite(req, opt);
}

Expand Down
8 changes: 4 additions & 4 deletions src/brpc/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ struct PipelinedInfo {
PipelinedInfo() { reset(); }
void reset() {
count = 0;
with_auth = false;
auth_flags = 0;
id_wait = INVALID_BTHREAD_ID;
}
uint32_t count;
bool with_auth;
uint32_t auth_flags;
bthread_id_t id_wait;
};

Expand Down Expand Up @@ -256,15 +256,15 @@ friend class policy::H2GlobalStreamCreator;
// The request contains authenticating information which will be
// responded by the server and processed specially when dealing
// with the response.
bool with_auth;
uint32_t auth_flags;

// Do not return EOVERCROWDED
// Default: false
bool ignore_eovercrowded;

WriteOptions()
: id_wait(INVALID_BTHREAD_ID), abstime(NULL)
, pipelined_count(0), with_auth(false)
, pipelined_count(0), auth_flags(0)
, ignore_eovercrowded(false) {}
};
int Write(butil::IOBuf *msg, const WriteOptions* options = NULL);
Expand Down

0 comments on commit 53cbd1a

Please sign in to comment.