diff --git a/source/extensions/clusters/aggregate/cluster.cc b/source/extensions/clusters/aggregate/cluster.cc index c8b406daadfb..c48ba103c9d7 100644 --- a/source/extensions/clusters/aggregate/cluster.cc +++ b/source/extensions/clusters/aggregate/cluster.cc @@ -17,21 +17,46 @@ Cluster::Cluster(const envoy::config::cluster::v3::Cluster& cluster, Upstream::ClusterManager& cluster_manager, Runtime::Loader& runtime, Random::RandomGenerator& random, Server::Configuration::TransportSocketFactoryContextImpl& factory_context, - Stats::ScopePtr&& stats_scope, ThreadLocal::SlotAllocator& tls, bool added_via_api) + Stats::ScopePtr&& stats_scope, bool added_via_api) : Upstream::ClusterImplBase(cluster, runtime, factory_context, std::move(stats_scope), added_via_api, factory_context.dispatcher().timeSource()), - cluster_manager_(cluster_manager), runtime_(runtime), random_(random), tls_(tls), - clusters_(config.clusters().begin(), config.clusters().end()) { - tls_.set([info = info(), &runtime, &random](Event::Dispatcher&) { - auto per_thread_load_balancer = std::make_unique<PerThreadLoadBalancer>(); - per_thread_load_balancer->lb_ = std::make_unique<AggregateClusterLoadBalancer>( - info->stats(), runtime, random, info->lbConfig()); - return per_thread_load_balancer; - }); + cluster_manager_(cluster_manager), runtime_(runtime), random_(random), + clusters_(std::make_shared<ClusterSet>(config.clusters().begin(), config.clusters().end())) {} + +AggregateClusterLoadBalancer::AggregateClusterLoadBalancer( + const Upstream::ClusterInfoConstSharedPtr& parent_info, + Upstream::ClusterManager& cluster_manager, Runtime::Loader& runtime, + Random::RandomGenerator& random, const ClusterSetConstSharedPtr& clusters) + : parent_info_(parent_info), cluster_manager_(cluster_manager), runtime_(runtime), + random_(random), clusters_(clusters) { + for (const auto& cluster : *clusters_) { + auto tlc = cluster_manager_.getThreadLocalCluster(cluster); + // It is possible when initializing the cluster, the included cluster doesn't exist. e.g., the + // cluster could be added dynamically by xDS. + if (tlc == nullptr) { + continue; + } + + // Add callback for clusters initialized before aggregate cluster. + addMemberUpdateCallbackForCluster(*tlc); + } + refresh(); + handle_ = cluster_manager_.addThreadLocalClusterUpdateCallbacks(*this); +} + +void AggregateClusterLoadBalancer::addMemberUpdateCallbackForCluster( + Upstream::ThreadLocalCluster& thread_local_cluster) { + thread_local_cluster.prioritySet().addMemberUpdateCb( + [this, target_cluster_info = thread_local_cluster.info()](const Upstream::HostVector&, + const Upstream::HostVector&) { + ENVOY_LOG(debug, "member update for cluster '{}' in aggregate cluster '{}'", + target_cluster_info->name(), parent_info_->name()); + refresh(); + }); } PriorityContextPtr -Cluster::linearizePrioritySet(const std::function<bool(const std::string&)>& skip_predicate) { +AggregateClusterLoadBalancer::linearizePrioritySet(OptRef<const std::string> excluded_cluster) { PriorityContextPtr priority_context = std::make_unique<PriorityContext>(); uint32_t next_priority_after_linearizing = 0; @@ -42,15 +67,20 @@ Cluster::linearizePrioritySet(const std::function<bool(const std::string&)>& ski // The linearization result is: // [C_0.P_0, C_0.P_1, C_0.P_2, C_1.P_0, C_1.P_1, C_2.P_0, C_2.P_1, C_2.P_2, C_2.P_3] // and the traffic will be distributed among these priorities. - for (const auto& cluster : clusters_) { - if (skip_predicate(cluster)) { + for (const auto& cluster : *clusters_) { + if (excluded_cluster.has_value() && excluded_cluster.value().get() == cluster) { continue; } auto tlc = cluster_manager_.getThreadLocalCluster(cluster); // It is possible that the cluster doesn't exist, e.g., the cluster could be deleted or the // cluster hasn't been added by xDS. if (tlc == nullptr) { + ENVOY_LOG(debug, "refresh: cluster '{}' absent in aggregate cluster '{}'", cluster, + parent_info_->name()); continue; + } else { + ENVOY_LOG(debug, "refresh: cluster '{}' found in aggregate cluster '{}'", cluster, + parent_info_->name()); } uint32_t priority_in_current_cluster = 0; @@ -73,57 +103,34 @@ Cluster::linearizePrioritySet(const std::function<bool(const std::string&)>& ski return priority_context; } -void Cluster::startPreInit() { - for (const auto& cluster : clusters_) { - auto tlc = cluster_manager_.getThreadLocalCluster(cluster); - // It is possible when initializing the cluster, the included cluster doesn't exist. e.g., the - // cluster could be added dynamically by xDS. - if (tlc == nullptr) { - continue; - } - - // Add callback for clusters initialized before aggregate cluster. - tlc->prioritySet().addMemberUpdateCb( - [this, cluster](const Upstream::HostVector&, const Upstream::HostVector&) { - ENVOY_LOG(debug, "member update for cluster '{}' in aggregate cluster '{}'", cluster, - this->info()->name()); - refresh(); - }); +void AggregateClusterLoadBalancer::refresh(OptRef<const std::string> excluded_cluster) { + PriorityContextPtr priority_context = linearizePrioritySet(excluded_cluster); + if (!priority_context->priority_set_.hostSetsPerPriority().empty()) { + load_balancer_ = std::make_unique<LoadBalancerImpl>( + *priority_context, parent_info_->stats(), runtime_, random_, parent_info_->lbConfig()); + } else { + load_balancer_ = nullptr; } - refresh(); - handle_ = cluster_manager_.addThreadLocalClusterUpdateCallbacks(*this); - - onPreInitComplete(); -} - -void Cluster::refresh(const std::function<bool(const std::string&)>& skip_predicate) { - // Post the priority set to worker threads. - // TODO(mattklein123): Remove "this" capture. - tls_.runOnAllThreads([this, skip_predicate, cluster_name = this->info()->name()]( - OptRef<PerThreadLoadBalancer> per_thread_load_balancer) { - PriorityContextPtr priority_context = linearizePrioritySet(skip_predicate); - per_thread_load_balancer->get().refresh(std::move(priority_context)); - }); + priority_context_ = std::move(priority_context); } -void Cluster::onClusterAddOrUpdate(Upstream::ThreadLocalCluster& cluster) { - if (std::find(clusters_.begin(), clusters_.end(), cluster.info()->name()) != clusters_.end()) { +void AggregateClusterLoadBalancer::onClusterAddOrUpdate(Upstream::ThreadLocalCluster& cluster) { + if (std::find(clusters_->begin(), clusters_->end(), cluster.info()->name()) != clusters_->end()) { ENVOY_LOG(debug, "adding or updating cluster '{}' for aggregate cluster '{}'", - cluster.info()->name(), info()->name()); + cluster.info()->name(), parent_info_->name()); refresh(); - cluster.prioritySet().addMemberUpdateCb( - [this](const Upstream::HostVector&, const Upstream::HostVector&) { refresh(); }); + addMemberUpdateCallbackForCluster(cluster); } } -void Cluster::onClusterRemoval(const std::string& cluster_name) { +void AggregateClusterLoadBalancer::onClusterRemoval(const std::string& cluster_name) { // The onClusterRemoval callback is called before the thread local cluster is removed. There // will be a dangling pointer to the thread local cluster if the deleted cluster is not skipped // when we refresh the load balancer. - if (std::find(clusters_.begin(), clusters_.end(), cluster_name) != clusters_.end()) { - ENVOY_LOG(debug, "removing cluster '{}' from aggreagte cluster '{}'", cluster_name, - info()->name()); - refresh([cluster_name](const std::string& c) { return cluster_name == c; }); + if (std::find(clusters_->begin(), clusters_->end(), cluster_name) != clusters_->end()) { + ENVOY_LOG(debug, "removing cluster '{}' from aggregate cluster '{}'", cluster_name, + parent_info_->name()); + refresh(cluster_name); } } @@ -182,7 +189,7 @@ ClusterFactory::createClusterWithConfig( auto new_cluster = std::make_shared<Cluster>(cluster, proto_config, context.clusterManager(), context.runtime(), context.api().randomGenerator(), socket_factory_context, - std::move(stats_scope), context.tls(), context.addedViaApi()); + std::move(stats_scope), context.addedViaApi()); auto lb = std::make_unique<AggregateThreadAwareLoadBalancer>(*new_cluster); return std::make_pair(new_cluster, std::move(lb)); } diff --git a/source/extensions/clusters/aggregate/cluster.h b/source/extensions/clusters/aggregate/cluster.h index 59f74b08d876..72df0d685824 100644 --- a/source/extensions/clusters/aggregate/cluster.h +++ b/source/extensions/clusters/aggregate/cluster.h @@ -3,8 +3,11 @@ #include "envoy/config/cluster/v3/cluster.pb.h" #include "envoy/extensions/clusters/aggregate/v3/cluster.pb.h" #include "envoy/extensions/clusters/aggregate/v3/cluster.pb.validate.h" +#include "envoy/stream_info/stream_info.h" #include "envoy/thread_local/thread_local_object.h" +#include "envoy/upstream/thread_local_cluster.h" +#include "common/common/logger.h" #include "common/upstream/cluster_factory_impl.h" #include "common/upstream/upstream_impl.h" @@ -29,72 +32,49 @@ struct PriorityContext { using PriorityContextPtr = std::unique_ptr<PriorityContext>; -class AggregateClusterLoadBalancer; +// Order matters so a vector must be used for rebuilds. If the vector size becomes larger we can +// maintain a parallel set for lookups during cluster update callbacks. +using ClusterSet = std::vector<std::string>; +using ClusterSetConstSharedPtr = std::shared_ptr<const ClusterSet>; -class Cluster : public Upstream::ClusterImplBase, Upstream::ClusterUpdateCallbacks { +class Cluster : public Upstream::ClusterImplBase { public: Cluster(const envoy::config::cluster::v3::Cluster& cluster, const envoy::extensions::clusters::aggregate::v3::ClusterConfig& config, Upstream::ClusterManager& cluster_manager, Runtime::Loader& runtime, Random::RandomGenerator& random, Server::Configuration::TransportSocketFactoryContextImpl& factory_context, - Stats::ScopePtr&& stats_scope, ThreadLocal::SlotAllocator& tls, bool added_via_api); - - struct PerThreadLoadBalancer : public ThreadLocal::ThreadLocalObject { - AggregateClusterLoadBalancer& get() { - // We can refresh before the per-worker LB is created. One of these variants should hold - // a non-null value. - if (absl::holds_alternative<std::unique_ptr<AggregateClusterLoadBalancer>>(lb_)) { - ASSERT(absl::get<std::unique_ptr<AggregateClusterLoadBalancer>>(lb_) != nullptr); - return *absl::get<std::unique_ptr<AggregateClusterLoadBalancer>>(lb_); - } else { - ASSERT(absl::get<AggregateClusterLoadBalancer*>(lb_) != nullptr); - return *absl::get<AggregateClusterLoadBalancer*>(lb_); - } - } - - // For aggregate cluster the per-thread LB is only created once. We need to own it so we - // can pre-populate it before the LB is created and handed to the cluster. - absl::variant<std::unique_ptr<AggregateClusterLoadBalancer>, AggregateClusterLoadBalancer*> lb_; - }; + Stats::ScopePtr&& stats_scope, bool added_via_api); // Upstream::Cluster Upstream::Cluster::InitializePhase initializePhase() const override { return Upstream::Cluster::InitializePhase::Secondary; } - // Upstream::ClusterUpdateCallbacks - void onClusterAddOrUpdate(Upstream::ThreadLocalCluster& cluster) override; - void onClusterRemoval(const std::string& cluster_name) override; - - void refresh() { - refresh([](const std::string&) { return false; }); - } - - Upstream::ClusterUpdateCallbacksHandlePtr handle_; Upstream::ClusterManager& cluster_manager_; Runtime::Loader& runtime_; Random::RandomGenerator& random_; - ThreadLocal::TypedSlot<PerThreadLoadBalancer> tls_; - const std::vector<std::string> clusters_; + const ClusterSetConstSharedPtr clusters_; private: // Upstream::ClusterImplBase - void startPreInit() override; - - void refresh(const std::function<bool(const std::string&)>& skip_predicate); - PriorityContextPtr - linearizePrioritySet(const std::function<bool(const std::string&)>& skip_predicate); + void startPreInit() override { onPreInitComplete(); } }; // Load balancer used by each worker thread. It will be refreshed when clusters, hosts or priorities // are updated. -class AggregateClusterLoadBalancer : public Upstream::LoadBalancer { +class AggregateClusterLoadBalancer : public Upstream::LoadBalancer, + Upstream::ClusterUpdateCallbacks, + Logger::Loggable<Logger::Id::upstream> { public: - AggregateClusterLoadBalancer( - Upstream::ClusterStats& stats, Runtime::Loader& runtime, Random::RandomGenerator& random, - const envoy::config::cluster::v3::Cluster::CommonLbConfig& common_config) - : stats_(stats), runtime_(runtime), random_(random), common_config_(common_config) {} + AggregateClusterLoadBalancer(const Upstream::ClusterInfoConstSharedPtr& parent_info, + Upstream::ClusterManager& cluster_manager, Runtime::Loader& runtime, + Random::RandomGenerator& random, + const ClusterSetConstSharedPtr& clusters); + + // Upstream::ClusterUpdateCallbacks + void onClusterAddOrUpdate(Upstream::ThreadLocalCluster& cluster) override; + void onClusterRemoval(const std::string& cluster_name) override; // Upstream::LoadBalancer Upstream::HostConstSharedPtr chooseHost(Upstream::LoadBalancerContext* context) override; @@ -135,23 +115,18 @@ class AggregateClusterLoadBalancer : public Upstream::LoadBalancer { using LoadBalancerImplPtr = std::unique_ptr<LoadBalancerImpl>; + void addMemberUpdateCallbackForCluster(Upstream::ThreadLocalCluster& thread_local_cluster); + PriorityContextPtr linearizePrioritySet(OptRef<const std::string> excluded_cluster); + void refresh(OptRef<const std::string> excluded_cluster = OptRef<const std::string>()); + LoadBalancerImplPtr load_balancer_; - Upstream::ClusterStats& stats_; + Upstream::ClusterInfoConstSharedPtr parent_info_; + Upstream::ClusterManager& cluster_manager_; Runtime::Loader& runtime_; Random::RandomGenerator& random_; - const envoy::config::cluster::v3::Cluster::CommonLbConfig& common_config_; PriorityContextPtr priority_context_; - -public: - void refresh(PriorityContextPtr priority_context) { - if (!priority_context->priority_set_.hostSetsPerPriority().empty()) { - load_balancer_ = std::make_unique<LoadBalancerImpl>(*priority_context, stats_, runtime_, - random_, common_config_); - } else { - load_balancer_ = nullptr; - } - priority_context_ = std::move(priority_context); - } + const ClusterSetConstSharedPtr clusters_; + Upstream::ClusterUpdateCallbacksHandlePtr handle_; }; // Load balancer factory created by the main thread and will be called in each worker thread to @@ -160,14 +135,9 @@ struct AggregateLoadBalancerFactory : public Upstream::LoadBalancerFactory { AggregateLoadBalancerFactory(const Cluster& cluster) : cluster_(cluster) {} // Upstream::LoadBalancerFactory Upstream::LoadBalancerPtr create() override { - // See comments in PerThreadLoadBalancer above for why the follow is done. - auto per_thread_local_balancer = cluster_.tls_.get(); - ASSERT(absl::get<std::unique_ptr<AggregateClusterLoadBalancer>>( - per_thread_local_balancer->lb_) != nullptr); - auto to_return = std::move( - absl::get<std::unique_ptr<AggregateClusterLoadBalancer>>(per_thread_local_balancer->lb_)); - per_thread_local_balancer->lb_ = to_return.get(); - return to_return; + return std::make_unique<AggregateClusterLoadBalancer>( + cluster_.info(), cluster_.cluster_manager_, cluster_.runtime_, cluster_.random_, + cluster_.clusters_); } const Cluster& cluster_; diff --git a/test/extensions/clusters/aggregate/cluster_test.cc b/test/extensions/clusters/aggregate/cluster_test.cc index 1d78f5e8779e..303f9d841190 100644 --- a/test/extensions/clusters/aggregate/cluster_test.cc +++ b/test/extensions/clusters/aggregate/cluster_test.cc @@ -72,7 +72,6 @@ class AggregateClusterTest : public Event::TestUsingSimulatedTime, public testin Upstream::HostSetImpl::partitionHosts(std::make_shared<Upstream::HostVector>(hosts), Upstream::HostsPerLocalityImpl::empty()), nullptr, hosts, {}, 100); - cluster_->refresh(); } void setupSecondary(int priority, int healthy_hosts, int degraded_hosts, int unhealthy_hosts) { @@ -83,7 +82,6 @@ class AggregateClusterTest : public Event::TestUsingSimulatedTime, public testin Upstream::HostSetImpl::partitionHosts(std::make_shared<Upstream::HostVector>(hosts), Upstream::HostsPerLocalityImpl::empty()), nullptr, hosts, {}, 100); - cluster_->refresh(); } void setupPrioritySet() { @@ -107,25 +105,22 @@ class AggregateClusterTest : public Event::TestUsingSimulatedTime, public testin cluster_ = std::make_shared<Cluster>(cluster_config, config, cm_, runtime_, api_->randomGenerator(), - factory_context, std::move(scope), tls_, false); + factory_context, std::move(scope), false); - thread_aware_lb_ = std::make_unique<AggregateThreadAwareLoadBalancer>(*cluster_); - lb_factory_ = thread_aware_lb_->factory(); - lb_ = lb_factory_->create(); - - EXPECT_CALL(cm_, getThreadLocalCluster(Eq("aggregate_cluster"))) - .WillRepeatedly(Return(&aggregate_cluster_)); + cm_.initializeThreadLocalClusters({"primary", "secondary"}); EXPECT_CALL(cm_, getThreadLocalCluster(Eq("primary"))).WillRepeatedly(Return(&primary_)); EXPECT_CALL(cm_, getThreadLocalCluster(Eq("secondary"))).WillRepeatedly(Return(&secondary_)); - EXPECT_CALL(cm_, getThreadLocalCluster(Eq("tertiary"))).WillRepeatedly(Return(nullptr)); ON_CALL(primary_, prioritySet()).WillByDefault(ReturnRef(primary_ps_)); ON_CALL(secondary_, prioritySet()).WillByDefault(ReturnRef(secondary_ps_)); - ON_CALL(aggregate_cluster_, loadBalancer()).WillByDefault(ReturnRef(*lb_)); setupPrioritySet(); ON_CALL(primary_, loadBalancer()).WillByDefault(ReturnRef(primary_load_balancer_)); ON_CALL(secondary_, loadBalancer()).WillByDefault(ReturnRef(secondary_load_balancer_)); + + thread_aware_lb_ = std::make_unique<AggregateThreadAwareLoadBalancer>(*cluster_); + lb_factory_ = thread_aware_lb_->factory(); + lb_ = lb_factory_->create(); } Stats::TestUtil::TestStore stats_store_; @@ -150,7 +145,7 @@ class AggregateClusterTest : public Event::TestUsingSimulatedTime, public testin new NiceMock<Upstream::MockClusterInfo>()}; std::shared_ptr<Upstream::MockClusterInfo> secondary_info_{ new NiceMock<Upstream::MockClusterInfo>()}; - NiceMock<Upstream::MockThreadLocalCluster> aggregate_cluster_, primary_, secondary_; + NiceMock<Upstream::MockThreadLocalCluster> primary_, secondary_; Upstream::PrioritySetImpl primary_ps_, secondary_ps_; NiceMock<Upstream::MockLoadBalancer> primary_load_balancer_, secondary_load_balancer_;