Skip to content

Commit

Permalink
aggregate cluster: cleanups (#14411)
Browse files Browse the repository at this point in the history
Follow-up to #14382. Remove TLS use in the aggregate cluster. Move all
logic into the thread-local load balancers, making the implementation
less brittle and easier to understand.

Signed-off-by: Matt Klein <mklein@lyft.com>
  • Loading branch information
mattklein123 authored Dec 15, 2020
1 parent 0d89faf commit 0cb98ff
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 128 deletions.
113 changes: 60 additions & 53 deletions source/extensions/clusters/aggregate/cluster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,46 @@ Cluster::Cluster(const envoy::config::cluster::v3::Cluster& cluster,
Upstream::ClusterManager& cluster_manager, Runtime::Loader& runtime,
Random::RandomGenerator& random,
Server::Configuration::TransportSocketFactoryContextImpl& factory_context,
Stats::ScopePtr&& stats_scope, ThreadLocal::SlotAllocator& tls, bool added_via_api)
Stats::ScopePtr&& stats_scope, bool added_via_api)
: Upstream::ClusterImplBase(cluster, runtime, factory_context, std::move(stats_scope),
added_via_api, factory_context.dispatcher().timeSource()),
cluster_manager_(cluster_manager), runtime_(runtime), random_(random), tls_(tls),
clusters_(config.clusters().begin(), config.clusters().end()) {
tls_.set([info = info(), &runtime, &random](Event::Dispatcher&) {
auto per_thread_load_balancer = std::make_unique<PerThreadLoadBalancer>();
per_thread_load_balancer->lb_ = std::make_unique<AggregateClusterLoadBalancer>(
info->stats(), runtime, random, info->lbConfig());
return per_thread_load_balancer;
});
cluster_manager_(cluster_manager), runtime_(runtime), random_(random),
clusters_(std::make_shared<ClusterSet>(config.clusters().begin(), config.clusters().end())) {}

// Constructs the load balancer for an aggregate cluster.
//
// The constructor performs three steps, in this order:
//   1. For every configured cluster name that the cluster manager can currently resolve, attach a
//      member update callback so host changes in that cluster trigger refresh().
//   2. Call refresh() once to build the initial load balancer from the clusters that exist now.
//   3. Register this object for cluster add/update/removal notifications and keep the returned
//      handle in handle_.
//
// @param parent_info info of the aggregate cluster itself; used here for log messages, stats and
//        the LB config passed to the inner load balancer.
// @param cluster_manager used to look up thread local clusters and to register update callbacks.
// @param runtime runtime loader forwarded to the inner load balancer on refresh().
// @param random random generator forwarded to the inner load balancer on refresh().
// @param clusters ordered list of cluster names that make up the aggregate cluster.
AggregateClusterLoadBalancer::AggregateClusterLoadBalancer(
const Upstream::ClusterInfoConstSharedPtr& parent_info,
Upstream::ClusterManager& cluster_manager, Runtime::Loader& runtime,
Random::RandomGenerator& random, const ClusterSetConstSharedPtr& clusters)
: parent_info_(parent_info), cluster_manager_(cluster_manager), runtime_(runtime),
random_(random), clusters_(clusters) {
for (const auto& cluster : *clusters_) {
auto tlc = cluster_manager_.getThreadLocalCluster(cluster);
// A configured cluster may not exist yet when this load balancer is created, e.g. because it
// is added dynamically by xDS. Skip it here; onClusterAddOrUpdate() attaches the callback if
// the cluster shows up later.
if (tlc == nullptr) {
continue;
}

// The cluster already exists, so start watching its host membership right away.
addMemberUpdateCallbackForCluster(*tlc);
}
// Build the initial load balancer from whichever configured clusters are available now.
refresh();
// Subscribe to cluster add/update/removal events; the handle is stored in handle_.
handle_ = cluster_manager_.addThreadLocalClusterUpdateCallbacks(*this);
}

// Registers a member update callback on the given cluster's priority set. Whenever the callback
// fires, it logs the name of the updated cluster together with the aggregate cluster's name and
// then rebuilds the load balancer via refresh().
//
// @param thread_local_cluster the underlying cluster whose host membership should be watched.
void AggregateClusterLoadBalancer::addMemberUpdateCallbackForCluster(
Upstream::ThreadLocalCluster& thread_local_cluster) {
thread_local_cluster.prioritySet().addMemberUpdateCb(
// The callback captures `this` plus a copy of the value returned by info(), so the log line
// reads the target cluster's name from the captured info rather than from the
// thread_local_cluster reference. The added/removed host vectors are ignored; any membership
// change simply triggers a full refresh.
// NOTE(review): `this` is captured raw; confirm the callback cannot outlive this object.
[this, target_cluster_info = thread_local_cluster.info()](const Upstream::HostVector&,
const Upstream::HostVector&) {
ENVOY_LOG(debug, "member update for cluster '{}' in aggregate cluster '{}'",
target_cluster_info->name(), parent_info_->name());
refresh();
});
}

PriorityContextPtr
Cluster::linearizePrioritySet(const std::function<bool(const std::string&)>& skip_predicate) {
AggregateClusterLoadBalancer::linearizePrioritySet(OptRef<const std::string> excluded_cluster) {
PriorityContextPtr priority_context = std::make_unique<PriorityContext>();
uint32_t next_priority_after_linearizing = 0;

Expand All @@ -42,15 +67,20 @@ Cluster::linearizePrioritySet(const std::function<bool(const std::string&)>& ski
// The linearization result is:
// [C_0.P_0, C_0.P_1, C_0.P_2, C_1.P_0, C_1.P_1, C_2.P_0, C_2.P_1, C_2.P_2, C_2.P_3]
// and the traffic will be distributed among these priorities.
for (const auto& cluster : clusters_) {
if (skip_predicate(cluster)) {
for (const auto& cluster : *clusters_) {
if (excluded_cluster.has_value() && excluded_cluster.value().get() == cluster) {
continue;
}
auto tlc = cluster_manager_.getThreadLocalCluster(cluster);
// It is possible that the cluster doesn't exist, e.g., the cluster could be deleted or the
// cluster hasn't been added by xDS.
if (tlc == nullptr) {
ENVOY_LOG(debug, "refresh: cluster '{}' absent in aggregate cluster '{}'", cluster,
parent_info_->name());
continue;
} else {
ENVOY_LOG(debug, "refresh: cluster '{}' found in aggregate cluster '{}'", cluster,
parent_info_->name());
}

uint32_t priority_in_current_cluster = 0;
Expand All @@ -73,57 +103,34 @@ Cluster::linearizePrioritySet(const std::function<bool(const std::string&)>& ski
return priority_context;
}

void Cluster::startPreInit() {
for (const auto& cluster : clusters_) {
auto tlc = cluster_manager_.getThreadLocalCluster(cluster);
// It is possible when initializing the cluster, the included cluster doesn't exist. e.g., the
// cluster could be added dynamically by xDS.
if (tlc == nullptr) {
continue;
}

// Add callback for clusters initialized before aggregate cluster.
tlc->prioritySet().addMemberUpdateCb(
[this, cluster](const Upstream::HostVector&, const Upstream::HostVector&) {
ENVOY_LOG(debug, "member update for cluster '{}' in aggregate cluster '{}'", cluster,
this->info()->name());
refresh();
});
void AggregateClusterLoadBalancer::refresh(OptRef<const std::string> excluded_cluster) {
PriorityContextPtr priority_context = linearizePrioritySet(excluded_cluster);
if (!priority_context->priority_set_.hostSetsPerPriority().empty()) {
load_balancer_ = std::make_unique<LoadBalancerImpl>(
*priority_context, parent_info_->stats(), runtime_, random_, parent_info_->lbConfig());
} else {
load_balancer_ = nullptr;
}
refresh();
handle_ = cluster_manager_.addThreadLocalClusterUpdateCallbacks(*this);

onPreInitComplete();
}

void Cluster::refresh(const std::function<bool(const std::string&)>& skip_predicate) {
// Post the priority set to worker threads.
// TODO(mattklein123): Remove "this" capture.
tls_.runOnAllThreads([this, skip_predicate, cluster_name = this->info()->name()](
OptRef<PerThreadLoadBalancer> per_thread_load_balancer) {
PriorityContextPtr priority_context = linearizePrioritySet(skip_predicate);
per_thread_load_balancer->get().refresh(std::move(priority_context));
});
priority_context_ = std::move(priority_context);
}

void Cluster::onClusterAddOrUpdate(Upstream::ThreadLocalCluster& cluster) {
if (std::find(clusters_.begin(), clusters_.end(), cluster.info()->name()) != clusters_.end()) {
void AggregateClusterLoadBalancer::onClusterAddOrUpdate(Upstream::ThreadLocalCluster& cluster) {
if (std::find(clusters_->begin(), clusters_->end(), cluster.info()->name()) != clusters_->end()) {
ENVOY_LOG(debug, "adding or updating cluster '{}' for aggregate cluster '{}'",
cluster.info()->name(), info()->name());
cluster.info()->name(), parent_info_->name());
refresh();
cluster.prioritySet().addMemberUpdateCb(
[this](const Upstream::HostVector&, const Upstream::HostVector&) { refresh(); });
addMemberUpdateCallbackForCluster(cluster);
}
}

void Cluster::onClusterRemoval(const std::string& cluster_name) {
void AggregateClusterLoadBalancer::onClusterRemoval(const std::string& cluster_name) {
// The onClusterRemoval callback is called before the thread local cluster is removed. There
// will be a dangling pointer to the thread local cluster if the deleted cluster is not skipped
// when we refresh the load balancer.
if (std::find(clusters_.begin(), clusters_.end(), cluster_name) != clusters_.end()) {
ENVOY_LOG(debug, "removing cluster '{}' from aggreagte cluster '{}'", cluster_name,
info()->name());
refresh([cluster_name](const std::string& c) { return cluster_name == c; });
if (std::find(clusters_->begin(), clusters_->end(), cluster_name) != clusters_->end()) {
ENVOY_LOG(debug, "removing cluster '{}' from aggregate cluster '{}'", cluster_name,
parent_info_->name());
refresh(cluster_name);
}
}

Expand Down Expand Up @@ -182,7 +189,7 @@ ClusterFactory::createClusterWithConfig(
auto new_cluster =
std::make_shared<Cluster>(cluster, proto_config, context.clusterManager(), context.runtime(),
context.api().randomGenerator(), socket_factory_context,
std::move(stats_scope), context.tls(), context.addedViaApi());
std::move(stats_scope), context.addedViaApi());
auto lb = std::make_unique<AggregateThreadAwareLoadBalancer>(*new_cluster);
return std::make_pair(new_cluster, std::move(lb));
}
Expand Down
96 changes: 33 additions & 63 deletions source/extensions/clusters/aggregate/cluster.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
#include "envoy/config/cluster/v3/cluster.pb.h"
#include "envoy/extensions/clusters/aggregate/v3/cluster.pb.h"
#include "envoy/extensions/clusters/aggregate/v3/cluster.pb.validate.h"
#include "envoy/stream_info/stream_info.h"
#include "envoy/thread_local/thread_local_object.h"
#include "envoy/upstream/thread_local_cluster.h"

#include "common/common/logger.h"
#include "common/upstream/cluster_factory_impl.h"
#include "common/upstream/upstream_impl.h"

Expand All @@ -29,72 +32,49 @@ struct PriorityContext {

using PriorityContextPtr = std::unique_ptr<PriorityContext>;

class AggregateClusterLoadBalancer;
// Order matters so a vector must be used for rebuilds. If the vector size becomes larger we can
// maintain a parallel set for lookups during cluster update callbacks.
using ClusterSet = std::vector<std::string>;
using ClusterSetConstSharedPtr = std::shared_ptr<const ClusterSet>;

class Cluster : public Upstream::ClusterImplBase, Upstream::ClusterUpdateCallbacks {
class Cluster : public Upstream::ClusterImplBase {
public:
Cluster(const envoy::config::cluster::v3::Cluster& cluster,
const envoy::extensions::clusters::aggregate::v3::ClusterConfig& config,
Upstream::ClusterManager& cluster_manager, Runtime::Loader& runtime,
Random::RandomGenerator& random,
Server::Configuration::TransportSocketFactoryContextImpl& factory_context,
Stats::ScopePtr&& stats_scope, ThreadLocal::SlotAllocator& tls, bool added_via_api);

struct PerThreadLoadBalancer : public ThreadLocal::ThreadLocalObject {
AggregateClusterLoadBalancer& get() {
// We can refresh before the per-worker LB is created. One of these variants should hold
// a non-null value.
if (absl::holds_alternative<std::unique_ptr<AggregateClusterLoadBalancer>>(lb_)) {
ASSERT(absl::get<std::unique_ptr<AggregateClusterLoadBalancer>>(lb_) != nullptr);
return *absl::get<std::unique_ptr<AggregateClusterLoadBalancer>>(lb_);
} else {
ASSERT(absl::get<AggregateClusterLoadBalancer*>(lb_) != nullptr);
return *absl::get<AggregateClusterLoadBalancer*>(lb_);
}
}

// For aggregate cluster the per-thread LB is only created once. We need to own it so we
// can pre-populate it before the LB is created and handed to the cluster.
absl::variant<std::unique_ptr<AggregateClusterLoadBalancer>, AggregateClusterLoadBalancer*> lb_;
};
Stats::ScopePtr&& stats_scope, bool added_via_api);

// Upstream::Cluster
Upstream::Cluster::InitializePhase initializePhase() const override {
return Upstream::Cluster::InitializePhase::Secondary;
}

// Upstream::ClusterUpdateCallbacks
void onClusterAddOrUpdate(Upstream::ThreadLocalCluster& cluster) override;
void onClusterRemoval(const std::string& cluster_name) override;

void refresh() {
refresh([](const std::string&) { return false; });
}

Upstream::ClusterUpdateCallbacksHandlePtr handle_;
Upstream::ClusterManager& cluster_manager_;
Runtime::Loader& runtime_;
Random::RandomGenerator& random_;
ThreadLocal::TypedSlot<PerThreadLoadBalancer> tls_;
const std::vector<std::string> clusters_;
const ClusterSetConstSharedPtr clusters_;

private:
// Upstream::ClusterImplBase
void startPreInit() override;

void refresh(const std::function<bool(const std::string&)>& skip_predicate);
PriorityContextPtr
linearizePrioritySet(const std::function<bool(const std::string&)>& skip_predicate);
void startPreInit() override { onPreInitComplete(); }
};

// Load balancer used by each worker thread. It will be refreshed when clusters, hosts or priorities
// are updated.
class AggregateClusterLoadBalancer : public Upstream::LoadBalancer {
class AggregateClusterLoadBalancer : public Upstream::LoadBalancer,
Upstream::ClusterUpdateCallbacks,
Logger::Loggable<Logger::Id::upstream> {
public:
AggregateClusterLoadBalancer(
Upstream::ClusterStats& stats, Runtime::Loader& runtime, Random::RandomGenerator& random,
const envoy::config::cluster::v3::Cluster::CommonLbConfig& common_config)
: stats_(stats), runtime_(runtime), random_(random), common_config_(common_config) {}
AggregateClusterLoadBalancer(const Upstream::ClusterInfoConstSharedPtr& parent_info,
Upstream::ClusterManager& cluster_manager, Runtime::Loader& runtime,
Random::RandomGenerator& random,
const ClusterSetConstSharedPtr& clusters);

// Upstream::ClusterUpdateCallbacks
void onClusterAddOrUpdate(Upstream::ThreadLocalCluster& cluster) override;
void onClusterRemoval(const std::string& cluster_name) override;

// Upstream::LoadBalancer
Upstream::HostConstSharedPtr chooseHost(Upstream::LoadBalancerContext* context) override;
Expand Down Expand Up @@ -135,23 +115,18 @@ class AggregateClusterLoadBalancer : public Upstream::LoadBalancer {

using LoadBalancerImplPtr = std::unique_ptr<LoadBalancerImpl>;

void addMemberUpdateCallbackForCluster(Upstream::ThreadLocalCluster& thread_local_cluster);
PriorityContextPtr linearizePrioritySet(OptRef<const std::string> excluded_cluster);
void refresh(OptRef<const std::string> excluded_cluster = OptRef<const std::string>());

LoadBalancerImplPtr load_balancer_;
Upstream::ClusterStats& stats_;
Upstream::ClusterInfoConstSharedPtr parent_info_;
Upstream::ClusterManager& cluster_manager_;
Runtime::Loader& runtime_;
Random::RandomGenerator& random_;
const envoy::config::cluster::v3::Cluster::CommonLbConfig& common_config_;
PriorityContextPtr priority_context_;

public:
void refresh(PriorityContextPtr priority_context) {
if (!priority_context->priority_set_.hostSetsPerPriority().empty()) {
load_balancer_ = std::make_unique<LoadBalancerImpl>(*priority_context, stats_, runtime_,
random_, common_config_);
} else {
load_balancer_ = nullptr;
}
priority_context_ = std::move(priority_context);
}
const ClusterSetConstSharedPtr clusters_;
Upstream::ClusterUpdateCallbacksHandlePtr handle_;
};

// Load balancer factory created by the main thread and will be called in each worker thread to
Expand All @@ -160,14 +135,9 @@ struct AggregateLoadBalancerFactory : public Upstream::LoadBalancerFactory {
AggregateLoadBalancerFactory(const Cluster& cluster) : cluster_(cluster) {}
// Upstream::LoadBalancerFactory
Upstream::LoadBalancerPtr create() override {
// See comments in PerThreadLoadBalancer above for why the follow is done.
auto per_thread_local_balancer = cluster_.tls_.get();
ASSERT(absl::get<std::unique_ptr<AggregateClusterLoadBalancer>>(
per_thread_local_balancer->lb_) != nullptr);
auto to_return = std::move(
absl::get<std::unique_ptr<AggregateClusterLoadBalancer>>(per_thread_local_balancer->lb_));
per_thread_local_balancer->lb_ = to_return.get();
return to_return;
return std::make_unique<AggregateClusterLoadBalancer>(
cluster_.info(), cluster_.cluster_manager_, cluster_.runtime_, cluster_.random_,
cluster_.clusters_);
}

const Cluster& cluster_;
Expand Down
19 changes: 7 additions & 12 deletions test/extensions/clusters/aggregate/cluster_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class AggregateClusterTest : public Event::TestUsingSimulatedTime, public testin
Upstream::HostSetImpl::partitionHosts(std::make_shared<Upstream::HostVector>(hosts),
Upstream::HostsPerLocalityImpl::empty()),
nullptr, hosts, {}, 100);
cluster_->refresh();
}

void setupSecondary(int priority, int healthy_hosts, int degraded_hosts, int unhealthy_hosts) {
Expand All @@ -83,7 +82,6 @@ class AggregateClusterTest : public Event::TestUsingSimulatedTime, public testin
Upstream::HostSetImpl::partitionHosts(std::make_shared<Upstream::HostVector>(hosts),
Upstream::HostsPerLocalityImpl::empty()),
nullptr, hosts, {}, 100);
cluster_->refresh();
}

void setupPrioritySet() {
Expand All @@ -107,25 +105,22 @@ class AggregateClusterTest : public Event::TestUsingSimulatedTime, public testin

cluster_ =
std::make_shared<Cluster>(cluster_config, config, cm_, runtime_, api_->randomGenerator(),
factory_context, std::move(scope), tls_, false);
factory_context, std::move(scope), false);

thread_aware_lb_ = std::make_unique<AggregateThreadAwareLoadBalancer>(*cluster_);
lb_factory_ = thread_aware_lb_->factory();
lb_ = lb_factory_->create();

EXPECT_CALL(cm_, getThreadLocalCluster(Eq("aggregate_cluster")))
.WillRepeatedly(Return(&aggregate_cluster_));
cm_.initializeThreadLocalClusters({"primary", "secondary"});
EXPECT_CALL(cm_, getThreadLocalCluster(Eq("primary"))).WillRepeatedly(Return(&primary_));
EXPECT_CALL(cm_, getThreadLocalCluster(Eq("secondary"))).WillRepeatedly(Return(&secondary_));
EXPECT_CALL(cm_, getThreadLocalCluster(Eq("tertiary"))).WillRepeatedly(Return(nullptr));
ON_CALL(primary_, prioritySet()).WillByDefault(ReturnRef(primary_ps_));
ON_CALL(secondary_, prioritySet()).WillByDefault(ReturnRef(secondary_ps_));
ON_CALL(aggregate_cluster_, loadBalancer()).WillByDefault(ReturnRef(*lb_));

setupPrioritySet();

ON_CALL(primary_, loadBalancer()).WillByDefault(ReturnRef(primary_load_balancer_));
ON_CALL(secondary_, loadBalancer()).WillByDefault(ReturnRef(secondary_load_balancer_));

thread_aware_lb_ = std::make_unique<AggregateThreadAwareLoadBalancer>(*cluster_);
lb_factory_ = thread_aware_lb_->factory();
lb_ = lb_factory_->create();
}

Stats::TestUtil::TestStore stats_store_;
Expand All @@ -150,7 +145,7 @@ class AggregateClusterTest : public Event::TestUsingSimulatedTime, public testin
new NiceMock<Upstream::MockClusterInfo>()};
std::shared_ptr<Upstream::MockClusterInfo> secondary_info_{
new NiceMock<Upstream::MockClusterInfo>()};
NiceMock<Upstream::MockThreadLocalCluster> aggregate_cluster_, primary_, secondary_;
NiceMock<Upstream::MockThreadLocalCluster> primary_, secondary_;
Upstream::PrioritySetImpl primary_ps_, secondary_ps_;
NiceMock<Upstream::MockLoadBalancer> primary_load_balancer_, secondary_load_balancer_;

Expand Down

0 comments on commit 0cb98ff

Please sign in to comment.