Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix CQL Stitcher 3/4] Populate map of streams to frames during parsing #1716

Merged
merged 5 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 35 additions & 59 deletions src/stirling/source_connectors/socket_tracer/conn_tracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,14 @@ class ConnTracker : NotCopyMoveable {
using TRecordType = typename TProtocolTraits::record_type;
using TFrameType = typename TProtocolTraits::frame_type;
using TStateType = typename TProtocolTraits::state_type;
using TKey = typename TProtocolTraits::key_type;

InitProtocolState<TStateType>();

DataStreamsToFrames<TFrameType, TStateType>();
DataStreamsToFrames<TKey, TFrameType, TStateType>();

auto& req_frames = req_data()->Frames<TFrameType>();
auto& resp_frames = resp_data()->Frames<TFrameType>();
auto& req_frames = req_data()->Frames<TKey, TFrameType>();
auto& resp_frames = resp_data()->Frames<TKey, TFrameType>();
auto state_ptr = protocol_state<TStateType>();

CONN_TRACE(2) << absl::Substitute("req_frames=$0 resp_frames=$1", req_frames.size(),
Expand All @@ -279,39 +280,11 @@ class ConnTracker : NotCopyMoveable {
// TODO(@benkilimnik): Eventually, we should migrate all of the protocols to use the map.
if constexpr (TProtocolTraits::stream_support ==
protocols::BaseProtocolTraits<TRecordType>::UseStream) {
using TKey = typename TProtocolTraits::key_type;
// TODO(@benkilimnik): For now, we populate the map using the parsed req and resp deques.
// In a future PR, we should parse the map earlier in the event parser.
absl::flat_hash_map<TKey, std::deque<TFrameType>> requests;
absl::flat_hash_map<TKey, std::deque<TFrameType>> responses;
for (auto& frame : req_frames) {
// GetStreamID returns 0 by default if not specialized in protocol.
auto key = protocols::GetStreamID<TKey, TFrameType>(&frame);
requests[key].push_back(std::move(frame));
}
for (auto& frame : resp_frames) {
auto key = protocols::GetStreamID<TKey, TFrameType>(&frame);
responses[key].push_back(std::move(frame));
}
result = protocols::StitchFrames<TRecordType, TKey, TFrameType, TStateType>(
&requests, &responses, state_ptr);
// TODO(@benkilimnik): Update req and resp frame deques to match maps for now. Populate maps
// during parsing in a future PR.
req_frames.clear();
for (auto& [_, frames] : requests) {
for (auto& frame : frames) {
req_frames.push_back(std::move(frame));
}
}
resp_frames.clear();
for (auto& [_, frames] : responses) {
for (auto& frame : frames) {
resp_frames.push_back(std::move(frame));
}
}
&req_frames, &resp_frames, state_ptr);
} else {
result = protocols::StitchFrames<TRecordType, TFrameType, TStateType>(
&req_frames, &resp_frames, state_ptr);
&req_frames[0], &resp_frames[0], state_ptr);
}

CONN_TRACE(2) << absl::Substitute("records=$0", result.records.size());
Expand All @@ -325,15 +298,15 @@ class ConnTracker : NotCopyMoveable {
* Returns reference to current set of unconsumed requests.
* Note: A call to ProcessBytesToFrames() is required to parse new requests.
*/
template <typename TFrameType>
std::deque<TFrameType>& req_frames() {
return req_data()->Frames<TFrameType>();
template <typename TKey, typename TFrameType>
absl::flat_hash_map<TKey, std::deque<TFrameType>>& req_frames() {
return req_data()->Frames<TKey, TFrameType>();
}
// TODO(yzhao): req_data() requires role_ to be set. But HTTP2 uprobe tracing does
// not set that. So send_data() is created. Investigate a more unified approach.
template <typename TFrameType>
const std::deque<TFrameType>& send_frames() const {
return send_data_.Frames<TFrameType>();
template <typename TKey, typename TFrameType>
const absl::flat_hash_map<TKey, std::deque<TFrameType>>& send_frames() const {
return send_data_.Frames<TKey, TFrameType>();
}

size_t http2_client_streams_size() const { return http2_client_streams_.streams().size(); }
Expand All @@ -343,13 +316,13 @@ class ConnTracker : NotCopyMoveable {
* Returns reference to current set of unconsumed responses.
* Note: A call to ProcessBytesToFrames() is required to parse new responses.
*/
template <typename TFrameType>
std::deque<TFrameType>& resp_frames() {
return resp_data()->Frames<TFrameType>();
template <typename TKey, typename TFrameType>
absl::flat_hash_map<TKey, std::deque<TFrameType>>& resp_frames() {
return resp_data()->Frames<TKey, TFrameType>();
}
template <typename TFrameType>
const std::deque<TFrameType>& recv_frames() const {
return recv_data_.Frames<TFrameType>();
template <typename TKey, typename TFrameType>
const absl::flat_hash_map<TKey, std::deque<TFrameType>>& recv_frames() const {
return recv_data_.Frames<TKey, TFrameType>();
}

const conn_id_t& conn_id() const { return conn_id_; }
Expand Down Expand Up @@ -572,13 +545,14 @@ class ConnTracker : NotCopyMoveable {
std::chrono::time_point<std::chrono::steady_clock> buffer_expiry_timestamp) {
using TFrameType = typename TProtocolTraits::frame_type;
using TStateType = typename TProtocolTraits::state_type;
using TKey = typename TProtocolTraits::key_type;

if constexpr (std::is_same_v<TFrameType, protocols::http2::Stream>) {
http2_client_streams_.Cleanup(frame_size_limit_bytes, frame_expiry_timestamp);
http2_server_streams_.Cleanup(frame_size_limit_bytes, frame_expiry_timestamp);
} else {
send_data_.CleanupFrames<TFrameType>(frame_size_limit_bytes, frame_expiry_timestamp);
recv_data_.CleanupFrames<TFrameType>(frame_size_limit_bytes, frame_expiry_timestamp);
send_data_.CleanupFrames<TKey, TFrameType>(frame_size_limit_bytes, frame_expiry_timestamp);
recv_data_.CleanupFrames<TKey, TFrameType>(frame_size_limit_bytes, frame_expiry_timestamp);
}

auto* state = protocol_state<TStateType>();
Expand Down Expand Up @@ -617,11 +591,11 @@ class ConnTracker : NotCopyMoveable {

std::string ToString() const;

template <typename TFrameType>
template <typename TKey, typename TFrameType>
void InitFrames() {
if constexpr (!std::is_same_v<TFrameType, protocols::http2::Stream>) {
send_data_.InitFrames<TFrameType>();
recv_data_.InitFrames<TFrameType>();
send_data_.InitFrames<TKey, TFrameType>();
recv_data_.InitFrames<TKey, TFrameType>();
}
}

Expand All @@ -631,6 +605,7 @@ class ConnTracker : NotCopyMoveable {
template <typename TProtocolTraits>
size_t MemUsage() const {
using TFrameType = typename TProtocolTraits::frame_type;
using TKey = typename TProtocolTraits::key_type;

size_t data_buffer_total = 0;
data_buffer_total += send_data().data_buffer().capacity();
Expand All @@ -642,8 +617,8 @@ class ConnTracker : NotCopyMoveable {
http2_events_total += http2_client_streams_.StreamsSize();
http2_events_total += http2_server_streams_.StreamsSize();
} else {
parsed_msg_total += send_data().FramesSize<TFrameType>();
parsed_msg_total += recv_data().FramesSize<TFrameType>();
parsed_msg_total += send_data().FramesSize<TKey, TFrameType>();
parsed_msg_total += recv_data().FramesSize<TKey, TFrameType>();
}

return data_buffer_total + http2_events_total + parsed_msg_total;
Expand Down Expand Up @@ -687,19 +662,19 @@ class ConnTracker : NotCopyMoveable {

void UpdateDataStats(const SocketDataEvent& event);

template <typename TFrameType, typename TStateType>
template <typename TKey, typename TFrameType, typename TStateType>
void DataStreamsToFrames() {
auto state_ptr = protocol_state<TStateType>();

DataStream* req_data_ptr = req_data();
DCHECK_NE(req_data_ptr, nullptr);
req_data_ptr->template ProcessBytesToFrames<TFrameType, TStateType>(message_type_t::kRequest,
state_ptr);
req_data_ptr->template ProcessBytesToFrames<TKey, TFrameType, TStateType>(
message_type_t::kRequest, state_ptr);

DataStream* resp_data_ptr = resp_data();
DCHECK_NE(resp_data_ptr, nullptr);
resp_data_ptr->template ProcessBytesToFrames<TFrameType, TStateType>(message_type_t::kResponse,
state_ptr);
resp_data_ptr->template ProcessBytesToFrames<TKey, TFrameType, TStateType>(
message_type_t::kResponse, state_ptr);
}

template <typename TRecordType>
Expand Down Expand Up @@ -817,6 +792,7 @@ ConnTracker::ProcessToRecords<protocols::http2::ProtocolTraits>();
template <typename TProtocolTraits>
std::string DebugString(const ConnTracker& c, std::string_view prefix) {
using TFrameType = typename TProtocolTraits::frame_type;
using TKey = typename TProtocolTraits::key_type;

std::string info;
info += absl::Substitute("$0conn_id=$1\n", prefix, ToString(c.conn_id()));
Expand All @@ -829,9 +805,9 @@ std::string DebugString(const ConnTracker& c, std::string_view prefix) {
info += c.http2_server_streams_.DebugString(absl::StrCat(prefix, " "));
} else {
info += absl::Substitute("$0recv queue\n", prefix);
info += DebugString<TFrameType>(c.recv_data(), absl::StrCat(prefix, " "));
info += DebugString<TKey, TFrameType>(c.recv_data(), absl::StrCat(prefix, " "));
info += absl::Substitute("$0send queue\n", prefix);
info += DebugString<TFrameType>(c.send_data(), absl::StrCat(prefix, " "));
info += DebugString<TKey, TFrameType>(c.send_data(), absl::StrCat(prefix, " "));
}

return info;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ TEST_F(ConnTrackerTest, MemUsage) {
auto frame1 = event_gen_.InitSendEvent<kProtocolHTTP>(kHTTPResp0);

ConnTracker tracker;
tracker.InitFrames<http::Message>();
tracker.InitFrames<http::stream_id_t, http::Message>();

// Initial memory use is not 0, because the DataStreamBuffer has a small initial capacity.
size_t mem_usage = tracker.MemUsage<http::ProtocolTraits>();
Expand Down Expand Up @@ -442,7 +442,7 @@ TEST_F(ConnTrackerTest, BufferClearedAfterExpiration) {
tracker.ProcessToRecords<http::ProtocolTraits>();
tracker.Cleanup<http::ProtocolTraits>(frame_size_limit_bytes, buffer_size_limit_bytes,
frame_expiry_timestamp, buffer_expiry_timestamp);
EXPECT_EQ(tracker.req_data()->Frames<http::Message>().size(), 1);
EXPECT_EQ((tracker.req_data()->Frames<http::stream_id_t, http::Message>()[0].size()), 1);
}

TEST_F(ConnTrackerTest, BufferTruncatedBeyondSizeLimit) {
Expand All @@ -460,7 +460,7 @@ TEST_F(ConnTrackerTest, BufferTruncatedBeyondSizeLimit) {
tracker.Cleanup<http::ProtocolTraits>(frame_size_limit_bytes, buffer_size_limit_bytes,
frame_expiry_timestamp, buffer_expiry_timestamp);
EXPECT_EQ(tracker.req_data()->data_buffer().size(), buffer_size_limit_bytes);
EXPECT_THAT(tracker.req_frames<http::Message>(), IsEmpty());
EXPECT_THAT((tracker.req_frames<http::stream_id_t, http::Message>()[0]), IsEmpty());
}

TEST_F(ConnTrackerTest, MessagesErasedAfterExpiration) {
Expand All @@ -480,13 +480,13 @@ TEST_F(ConnTrackerTest, MessagesErasedAfterExpiration) {
tracker.ProcessToRecords<http::ProtocolTraits>();
tracker.Cleanup<http::ProtocolTraits>(frame_size_limit_bytes, buffer_size_limit_bytes,
frame_expiry_timestamp, buffer_expiry_timestamp);
EXPECT_THAT(tracker.req_frames<http::Message>(), SizeIs(1));
EXPECT_THAT((tracker.req_frames<http::stream_id_t, http::Message>()[0]), SizeIs(1));

frame_expiry_timestamp = now();
tracker.ProcessToRecords<http::ProtocolTraits>();
tracker.Cleanup<http::ProtocolTraits>(frame_size_limit_bytes, buffer_size_limit_bytes,
frame_expiry_timestamp, buffer_expiry_timestamp);
EXPECT_THAT(tracker.req_frames<http::Message>(), IsEmpty());
EXPECT_THAT((tracker.req_frames<http::stream_id_t, http::Message>()[0]), IsEmpty());
}

// Tests that tracker state is kDisabled if the remote address is in the cluster's CIDR range.
Expand Down
48 changes: 27 additions & 21 deletions src/stirling/source_connectors/socket_tracer/data_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ void DataStream::AddData(std::unique_ptr<SocketDataEvent> event) {
// To be robust to lost events, which are not necessarily aligned to parseable entity boundaries,
// ProcessBytesToFrames() will invoke a call to ParseFrames() with a stream recovery argument when
// necessary.
template <typename TFrameType, typename TStateType = protocols::NoState>
template <typename TKey, typename TFrameType, typename TStateType = protocols::NoState>
void DataStream::ProcessBytesToFrames(message_type_t type, TStateType* state) {
auto& typed_messages = Frames<TFrameType>();
auto& typed_messages = Frames<TKey, TFrameType>();

// TODO(oazizi): Convert to ECHECK once we have more confidence.
LOG_IF(WARNING, IsEOS()) << "DataStream reaches EOS, no more data to process.";
Expand Down Expand Up @@ -180,30 +180,36 @@ void DataStream::ProcessBytesToFrames(message_type_t type, TStateType* state) {
}

// PROTOCOL_LIST: Requires update on new protocols.
template void
DataStream::ProcessBytesToFrames<protocols::http::Message, protocols::http::StateWrapper>(
template void DataStream::ProcessBytesToFrames<
protocols::http::stream_id_t, protocols::http::Message, protocols::http::StateWrapper>(
message_type_t type, protocols::http::StateWrapper* state);
template void DataStream::ProcessBytesToFrames<protocols::mux::Frame, protocols::NoState>(
message_type_t type, protocols::NoState* state);
template void
DataStream::ProcessBytesToFrames<protocols::mysql::Packet, protocols::mysql::StateWrapper>(
template void DataStream::ProcessBytesToFrames<protocols::mux::stream_id_t, protocols::mux::Frame,
protocols::NoState>(message_type_t type,
protocols::NoState* state);
template void DataStream::ProcessBytesToFrames<
protocols::mysql::connection_id_t, protocols::mysql::Packet, protocols::mysql::StateWrapper>(
message_type_t type, protocols::mysql::StateWrapper* state);
template void DataStream::ProcessBytesToFrames<protocols::cass::Frame, protocols::NoState>(
message_type_t type, protocols::NoState* state);
template void
DataStream::ProcessBytesToFrames<protocols::pgsql::RegularMessage, protocols::pgsql::StateWrapper>(
message_type_t type, protocols::pgsql::StateWrapper* state);
template void DataStream::ProcessBytesToFrames<protocols::dns::Frame, protocols::NoState>(
template void DataStream::ProcessBytesToFrames<protocols::cass::stream_id_t, protocols::cass::Frame,
protocols::NoState>(message_type_t type,
protocols::NoState* state);
template void DataStream::ProcessBytesToFrames<
protocols::pgsql::connection_id_t, protocols::pgsql::RegularMessage,
protocols::pgsql::StateWrapper>(message_type_t type, protocols::pgsql::StateWrapper* state);
template void DataStream::ProcessBytesToFrames<protocols::dns::stream_id_t, protocols::dns::Frame,
protocols::NoState>(message_type_t type,
protocols::NoState* state);
template void DataStream::ProcessBytesToFrames<protocols::redis::stream_id_t,
protocols::redis::Message, protocols::NoState>(
message_type_t type, protocols::NoState* state);
template void DataStream::ProcessBytesToFrames<protocols::redis::Message, protocols::NoState>(
message_type_t type, protocols::NoState* state);
template void
DataStream::ProcessBytesToFrames<protocols::kafka::Packet, protocols::kafka::StateWrapper>(
template void DataStream::ProcessBytesToFrames<
protocols::kafka::correlation_id_t, protocols::kafka::Packet, protocols::kafka::StateWrapper>(
message_type_t type, protocols::kafka::StateWrapper* state);
template void DataStream::ProcessBytesToFrames<protocols::nats::Message, protocols::NoState>(
message_type_t type, protocols::NoState* state);
template void DataStream::ProcessBytesToFrames<protocols::amqp::Frame, protocols::NoState>(
template void DataStream::ProcessBytesToFrames<protocols::nats::stream_id_t,
protocols::nats::Message, protocols::NoState>(
message_type_t type, protocols::NoState* state);
template void DataStream::ProcessBytesToFrames<protocols::amqp::channel_id, protocols::amqp::Frame,
protocols::NoState>(message_type_t type,
protocols::NoState* state);
void DataStream::Reset() {
data_buffer_.Reset();
has_new_events_ = false;
Expand Down
Loading