Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support on_failed callback for streaming rpc #2565

Merged
merged 2 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,6 @@ CTestTestfile.cmake
/test/curl.out
/test/out.txt
/test/recordio_ref.io

# Ignore protoc-gen-mcpack files
/protoc-gen-mcpack*/
3 changes: 2 additions & 1 deletion src/brpc/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1387,7 +1387,8 @@ void Controller::HandleStreamConnection(Socket *host_socket) {
}
}
if (FailedInline()) {
Stream::SetFailed(_request_stream);
Stream::SetFailed(_request_stream, _error_code,
"%s", _error_text.c_str());
if (_remote_stream_settings != NULL) {
policy::SendStreamRst(host_socket,
_remote_stream_settings->stream_id());
Expand Down
10 changes: 6 additions & 4 deletions src/brpc/policy/baidu_rpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,13 @@ void SendRpcResponse(int64_t correlation_id,
accessor.remote_stream_settings()->stream_id(),
accessor.response_stream()) != 0) {
const int errcode = errno;
PLOG_IF(WARNING, errcode != EPIPE) << "Fail to write into " << *sock;
cntl->SetFailed(errcode, "Fail to write into %s",
sock->description().c_str());
std::string error_text = butil::string_printf(64, "Fail to write into %s",
sock->description().c_str());
PLOG_IF(WARNING, errcode != EPIPE) << error_text;
cntl->SetFailed(errcode, "%s", error_text.c_str());
if(stream_ptr) {
((Stream*)stream_ptr->conn())->Close();
((Stream*)stream_ptr->conn())->Close(errcode, "%s",
error_text.c_str());
}
return;
}
Expand Down
10 changes: 5 additions & 5 deletions src/brpc/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,7 @@ int Socket::SetFailed(int error_code, const char* error_fmt, ...) {
&_id_wait_list, error_code, error_text,
&_id_wait_list_mutex));

ResetAllStreams();
ResetAllStreams(error_code, error_text);
// _app_connect shouldn't be set to NULL in SetFailed otherwise
// HC is always not supported.
// FIXME: Design a better interface for AppConnect
Expand Down Expand Up @@ -2541,7 +2541,7 @@ int Socket::RemoveStream(StreamId stream_id) {
return 0;
}

void Socket::ResetAllStreams() {
void Socket::ResetAllStreams(int error_code, const std::string& error_text) {
DCHECK(Failed());
std::set<StreamId> saved_stream_set;
_stream_mutex.lock();
Expand All @@ -2552,9 +2552,9 @@ void Socket::ResetAllStreams() {
saved_stream_set.swap(*_stream_set);
}
_stream_mutex.unlock();
for (std::set<StreamId>::const_iterator
it = saved_stream_set.begin(); it != saved_stream_set.end(); ++it) {
Stream::SetFailed(*it);
for (auto it = saved_stream_set.begin();
it != saved_stream_set.end(); ++it) {
Stream::SetFailed(*it, error_code, "%s", error_text.c_str());
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/brpc/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ friend void DereferenceSocket(Socket*);
// broken socket.
int AddStream(StreamId stream_id);
int RemoveStream(StreamId stream_id);
void ResetAllStreams();
void ResetAllStreams(int error_code, const std::string& error_text);

bool ValidFileDescriptor(int fd);

Expand Down
52 changes: 38 additions & 14 deletions src/brpc/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Stream::Stream()
, _fake_socket_weak_ref(NULL)
, _connected(false)
, _closed(false)
, _error_code(0)
, _produced(0)
, _remote_consumed(0)
, _cur_buf_size(0)
Expand Down Expand Up @@ -74,6 +75,7 @@ int Stream::Create(const StreamOptions &options,
s->_connected = false;
s->_options = options;
s->_closed = false;
s->_error_code = 0;
s->_cur_buf_size = options.max_buf_size > 0 ? options.max_buf_size : 0;
if (options.max_buf_size > 0 && options.min_buf_size > options.max_buf_size) {
// set 0 if min_buf_size is invalid.
Expand Down Expand Up @@ -131,7 +133,7 @@ void Stream::BeforeRecycle(Socket *) {
if (_host_socket) {
_host_socket->RemoveStream(id());
}

// The instance is to be deleted in the consumer thread
bthread::execution_queue_stop(_consumer_queue);
}
Expand Down Expand Up @@ -466,21 +468,22 @@ int Stream::OnReceived(const StreamFrameMeta& fm, butil::IOBuf *buf, Socket* soc
if (!fm.has_continuation()) {
butil::IOBuf *tmp = _pending_buf;
_pending_buf = NULL;
if (bthread::execution_queue_execute(_consumer_queue, tmp) != 0) {
int rc = bthread::execution_queue_execute(_consumer_queue, tmp);
if (rc != 0) {
CHECK(false) << "Fail to push into channel";
delete tmp;
Close();
Close(rc, "Fail to push into channel");
}
}
break;
case FRAME_TYPE_RST:
RPC_VLOG << "stream=" << id() << " received rst frame";
Close();
Close(ECONNRESET, "Received RST frame");
break;
case FRAME_TYPE_CLOSE:
RPC_VLOG << "stream=" << id() << " received close frame";
// TODO:: See the comments in Consume
Close();
Close(0, "Received CLOSE frame");
break;
case FRAME_TYPE_UNKNOWN:
RPC_VLOG << "Received unknown frame";
Expand Down Expand Up @@ -517,7 +520,7 @@ class MessageBatcher {
_total_length += buf->length();

}
size_t total_length() { return _total_length; }
size_t total_length() const { return _total_length; }
private:
butil::IOBuf** _storage;
size_t _cap;
Expand All @@ -530,15 +533,26 @@ int Stream::Consume(void *meta, bthread::TaskIterator<butil::IOBuf*>& iter) {
Stream* s = (Stream*)meta;
s->StopIdleTimer();
if (iter.is_queue_stopped()) {
// indicating the queue was closed
scoped_ptr<Stream> recycled_stream(s);
// Indicating the queue was closed.
if (s->_host_socket) {
DereferenceSocket(s->_host_socket);
s->_host_socket = NULL;
}
if (s->_options.handler != NULL) {
int error_code;
std::string error_text;
{
BAIDU_SCOPED_LOCK(s->_connect_mutex);
error_code = s->_error_code;
error_text = s->_error_text;
}
if (error_code != 0) {
// The stream is closed abnormally.
s->_options.handler->on_failed(s->id(), error_code, error_text);
}
s->_options.handler->on_closed(s->id());
}
delete s;
return 0;
}
DEFINE_SMALL_ARRAY(butil::IOBuf*, buf_list, s->_options.messages_in_batch, 256);
Expand Down Expand Up @@ -630,14 +644,21 @@ void Stream::StopIdleTimer() {
}
}

void Stream::Close() {
void Stream::Close(int error_code, const char* reason_fmt, ...) {
_fake_socket_weak_ref->SetFailed();
bthread_mutex_lock(&_connect_mutex);
if (_closed) {
bthread_mutex_unlock(&_connect_mutex);
return;
}
_closed = true;
_error_code = error_code;

va_list ap;
va_start(ap, reason_fmt);
butil::string_vappendf(&_error_text, reason_fmt, ap);
va_end(ap);

if (_connected) {
bthread_mutex_unlock(&_connect_mutex);
return;
Expand All @@ -647,14 +668,17 @@ void Stream::Close() {
return TriggerOnConnectIfNeed();
}

int Stream::SetFailed(StreamId id) {
int Stream::SetFailed(StreamId id, int error_code, const char* reason_fmt, ...) {
SocketUniquePtr ptr;
if (Socket::AddressFailedAsWell(id, &ptr) == -1) {
// Don't care recycled stream
return 0;
}
Stream* s = (Stream*)ptr->conn();
s->Close();
va_list ap;
va_start(ap, reason_fmt);
s->Close(error_code, reason_fmt, ap);
va_end(ap);
return 0;
}

Expand All @@ -665,13 +689,13 @@ void Stream::HandleRpcResponse(butil::IOBuf* response_buffer) {
ParseResult pr = policy::ParseRpcMessage(response_buffer, NULL, true, NULL);
if (!pr.is_ok()) {
CHECK(false);
Close();
Close(EPROTO, "Fail to parse rpc response message");
return;
}
InputMessageBase* msg = pr.message();
if (msg == NULL) {
CHECK(false);
Close();
Close(ENOMEM, "Message is NULL");
return;
}
_host_socket->PostponeEOF();
Expand Down Expand Up @@ -730,7 +754,7 @@ int StreamWait(StreamId stream_id, const timespec* due_time) {
}

int StreamClose(StreamId stream_id) {
return Stream::SetFailed(stream_id);
return Stream::SetFailed(stream_id, 0, "Local close");
}

int StreamCreate(StreamId *request_stream, Controller &cntl,
Expand Down
9 changes: 6 additions & 3 deletions src/brpc/stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ class StreamInputHandler {
butil::IOBuf *const messages[],
size_t size) = 0;
virtual void on_idle_timeout(StreamId id) = 0;
virtual void on_closed(StreamId id) = 0;
virtual void on_closed(StreamId id) = 0;
// `on_failed` will be called before `on_closed`
// when the stream is closed abnormally.
virtual void on_failed(StreamId id, int error_code,
const std::string& error_text) {}
};

struct StreamOptions {
Expand Down Expand Up @@ -82,8 +86,7 @@ struct StreamOptions {
StreamInputHandler* handler;
};

struct StreamWriteOptions
{
struct StreamWriteOptions {
StreamWriteOptions() : write_in_background(false) {}

// Write message to socket in background thread.
Expand Down
11 changes: 8 additions & 3 deletions src/brpc/stream_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,16 @@ class BAIDU_CACHELINE_ALIGNMENT Stream : public SocketConnection {
const timespec *due_time);
int Wait(const timespec* due_time);
void FillSettings(StreamSettings *settings);
static int SetFailed(StreamId id);
void Close();
static int SetFailed(StreamId id, int error_code, const char* reason_fmt, ...)
__attribute__ ((__format__ (__printf__, 3, 4)));
void Close(int error_code, const char* reason_fmt, ...)
__attribute__ ((__format__ (__printf__, 3, 4)));

private:
friend void StreamWait(StreamId stream_id, const timespec *due_time,
void (*on_writable)(StreamId, void*, int), void *arg);
void (*on_writable)(StreamId, void*, int), void *arg);
friend class MessageBatcher;
friend struct butil::DefaultDeleter<Stream>;
Stream();
~Stream();
int Init(const StreamOptions options);
Expand Down Expand Up @@ -111,6 +114,8 @@ friend class MessageBatcher;
ConnectMeta _connect_meta;
bool _connected;
bool _closed;
int _error_code;
std::string _error_text;

bthread_mutex_t _congestion_control_mutex;
size_t _produced;
Expand Down
Loading
Loading