Skip to content

Commit

Permalink
Support on_failed callback for streaming rpc (#2565)
Browse files Browse the repository at this point in the history
* Support on_failed callback for streaming rpc

* Call on_failed before on_closed
  • Loading branch information
chenBright authored Apr 8, 2024
1 parent 16ab5b5 commit 337142f
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 48 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,6 @@ CTestTestfile.cmake
/test/curl.out
/test/out.txt
/test/recordio_ref.io

# Ignore protoc-gen-mcpack files
/protoc-gen-mcpack*/
3 changes: 2 additions & 1 deletion src/brpc/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1387,7 +1387,8 @@ void Controller::HandleStreamConnection(Socket *host_socket) {
}
}
if (FailedInline()) {
Stream::SetFailed(_request_stream);
Stream::SetFailed(_request_stream, _error_code,
"%s", _error_text.c_str());
if (_remote_stream_settings != NULL) {
policy::SendStreamRst(host_socket,
_remote_stream_settings->stream_id());
Expand Down
10 changes: 6 additions & 4 deletions src/brpc/policy/baidu_rpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,13 @@ void SendRpcResponse(int64_t correlation_id,
accessor.remote_stream_settings()->stream_id(),
accessor.response_stream()) != 0) {
const int errcode = errno;
PLOG_IF(WARNING, errcode != EPIPE) << "Fail to write into " << *sock;
cntl->SetFailed(errcode, "Fail to write into %s",
sock->description().c_str());
std::string error_text = butil::string_printf(64, "Fail to write into %s",
sock->description().c_str());
PLOG_IF(WARNING, errcode != EPIPE) << error_text;
cntl->SetFailed(errcode, "%s", error_text.c_str());
if(stream_ptr) {
((Stream*)stream_ptr->conn())->Close();
((Stream*)stream_ptr->conn())->Close(errcode, "%s",
error_text.c_str());
}
return;
}
Expand Down
10 changes: 5 additions & 5 deletions src/brpc/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,7 @@ int Socket::SetFailed(int error_code, const char* error_fmt, ...) {
&_id_wait_list, error_code, error_text,
&_id_wait_list_mutex));

ResetAllStreams();
ResetAllStreams(error_code, error_text);
// _app_connect shouldn't be set to NULL in SetFailed otherwise
// HC is always not supported.
// FIXME: Design a better interface for AppConnect
Expand Down Expand Up @@ -2541,7 +2541,7 @@ int Socket::RemoveStream(StreamId stream_id) {
return 0;
}

void Socket::ResetAllStreams() {
void Socket::ResetAllStreams(int error_code, const std::string& error_text) {
DCHECK(Failed());
std::set<StreamId> saved_stream_set;
_stream_mutex.lock();
Expand All @@ -2552,9 +2552,9 @@ void Socket::ResetAllStreams() {
saved_stream_set.swap(*_stream_set);
}
_stream_mutex.unlock();
for (std::set<StreamId>::const_iterator
it = saved_stream_set.begin(); it != saved_stream_set.end(); ++it) {
Stream::SetFailed(*it);
for (auto it = saved_stream_set.begin();
it != saved_stream_set.end(); ++it) {
Stream::SetFailed(*it, error_code, "%s", error_text.c_str());
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/brpc/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ friend void DereferenceSocket(Socket*);
// broken socket.
int AddStream(StreamId stream_id);
int RemoveStream(StreamId stream_id);
void ResetAllStreams();
void ResetAllStreams(int error_code, const std::string& error_text);

bool ValidFileDescriptor(int fd);

Expand Down
50 changes: 37 additions & 13 deletions src/brpc/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Stream::Stream()
, _fake_socket_weak_ref(NULL)
, _connected(false)
, _closed(false)
, _error_code(0)
, _produced(0)
, _remote_consumed(0)
, _cur_buf_size(0)
Expand Down Expand Up @@ -74,6 +75,7 @@ int Stream::Create(const StreamOptions &options,
s->_connected = false;
s->_options = options;
s->_closed = false;
s->_error_code = 0;
s->_cur_buf_size = options.max_buf_size > 0 ? options.max_buf_size : 0;
if (options.max_buf_size > 0 && options.min_buf_size > options.max_buf_size) {
// set 0 if min_buf_size is invalid.
Expand Down Expand Up @@ -131,7 +133,7 @@ void Stream::BeforeRecycle(Socket *) {
if (_host_socket) {
_host_socket->RemoveStream(id());
}

// The instance is to be deleted in the consumer thread
bthread::execution_queue_stop(_consumer_queue);
}
Expand Down Expand Up @@ -466,21 +468,22 @@ int Stream::OnReceived(const StreamFrameMeta& fm, butil::IOBuf *buf, Socket* soc
if (!fm.has_continuation()) {
butil::IOBuf *tmp = _pending_buf;
_pending_buf = NULL;
if (bthread::execution_queue_execute(_consumer_queue, tmp) != 0) {
int rc = bthread::execution_queue_execute(_consumer_queue, tmp);
if (rc != 0) {
CHECK(false) << "Fail to push into channel";
delete tmp;
Close();
Close(rc, "Fail to push into channel");
}
}
break;
case FRAME_TYPE_RST:
RPC_VLOG << "stream=" << id() << " received rst frame";
Close();
Close(ECONNRESET, "Received RST frame");
break;
case FRAME_TYPE_CLOSE:
RPC_VLOG << "stream=" << id() << " received close frame";
// TODO:: See the comments in Consume
Close();
Close(0, "Received CLOSE frame");
break;
case FRAME_TYPE_UNKNOWN:
RPC_VLOG << "Received unknown frame";
Expand Down Expand Up @@ -530,15 +533,26 @@ int Stream::Consume(void *meta, bthread::TaskIterator<butil::IOBuf*>& iter) {
Stream* s = (Stream*)meta;
s->StopIdleTimer();
if (iter.is_queue_stopped()) {
// indicating the queue was closed
scoped_ptr<Stream> recycled_stream(s);
// Indicating the queue was closed.
if (s->_host_socket) {
DereferenceSocket(s->_host_socket);
s->_host_socket = NULL;
}
if (s->_options.handler != NULL) {
int error_code;
std::string error_text;
{
BAIDU_SCOPED_LOCK(s->_connect_mutex);
error_code = s->_error_code;
error_text = s->_error_text;
}
if (error_code != 0) {
// The stream is closed abnormally.
s->_options.handler->on_failed(s->id(), error_code, error_text);
}
s->_options.handler->on_closed(s->id());
}
delete s;
return 0;
}
DEFINE_SMALL_ARRAY(butil::IOBuf*, buf_list, s->_options.messages_in_batch, 256);
Expand Down Expand Up @@ -630,14 +644,21 @@ void Stream::StopIdleTimer() {
}
}

void Stream::Close() {
void Stream::Close(int error_code, const char* reason_fmt, ...) {
_fake_socket_weak_ref->SetFailed();
bthread_mutex_lock(&_connect_mutex);
if (_closed) {
bthread_mutex_unlock(&_connect_mutex);
return;
}
_closed = true;
_error_code = error_code;

va_list ap;
va_start(ap, reason_fmt);
butil::string_vappendf(&_error_text, reason_fmt, ap);
va_end(ap);

if (_connected) {
bthread_mutex_unlock(&_connect_mutex);
return;
Expand All @@ -647,14 +668,17 @@ void Stream::Close() {
return TriggerOnConnectIfNeed();
}

int Stream::SetFailed(StreamId id) {
int Stream::SetFailed(StreamId id, int error_code, const char* reason_fmt, ...) {
SocketUniquePtr ptr;
if (Socket::AddressFailedAsWell(id, &ptr) == -1) {
// Don't care recycled stream
return 0;
}
Stream* s = (Stream*)ptr->conn();
s->Close();
va_list ap;
va_start(ap, reason_fmt);
s->Close(error_code, reason_fmt, ap);
va_end(ap);
return 0;
}

Expand All @@ -665,13 +689,13 @@ void Stream::HandleRpcResponse(butil::IOBuf* response_buffer) {
ParseResult pr = policy::ParseRpcMessage(response_buffer, NULL, true, NULL);
if (!pr.is_ok()) {
CHECK(false);
Close();
Close(EPROTO, "Fail to parse rpc response message");
return;
}
InputMessageBase* msg = pr.message();
if (msg == NULL) {
CHECK(false);
Close();
Close(ENOMEM, "Message is NULL");
return;
}
_host_socket->PostponeEOF();
Expand Down Expand Up @@ -730,7 +754,7 @@ int StreamWait(StreamId stream_id, const timespec* due_time) {
}

int StreamClose(StreamId stream_id) {
return Stream::SetFailed(stream_id);
return Stream::SetFailed(stream_id, 0, "Local close");
}

int StreamCreate(StreamId *request_stream, Controller &cntl,
Expand Down
9 changes: 6 additions & 3 deletions src/brpc/stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ class StreamInputHandler {
butil::IOBuf *const messages[],
size_t size) = 0;
virtual void on_idle_timeout(StreamId id) = 0;
virtual void on_closed(StreamId id) = 0;
virtual void on_closed(StreamId id) = 0;
// `on_failed` will be called before `on_closed`
// when the stream is closed abnormally.
virtual void on_failed(StreamId id, int error_code,
const std::string& error_text) {}
};

struct StreamOptions {
Expand Down Expand Up @@ -82,8 +86,7 @@ struct StreamOptions {
StreamInputHandler* handler;
};

struct StreamWriteOptions
{
struct StreamWriteOptions {
StreamWriteOptions() : write_in_background(false) {}

// Write message to socket in background thread.
Expand Down
11 changes: 8 additions & 3 deletions src/brpc/stream_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,16 @@ class BAIDU_CACHELINE_ALIGNMENT Stream : public SocketConnection {
const timespec *due_time);
int Wait(const timespec* due_time);
void FillSettings(StreamSettings *settings);
static int SetFailed(StreamId id);
void Close();
static int SetFailed(StreamId id, int error_code, const char* reason_fmt, ...)
__attribute__ ((__format__ (__printf__, 3, 4)));
void Close(int error_code, const char* reason_fmt, ...)
__attribute__ ((__format__ (__printf__, 3, 4)));

private:
friend void StreamWait(StreamId stream_id, const timespec *due_time,
void (*on_writable)(StreamId, void*, int), void *arg);
void (*on_writable)(StreamId, void*, int), void *arg);
friend class MessageBatcher;
friend struct butil::DefaultDeleter<Stream>;
Stream();
~Stream();
int Init(const StreamOptions options);
Expand Down Expand Up @@ -111,6 +114,8 @@ friend class MessageBatcher;
ConnectMeta _connect_meta;
bool _connected;
bool _closed;
int _error_code;
std::string _error_text;

bthread_mutex_t _congestion_control_mutex;
size_t _produced;
Expand Down
Loading

0 comments on commit 337142f

Please sign in to comment.