Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[thrift_proxy] Replace local HeaderMap with Http::HeaderMap[Impl] #4169

1 change: 0 additions & 1 deletion source/extensions/filters/network/thrift_proxy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ envoy_cc_library(

envoy_cc_library(
name = "metadata_lib",
srcs = ["metadata.cc"],
hdrs = ["metadata.h"],
external_deps = ["abseil_optional"],
deps = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ bool HeaderTransportImpl::decodeFrameStart(Buffer::Instance& buffer, MessageMeta
while (num_headers-- > 0) {
std::string key = drainVarString(buffer, header_size, "header key");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: couldn't key be constructed here as a LowerCaseString?

std::string value = drainVarString(buffer, header_size, "header value");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should key/value be consts?

metadata.addHeader(Header(key, value));
metadata.headers().addCopy(Http::LowerCaseString(key), value);
}
}

Expand All @@ -172,7 +172,7 @@ void HeaderTransportImpl::encodeFrame(Buffer::Instance& buffer, const MessageMet
throw EnvoyException(fmt::format("invalid thrift header transport message size {}", msg_size));
}

const HeaderMap& headers = metadata.headers();
const Http::HeaderMap& headers = metadata.headers();
if (headers.size() > MaxHeadersSize / 2) {
// Each header takes a minimum of 2 bytes, yielding this limit.
throw EnvoyException(
Expand Down Expand Up @@ -205,10 +205,14 @@ void HeaderTransportImpl::encodeFrame(Buffer::Instance& buffer, const MessageMet
// Num headers
BufferHelper::writeVarIntI32(header_buffer, static_cast<int32_t>(headers.size()));

for (const Header& header : headers) {
writeVarString(header_buffer, header.key());
writeVarString(header_buffer, header.value());
}
headers.iterate(
[](const Http::HeaderEntry& header, void* context) -> Http::HeaderMap::Iterate {
Buffer::Instance* hb = static_cast<Buffer::Instance*>(context);
writeVarString(*hb, header.key().getStringView());
writeVarString(*hb, header.value().getStringView());
return Http::HeaderMap::Iterate::Continue;
},
&header_buffer);
}

uint64_t header_size = header_buffer.length();
Expand Down Expand Up @@ -286,7 +290,7 @@ std::string HeaderTransportImpl::drainVarString(Buffer::Instance& buffer, int32_
return value;
}

void HeaderTransportImpl::writeVarString(Buffer::Instance& buffer, const std::string& str) {
void HeaderTransportImpl::writeVarString(Buffer::Instance& buffer, const absl::string_view str) {
std::string::size_type len = str.length();
if (len > static_cast<uint32_t>(std::numeric_limits<int16_t>::max())) {
throw EnvoyException(fmt::format("header string too long: {}", len));
Expand All @@ -296,7 +300,7 @@ void HeaderTransportImpl::writeVarString(Buffer::Instance& buffer, const std::st
if (len == 0) {
return;
}
buffer.add(str);
buffer.add(str.data(), len);
}

class HeaderTransportConfigFactory : public TransportFactoryBase<HeaderTransportImpl> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class HeaderTransportImpl : public Transport {
static int32_t drainVarIntI32(Buffer::Instance& buffer, int32_t& header_size, const char* desc);
static std::string drainVarString(Buffer::Instance& buffer, int32_t& header_size,
const char* desc);
static void writeVarString(Buffer::Instance& buffer, const std::string& str);
static void writeVarString(Buffer::Instance& buffer, const absl::string_view str);

void setException(AppExceptionType type, std::string reason) {
if (exception_.has_value()) {
Expand Down
47 changes: 0 additions & 47 deletions source/extensions/filters/network/thrift_proxy/metadata.cc

This file was deleted.

83 changes: 4 additions & 79 deletions source/extensions/filters/network/thrift_proxy/metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <string>

#include "common/common/macros.h"
#include "common/http/header_map_impl.h"

#include "extensions/filters/network/thrift_proxy/thrift.h"

Expand All @@ -17,82 +18,6 @@ namespace Extensions {
namespace NetworkFilters {
namespace ThriftProxy {

/**
* Header is a name-value pair in Thrift transport or protocol headers.
*/
class Header {
public:
Header(const std::string key, const std::string value) : key_(key), value_(value) {}
Header(const Header& rhs) : key_(rhs.key_), value_(rhs.value_) {}

const std::string& key() const { return key_; }
const std::string& value() const { return value_; }

private:
std::string key_;
std::string value_;
};

// TODO(zuercher): replace this with Http::HeaderMap[Impl]
/*
* HeaderMap contains Thrift transport and/or protocol-level headers.
*/
class HeaderMap {
public:
HeaderMap() {}
HeaderMap(const std::initializer_list<std::pair<std::string, std::string>>& values);
HeaderMap(const HeaderMap& rhs);

/**
* @return true if the HeaderMap is empty
*/
bool empty() const { return headers_.empty(); }

/**
* @return uint32_t the number of headers in the map
*/
uint32_t size() const { return headers_.size(); }

/**
* @param header Header to move into the HeaderMap
*/
void add(Header&& header) { headers_.emplace_back(std::move(header)); }

/**
* Clears all Headers from the HeaderMap.
*/
void clear() { headers_.clear(); }

/**
* Retrieves a Header from the HeaderMap.
* @param key std::string containing the key to lookup
* @return Header* corresponding to key or nullptr if not found.
*/
Header* get(const std::string& key);

/**
* Const iterators for the HeaderMap.
*/
std::list<Header>::const_iterator begin() const noexcept { return headers_.begin(); }
std::list<Header>::const_iterator end() const noexcept { return headers_.end(); }
std::list<Header>::const_iterator cbegin() const noexcept { return headers_.cbegin(); }
std::list<Header>::const_iterator cend() const noexcept { return headers_.cend(); }

/**
* For testing. Equality is based on equality of the backing list. This is an exact match
* comparison (order matters).
*/
bool operator==(const HeaderMap& rhs) const;

/**
* @return an empty HeaderMap
*/
static const HeaderMap& emptyHeaderMap() { CONSTRUCT_ON_FIRST_USE(HeaderMap, HeaderMap({})); }

private:
std::list<Header> headers_;
};

/**
* MessageMetadata encapsulates metadata about Thrift messages. The various fields are considered
* optional since they may come from either the transport or protocol in some cases. Unless
Expand Down Expand Up @@ -123,11 +48,11 @@ class MessageMetadata {
MessageType messageType() const { return msg_type_.value(); }
void setMessageType(MessageType msg_type) { msg_type_ = msg_type; }

void addHeader(Header&& header) { headers_.add(std::move(header)); }
/**
* @return HeaderMap of current headers (never throws)
*/
const HeaderMap& headers() const { return headers_; }
const Http::HeaderMap& headers() const { return headers_; }
Http::HeaderMap& headers() { return headers_; }

bool hasAppException() const { return app_ex_type_.has_value(); }
void setAppException(AppExceptionType app_ex_type, const std::string& message) {
Expand All @@ -143,7 +68,7 @@ class MessageMetadata {
absl::optional<std::string> method_name_{};
absl::optional<int32_t> seq_id_{};
absl::optional<MessageType> msg_type_{};
HeaderMap headers_;
Http::HeaderMapImpl headers_;
absl::optional<AppExceptionType> app_ex_type_;
absl::optional<std::string> app_ex_msg_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class BinaryProtocolTest : public testing::Test {
EXPECT_FALSE(metadata_.hasFrameSize());
EXPECT_FALSE(metadata_.hasProtocol());
EXPECT_FALSE(metadata_.hasAppException());
EXPECT_TRUE(metadata_.headers().empty());
EXPECT_EQ(metadata_.headers().size(), 0);
}

void expectDefaultMetadata() { expectMetadata("-", MessageType::Oneway, 1); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class CompactProtocolTest : public testing::Test {
EXPECT_FALSE(metadata_.hasFrameSize());
EXPECT_FALSE(metadata_.hasProtocol());
EXPECT_FALSE(metadata_.hasAppException());
EXPECT_TRUE(metadata_.headers().empty());
EXPECT_EQ(metadata_.headers().size(), 0);
}

void expectDefaultMetadata() { expectMetadata("-", MessageType::Oneway, 1); }
Expand Down
2 changes: 0 additions & 2 deletions test/extensions/filters/network/thrift_proxy/decoder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1040,8 +1040,6 @@ TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) {
Buffer::OwnedImpl buffer;
bool underflow = true;

HeaderMap headers{{"test", "header"}};

EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _))
.WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool {
metadata.setFrameSize(100);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ class MockBuffer : public Envoy::MockBuffer {
MOCK_CONST_METHOD0(length, uint64_t());
};

MessageMetadata mkMessageMetadata(uint32_t num_headers) {
MessageMetadata metadata;
MessageMetadata* mkMessageMetadata(uint32_t num_headers) {
MessageMetadata* metadata = new MessageMetadata;
while (num_headers-- > 0) {
metadata.addHeader(Header("x", "y"));
metadata->headers().addCopy(Http::LowerCaseString("x"), "y");
}
return metadata;
}
Expand Down Expand Up @@ -439,7 +439,7 @@ TEST(HeaderTransportTest, InfoBlock) {
HeaderTransportImpl transport;
Buffer::OwnedImpl buffer;
MessageMetadata metadata;
metadata.addHeader(Header("not", "empty"));
metadata.headers().addCopy(Http::LowerCaseString("not"), "empty");

addInt32(buffer, 200);
addInt16(buffer, 0x0FFF);
Expand All @@ -459,16 +459,17 @@ TEST(HeaderTransportTest, InfoBlock) {
addInt8(buffer, 0); // empty value
addInt8(buffer, 0); // padding

HeaderMap expected_headers{
{"not", "empty"},
{"key", "value"},
{"key2", std::string(128, 'x')},
{"", ""},
};
Http::HeaderMapImpl expected_headers;
expected_headers.addCopy(Http::LowerCaseString("not"), "empty");
expected_headers.addCopy(Http::LowerCaseString("key"), "value");
expected_headers.addCopy(Http::LowerCaseString("key2"), std::string(128, 'x'));
expected_headers.addCopy(Http::LowerCaseString(""), "");

EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata));
EXPECT_THAT(metadata, HasFrameSize(38U));
EXPECT_EQ(expected_headers, metadata.headers());

Http::HeaderMapImpl& actual_headers = dynamic_cast<Http::HeaderMapImpl&>(metadata.headers());
EXPECT_EQ(expected_headers, actual_headers);
EXPECT_EQ(buffer.length(), 0);
}

Expand Down Expand Up @@ -530,22 +531,23 @@ TEST(HeaderTransportImpl, TestEncodeFrame) {
// Too many headers
{
Buffer::OwnedImpl buffer;
MessageMetadata metadata = mkMessageMetadata(32769);
metadata.setProtocol(ProtocolType::Binary);
MessageMetadata* metadata = mkMessageMetadata(32769);
metadata->setProtocol(ProtocolType::Binary);

Buffer::OwnedImpl msg;
msg.add("fake message");

EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, metadata, msg), EnvoyException,
EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, *metadata, msg), EnvoyException,
"invalid thrift header transport too many headers 32769");
delete (metadata);
}

// Header string too large
{
Buffer::OwnedImpl buffer;
MessageMetadata metadata;
metadata.setProtocol(ProtocolType::Binary);
metadata.addHeader(Header("key", std::string(32768, 'x')));
metadata.headers().addCopy(Http::LowerCaseString("key"), std::string(32768, 'x'));

Buffer::OwnedImpl msg;
msg.add("fake message");
Expand All @@ -559,10 +561,10 @@ TEST(HeaderTransportImpl, TestEncodeFrame) {
Buffer::OwnedImpl buffer;
MessageMetadata metadata;
metadata.setProtocol(ProtocolType::Binary);
metadata.addHeader(Header("k1", std::string(16384, 'x')));
metadata.addHeader(Header("k2", std::string(16384, 'x')));
metadata.addHeader(Header("k3", std::string(16384, 'x')));
metadata.addHeader(Header("k4", std::string(16384, 'x')));
metadata.headers().addCopy(Http::LowerCaseString("k1"), std::string(16384, 'x'));
metadata.headers().addCopy(Http::LowerCaseString("k2"), std::string(16384, 'x'));
metadata.headers().addCopy(Http::LowerCaseString("k3"), std::string(16384, 'x'));
metadata.headers().addCopy(Http::LowerCaseString("k4"), std::string(16384, 'x'));

Buffer::OwnedImpl msg;
msg.add("fake message");
Expand Down Expand Up @@ -620,8 +622,8 @@ TEST(HeaderTransportImpl, TestEncodeFrame) {
MessageMetadata metadata;
metadata.setProtocol(ProtocolType::Compact);
metadata.setSequenceId(10);
metadata.addHeader(Header("key", "value"));
metadata.addHeader(Header("", ""));
metadata.headers().addCopy(Http::LowerCaseString("key"), "value");
metadata.headers().addCopy(Http::LowerCaseString(""), "");
Buffer::OwnedImpl msg;
msg.add("fake message");

Expand Down
Loading