Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

matchers: implement interval tree for sublinear port range matching #19912

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions api/bazel/repository_locations.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ REPOSITORY_LOCATIONS_SPEC = dict(
project_desc = "xDS API Working Group (xDS-WG)",
project_url = "https://github.com/cncf/xds",
# During the UDPA -> xDS migration, we aren't working with releases.
version = "0fa49ea1db0ccf084453766a755b2e76434d99fc",
sha256 = "9369c65e20201ea43e2c293cf024f58167dd727d864706481972ccdf3aacdaab",
release_date = "2022-01-12",
version = "4a2b9fdd466b16721f8c058d7cadf5a54e229d66",
sha256 = "518c99eded8383bd35932879a15f195c799e62929bc1f36803e741ca15238a58",
release_date = "2022-01-21",
strip_prefix = "xds-{version}",
urls = ["https://github.com/cncf/xds/archive/{version}.tar.gz"],
use_category = ["api"],
Expand Down
4 changes: 4 additions & 0 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,10 @@ def _com_google_absl():
name = "abseil_strings",
actual = "@com_google_absl//absl/strings:strings",
)
native.bind(
name = "abseil_span",
actual = "@com_google_absl//absl/types:span",
)
native.bind(
name = "abseil_int128",
actual = "@com_google_absl//absl/numeric:int128",
Expand Down
12 changes: 12 additions & 0 deletions source/common/network/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ envoy_cc_library(
],
)

# Header-only library exposing interval_tree.h.
envoy_cc_library(
name = "interval_tree_lib",
hdrs = ["interval_tree.h"],
# Abseil targets bound in bazel/repositories.bzl.
external_deps = [
"abseil_optional",
"abseil_span",
],
deps = [
"//source/common/common:assert_lib",
],
)

envoy_cc_library(
name = "socket_interface_lib",
hdrs = ["socket_interface.h"],
Expand Down
203 changes: 203 additions & 0 deletions source/common/network/interval_tree.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
#pragma once

#include <algorithm>
#include <iostream>
#include <vector>

#include "source/common/common/assert.h"

#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"

namespace Envoy {
namespace Network {
namespace IntervalTree {

template <class Data, class N> class IntervalTree {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's very difficult, if not impossible, to read this kind of code and understand what is going on as a new reader. Can you please add a large block comment at the top that describes what this does, provides a high-level implementation overview, etc.? Also, in general, I would over-comment the code below if possible.

Also in general this kind of code scares me a bit so you might consider writing a small fuzz test for it to help look for corner cases.

public:
/**
* @param data supplies a list of data and possibly overlapping number intervals [start, end).
*/
IntervalTree(const std::vector<std::tuple<Data, N, N>>& data) {
ASSERT(!data.empty(), "Must supply non-empty list.");
// Create a scaffold tree from medians in the set of all start
// and end points of the intervals to ensure the tree is balanced.
std::vector<N> medians;
medians.reserve(data.size() * 2);
for (const auto [_, start, end] : data) {
ASSERT(start < end, "Interval must be properly formed.");
medians.push_back(start);
medians.push_back(end);
}
std::sort(medians.begin(), medians.end());
root_ = std::make_unique<Node>();
root_->populate(medians);
size_t rank = 0;
// Insert intervals and recursively prune and order intervals in each node.
for (const auto& datum : data) {
root_->insert(datum, rank++);
}
root_->pruneAndOrder();
}
/**
 * Looks up every interval that contains the query number.
 * @param query supplies the number to look up.
 * @return the data of the matching intervals, ordered as the intervals were in the original list.
 **/
std::vector<Data> getData(N query) {
  std::vector<RankedInterval*> hits;
  root_->search(query, hits);

  // The tree reports hits in traversal order; the rank restores the input list order.
  const auto by_rank = [](const auto* lhs, const auto* rhs) { return lhs->rank_ < rhs->rank_; };
  std::sort(hits.begin(), hits.end(), by_rank);

  std::vector<Data> data;
  data.reserve(hits.size());
  for (RankedInterval* hit : hits) {
    data.push_back(hit->data_);
  }
  return data;
}

private:
using Interval = std::tuple<Data, N, N>;
// An input interval [start_, end_) stored together with its payload and its rank.
struct RankedInterval {
  RankedInterval(const Interval& datum, size_t rank)
      : data_(std::get<0>(datum)), rank_(rank), start_(std::get<1>(datum)),
        end_(std::get<2>(datum)) {}

  // Payload copied from the first tuple element.
  Data data_;
  // Rank supplied by the caller at construction.
  size_t rank_;
  // Interval bounds copied from the second and third tuple elements.
  N start_;
  N end_;
  // Links to the next interval by increasing start and by decreasing end, respectively.
  // Both stay null until the intervals are linked.
  RankedInterval* next_start_ = nullptr;
  RankedInterval* next_end_ = nullptr;
};
/**
* Nodes in the tree satisfy the following invariants:
* * Intervals in the node contain the median value.
* * Intervals in the left sub tree have the end less or equal to the median value.
* * Intervals in the right sub tree have the start greater than the median value.
* * Intervals in the node are linked from the lowest start in the ascending order.
* * Intervals in the node are linked from the highest end in the descending order.
**/
struct Node {
N median_;
std::unique_ptr<Node> left_;
std::unique_ptr<Node> right_;
// Ranked intervals are sorted by position in the input vector.
std::vector<RankedInterval> intervals_;
// Interval with the lowest start.
RankedInterval* low_start_;
// Interval with the highest end.
RankedInterval* high_end_;

void populate(absl::Span<const N> span) {
const size_t size = span.size();
const size_t mid = size >> 1;
median_ = span[mid];
// Last value equal to median on the left.
size_t left = mid;
while (left > 0 && span[left - 1] == median_) {
left--;
}
if (left > 0) {
left_ = std::make_unique<Node>();
left_->populate(span.subspan(0, left));
}
// Last value equal to median on the right.
size_t right = mid;
while (right < size - 1 && span[right + 1] == median_) {
right++;
}
if (right < size - 1) {
right_ = std::make_unique<Node>();
right_->populate(span.subspan(right + 1));
}
}
void insert(const Interval& datum, size_t rank) {
N start = std::get<1>(datum);
N end = std::get<2>(datum);
if (end <= median_) {
left_->insert(datum, rank);
} else if (median_ < start) {
right_->insert(datum, rank);
} else {
intervals_.emplace_back(datum, rank);
}
}
bool pruneAndOrder() {
bool left_empty = true;
if (left_) {
left_empty = left_->pruneAndOrder();
if (left_empty) {
left_ = nullptr;
}
}
bool right_empty = true;
if (right_) {
right_empty = right_->pruneAndOrder();
if (right_empty) {
right_ = nullptr;
}
}
if (!intervals_.empty()) {
linkStart();
linkEnd();
return false;
}
return left_empty && right_empty;
}
void linkStart() {
std::vector<RankedInterval*> sorted;
sorted.reserve(intervals_.size());
for (auto& elt : intervals_) {
sorted.push_back(&elt);
}
std::sort(sorted.begin(), sorted.end(),
[](const auto* lhs, const auto* rhs) { return lhs->start_ < rhs->start_; });
for (size_t i = 0; i < sorted.size() - 1; i++) {
sorted[i]->next_start_ = sorted[i + 1];
}
low_start_ = sorted[0];
}
void linkEnd() {
std::vector<RankedInterval*> sorted;
sorted.reserve(intervals_.size());
for (auto& elt : intervals_) {
sorted.push_back(&elt);
}
std::sort(sorted.begin(), sorted.end(),
[](const auto* lhs, const auto* rhs) { return lhs->end_ > rhs->end_; });
for (size_t i = 0; i < sorted.size() - 1; i++) {
sorted[i]->next_end_ = sorted[i + 1];
}
high_end_ = sorted[0];
}
void search(N query, std::vector<RankedInterval*>& result) {
// Always search within the node.
if (query <= median_) {
auto* cur = low_start_;
while (cur && cur->start_ <= query) {
result.push_back(cur);
cur = cur->next_start_;
}
if (query < median_ && left_) {
left_->search(query, result);
}
} else {
auto* cur = high_end_;
while (cur && cur->end_ > query) {
result.push_back(cur);
cur = cur->next_end_;
}
if (right_) {
right_->search(query, result);
}
}
}
};
using NodePtr = std::unique_ptr<Node>;
NodePtr root_;
};

} // namespace IntervalTree
} // namespace Network
} // namespace Envoy
15 changes: 15 additions & 0 deletions source/extensions/common/matcher/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,18 @@ envoy_cc_library(
"@com_github_cncf_udpa//xds/type/matcher/v3:pkg_cc_proto",
],
)

# Range matcher extension: range_matcher.h plus the factory registration in range_matcher.cc.
envoy_cc_library(
name = "range_matcher_lib",
srcs = ["range_matcher.cc"],
hdrs = ["range_matcher.h"],
deps = [
"//envoy/matcher:matcher_interface",
"//envoy/network:filter_interface",
"//envoy/registry",
"//envoy/server:factory_context_interface",
"//source/common/matcher:matcher_lib",
# Interval tree used by the range matcher.
"//source/common/network:interval_tree_lib",
"@com_github_cncf_udpa//xds/type/matcher/v3:pkg_cc_proto",
],
)
16 changes: 16 additions & 0 deletions source/extensions/common/matcher/range_matcher.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include "source/extensions/common/matcher/range_matcher.h"

#include "envoy/registry/registry.h"

namespace Envoy {
namespace Extensions {
namespace Common {
namespace Matcher {

// Registers NetworkRangeMatcherFactory as a custom matcher factory for Network::MatchingData.
REGISTER_FACTORY(NetworkRangeMatcherFactory,
::Envoy::Matcher::CustomMatcherFactory<Network::MatchingData>);

} // namespace Matcher
} // namespace Common
} // namespace Extensions
} // namespace Envoy
113 changes: 113 additions & 0 deletions source/extensions/common/matcher/range_matcher.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#pragma once

#include "envoy/matcher/matcher.h"
#include "envoy/network/filter.h"
#include "envoy/server/factory_context.h"

#include "source/common/matcher/matcher.h"
#include "source/common/network/interval_tree.h"

#include "xds/type/matcher/v3/range.pb.h"
#include "xds/type/matcher/v3/range.pb.validate.h"

namespace Envoy {
namespace Extensions {
namespace Common {
namespace Matcher {

using ::Envoy::Matcher::DataInputFactoryCb;
using ::Envoy::Matcher::DataInputGetResult;
using ::Envoy::Matcher::DataInputPtr;
using ::Envoy::Matcher::evaluateMatch;
using ::Envoy::Matcher::MatchState;
using ::Envoy::Matcher::MatchTree;
using ::Envoy::Matcher::OnMatch;
using ::Envoy::Matcher::OnMatchFactory;
using ::Envoy::Matcher::OnMatchFactoryCb;

/**
 * Implementation of a `sublinear` port range matcher.
 *
 * The data input is parsed as a decimal int32 and looked up in an interval tree whose payloads
 * are the OnMatch entries of the configured ranges. Candidates are tried in the order returned
 * by the tree and the first one that yields a result wins.
 */
template <class DataType> class RangeMatcher : public MatchTree<DataType> {
public:
  /**
   * @param data_input supplies the input that is read on every match.
   * @param tree supplies the interval tree; shared with the other matchers created from the
   *        same factory callback.
   */
  RangeMatcher(
      DataInputPtr<DataType>&& data_input,
      const std::shared_ptr<Network::IntervalTree::IntervalTree<OnMatch<DataType>, int32_t>>& tree)
      : data_input_(std::move(data_input)), tree_(tree) {}

  /**
   * @return UnableToMatch while the input is not fully available or a nested matcher cannot
   *         decide yet; otherwise MatchComplete with the first action found, or with no result
   *         when the input is absent, is not an int32, or no candidate produces a result.
   */
  typename MatchTree<DataType>::MatchResult match(const DataType& data) override {
    const auto input = data_input_->get(data);
    if (input.data_availability_ != DataInputGetResult::DataAvailability::AllDataAvailable) {
      return {MatchState::UnableToMatch, absl::nullopt};
    }
    if (!input.data_) {
      return {MatchState::MatchComplete, absl::nullopt};
    }
    int32_t port;
    if (!absl::SimpleAtoi(*input.data_, &port)) {
      return {MatchState::MatchComplete, absl::nullopt};
    }
    const auto values = tree_->getData(port);
    // Iterate by reference: each OnMatch was already copied out of the tree by getData().
    for (const auto& on_match : values) {
      if (on_match.action_cb_) {
        return {MatchState::MatchComplete, OnMatch<DataType>{on_match.action_cb_, nullptr}};
      }
      // No direct action: delegate to the nested matcher.
      auto matched = evaluateMatch(*on_match.matcher_, data);
      if (matched.match_state_ == MatchState::UnableToMatch) {
        return {MatchState::UnableToMatch, absl::nullopt};
      }
      if (matched.match_state_ == MatchState::MatchComplete && matched.result_) {
        return {MatchState::MatchComplete, OnMatch<DataType>{matched.result_, nullptr}};
      }
      // The nested matcher completed without a result; fall through to the next candidate.
    }
    return {MatchState::MatchComplete, absl::nullopt};
  }

private:
  const DataInputPtr<DataType> data_input_;
  std::shared_ptr<Network::IntervalTree::IntervalTree<OnMatch<DataType>, int32_t>> tree_;
};

/**
 * Factory that turns an Int32RangeMatcher configuration into a callback creating RangeMatcher
 * instances. The interval tree is built once and shared by every matcher the callback creates.
 */
template <class DataType>
class RangeMatcherFactoryBase : public ::Envoy::Matcher::CustomMatcherFactory<DataType> {
public:
  ::Envoy::Matcher::MatchTreeFactoryCb<DataType>
  createCustomMatcherFactoryCb(const Protobuf::Message& config,
                               Server::Configuration::ServerFactoryContext& factory_context,
                               DataInputFactoryCb<DataType> data_input,
                               OnMatchFactory<DataType>& on_match_factory) override {
    const auto& typed_config =
        MessageUtil::downcastAndValidate<const xds::type::matcher::v3::Int32RangeMatcher&>(
            config, factory_context.messageValidationVisitor());
    const auto& range_matchers = typed_config.range_matchers();

    // First pass: create the on-match factory callbacks for all entries before any of them is
    // invoked, and count the ranges so the tuple list is allocated once.
    std::vector<OnMatchFactoryCb<DataType>> match_children;
    match_children.reserve(range_matchers.size());
    size_t range_count = 0;
    for (const auto& range_matcher : range_matchers) {
      // NOTE(review): the result of createOnMatch is dereferenced without a check; confirm it
      // cannot be empty for a validated configuration.
      match_children.push_back(*on_match_factory.createOnMatch(range_matcher.on_match()));
      range_count += static_cast<size_t>(range_matcher.ranges().size());
    }

    // Second pass: every range of an entry shares a copy of that entry's OnMatch. The list
    // order defines the priority between overlapping ranges.
    std::vector<std::tuple<OnMatch<DataType>, int32_t, int32_t>> data;
    data.reserve(range_count);
    size_t i = 0;
    for (const auto& range_matcher : range_matchers) {
      const auto on_match = match_children[i++]();
      for (const auto& range : range_matcher.ranges()) {
        // NOTE(review): the tree only ASSERTs start < end and a non-empty list; confirm the
        // proto validation rejects such configurations before they reach this point.
        data.emplace_back(on_match, range.start(), range.end());
      }
    }
    auto tree =
        std::make_shared<Network::IntervalTree::IntervalTree<OnMatch<DataType>, int32_t>>(data);
    return [data_input, tree]() {
      return std::make_unique<RangeMatcher<DataType>>(data_input(), tree);
    };
  }
  ProtobufTypes::MessagePtr createEmptyConfigProto() override {
    return std::make_unique<xds::type::matcher::v3::Int32RangeMatcher>();
  }
  std::string name() const override { return "range-matcher"; }
};

// Range matcher factory specialized for network matching data.
class NetworkRangeMatcherFactory : public RangeMatcherFactoryBase<Network::MatchingData> {};

} // namespace Matcher
} // namespace Common
} // namespace Extensions
} // namespace Envoy
Loading