-
Notifications
You must be signed in to change notification settings - Fork 4.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
matchers: implement interval tree for sublinear port range matching #19912
Closed
Closed
Changes from 6 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
bfb1d89
initial
kyessenov 6f849cf
Merge remote-tracking branch 'upstream/main' into interval_tree
kyessenov 6f49f4b
impl
kyessenov e348503
matchers: implement range matching
kyessenov 5f86a9c
remove stale
kyessenov f93a17d
add missing word
kyessenov 4f6b4b1
Merge remote-tracking branch 'upstream/main' into interval_tree
kyessenov 8e4ec04
implementation
kyessenov File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
#pragma once | ||
|
||
#include <algorithm> | ||
#include <iostream> | ||
#include <vector> | ||
|
||
#include "source/common/common/assert.h" | ||
|
||
#include "absl/strings/str_cat.h" | ||
#include "absl/types/optional.h" | ||
#include "absl/types/span.h" | ||
|
||
namespace Envoy { | ||
namespace Network { | ||
namespace IntervalTree { | ||
|
||
template <class Data, class N> class IntervalTree { | ||
public: | ||
/** | ||
* @param data supplies a list of data and possibly overlapping number intervals [start, end). | ||
*/ | ||
IntervalTree(const std::vector<std::tuple<Data, N, N>>& data) { | ||
ASSERT(!data.empty(), "Must supply non-empty list."); | ||
// Create a scaffold tree from medians in the set of all start | ||
// and end points of the intervals to ensure the tree is balanced. | ||
std::vector<N> medians; | ||
medians.reserve(data.size() * 2); | ||
for (const auto [_, start, end] : data) { | ||
ASSERT(start < end, "Interval must be properly formed."); | ||
medians.push_back(start); | ||
medians.push_back(end); | ||
} | ||
std::sort(medians.begin(), medians.end()); | ||
root_ = std::make_unique<Node>(); | ||
root_->populate(medians); | ||
size_t rank = 0; | ||
// Insert intervals and recursively prune and order intervals in each node. | ||
for (const auto& datum : data) { | ||
root_->insert(datum, rank++); | ||
} | ||
root_->pruneAndOrder(); | ||
} | ||
/** | ||
* Returns the data of intervals containing the query number in the original list order. | ||
**/ | ||
std::vector<Data> getData(N query) { | ||
std::vector<RankedInterval*> result; | ||
root_->search(query, result); | ||
std::sort(result.begin(), result.end(), | ||
[](const auto* lhs, const auto* rhs) { return lhs->rank_ < rhs->rank_; }); | ||
std::vector<Data> out; | ||
out.reserve(result.size()); | ||
for (auto* elt : result) { | ||
out.push_back(elt->data_); | ||
} | ||
return out; | ||
} | ||
|
||
private: | ||
using Interval = std::tuple<Data, N, N>; | ||
struct RankedInterval { | ||
RankedInterval(const Interval& datum, size_t rank) | ||
: data_(std::get<0>(datum)), rank_(rank), start_(std::get<1>(datum)), | ||
end_(std::get<2>(datum)) {} | ||
Data data_; | ||
size_t rank_; | ||
N start_; | ||
N end_; | ||
// Intervals are linked in their increasing start order and decreasing end order. | ||
RankedInterval* next_start_{nullptr}; | ||
RankedInterval* next_end_{nullptr}; | ||
}; | ||
/** | ||
* Nodes in the tree satisfy the following invariants: | ||
* * Intervals in the node contain the median value. | ||
* * Intervals in the left sub tree have the end less or equal to the median value. | ||
* * Intervals in the right sub tree have the start greater than the median value. | ||
* * Intervals in the node are linked from the lowest start in the ascending order. | ||
* * Intervals in the node are linked from the highest end in the descending order. | ||
**/ | ||
struct Node { | ||
N median_; | ||
std::unique_ptr<Node> left_; | ||
std::unique_ptr<Node> right_; | ||
// Ranked intervals are sorted by position in the input vector. | ||
std::vector<RankedInterval> intervals_; | ||
// Interval with the lowest start. | ||
RankedInterval* low_start_; | ||
// Interval with the highest end. | ||
RankedInterval* high_end_; | ||
|
||
void populate(absl::Span<const N> span) { | ||
const size_t size = span.size(); | ||
const size_t mid = size >> 1; | ||
median_ = span[mid]; | ||
// Last value equal to median on the left. | ||
size_t left = mid; | ||
while (left > 0 && span[left - 1] == median_) { | ||
left--; | ||
} | ||
if (left > 0) { | ||
left_ = std::make_unique<Node>(); | ||
left_->populate(span.subspan(0, left)); | ||
} | ||
// Last value equal to median on the right. | ||
size_t right = mid; | ||
while (right < size - 1 && span[right + 1] == median_) { | ||
right++; | ||
} | ||
if (right < size - 1) { | ||
right_ = std::make_unique<Node>(); | ||
right_->populate(span.subspan(right + 1)); | ||
} | ||
} | ||
void insert(const Interval& datum, size_t rank) { | ||
N start = std::get<1>(datum); | ||
N end = std::get<2>(datum); | ||
if (end <= median_) { | ||
left_->insert(datum, rank); | ||
} else if (median_ < start) { | ||
right_->insert(datum, rank); | ||
} else { | ||
intervals_.emplace_back(datum, rank); | ||
} | ||
} | ||
bool pruneAndOrder() { | ||
bool left_empty = true; | ||
if (left_) { | ||
left_empty = left_->pruneAndOrder(); | ||
if (left_empty) { | ||
left_ = nullptr; | ||
} | ||
} | ||
bool right_empty = true; | ||
if (right_) { | ||
right_empty = right_->pruneAndOrder(); | ||
if (right_empty) { | ||
right_ = nullptr; | ||
} | ||
} | ||
if (!intervals_.empty()) { | ||
linkStart(); | ||
linkEnd(); | ||
return false; | ||
} | ||
return left_empty && right_empty; | ||
} | ||
void linkStart() { | ||
std::vector<RankedInterval*> sorted; | ||
sorted.reserve(intervals_.size()); | ||
for (auto& elt : intervals_) { | ||
sorted.push_back(&elt); | ||
} | ||
std::sort(sorted.begin(), sorted.end(), | ||
[](const auto* lhs, const auto* rhs) { return lhs->start_ < rhs->start_; }); | ||
for (size_t i = 0; i < sorted.size() - 1; i++) { | ||
sorted[i]->next_start_ = sorted[i + 1]; | ||
} | ||
low_start_ = sorted[0]; | ||
} | ||
void linkEnd() { | ||
std::vector<RankedInterval*> sorted; | ||
sorted.reserve(intervals_.size()); | ||
for (auto& elt : intervals_) { | ||
sorted.push_back(&elt); | ||
} | ||
std::sort(sorted.begin(), sorted.end(), | ||
[](const auto* lhs, const auto* rhs) { return lhs->end_ > rhs->end_; }); | ||
for (size_t i = 0; i < sorted.size() - 1; i++) { | ||
sorted[i]->next_end_ = sorted[i + 1]; | ||
} | ||
high_end_ = sorted[0]; | ||
} | ||
void search(N query, std::vector<RankedInterval*>& result) { | ||
// Always search within the node. | ||
if (query <= median_) { | ||
auto* cur = low_start_; | ||
while (cur && cur->start_ <= query) { | ||
result.push_back(cur); | ||
cur = cur->next_start_; | ||
} | ||
if (query < median_ && left_) { | ||
left_->search(query, result); | ||
} | ||
} else { | ||
auto* cur = high_end_; | ||
while (cur && cur->end_ > query) { | ||
result.push_back(cur); | ||
cur = cur->next_end_; | ||
} | ||
if (right_) { | ||
right_->search(query, result); | ||
} | ||
} | ||
} | ||
}; | ||
using NodePtr = std::unique_ptr<Node>; | ||
NodePtr root_; | ||
}; | ||
|
||
} // namespace IntervalTree | ||
} // namespace Network | ||
} // namespace Envoy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#include "source/extensions/common/matcher/range_matcher.h" | ||
|
||
#include "envoy/registry/registry.h" | ||
|
||
namespace Envoy { | ||
namespace Extensions { | ||
namespace Common { | ||
namespace Matcher { | ||
|
||
REGISTER_FACTORY(NetworkRangeMatcherFactory, | ||
::Envoy::Matcher::CustomMatcherFactory<Network::MatchingData>); | ||
|
||
} // namespace Matcher | ||
} // namespace Common | ||
} // namespace Extensions | ||
} // namespace Envoy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
#pragma once | ||
|
||
#include "envoy/matcher/matcher.h" | ||
#include "envoy/network/filter.h" | ||
#include "envoy/server/factory_context.h" | ||
|
||
#include "source/common/matcher/matcher.h" | ||
#include "source/common/network/interval_tree.h" | ||
|
||
#include "xds/type/matcher/v3/range.pb.h" | ||
#include "xds/type/matcher/v3/range.pb.validate.h" | ||
|
||
namespace Envoy { | ||
namespace Extensions { | ||
namespace Common { | ||
namespace Matcher { | ||
|
||
using ::Envoy::Matcher::DataInputFactoryCb; | ||
using ::Envoy::Matcher::DataInputGetResult; | ||
using ::Envoy::Matcher::DataInputPtr; | ||
using ::Envoy::Matcher::evaluateMatch; | ||
using ::Envoy::Matcher::MatchState; | ||
using ::Envoy::Matcher::MatchTree; | ||
using ::Envoy::Matcher::OnMatch; | ||
using ::Envoy::Matcher::OnMatchFactory; | ||
using ::Envoy::Matcher::OnMatchFactoryCb; | ||
|
||
/** | ||
* Implementation of a `sublinear` port range matcher. | ||
*/ | ||
template <class DataType> class RangeMatcher : public MatchTree<DataType> { | ||
public: | ||
RangeMatcher( | ||
DataInputPtr<DataType>&& data_input, | ||
const std::shared_ptr<Network::IntervalTree::IntervalTree<OnMatch<DataType>, int32_t>>& tree) | ||
: data_input_(std::move(data_input)), tree_(tree) {} | ||
|
||
typename MatchTree<DataType>::MatchResult match(const DataType& data) override { | ||
const auto input = data_input_->get(data); | ||
if (input.data_availability_ != DataInputGetResult::DataAvailability::AllDataAvailable) { | ||
return {MatchState::UnableToMatch, absl::nullopt}; | ||
} | ||
if (!input.data_) { | ||
return {MatchState::MatchComplete, absl::nullopt}; | ||
} | ||
int32_t port; | ||
if (!absl::SimpleAtoi(*input.data_, &port)) { | ||
return {MatchState::MatchComplete, absl::nullopt}; | ||
} | ||
auto values = tree_->getData(port); | ||
for (const auto on_match : values) { | ||
if (on_match.action_cb_) { | ||
return {MatchState::MatchComplete, OnMatch<DataType>{on_match.action_cb_, nullptr}}; | ||
} | ||
auto matched = evaluateMatch(*on_match.matcher_, data); | ||
if (matched.match_state_ == MatchState::UnableToMatch) { | ||
return {MatchState::UnableToMatch, absl::nullopt}; | ||
} | ||
if (matched.match_state_ == MatchState::MatchComplete && matched.result_) { | ||
return {MatchState::MatchComplete, OnMatch<DataType>{matched.result_, nullptr}}; | ||
} | ||
} | ||
return {MatchState::MatchComplete, absl::nullopt}; | ||
} | ||
|
||
private: | ||
const DataInputPtr<DataType> data_input_; | ||
std::shared_ptr<Network::IntervalTree::IntervalTree<OnMatch<DataType>, int32_t>> tree_; | ||
}; | ||
|
||
template <class DataType> | ||
class RangeMatcherFactoryBase : public ::Envoy::Matcher::CustomMatcherFactory<DataType> { | ||
public: | ||
::Envoy::Matcher::MatchTreeFactoryCb<DataType> | ||
createCustomMatcherFactoryCb(const Protobuf::Message& config, | ||
Server::Configuration::ServerFactoryContext& factory_context, | ||
DataInputFactoryCb<DataType> data_input, | ||
OnMatchFactory<DataType>& on_match_factory) override { | ||
const auto& typed_config = | ||
MessageUtil::downcastAndValidate<const xds::type::matcher::v3::Int32RangeMatcher&>( | ||
config, factory_context.messageValidationVisitor()); | ||
std::vector<OnMatchFactoryCb<DataType>> match_children; | ||
match_children.reserve(typed_config.range_matchers().size()); | ||
for (const auto& range_matcher : typed_config.range_matchers()) { | ||
match_children.push_back(*on_match_factory.createOnMatch(range_matcher.on_match())); | ||
} | ||
std::vector<std::tuple<OnMatch<DataType>, int32_t, int32_t>> data; | ||
data.reserve(match_children.size()); | ||
size_t i = 0; | ||
for (const auto& range_matcher : typed_config.range_matchers()) { | ||
auto on_match = match_children[i++](); | ||
for (const auto& range : range_matcher.ranges()) { | ||
data.emplace_back(on_match, range.start(), range.end()); | ||
} | ||
} | ||
auto tree = | ||
std::make_shared<Network::IntervalTree::IntervalTree<OnMatch<DataType>, int32_t>>(data); | ||
return [data_input, tree]() { | ||
return std::make_unique<RangeMatcher<DataType>>(data_input(), tree); | ||
}; | ||
}; | ||
ProtobufTypes::MessagePtr createEmptyConfigProto() override { | ||
return std::make_unique<xds::type::matcher::v3::Int32RangeMatcher>(); | ||
} | ||
std::string name() const override { return "range-matcher"; } | ||
}; | ||
|
||
class NetworkRangeMatcherFactory : public RangeMatcherFactoryBase<Network::MatchingData> {}; | ||
|
||
} // namespace Matcher | ||
} // namespace Common | ||
} // namespace Extensions | ||
} // namespace Envoy |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's very difficult/impossible to read this kind of code and understand what is going on as a new reader. Can you please add a large block comment at the top that describes what this does, provides a high level implementation overview, etc.? Also in general I would over comment the code below if possible.
Also in general this kind of code scares me a bit so you might consider writing a small fuzz test for it to help look for corner cases.