From 87179ef206d89260e2880438d79a5fb1c52538bb Mon Sep 17 00:00:00 2001 From: Matt Klein Date: Sat, 13 Aug 2016 09:00:54 -0700 Subject: [PATCH] Refactor async client This change refactors async client so that it is owned by the cluster manager. Now users can just send requests and fire and forget if they don't need to cancel. This makes a lot of the consumers much simpler. The fire and forget functionality will be used for router request shadowing (similar to the HTTP tracer). --- include/envoy/http/async_client.h | 13 +- include/envoy/upstream/cluster_manager.h | 17 +- include/envoy/upstream/upstream.h | 3 +- source/common/common/linked_object.h | 5 + source/common/filter/auth/client_ssl.cc | 16 +- source/common/filter/auth/client_ssl.h | 8 - source/common/filter/tcp_proxy.cc | 2 +- source/common/grpc/rpc_channel_impl.cc | 11 +- source/common/grpc/rpc_channel_impl.h | 3 +- source/common/http/async_client_impl.cc | 81 ++++--- source/common/http/async_client_impl.h | 38 +++- source/common/ratelimit/ratelimit_impl.cc | 2 +- source/common/router/config_impl.cc | 2 +- source/common/stats/statsd.cc | 2 +- source/common/tracing/http_tracer_impl.cc | 83 ++----- source/common/tracing/http_tracer_impl.h | 68 ++---- .../common/upstream/cluster_manager_impl.cc | 68 +++--- source/common/upstream/cluster_manager_impl.h | 19 +- source/common/upstream/sds.cc | 14 +- source/common/upstream/sds.h | 3 +- source/server/configuration_impl.cc | 6 +- test/common/filter/auth/client_ssl_test.cc | 25 ++- test/common/filter/tcp_proxy_test.cc | 3 +- test/common/grpc/rpc_channel_impl_test.cc | 53 ++--- test/common/http/async_client_impl_test.cc | 87 ++++++-- test/common/ratelimit/ratelimit_impl_test.cc | 4 +- test/common/router/config_impl_test.cc | 6 - test/common/stats/statsd_test.cc | 2 +- test/common/tracing/http_tracer_impl_test.cc | 71 +++--- .../upstream/cluster_manager_impl_test.cc | 207 +++++++++++++++--- test/common/upstream/sds_test.cc | 26 +-- test/mocks/http/mocks.h | 6 +- test/mocks/upstream/mocks.h | 8 +- 33 files changed, 532 insertions(+), 430 deletions(-) diff --git a/include/envoy/http/async_client.h b/include/envoy/http/async_client.h index b2d559596c12..59d6fef6c301 100644 --- a/include/envoy/http/async_client.h +++ b/include/envoy/http/async_client.h @@ -52,19 +52,18 @@ class AsyncClient { virtual void cancel() PURE; }; - typedef std::unique_ptr RequestPtr; - virtual ~AsyncClient() {} /** * Send an HTTP request asynchronously - * @param request the request to send - * @param callbacks the callbacks to be notified of request status + * @param request the request to send. + * @param callbacks the callbacks to be notified of request status. * @return a request handle or nullptr if no request could be created. NOTE: In this case - * onFailure() has already been called inline. + * onFailure() has already been called inline. The client owns the request and the + * handle should just be used to cancel. */ - virtual RequestPtr send(MessagePtr&& request, Callbacks& callbacks, - const Optional& timeout) PURE; + virtual Request* send(MessagePtr&& request, Callbacks& callbacks, + const Optional& timeout) PURE; }; typedef std::unique_ptr AsyncClientPtr; diff --git a/include/envoy/upstream/cluster_manager.h b/include/envoy/upstream/cluster_manager.h index ae03f2a30f13..002c025fa0cb 100644 --- a/include/envoy/upstream/cluster_manager.h +++ b/include/envoy/upstream/cluster_manager.h @@ -31,18 +31,12 @@ class ClusterManager { */ virtual const Cluster* get(const std::string& cluster) PURE; - /** - * @return whether the cluster manager knows about a particular cluster by name. - */ - virtual bool has(const std::string& cluster) PURE; - /** * Allocate a load balanced HTTP connection pool for a cluster. This is *per-thread* so that * callers do not need to worry about per thread synchronization. The load balancing policy that * is used is the one defined on the cluster when it was created. * - * Can return nullptr if there is no host available in the cluster or the cluster name is not - * valid. + * Can return nullptr if there is no host available in the cluster. */ virtual Http::ConnectionPool::Instance* httpConnPoolForCluster(const std::string& cluster) PURE; @@ -52,15 +46,16 @@ class ClusterManager { * load balancing policy that is used is the one defined on the cluster when it was created. * * Returns both a connection and the host that backs the connection. Both can be nullptr if there - * is no host available in the cluster or the cluster name is not valid. + * is no host available in the cluster. */ virtual Host::CreateConnectionData tcpConnForCluster(const std::string& cluster) PURE; /** - * Returns a client that can be used to make async HTTP calls against the given cluster. The - * client may be backed by a connection pool or by a multiplexed connection. + * Returns a client that can be used to make async HTTP calls against the given cluster. The + * client may be backed by a connection pool or by a multiplexed connection. The cluster manager + * owns the client. */ - virtual Http::AsyncClientPtr httpAsyncClientForCluster(const std::string& cluster) PURE; + virtual Http::AsyncClient& httpAsyncClientForCluster(const std::string& cluster) PURE; /** * Shutdown the cluster prior to destroying connection pools and other thread local data. diff --git a/include/envoy/upstream/upstream.h b/include/envoy/upstream/upstream.h index 7a89bc26863f..75d5cc3d2b37 100644 --- a/include/envoy/upstream/upstream.h +++ b/include/envoy/upstream/upstream.h @@ -238,8 +238,7 @@ class Cluster : public virtual HostSet { virtual ResourceManager& resourceManager() const PURE; /** - * Shutdown the cluster manager prior to destroying connection pools and other thread local - * data. + * Shutdown the cluster prior to destroying connection pools and other thread local data. */ virtual void shutdown() PURE; diff --git a/source/common/common/linked_object.h b/source/common/common/linked_object.h index 127ffe431c04..c99ed7eadb99 100644 --- a/source/common/common/linked_object.h +++ b/source/common/common/linked_object.h @@ -18,6 +18,11 @@ template class LinkedObject { return entry_; } + /** + * @return whether the object is currently inserted into a list. + */ + bool inserted() { return inserted_; } + /** * Move a linked item between 2 lists. * @param list1 supplies the first list. diff --git a/source/common/filter/auth/client_ssl.cc b/source/common/filter/auth/client_ssl.cc index 5bce5fa01e8a..edce72ad5fcf 100644 --- a/source/common/filter/auth/client_ssl.cc +++ b/source/common/filter/auth/client_ssl.cc @@ -21,7 +21,7 @@ Config::Config(const Json::Object& config, ThreadLocal::Instance& tls, Upstream: ip_white_list_(config), stats_(generateStats(stats_store, config.getString("stat_prefix"))), runtime_(runtime), local_address_(local_address) { - if (!cm_.has(auth_api_cluster_)) { + if (!cm_.get(auth_api_cluster_)) { throw EnvoyException( fmt::format("unknown cluster '{}' in client ssl auth config", auth_api_cluster_)); } @@ -83,29 +83,19 @@ void Config::onFailure(Http::AsyncClient::FailureReason) { } void Config::refreshPrincipals() { - ASSERT(!active_request_); - active_request_.reset(new ActiveRequest()); - active_request_->client_ = cm_.httpAsyncClientForCluster(auth_api_cluster_); - if (!active_request_->client_) { - onFailure(Http::AsyncClient::FailureReason::Reset); - return; - } - Http::MessagePtr message(new Http::RequestMessageImpl()); message->headers().addViaMoveValue(Http::Headers::get().Scheme, "http"); message->headers().addViaMoveValue(Http::Headers::get().Method, "GET"); message->headers().addViaMoveValue(Http::Headers::get().Path, "/v1/certs/list/approved"); message->headers().addViaCopy(Http::Headers::get().Host, auth_api_cluster_); message->headers().addViaCopy(Http::Headers::get().ForwardedFor, local_address_); - active_request_->request_ = active_request_->client_->send(std::move(message), *this, - Optional()); + cm_.httpAsyncClientForCluster(auth_api_cluster_) + .send(std::move(message), *this, Optional()); } void Config::requestComplete() { std::chrono::milliseconds interval( runtime_.snapshot().getInteger("auth.clientssl.refresh_interval_ms", 60000)); - - active_request_.reset(); interval_timer_->enableTimer(interval); } diff --git a/source/common/filter/auth/client_ssl.h b/source/common/filter/auth/client_ssl.h index 9ad2966034f2..5d83558ab85c 100644 --- a/source/common/filter/auth/client_ssl.h +++ b/source/common/filter/auth/client_ssl.h @@ -78,13 +78,6 @@ class Config : public Http::AsyncClient::Callbacks { void onFailure(Http::AsyncClient::FailureReason reason) override; private: - struct ActiveRequest { - Http::AsyncClientPtr client_; - Http::AsyncClient::RequestPtr request_; - }; - - typedef std::unique_ptr ActiveRequestPtr; - static GlobalStats generateStats(Stats::Store& store, const std::string& prefix); AllowedPrincipalsPtr parseAuthResponse(Http::Message& message); void refreshPrincipals(); @@ -94,7 +87,6 @@ class Config : public Http::AsyncClient::Callbacks { uint32_t tls_slot_; Upstream::ClusterManager& cm_; const std::string auth_api_cluster_; - ActiveRequestPtr active_request_; Event::TimerPtr interval_timer_; Network::IpWhiteList ip_white_list_; GlobalStats stats_; diff --git a/source/common/filter/tcp_proxy.cc b/source/common/filter/tcp_proxy.cc index 7bd914df4c58..0d83d437db43 100644 --- a/source/common/filter/tcp_proxy.cc +++ b/source/common/filter/tcp_proxy.cc @@ -16,7 +16,7 @@ TcpProxyConfig::TcpProxyConfig(const Json::Object& config, Upstream::ClusterManager& cluster_manager, Stats::Store& stats_store) : cluster_name_(config.getString("cluster")), stats_(generateStats(config.getString("stat_prefix"), stats_store)) { - if (!cluster_manager.has(cluster_name_)) { + if (!cluster_manager.get(cluster_name_)) { throw EnvoyException(fmt::format("tcp proxy: unknown cluster '{}'", cluster_name_)); } } diff --git a/source/common/grpc/rpc_channel_impl.cc b/source/common/grpc/rpc_channel_impl.cc index e30f91855e28..99f1bb418bf3 100644 --- a/source/common/grpc/rpc_channel_impl.cc +++ b/source/common/grpc/rpc_channel_impl.cc @@ -30,12 +30,6 @@ void RpcChannelImpl::CallMethod(const proto::MethodDescriptor* method, proto::Rp // here for clarity. ASSERT(cm_.get(cluster_)->features() & Upstream::Cluster::Features::HTTP2); - client_ = cm_.httpAsyncClientForCluster(cluster_); - if (!client_) { - onFailureWorker(Optional(), "http request failure"); - return; - } - Http::MessagePtr message(new Http::RequestMessageImpl()); message->headers().addViaMoveValue(Http::Headers::get().Scheme, "http"); message->headers().addViaMoveValue(Http::Headers::get().Method, "POST"); @@ -46,10 +40,7 @@ void RpcChannelImpl::CallMethod(const proto::MethodDescriptor* method, proto::Rp message->headers().addViaCopy(Http::Headers::get().ContentType, Common::GRPC_CONTENT_TYPE); message->body(serializeBody(*grpc_request)); - http_request_ = client_->send(std::move(message), *this, timeout_); - if (!http_request_) { - onFailureWorker(Optional(), "http request failure"); - } + http_request_ = cm_.httpAsyncClientForCluster(cluster_).send(std::move(message), *this, timeout_); } void RpcChannelImpl::incStat(bool success) { diff --git a/source/common/grpc/rpc_channel_impl.h b/source/common/grpc/rpc_channel_impl.h index 194ca46745f0..72b74e0660a6 100644 --- a/source/common/grpc/rpc_channel_impl.h +++ b/source/common/grpc/rpc_channel_impl.h @@ -62,8 +62,7 @@ class RpcChannelImpl : public RpcChannel, public Http::AsyncClient::Callbacks { Upstream::ClusterManager& cm_; const std::string cluster_; - Http::AsyncClientPtr client_; - Http::AsyncClient::RequestPtr http_request_; + Http::AsyncClient::Request* http_request_{}; const proto::MethodDescriptor* grpc_method_{}; proto::Message* grpc_response_{}; RpcChannelCallbacks& callbacks_; diff --git a/source/common/http/async_client_impl.cc b/source/common/http/async_client_impl.cc index 5d2911b4f6c5..b1e51a2a3221 100644 --- a/source/common/http/async_client_impl.cc +++ b/source/common/http/async_client_impl.cc @@ -7,26 +7,35 @@ namespace Http { -const Http::HeaderMapImpl AsyncRequestImpl::SERVICE_UNAVAILABLE_HEADER{ - {Http::Headers::get().Status, std::to_string(enumToInt(Http::Code::ServiceUnavailable))}}; +const HeaderMapImpl AsyncRequestImpl::SERVICE_UNAVAILABLE_HEADER{ + {Headers::get().Status, std::to_string(enumToInt(Code::ServiceUnavailable))}}; -const Http::HeaderMapImpl AsyncRequestImpl::REQUEST_TIMEOUT_HEADER{ - {Http::Headers::get().Status, std::to_string(enumToInt(Http::Code::GatewayTimeout))}}; +const HeaderMapImpl AsyncRequestImpl::REQUEST_TIMEOUT_HEADER{ + {Headers::get().Status, std::to_string(enumToInt(Code::GatewayTimeout))}}; -AsyncClientImpl::AsyncClientImpl(ConnectionPool::Instance& conn_pool, const std::string& cluster, - Stats::Store& stats_store, Event::Dispatcher& dispatcher) - : conn_pool_(conn_pool), stat_prefix_(fmt::format("cluster.{}.", cluster)), - stats_store_(stats_store), dispatcher_(dispatcher) {} +AsyncClientImpl::AsyncClientImpl(const Upstream::Cluster& cluster, + AsyncClientConnPoolFactory& factory, Stats::Store& stats_store, + Event::Dispatcher& dispatcher) + : cluster_(cluster), factory_(factory), stats_store_(stats_store), dispatcher_(dispatcher), + stat_prefix_(fmt::format("cluster.{}.", cluster.name())) {} + +AsyncClientImpl::~AsyncClientImpl() { ASSERT(active_requests_.empty()); } + +AsyncClient::Request* AsyncClientImpl::send(MessagePtr&& request, AsyncClient::Callbacks& callbacks, + const Optional& timeout) { + ConnectionPool::Instance* conn_pool = factory_.connPool(); + if (!conn_pool) { + callbacks.onFailure(AsyncClient::FailureReason::Reset); + return nullptr; + } -AsyncClient::RequestPtr AsyncClientImpl::send(MessagePtr&& request, - AsyncClient::Callbacks& callbacks, - const Optional& timeout) { std::unique_ptr new_request{ - new AsyncRequestImpl(std::move(request), *this, callbacks, dispatcher_, timeout)}; + new AsyncRequestImpl(std::move(request), *this, callbacks, dispatcher_, *conn_pool, timeout)}; // The request may get immediately failed. If so, we will return nullptr. if (new_request->stream_encoder_) { - return std::move(new_request); + new_request->moveIntoList(std::move(new_request), active_requests_); + return active_requests_.front().get(); } else { return nullptr; } @@ -34,10 +43,11 @@ AsyncClient::RequestPtr AsyncClientImpl::send(MessagePtr&& request, AsyncRequestImpl::AsyncRequestImpl(MessagePtr&& request, AsyncClientImpl& parent, AsyncClient::Callbacks& callbacks, Event::Dispatcher& dispatcher, + ConnectionPool::Instance& conn_pool, const Optional& timeout) : request_(std::move(request)), parent_(parent), callbacks_(callbacks) { - stream_encoder_.reset(new PooledStreamEncoder(parent_.conn_pool_, *this, *this, 0, 0, *this)); + stream_encoder_.reset(new PooledStreamEncoder(conn_pool, *this, *this, 0, 0, *this)); stream_encoder_->encodeHeaders(request_->headers(), !request_->body()); // We might have been immediately failed. @@ -66,9 +76,9 @@ void AsyncRequestImpl::decodeHeaders(HeaderMapPtr&& headers, bool end_stream) { -> void { log_debug(" '{}':'{}'", key.get(), value); }); #endif - Http::CodeUtility::ResponseStatInfo info{parent_.stats_store_, parent_.stat_prefix_, - response_->headers(), true, EMPTY_STRING, EMPTY_STRING}; - Http::CodeUtility::chargeResponseStat(info); + CodeUtility::ResponseStatInfo info{parent_.stats_store_, parent_.stat_prefix_, + response_->headers(), true, EMPTY_STRING, EMPTY_STRING}; + CodeUtility::chargeResponseStat(info); if (end_stream) { onComplete(); @@ -102,34 +112,32 @@ void AsyncRequestImpl::decodeTrailers(HeaderMapPtr&& trailers) { void AsyncRequestImpl::onComplete() { // TODO: Check host's canary status in addition to canary header. - Http::CodeUtility::ResponseTimingInfo info{ + CodeUtility::ResponseTimingInfo info{ parent_.stats_store_, parent_.stat_prefix_, stream_encoder_->requestCompleteTime(), - response_->headers().get(Http::Headers::get().EnvoyUpstreamCanary) == "true", true, - EMPTY_STRING, EMPTY_STRING}; - Http::CodeUtility::chargeResponseTiming(info); + response_->headers().get(Headers::get().EnvoyUpstreamCanary) == "true", true, EMPTY_STRING, + EMPTY_STRING}; + CodeUtility::chargeResponseTiming(info); - cleanup(); callbacks_.onSuccess(std::move(response_)); + cleanup(); } void AsyncRequestImpl::onResetStream(StreamResetReason) { - Http::CodeUtility::ResponseStatInfo info{parent_.stats_store_, parent_.stat_prefix_, - SERVICE_UNAVAILABLE_HEADER, true, EMPTY_STRING, - EMPTY_STRING}; - Http::CodeUtility::chargeResponseStat(info); + CodeUtility::ResponseStatInfo info{parent_.stats_store_, parent_.stat_prefix_, + SERVICE_UNAVAILABLE_HEADER, true, EMPTY_STRING, EMPTY_STRING}; + CodeUtility::chargeResponseStat(info); + callbacks_.onFailure(AsyncClient::FailureReason::Reset); cleanup(); - callbacks_.onFailure(Http::AsyncClient::FailureReason::Reset); } void AsyncRequestImpl::onRequestTimeout() { - Http::CodeUtility::ResponseStatInfo info{parent_.stats_store_, parent_.stat_prefix_, - REQUEST_TIMEOUT_HEADER, true, EMPTY_STRING, - EMPTY_STRING}; - Http::CodeUtility::chargeResponseStat(info); - parent_.stats_store_.counter(fmt::format("{}upstream_rq_timeout", parent_.stat_prefix_)).inc(); + CodeUtility::ResponseStatInfo info{parent_.stats_store_, parent_.stat_prefix_, + REQUEST_TIMEOUT_HEADER, true, EMPTY_STRING, EMPTY_STRING}; + CodeUtility::chargeResponseStat(info); + parent_.cluster_.stats().upstream_rq_timeout_.inc(); stream_encoder_->resetStream(); + callbacks_.onFailure(AsyncClient::FailureReason::RequestTimemout); cleanup(); - callbacks_.onFailure(Http::AsyncClient::FailureReason::RequestTimemout); } void AsyncRequestImpl::cleanup() { @@ -137,5 +145,12 @@ void AsyncRequestImpl::cleanup() { if (request_timeout_) { request_timeout_->disableTimer(); } + + // This will destroy us, but only do so if we are actually in a list. This does not happen in + // the immediate failure case. + if (inserted()) { + removeFromList(parent_.active_requests_); + } } + } // Http diff --git a/source/common/http/async_client_impl.h b/source/common/http/async_client_impl.h index ad42a8d63d30..afc1b693afda 100644 --- a/source/common/http/async_client_impl.h +++ b/source/common/http/async_client_impl.h @@ -10,24 +10,43 @@ #include "envoy/http/message.h" #include "common/common/assert.h" +#include "common/common/linked_object.h" #include "common/http/header_map_impl.h" namespace Http { +/** + * Factory for obtaining a connection pool. + */ +class AsyncClientConnPoolFactory { +public: + virtual ~AsyncClientConnPoolFactory() {} + + /** + * Return a connection pool or nullptr if there is no healthy upstream host. + */ + virtual ConnectionPool::Instance* connPool() PURE; +}; + +class AsyncRequestImpl; + class AsyncClientImpl final : public AsyncClient { public: - AsyncClientImpl(ConnectionPool::Instance& conn_pool, const std::string& cluster, + AsyncClientImpl(const Upstream::Cluster& cluster, AsyncClientConnPoolFactory& factory, Stats::Store& stats_store, Event::Dispatcher& dispatcher); + ~AsyncClientImpl(); // Http::AsyncClient - RequestPtr send(MessagePtr&& request, Callbacks& callbacks, - const Optional& timeout) override; + Request* send(MessagePtr&& request, Callbacks& callbacks, + const Optional& timeout) override; private: - ConnectionPool::Instance& conn_pool_; - const std::string stat_prefix_; + const Upstream::Cluster& cluster_; + AsyncClientConnPoolFactory& factory_; Stats::Store& stats_store_; Event::Dispatcher& dispatcher_; + const std::string stat_prefix_; + std::list> active_requests_; friend class AsyncRequestImpl; }; @@ -40,10 +59,11 @@ class AsyncRequestImpl final : public AsyncClient::Request, StreamDecoder, StreamCallbacks, PooledStreamEncoderCallbacks, - Logger::Loggable { + Logger::Loggable, + LinkedObject { public: AsyncRequestImpl(MessagePtr&& request, AsyncClientImpl& parent, AsyncClient::Callbacks& callbacks, - Event::Dispatcher& dispatcher, + Event::Dispatcher& dispatcher, ConnectionPool::Instance& conn_pool, const Optional& timeout); ~AsyncRequestImpl(); @@ -74,8 +94,8 @@ class AsyncRequestImpl final : public AsyncClient::Request, std::unique_ptr response_; PooledStreamEncoderPtr stream_encoder_; - static const Http::HeaderMapImpl SERVICE_UNAVAILABLE_HEADER; - static const Http::HeaderMapImpl REQUEST_TIMEOUT_HEADER; + static const HeaderMapImpl SERVICE_UNAVAILABLE_HEADER; + static const HeaderMapImpl REQUEST_TIMEOUT_HEADER; friend class AsyncClientImpl; }; diff --git a/source/common/ratelimit/ratelimit_impl.cc b/source/common/ratelimit/ratelimit_impl.cc index f20aa9416a42..48733103c3d7 100644 --- a/source/common/ratelimit/ratelimit_impl.cc +++ b/source/common/ratelimit/ratelimit_impl.cc @@ -60,7 +60,7 @@ void GrpcClientImpl::onFailure(const Optional&, const std::string&) { GrpcFactoryImpl::GrpcFactoryImpl(const Json::Object& config, Upstream::ClusterManager& cm, Stats::Store& stats_store) : cluster_name_(config.getString("cluster_name")), cm_(cm), stats_store_(stats_store) { - if (!cm_.has(cluster_name_)) { + if (!cm_.get(cluster_name_)) { throw EnvoyException(fmt::format("unknown rate limit service cluster '{}'", cluster_name_)); } } diff --git a/source/common/router/config_impl.cc b/source/common/router/config_impl.cc index c3443ff64c61..721eb720040e 100644 --- a/source/common/router/config_impl.cc +++ b/source/common/router/config_impl.cc @@ -194,7 +194,7 @@ VirtualHost::VirtualHost(const Json::Object& virtual_host, Runtime::Loader& runt } if (!routes_.back()->isRedirect()) { - if (!cm.has(routes_.back()->clusterName())) { + if (!cm.get(routes_.back()->clusterName())) { throw EnvoyException( fmt::format("route: unknown cluster '{}'", routes_.back()->clusterName())); } diff --git a/source/common/stats/statsd.cc b/source/common/stats/statsd.cc index e6c054eafeb8..45536f1f5325 100644 --- a/source/common/stats/statsd.cc +++ b/source/common/stats/statsd.cc @@ -60,7 +60,7 @@ TcpStatsdSink::TcpStatsdSink(const std::string& stat_cluster, const std::string& : stat_cluster_(stat_cluster), stat_host_(stat_host), cluster_name_(cluster_name), tls_(tls), tls_slot_(tls.allocateSlot()), cluster_manager_(cluster_manager) { - if (!cluster_manager.has(cluster_name)) { + if (!cluster_manager.get(cluster_name)) { throw EnvoyException(fmt::format("unknown TCP statsd upstream cluster: {}", cluster_name)); } diff --git a/source/common/tracing/http_tracer_impl.cc b/source/common/tracing/http_tracer_impl.cc index ba86a684fadb..a95f1eab2841 100644 --- a/source/common/tracing/http_tracer_impl.cc +++ b/source/common/tracing/http_tracer_impl.cc @@ -258,90 +258,39 @@ std::string LightStepUtility::buildJsonBody(const Http::HeaderMap& request_heade } LightStepSink::LightStepSink(const Json::Object& config, Upstream::ClusterManager& cluster_manager, - ThreadLocal::Instance& tls, const std::string& stat_prefix, - Stats::Store& stats, Runtime::RandomGenerator& random, + const std::string& stat_prefix, Stats::Store& stats, + Runtime::RandomGenerator& random, const std::string& local_service_cluster, const std::string& service_node, const std::string& access_token) - : cm_(cluster_manager), tls_(tls), tls_slot_(tls.allocateSlot()) { - collector_cluster_ = config.getString("collector_cluster"); - if (!cm_.has(collector_cluster_)) { - throw EnvoyException(fmt::format("{} collector cluster is not defined on cluster manager level", - collector_cluster_)); - } - - tls.set(tls_slot_, - [this, stat_prefix, &stats, &random, local_service_cluster, service_node, access_token]( - Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectPtr { - return ThreadLocal::ThreadLocalObjectPtr{new TlsSink(*this, stat_prefix, stats, random, - local_service_cluster, - service_node, access_token)}; - }); -} - -LightStepSink::TlsSink::TlsSink(LightStepSink& parent, const std::string& stat_prefix, - Stats::Store& stats, Runtime::RandomGenerator& random, - const std::string& service_cluster, const std::string& service_node, - const std::string& access_token) - : parent_(parent), + : collector_cluster_(config.getString("collector_cluster")), cm_(cluster_manager), stats_{LIGHTSTEP_STATS(POOL_COUNTER_PREFIX(stats, stat_prefix + "tracing.lightstep."))}, - random_(random), local_service_cluster_(service_cluster), service_node_(service_node), + random_(random), local_service_cluster_(local_service_cluster), service_node_(service_node), access_token_(access_token) { - shutdown_ = false; -} - -void LightStepSink::TlsSink::shutdown() { - shutdown_ = true; - - for (auto& active_request : active_requests_) { - active_request->request_->cancel(); + if (!cm_.get(collector_cluster_)) { + throw EnvoyException(fmt::format("{} collector cluster is not defined on cluster manager level", + collector_cluster_)); } } -void LightStepSink::TlsSink::flushTrace(const Http::HeaderMap& request_headers, - const Http::HeaderMap& response_headers, - const Http::AccessLog::RequestInfo& request_info) { - if (shutdown_) { - return; - } - - Http::AsyncClientPtr client = parent_.cm_.httpAsyncClientForCluster(parent_.collector_cluster_); - - if (!client) { - stats_.client_failed_.inc(); - return; - } - +void LightStepSink::flushTrace(const Http::HeaderMap& request_headers, + const Http::HeaderMap& response_headers, + const Http::AccessLog::RequestInfo& request_info) { Http::MessagePtr msg = LightStepUtility::buildHeaders(access_token_); Buffer::InstancePtr buffer(new Buffer::OwnedImpl( LightStepUtility::buildJsonBody(request_headers, response_headers, request_info, random_, local_service_cluster_, service_node_))); msg->body(std::move(buffer)); - - executeRequest(std::move(client), std::move(msg)); + executeRequest(std::move(msg)); } -void LightStepSink::TlsSink::executeRequest(Http::AsyncClientPtr&& client, Http::MessagePtr&& msg) { - ActiveRequestPtr active_request(new LightStepSink::ActiveRequest(*this)); - Http::AsyncClient::RequestPtr request = - client->send(std::move(msg), *active_request, std::chrono::milliseconds(5000)); - if (request) { - active_request->request_ = std::move(request); - active_request->client_ = std::move(client); - active_request->moveIntoListBack(std::move(active_request), active_requests_); - } +void LightStepSink::executeRequest(Http::MessagePtr&& msg) { + cm_.httpAsyncClientForCluster(collector_cluster_) + .send(std::move(msg), *this, std::chrono::milliseconds(5000)); } -void LightStepSink::ActiveRequest::onFailure(Http::AsyncClient::FailureReason) { - parent_.stats_.collector_failed_.inc(); - clean(); -} - -void LightStepSink::ActiveRequest::onSuccess(Http::MessagePtr&&) { - parent_.stats_.collector_success_.inc(); - clean(); -} +void LightStepSink::onFailure(Http::AsyncClient::FailureReason) { stats_.collector_failed_.inc(); } -void LightStepSink::ActiveRequest::clean() { removeFromList(parent_.active_requests_); } +void LightStepSink::onSuccess(Http::MessagePtr&&) { stats_.collector_success_.inc(); } } // Tracing \ No newline at end of file diff --git a/source/common/tracing/http_tracer_impl.h b/source/common/tracing/http_tracer_impl.h index 9a3135032594..ca281e09e825 100644 --- a/source/common/tracing/http_tracer_impl.h +++ b/source/common/tracing/http_tracer_impl.h @@ -1,11 +1,9 @@ #pragma once #include "envoy/runtime/runtime.h" -#include "envoy/thread_local/thread_local.h" #include "envoy/tracing/http_tracer.h" #include "envoy/upstream/cluster_manager.h" -#include "common/common/linked_object.h" #include "common/http/header_map_impl.h" #include "common/json/json_loader.h" @@ -13,8 +11,7 @@ namespace Tracing { #define LIGHTSTEP_STATS(COUNTER) \ COUNTER(collector_failed) \ - COUNTER(collector_success) \ - COUNTER(client_failed) + COUNTER(collector_success) struct LightStepStats { LIGHTSTEP_STATS(GENERATE_COUNTER_STRUCT) @@ -136,64 +133,31 @@ class LightStepUtility { * * LightStepSink is for flushing data to LightStep collectors. */ -class LightStepSink : public HttpSink { +class LightStepSink : public HttpSink, public Http::AsyncClient::Callbacks { public: LightStepSink(const Json::Object& config, Upstream::ClusterManager& cluster_manager, - ThreadLocal::Instance& tls, const std::string& stat_prefix, Stats::Store& stats, + const std::string& stat_prefix, Stats::Store& stats, Runtime::RandomGenerator& random, const std::string& local_service_cluster, const std::string& service_node, const std::string& access_token); // Tracer::HttpSink void flushTrace(const Http::HeaderMap& request_headers, const Http::HeaderMap& response_headers, - const Http::AccessLog::RequestInfo& request_info) override { - tls_.getTyped(tls_slot_).flushTrace(request_headers, response_headers, request_info); - } + const Http::AccessLog::RequestInfo& request_info) override; + + // Http::AsyncClient::Callbacks + void onSuccess(Http::MessagePtr&&) override; + void onFailure(Http::AsyncClient::FailureReason reason) override; private: - struct ActiveRequest; - typedef std::unique_ptr ActiveRequestPtr; - - struct TlsSink : public ThreadLocal::ThreadLocalObject { - TlsSink(LightStepSink& parent, const std::string& stat_prefix, Stats::Store& stats, - Runtime::RandomGenerator& random, const std::string& service_cluster, - const std::string& service_node, const std::string& access_token); - ~TlsSink() {} - - void flushTrace(const Http::HeaderMap& request_headers, const Http::HeaderMap& response_headers, - const Http::AccessLog::RequestInfo& request_info); - void executeRequest(Http::AsyncClientPtr&& client, Http::MessagePtr&& msg); - - // ThreadLocal::ThreadLocalObject - void shutdown() override; - - LightStepSink& parent_; - bool shutdown_{}; - LightStepStats stats_; - Runtime::RandomGenerator& random_; - std::string local_service_cluster_; - std::string service_node_; - std::string access_token_; - std::list active_requests_; - }; - - struct ActiveRequest : public Http::AsyncClient::Callbacks, LinkedObject { - ActiveRequest(TlsSink& parent) : parent_(parent) {} - - // Http::AsyncClient::Callbacks - void onSuccess(Http::MessagePtr&&) override; - void onFailure(Http::AsyncClient::FailureReason reason) override; - - void clean(); - - TlsSink& parent_; - Http::AsyncClientPtr client_; - Http::AsyncClient::RequestPtr request_; - }; - - std::string collector_cluster_; + void executeRequest(Http::MessagePtr&& msg); + + const std::string collector_cluster_; Upstream::ClusterManager& cm_; - ThreadLocal::Instance& tls_; - const uint32_t tls_slot_; + LightStepStats stats_; + Runtime::RandomGenerator& random_; + const std::string local_service_cluster_; + const std::string service_node_; + const std::string access_token_; }; } // Tracing \ No newline at end of file diff --git a/source/common/upstream/cluster_manager_impl.cc b/source/common/upstream/cluster_manager_impl.cc index 2f2badb83a0c..8b78bd22ed59 100644 --- a/source/common/upstream/cluster_manager_impl.cc +++ b/source/common/upstream/cluster_manager_impl.cc @@ -145,19 +145,11 @@ ClusterManagerImpl::httpConnPoolForCluster(const std::string& cluster) { // Select a host and create a connection pool for it if it does not already exist. auto entry = cluster_manager.thread_local_clusters_.find(cluster); - ConstHostPtr host = entry->second->lb_->chooseHost(); - if (!host) { - entry->second->primary_cluster_.stats().upstream_cx_none_healthy_.inc(); - return nullptr; - } - - if (cluster_manager.host_http_conn_pool_map_.find(host) == - cluster_manager.host_http_conn_pool_map_.end()) { - cluster_manager.host_http_conn_pool_map_[host] = - allocateConnPool(cluster_manager.dispatcher_, host, stats_); + if (entry == cluster_manager.thread_local_clusters_.end()) { + throw EnvoyException(fmt::format("unknown cluster '{}'", cluster)); } - return cluster_manager.host_http_conn_pool_map_[host].get(); + return entry->second->connPool(); } void ClusterManagerImpl::postThreadLocalClusterUpdate(const ClusterImplBase& primary_cluster, @@ -184,6 +176,10 @@ Host::CreateConnectionData ClusterManagerImpl::tcpConnForCluster(const std::stri tls_.getTyped(thread_local_slot_); auto entry = cluster_manager.thread_local_clusters_.find(cluster); + if (entry == cluster_manager.thread_local_clusters_.end()) { + throw EnvoyException(fmt::format("unknown cluster '{}'", cluster)); + } + ConstHostPtr logical_host = entry->second->lb_->chooseHost(); if (logical_host) { return logical_host->createConnection(cluster_manager.dispatcher_); @@ -193,28 +189,28 @@ Host::CreateConnectionData ClusterManagerImpl::tcpConnForCluster(const std::stri } } -Http::AsyncClientPtr ClusterManagerImpl::httpAsyncClientForCluster(const std::string& cluster) { - Http::ConnectionPool::Instance* conn_pool = httpConnPoolForCluster(cluster); +Http::AsyncClient& ClusterManagerImpl::httpAsyncClientForCluster(const std::string& cluster) { ThreadLocalClusterManagerImpl& cluster_manager = tls_.getTyped(thread_local_slot_); - if (conn_pool) { - return Http::AsyncClientPtr{ - new Http::AsyncClientImpl(*conn_pool, cluster, stats_, cluster_manager.dispatcher_)}; + auto entry = cluster_manager.thread_local_clusters_.find(cluster); + if (entry != cluster_manager.thread_local_clusters_.end()) { + return entry->second->http_async_client_; } else { - return nullptr; + throw EnvoyException(fmt::format("unknown cluster '{}'", cluster)); } } ClusterManagerImpl::ThreadLocalClusterManagerImpl::ThreadLocalClusterManagerImpl( ClusterManagerImpl& parent, Event::Dispatcher& dispatcher, Runtime::Loader& runtime, Runtime::RandomGenerator& random) - : dispatcher_(dispatcher) { + : parent_(parent), dispatcher_(dispatcher) { for (auto& cluster : parent.primary_clusters_) { - thread_local_clusters_[cluster.first].reset(new ClusterEntry(*cluster.second, runtime, random)); + thread_local_clusters_[cluster.first].reset( + new ClusterEntry(*this, *cluster.second, runtime, random, parent.stats_, dispatcher)); } for (auto& cluster : thread_local_clusters_) { - cluster.second->host_set_->addMemberUpdateCb( + cluster.second->host_set_.addMemberUpdateCb( [this](const std::vector&, const std::vector& hosts_removed) -> void { // We need to go through and purge any connection pools for hosts that got deleted. // Right now hosts are specific to clusters, so even if two hosts actually point @@ -245,7 +241,7 @@ void ClusterManagerImpl::ThreadLocalClusterManagerImpl::updateClusterMembership( tls.getTyped(thead_local_slot); ASSERT(config.thread_local_clusters_.find(name) != config.thread_local_clusters_.end()); - config.thread_local_clusters_[name]->host_set_->updateHosts( + config.thread_local_clusters_[name]->host_set_.updateHosts( hosts, healthy_hosts, local_zone_hosts, local_zone_healthy_hosts, hosts_added, hosts_removed); } @@ -254,25 +250,43 @@ void ClusterManagerImpl::ThreadLocalClusterManagerImpl::shutdown() { } ClusterManagerImpl::ThreadLocalClusterManagerImpl::ClusterEntry::ClusterEntry( - const Cluster& parent, Runtime::Loader& runtime, Runtime::RandomGenerator& random) - : host_set_(new HostSetImpl()), primary_cluster_(parent) { + ThreadLocalClusterManagerImpl& parent, const Cluster& cluster, Runtime::Loader& runtime, + Runtime::RandomGenerator& random, Stats::Store& stats_store, Event::Dispatcher& dispatcher) + : parent_(parent), primary_cluster_(cluster), + http_async_client_(cluster, *this, stats_store, dispatcher) { - switch (parent.lbType()) { + switch (cluster.lbType()) { case LoadBalancerType::LeastRequest: { - lb_.reset(new LeastRequestLoadBalancer(*host_set_, parent.stats(), runtime, random)); + lb_.reset(new LeastRequestLoadBalancer(host_set_, cluster.stats(), runtime, random)); break; } case LoadBalancerType::Random: { - lb_.reset(new RandomLoadBalancer(*host_set_, parent.stats(), runtime, random)); + lb_.reset(new RandomLoadBalancer(host_set_, cluster.stats(), runtime, random)); break; } case LoadBalancerType::RoundRobin: { - lb_.reset(new RoundRobinLoadBalancer(*host_set_, parent.stats(), runtime)); + lb_.reset(new RoundRobinLoadBalancer(host_set_, cluster.stats(), runtime)); break; } } } +Http::ConnectionPool::Instance* +ClusterManagerImpl::ThreadLocalClusterManagerImpl::ClusterEntry::connPool() { + ConstHostPtr host = lb_->chooseHost(); + if (!host) { + primary_cluster_.stats().upstream_cx_none_healthy_.inc(); + return nullptr; + } + + if (parent_.host_http_conn_pool_map_.find(host) == parent_.host_http_conn_pool_map_.end()) { + parent_.host_http_conn_pool_map_[host] = + parent_.parent_.allocateConnPool(parent_.dispatcher_, host, parent_.parent_.stats_); + } + + return parent_.host_http_conn_pool_map_[host].get(); +} + Http::ConnectionPool::InstancePtr ProdClusterManagerImpl::allocateConnPool(Event::Dispatcher& dispatcher, ConstHostPtr host, Stats::Store& store) { diff --git a/source/common/upstream/cluster_manager_impl.h b/source/common/upstream/cluster_manager_impl.h index 8fcbbf554ece..4b00de9c0012 100644 --- a/source/common/upstream/cluster_manager_impl.h +++ b/source/common/upstream/cluster_manager_impl.h @@ -7,6 +7,7 @@ #include "envoy/thread_local/thread_local.h" #include "envoy/upstream/cluster_manager.h" +#include "common/http/async_client_impl.h" #include "common/json/json_loader.h" namespace Upstream { @@ -41,10 +42,9 @@ class ClusterManagerImpl : public ClusterManager { } const Cluster* get(const std::string& cluster) override; - bool has(const std::string& cluster) override { return primary_clusters_.count(cluster); } Http::ConnectionPool::Instance* httpConnPoolForCluster(const std::string& cluster) override; Host::CreateConnectionData tcpConnForCluster(const std::string& cluster) override; - Http::AsyncClientPtr httpAsyncClientForCluster(const std::string& cluster) override; + Http::AsyncClient& httpAsyncClientForCluster(const std::string& cluster) override; void shutdown() override { for (auto& cluster : primary_clusters_) { @@ -62,13 +62,19 @@ class ClusterManagerImpl : public ClusterManager { * connection pools. */ struct ThreadLocalClusterManagerImpl : public ThreadLocal::ThreadLocalObject { - struct ClusterEntry { - ClusterEntry(const Cluster& parent, Runtime::Loader& runtime, - Runtime::RandomGenerator& random); + struct ClusterEntry : public Http::AsyncClientConnPoolFactory { + ClusterEntry(ThreadLocalClusterManagerImpl& parent, const Cluster& cluster, + Runtime::Loader& runtime, Runtime::RandomGenerator& random, + Stats::Store& stats_store, Event::Dispatcher& dispatcher); - HostSetImplPtr host_set_; + // Http::AsyncClientConnPoolFactory + Http::ConnectionPool::Instance* connPool() override; + + ThreadLocalClusterManagerImpl& parent_; + HostSetImpl host_set_; LoadBalancerPtr lb_; const Cluster& primary_cluster_; + Http::AsyncClientImpl http_async_client_; }; typedef std::unique_ptr ClusterEntryPtr; @@ -87,6 +93,7 @@ class ClusterManagerImpl : public ClusterManager { // ThreadLocal::ThreadLocalObject void shutdown() override; + ClusterManagerImpl& parent_; Event::Dispatcher& dispatcher_; std::unordered_map thread_local_clusters_; std::unordered_map host_http_conn_pool_map_; diff --git a/source/common/upstream/sds.cc b/source/common/upstream/sds.cc index 82e9fc43bc61..48bf721ebf82 100644 --- a/source/common/upstream/sds.cc +++ b/source/common/upstream/sds.cc @@ -101,11 +101,6 @@ void SdsClusterImpl::parseSdsResponse(Http::Message& response) { void SdsClusterImpl::refreshHosts() { log_debug("starting sds refresh for cluster: {}", name_); stats_.update_attempt_.inc(); - client_ = cm_.httpAsyncClientForCluster(sds_config_.sds_cluster_name_); - if (!client_) { - onFailure(Http::AsyncClient::FailureReason::Reset); - return; - } Http::MessagePtr message(new Http::RequestMessageImpl()); message->headers().addViaMoveValue(Http::Headers::get().Scheme, "http"); @@ -113,7 +108,8 @@ void SdsClusterImpl::refreshHosts() { message->headers().addViaMoveValue(Http::Headers::get().Path, "/v1/registration/" + service_name_); message->headers().addViaMoveValue(Http::Headers::get().Host, "sds"); - active_request_ = client_->send(std::move(message), *this, Optional()); + active_request_ = cm_.httpAsyncClientForCluster(sds_config_.sds_cluster_name_) + .send(std::move(message), *this, Optional()); } void SdsClusterImpl::requestComplete() { @@ -125,8 +121,7 @@ void SdsClusterImpl::requestComplete() { initialize_callback_ = nullptr; } - active_request_.reset(); - client_.reset(); + active_request_ = nullptr; // Add refresh jitter based on the configured interval. std::chrono::milliseconds final_delay = @@ -139,8 +134,7 @@ void SdsClusterImpl::requestComplete() { void SdsClusterImpl::shutdown() { if (active_request_) { active_request_->cancel(); - active_request_.reset(); - client_.reset(); + active_request_ = nullptr; } refresh_timer_.reset(); diff --git a/source/common/upstream/sds.h b/source/common/upstream/sds.h index 94e09261d81c..e465dec280e4 100644 --- a/source/common/upstream/sds.h +++ b/source/common/upstream/sds.h @@ -54,8 +54,7 @@ class SdsClusterImpl : public BaseDynamicClusterImpl, public Http::AsyncClient:: const std::string service_name_; Runtime::RandomGenerator& random_; Event::TimerPtr refresh_timer_; - Http::AsyncClientPtr client_; - Http::AsyncClient::RequestPtr active_request_; + Http::AsyncClient::Request* active_request_{}; uint64_t pending_health_checks_{}; }; diff --git a/source/server/configuration_impl.cc b/source/server/configuration_impl.cc index 8c7a4ddf5eba..76bda817e681 100644 --- a/source/server/configuration_impl.cc +++ b/source/server/configuration_impl.cc @@ -82,9 +82,9 @@ void MainImpl::initializeTracers(const Json::Object& tracing_configuration_) { StringUtil::rtrim(access_token); http_tracer_->addSink(Tracing::HttpSinkPtr{new Tracing::LightStepSink( - sink.getObject("config"), *cluster_manager_, server_.threadLocal(), "", - server_.stats(), server_.random(), server_.options().serviceClusterName(), - server_.options().serviceNodeName(), access_token)}); + sink.getObject("config"), *cluster_manager_, "", server_.stats(), server_.random(), + server_.options().serviceClusterName(), server_.options().serviceNodeName(), + access_token)}); } else { throw EnvoyException(fmt::format("Unsupported sink type: '{}'", type)); } diff --git a/test/common/filter/auth/client_ssl_test.cc b/test/common/filter/auth/client_ssl_test.cc index d124b95b4cdb..401ac1281ee8 100644 --- a/test/common/filter/auth/client_ssl_test.cc +++ b/test/common/filter/auth/client_ssl_test.cc @@ -26,7 +26,8 @@ TEST(ClientSslAuthAllowedPrincipalsTest, EmptyString) { class ClientSslAuthFilterTest : public testing::Test { public: - ClientSslAuthFilterTest() : interval_timer_(new Event::MockTimer(&dispatcher_)) {} + ClientSslAuthFilterTest() + : interval_timer_(new Event::MockTimer(&dispatcher_)), request_(&cm_.async_client_) {} ~ClientSslAuthFilterTest() { tls_.shutdownThread(); } void setup() { @@ -39,7 +40,7 @@ class ClientSslAuthFilterTest : public testing::Test { )EOF"; Json::StringLoader loader(json); - EXPECT_CALL(cm_, has("vpn")).WillOnce(Return(true)); + EXPECT_CALL(cm_, get("vpn")); setupRequest(); config_.reset(new Config(loader, tls_, cm_, dispatcher_, stats_store_, runtime_, "127.0.0.1")); @@ -52,15 +53,14 @@ class ClientSslAuthFilterTest : public testing::Test { } void setupRequest() { - client_ = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("vpn")).WillOnce(Return(client_)); - EXPECT_CALL(*client_, send_(_, _, _)) + EXPECT_CALL(cm_, httpAsyncClientForCluster("vpn")).WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) .WillOnce( Invoke([this](Http::MessagePtr& request, Http::AsyncClient::Callbacks& callbacks, Optional) -> Http::AsyncClient::Request* { EXPECT_EQ("127.0.0.1", request->headers().get("x-forwarded-for")); callbacks_ = &callbacks; - return new Http::MockAsyncClientRequest(client_); + return &request_; })); } @@ -71,11 +71,11 @@ class ClientSslAuthFilterTest : public testing::Test { NiceMock filter_callbacks_; std::unique_ptr instance_; Event::MockTimer* interval_timer_; - Http::MockAsyncClient* client_; Http::AsyncClient::Callbacks* callbacks_; Ssl::MockConnection ssl_; Stats::IsolatedStoreImpl stats_store_; NiceMock runtime_; + Http::MockAsyncClientRequest request_; }; TEST_F(ClientSslAuthFilterTest, NoCluster) { @@ -87,7 +87,7 @@ TEST_F(ClientSslAuthFilterTest, NoCluster) { )EOF"; Json::StringLoader loader(json); - EXPECT_CALL(cm_, has("bad_cluster")).WillOnce(Return(false)); + EXPECT_CALL(cm_, get("bad_cluster")).WillOnce(Return(nullptr)); EXPECT_THROW(new Config(loader, tls_, cm_, dispatcher_, stats_store_, runtime_, "127.0.0.1"), EnvoyException); } @@ -174,7 +174,14 @@ TEST_F(ClientSslAuthFilterTest, Basic) { // Interval timer fires, cannot obtain async client. EXPECT_CALL(*interval_timer_, enableTimer(_)); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("vpn")).WillOnce(Return(nullptr)); + EXPECT_CALL(cm_, httpAsyncClientForCluster("vpn")).WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) + .WillOnce( + Invoke([&](Http::MessagePtr&, Http::AsyncClient::Callbacks& callbacks, + const Optional&) -> Http::AsyncClient::Request* { + callbacks.onFailure(Http::AsyncClient::FailureReason::Reset); + return nullptr; + })); interval_timer_->callback_(); EXPECT_EQ(4U, stats_store_.counter("auth.clientssl.vpn.update_failure").value()); diff --git a/test/common/filter/tcp_proxy_test.cc b/test/common/filter/tcp_proxy_test.cc index d6765af3f087..623473c395ab 100644 --- a/test/common/filter/tcp_proxy_test.cc +++ b/test/common/filter/tcp_proxy_test.cc @@ -25,7 +25,7 @@ TEST(TcpProxyConfigTest, NoCluster) { Json::StringLoader config(json); NiceMock cluster_manager; - EXPECT_CALL(cluster_manager, has("fake_cluster")).WillOnce(Return(false)); + EXPECT_CALL(cluster_manager, get("fake_cluster")).WillOnce(Return(nullptr)); EXPECT_THROW(TcpProxyConfig(config, cluster_manager, cluster_manager.cluster_.stats_store_), EnvoyException); } @@ -41,7 +41,6 @@ class TcpProxyTest : public testing::Test { )EOF"; Json::StringLoader config(json); - EXPECT_CALL(cluster_manager_, has("fake_cluster")).WillOnce(Return(true)); config_.reset( new TcpProxyConfig(config, cluster_manager_, cluster_manager_.cluster_.stats_store_)); } diff --git a/test/common/grpc/rpc_channel_impl_test.cc b/test/common/grpc/rpc_channel_impl_test.cc index 4eb92ed09974..15e9ea6bbbff 100644 --- a/test/common/grpc/rpc_channel_impl_test.cc +++ b/test/common/grpc/rpc_channel_impl_test.cc @@ -12,20 +12,19 @@ namespace Grpc { class GrpcRequestImplTest : public testing::Test { public: - GrpcRequestImplTest() { + GrpcRequestImplTest() : http_async_client_request_(&cm_.async_client_) { ON_CALL(cm_.cluster_, features()).WillByDefault(Return(Upstream::Cluster::Features::HTTP2)); } - void expectNormalRequest() { - http_async_client_ = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("cluster")).WillOnce(Return(http_async_client_)); - EXPECT_CALL(*http_async_client_, send_(_, _, _)) + void expectNormalRequest( + const Optional timeout = Optional()) { + EXPECT_CALL(cm_, httpAsyncClientForCluster("cluster")).WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, timeout)) .WillOnce(Invoke([&](Http::MessagePtr& request, Http::AsyncClient::Callbacks& callbacks, Optional) -> Http::AsyncClient::Request* { http_request_ = std::move(request); http_callbacks_ = &callbacks; - http_async_client_request_ = new Http::MockAsyncClientRequest(http_async_client_); - return http_async_client_request_; + return &http_async_client_request_; })); } @@ -34,8 +33,7 @@ class GrpcRequestImplTest : public testing::Test { RpcChannelImpl grpc_request_{cm_, "cluster", grpc_callbacks_, cm_.cluster_.stats_store_, Optional()}; helloworld::Greeter::Stub service_{&grpc_request_}; - Http::MockAsyncClient* http_async_client_{}; - Http::MockAsyncClientRequest* http_async_client_request_{}; + Http::MockAsyncClientRequest http_async_client_request_; Http::MessagePtr http_request_; Http::AsyncClient::Callbacks* http_callbacks_{}; }; @@ -231,21 +229,16 @@ TEST_F(GrpcRequestImplTest, HttpAsyncRequestTimeout) { http_callbacks_->onFailure(Http::AsyncClient::FailureReason::RequestTimemout); } -TEST_F(GrpcRequestImplTest, NoHttpAsyncClient) { - EXPECT_CALL(cm_, httpAsyncClientForCluster_("cluster")).WillOnce(Return(nullptr)); - EXPECT_CALL(grpc_callbacks_, onFailure(Optional(), "http request failure")); - - helloworld::HelloRequest request; - request.set_name("a name"); - helloworld::HelloReply response; - service_.SayHello(nullptr, &request, &response, nullptr); -} - TEST_F(GrpcRequestImplTest, NoHttpAsyncRequest) { - Http::MockAsyncClient* http_async_client = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("cluster")).WillOnce(Return(http_async_client)); - EXPECT_CALL(*http_async_client, send_(_, _, _)).WillOnce(Return(nullptr)); - EXPECT_CALL(grpc_callbacks_, onFailure(Optional(), "http request failure")); + EXPECT_CALL(cm_, httpAsyncClientForCluster("cluster")).WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) + .WillOnce( + Invoke([&](Http::MessagePtr&, Http::AsyncClient::Callbacks& callbacks, + const Optional&) -> Http::AsyncClient::Request* { + callbacks.onFailure(Http::AsyncClient::FailureReason::Reset); + return nullptr; + })); + EXPECT_CALL(grpc_callbacks_, onFailure(Optional(), "stream reset")); helloworld::HelloRequest request; request.set_name("a name"); @@ -261,7 +254,7 @@ TEST_F(GrpcRequestImplTest, Cancel) { helloworld::HelloReply response; service_.SayHello(nullptr, &request, &response, nullptr); - EXPECT_CALL(*http_async_client_request_, cancel()); + EXPECT_CALL(http_async_client_request_, cancel()); grpc_request_.cancel(); } @@ -270,17 +263,7 @@ TEST_F(GrpcRequestImplTest, RequestTimeoutSet) { RpcChannelImpl grpc_request_timeout{cm_, "cluster", grpc_callbacks_, cm_.cluster_.stats_store_, timeout}; helloworld::Greeter::Stub service_timeout{&grpc_request_timeout}; - http_async_client_ = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("cluster")).WillOnce(Return(http_async_client_)); - EXPECT_CALL(*http_async_client_, send_(_, _, timeout)) - .WillOnce( - Invoke([&](Http::MessagePtr& request, Http::AsyncClient::Callbacks& callbacks, - const Optional&) -> Http::AsyncClient::Request* { - http_request_ = std::move(request); - http_callbacks_ = &callbacks; - http_async_client_request_ = new Http::MockAsyncClientRequest(http_async_client_); - return http_async_client_request_; - })); + expectNormalRequest(timeout); helloworld::HelloRequest request; request.set_name("a name"); helloworld::HelloReply response; diff --git a/test/common/http/async_client_impl_test.cc b/test/common/http/async_client_impl_test.cc index f306c34c9b4a..ff4c20ad8b0a 100644 --- a/test/common/http/async_client_impl_test.cc +++ b/test/common/http/async_client_impl_test.cc @@ -6,6 +6,7 @@ #include "test/mocks/common.h" #include "test/mocks/http/mocks.h" #include "test/mocks/stats/mocks.h" +#include "test/mocks/upstream/mocks.h" using testing::_; using testing::ByRef; @@ -15,10 +16,13 @@ using testing::Ref; namespace Http { -class AsyncClientImplTest : public testing::Test { +class AsyncClientImplTest : public testing::Test, public AsyncClientConnPoolFactory { public: AsyncClientImplTest() { HttpTestUtility::addDefaultHeaders(message_->headers()); } + // Http::AsyncClientConnPoolFactory + Http::ConnectionPool::Instance* connPool() override { return &conn_pool_; } + MessagePtr message_{new RequestMessageImpl()}; MockAsyncClientCallbacks callbacks_; ConnectionPool::MockInstance conn_pool_; @@ -27,6 +31,7 @@ class AsyncClientImplTest : public testing::Test { NiceMock stats_store_; NiceMock* timer_; NiceMock dispatcher_; + NiceMock cluster_; }; TEST_F(AsyncClientImplTest, Basic) { @@ -45,9 +50,8 @@ TEST_F(AsyncClientImplTest, Basic) { EXPECT_CALL(stream_encoder_, encodeData(BufferEqual(&data), true)); EXPECT_CALL(callbacks_, onSuccess_(_)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); - AsyncClient::RequestPtr request = - client.send(std::move(message_), callbacks_, Optional()); + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + client.send(std::move(message_), callbacks_, Optional()); EXPECT_CALL(stats_store_, counter("cluster.fake_cluster.upstream_rq_2xx")); EXPECT_CALL(stats_store_, counter("cluster.fake_cluster.upstream_rq_200")); @@ -62,6 +66,53 @@ TEST_F(AsyncClientImplTest, Basic) { response_decoder_->decodeData(data, true); } +TEST_F(AsyncClientImplTest, MultipleRequests) { + // Send request 1 + message_->body(Buffer::InstancePtr{new Buffer::OwnedImpl("test body")}); + Buffer::Instance& data = *message_->body(); + + EXPECT_CALL(conn_pool_, newStream(_, _)) + .WillOnce(Invoke([&](StreamDecoder& decoder, ConnectionPool::Callbacks& callbacks) + -> ConnectionPool::Cancellable* { + callbacks.onPoolReady(stream_encoder_, conn_pool_.host_); + response_decoder_ = &decoder; + return nullptr; + })); + + EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(ByRef(message_->headers())), false)); + EXPECT_CALL(stream_encoder_, encodeData(BufferEqual(&data), true)); + + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + client.send(std::move(message_), callbacks_, Optional()); + + // Send request 2. + MessagePtr message2{new RequestMessageImpl()}; + HttpTestUtility::addDefaultHeaders(message2->headers()); + NiceMock stream_encoder2; + StreamDecoder* response_decoder2{}; + MockAsyncClientCallbacks callbacks2; + EXPECT_CALL(conn_pool_, newStream(_, _)) + .WillOnce(Invoke([&](StreamDecoder& decoder, ConnectionPool::Callbacks& callbacks) + -> ConnectionPool::Cancellable* { + callbacks.onPoolReady(stream_encoder2, conn_pool_.host_); + response_decoder2 = &decoder; + return nullptr; + })); + EXPECT_CALL(stream_encoder2, encodeHeaders(HeaderMapEqualRef(ByRef(message2->headers())), true)); + client.send(std::move(message2), callbacks2, Optional()); + + // Finish request 2. + HeaderMapPtr response_headers2(new HeaderMapImpl{{":status", "503"}}); + EXPECT_CALL(callbacks2, onSuccess_(_)); + response_decoder2->decodeHeaders(std::move(response_headers2), true); + + // Finish request 1. + HeaderMapPtr response_headers(new HeaderMapImpl{{":status", "200"}}); + response_decoder_->decodeHeaders(std::move(response_headers), false); + EXPECT_CALL(callbacks_, onSuccess_(_)); + response_decoder_->decodeData(data, true); +} + TEST_F(AsyncClientImplTest, Trailers) { message_->body(Buffer::InstancePtr{new Buffer::OwnedImpl("test body")}); Buffer::Instance& data = *message_->body(); @@ -78,9 +129,8 @@ TEST_F(AsyncClientImplTest, Trailers) { EXPECT_CALL(stream_encoder_, encodeData(BufferEqual(&data), true)); EXPECT_CALL(callbacks_, onSuccess_(_)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); - AsyncClient::RequestPtr request = - client.send(std::move(message_), callbacks_, Optional()); + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + client.send(std::move(message_), callbacks_, Optional()); HeaderMapPtr response_headers(new HeaderMapImpl{{":status", "200"}}); response_decoder_->decodeHeaders(std::move(response_headers), false); response_decoder_->decodeData(data, false); @@ -103,9 +153,8 @@ TEST_F(AsyncClientImplTest, FailRequest) { EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(ByRef(message_->headers())), true)); EXPECT_CALL(callbacks_, onFailure(Http::AsyncClient::FailureReason::Reset)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); - AsyncClient::RequestPtr request = - client.send(std::move(message_), callbacks_, Optional()); + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + client.send(std::move(message_), callbacks_, Optional()); stream_encoder_.getStream().resetStream(StreamResetReason::RemoteReset); } @@ -120,8 +169,8 @@ TEST_F(AsyncClientImplTest, CancelRequest) { EXPECT_CALL(stream_encoder_, encodeHeaders(HeaderMapEqualRef(ByRef(message_->headers())), true)); EXPECT_CALL(stream_encoder_.stream_, resetStream(_)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); - AsyncClient::RequestPtr request = + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + AsyncClient::Request* request = client.send(std::move(message_), callbacks_, Optional()); request->cancel(); } @@ -141,7 +190,7 @@ TEST_F(AsyncClientImplTest, PoolFailure) { })); EXPECT_CALL(callbacks_, onFailure(Http::AsyncClient::FailureReason::Reset)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); EXPECT_EQ(nullptr, client.send(std::move(message_), callbacks_, Optional())); } @@ -151,7 +200,6 @@ TEST_F(AsyncClientImplTest, RequestTimeout) { EXPECT_CALL(stats_store_, counter("cluster.fake_cluster.upstream_rq_504")); EXPECT_CALL(stats_store_, counter("cluster.fake_cluster.internal.upstream_rq_5xx")); EXPECT_CALL(stats_store_, counter("cluster.fake_cluster.internal.upstream_rq_504")); - EXPECT_CALL(stats_store_, counter("cluster.fake_cluster.upstream_rq_timeout")); EXPECT_CALL(conn_pool_, newStream(_, _)) .WillOnce(Invoke([&](StreamDecoder&, ConnectionPool::Callbacks& callbacks) -> ConnectionPool::Cancellable* { @@ -164,10 +212,11 @@ TEST_F(AsyncClientImplTest, RequestTimeout) { timer_ = new NiceMock(&dispatcher_); EXPECT_CALL(*timer_, enableTimer(std::chrono::milliseconds(40))); EXPECT_CALL(stream_encoder_.stream_, resetStream(_)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); - AsyncClient::RequestPtr request = - client.send(std::move(message_), callbacks_, std::chrono::milliseconds(40)); + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + client.send(std::move(message_), callbacks_, std::chrono::milliseconds(40)); timer_->callback_(); + + EXPECT_EQ(1UL, cluster_.stats_store_.counter("cluster.fake_cluster.upstream_rq_timeout").value()); } TEST_F(AsyncClientImplTest, DisableTimer) { @@ -183,8 +232,8 @@ TEST_F(AsyncClientImplTest, DisableTimer) { EXPECT_CALL(*timer_, enableTimer(std::chrono::milliseconds(200))); EXPECT_CALL(*timer_, disableTimer()); EXPECT_CALL(stream_encoder_.stream_, resetStream(_)); - AsyncClientImpl client(conn_pool_, "fake_cluster", stats_store_, dispatcher_); - AsyncClient::RequestPtr request = + AsyncClientImpl client(cluster_, *this, stats_store_, dispatcher_); + AsyncClient::Request* request = client.send(std::move(message_), callbacks_, std::chrono::milliseconds(200)); request->cancel(); } diff --git a/test/common/ratelimit/ratelimit_impl_test.cc b/test/common/ratelimit/ratelimit_impl_test.cc index a5773a2e93f6..92862b97b2a6 100644 --- a/test/common/ratelimit/ratelimit_impl_test.cc +++ b/test/common/ratelimit/ratelimit_impl_test.cc @@ -112,7 +112,7 @@ TEST(RateLimitGrpcFactoryTest, NoCluster) { Upstream::MockClusterManager cm; Stats::IsolatedStoreImpl stats_store; - EXPECT_CALL(cm, has("foo")).WillOnce(Return(false)); + EXPECT_CALL(cm, get("foo")).WillOnce(Return(nullptr)); EXPECT_THROW(GrpcFactoryImpl(config, cm, stats_store), EnvoyException); } @@ -127,7 +127,7 @@ TEST(RateLimitGrpcFactoryTest, Create) { Upstream::MockClusterManager cm; Stats::IsolatedStoreImpl stats_store; - EXPECT_CALL(cm, has("foo")).WillOnce(Return(true)); + EXPECT_CALL(cm, get("foo")); GrpcFactoryImpl factory(config, cm, stats_store); factory.create(Optional()); } diff --git a/test/common/router/config_impl_test.cc b/test/common/router/config_impl_test.cc index b68d0b080cc1..212cd6dfd75c 100644 --- a/test/common/router/config_impl_test.cc +++ b/test/common/router/config_impl_test.cc @@ -146,7 +146,6 @@ TEST(RouteMatcherTest, TestRoutes) { Json::StringLoader loader(json); NiceMock runtime; NiceMock cm; - ON_CALL(cm, has(_)).WillByDefault(Return(true)); ConfigImpl config(loader, runtime, cm); // Base routing testing. @@ -351,7 +350,6 @@ TEST(RouteMatcherTest, ContentType) { Json::StringLoader loader(json); NiceMock runtime; NiceMock cm; - ON_CALL(cm, has(_)).WillByDefault(Return(true)); ConfigImpl config(loader, runtime, cm); { @@ -403,7 +401,6 @@ TEST(RouteMatcherTest, Runtime) { NiceMock cm; Runtime::MockSnapshot snapshot; - ON_CALL(cm, has(_)).WillByDefault(Return(true)); ON_CALL(runtime, snapshot()).WillByDefault(ReturnRef(snapshot)); ConfigImpl config(loader, runtime, cm); @@ -445,7 +442,6 @@ TEST(RouteMatcherTest, RateLimit) { Json::StringLoader loader(json); NiceMock runtime; NiceMock cm; - ON_CALL(cm, has(_)).WillByDefault(Return(true)); ConfigImpl config(loader, runtime, cm); EXPECT_TRUE(config.routeForRequest(genHeaders("www.lyft.com", "/foo", "GET"), 0) @@ -492,7 +488,6 @@ TEST(RouteMatcherTest, Retry) { Json::StringLoader loader(json); NiceMock runtime; NiceMock cm; - ON_CALL(cm, has(_)).WillByDefault(Return(true)); ConfigImpl config(loader, runtime, cm); EXPECT_EQ(1U, config.routeForRequest(genHeaders("www.lyft.com", "/foo", "GET"), 0) @@ -619,7 +614,6 @@ TEST(RouteMatcherTest, Redirect) { Json::StringLoader loader(json); NiceMock runtime; NiceMock cm; - ON_CALL(cm, has(StrNe(""))).WillByDefault(Return(true)); ConfigImpl config(loader, runtime, cm); EXPECT_EQ(nullptr, diff --git a/test/common/stats/statsd_test.cc b/test/common/stats/statsd_test.cc index e4dca82ae755..11f791c32003 100644 --- a/test/common/stats/statsd_test.cc +++ b/test/common/stats/statsd_test.cc @@ -16,7 +16,7 @@ namespace Statsd { class TcpStatsdSinkTest : public testing::Test { public: TcpStatsdSinkTest() { - EXPECT_CALL(cluster_manager_, has(_)).WillOnce(Return(true)); + EXPECT_CALL(cluster_manager_, get("statsd")); sink_.reset(new TcpStatsdSink("cluster", "host", "statsd", tls_, cluster_manager_)); } diff --git a/test/common/tracing/http_tracer_impl_test.cc b/test/common/tracing/http_tracer_impl_test.cc index e4556aed7fb4..7eb17b6f2c4e 100644 --- a/test/common/tracing/http_tracer_impl_test.cc +++ b/test/common/tracing/http_tracer_impl_test.cc @@ -246,12 +246,12 @@ class LightStepSinkTest : public Test { : stats_{LIGHTSTEP_STATS(POOL_COUNTER_PREFIX(fake_stats_, "prefix.tracing.lightstep."))} {} void setup(Json::Object& config) { - sink_.reset(new LightStepSink(config, cm_, tls_, "prefix.", fake_stats_, random_, - "service_cluster", "service_node", "token")); + sink_.reset(new LightStepSink(config, cm_, "prefix.", fake_stats_, random_, "service_cluster", + "service_node", "token")); } void setupValidSink() { - EXPECT_CALL(cm_, has("lightstep_saas")).WillOnce(Return(true)); + EXPECT_CALL(cm_, get("lightstep_saas")); std::string valid_config = R"EOF( {"collector_cluster": "lightstep_saas"} @@ -266,7 +266,6 @@ class LightStepSinkTest : public Test { Stats::IsolatedStoreImpl fake_stats_; LightStepStats stats_; NiceMock cm_; - NiceMock tls_; NiceMock random_; std::unique_ptr sink_; }; @@ -290,7 +289,7 @@ TEST_F(LightStepSinkTest, InitializeSink) { { // Valid config but not valid cluster - EXPECT_CALL(cm_, has("lightstep_saas")).WillOnce(Return(false)); + EXPECT_CALL(cm_, get("lightstep_saas")).WillOnce(Return(nullptr)); std::string valid_config = R"EOF( {"collector_cluster": "lightstep_saas"} @@ -301,7 +300,7 @@ TEST_F(LightStepSinkTest, InitializeSink) { } { - EXPECT_CALL(cm_, has("lightstep_saas")).WillOnce(Return(true)); + EXPECT_CALL(cm_, get("lightstep_saas")); std::string valid_config = R"EOF( {"collector_cluster": "lightstep_saas"} @@ -317,19 +316,19 @@ TEST_F(LightStepSinkTest, CallbacksCalled) { NiceMock request_info; - NiceMock* client_1 = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("lightstep_saas")).WillOnce(Return(client_1)); + EXPECT_CALL(cm_, httpAsyncClientForCluster("lightstep_saas")) + .WillOnce(ReturnRef(cm_.async_client_)); - Http::MockAsyncClientRequest* request_1 = new Http::MockAsyncClientRequest(client_1); + Http::MockAsyncClientRequest request_1(&cm_.async_client_); Http::AsyncClient::Callbacks* callback_1; const Optional timeout(std::chrono::seconds(5)); - EXPECT_CALL(*client_1, send_(_, _, timeout)) + EXPECT_CALL(cm_.async_client_, send_(_, _, timeout)) .WillOnce( Invoke([&](Http::MessagePtr&, Http::AsyncClient::Callbacks& callbacks, const Optional&) -> Http::AsyncClient::Request* { callback_1 = &callbacks; - return request_1; + return &request_1; })); EXPECT_CALL(random_, uuid()).WillOnce(Return("1")).WillOnce(Return("2")); SystemTime start_time_1; @@ -343,16 +342,16 @@ TEST_F(LightStepSinkTest, CallbacksCalled) { sink_->flushTrace(empty_header_, empty_header_, request_info); - NiceMock* client_2 = new NiceMock(); - Http::MockAsyncClientRequest* request_2 = new Http::MockAsyncClientRequest(client_2); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("lightstep_saas")).WillOnce(Return(client_2)); + Http::MockAsyncClientRequest request_2(&cm_.async_client_); + EXPECT_CALL(cm_, httpAsyncClientForCluster("lightstep_saas")) + .WillOnce(ReturnRef(cm_.async_client_)); Http::AsyncClient::Callbacks* callback_2; - EXPECT_CALL(*client_2, send_(_, _, _)) + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) .WillOnce(Invoke([&](Http::MessagePtr&, Http::AsyncClient::Callbacks& callbacks, Optional) -> Http::AsyncClient::Request* { callback_2 = &callbacks; - return request_2; + return &request_2; })); EXPECT_CALL(random_, uuid()).WillOnce(Return("3")).WillOnce(Return("4")); SystemTime start_time_2; @@ -372,12 +371,6 @@ TEST_F(LightStepSinkTest, CallbacksCalled) { callback_1->onSuccess(std::move(msg)); EXPECT_EQ(1UL, stats_.collector_failed_.value()); EXPECT_EQ(1UL, stats_.collector_success_.value()); - - // Shutdown sink and try to make trace - tls_.shutdownThread_(); - - EXPECT_CALL(cm_, httpAsyncClientForCluster_("lightstep_saas")).Times(0); - sink_->flushTrace(empty_header_, empty_header_, request_info); } TEST_F(LightStepSinkTest, ClientNotAvailable) { @@ -385,21 +378,36 @@ TEST_F(LightStepSinkTest, ClientNotAvailable) { NiceMock request_info; - EXPECT_CALL(cm_, httpAsyncClientForCluster_("lightstep_saas")).WillOnce(Return(nullptr)); + EXPECT_CALL(cm_, httpAsyncClientForCluster("lightstep_saas")) + .WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) + .WillOnce( + Invoke([&](Http::MessagePtr&, Http::AsyncClient::Callbacks& callbacks, + const Optional&) -> Http::AsyncClient::Request* { + callbacks.onFailure(Http::AsyncClient::FailureReason::Reset); + return nullptr; + })); + SystemTime start_time_1; + EXPECT_CALL(request_info, startTime()).WillOnce(Return(start_time_1)); + std::chrono::seconds duration_1(1); + EXPECT_CALL(request_info, duration()).WillOnce(Return(duration_1)); + const std::string protocol = "http/1"; + EXPECT_CALL(request_info, protocol()).WillRepeatedly(ReturnRef(protocol)); + Optional code_1(200); + EXPECT_CALL(request_info, responseCode()).WillRepeatedly(ReturnRef(code_1)); sink_->flushTrace(empty_header_, empty_header_, request_info); - EXPECT_EQ(1UL, stats_.client_failed_.value()); - EXPECT_EQ(0UL, stats_.collector_failed_.value()); + EXPECT_EQ(1UL, stats_.collector_failed_.value()); EXPECT_EQ(0UL, stats_.collector_success_.value()); } TEST_F(LightStepSinkTest, ShutdownWhenActiveRequests) { setupValidSink(); - NiceMock* client = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("lightstep_saas")).WillOnce(Return(client)); + EXPECT_CALL(cm_, httpAsyncClientForCluster("lightstep_saas")) + .WillOnce(ReturnRef(cm_.async_client_)); - Http::MockAsyncClientRequest* request = new Http::MockAsyncClientRequest(client); + Http::MockAsyncClientRequest request(&cm_.async_client_); NiceMock request_info; const std::string protocol = "http/1"; @@ -467,19 +475,16 @@ TEST_F(LightStepSinkTest, ShutdownWhenActiveRequests) { )EOF"; Http::AsyncClient::Callbacks* callback; - EXPECT_CALL(*client, send_(_, _, _)) + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) .WillOnce(Invoke([&](Http::MessagePtr& msg, Http::AsyncClient::Callbacks& callbacks, Optional) -> Http::AsyncClient::Request* { callback = &callbacks; EXPECT_EQ(expected_json, msg->bodyAsString()); EXPECT_EQ("token", msg->headers().get("LightStep-Access-Token")); - return request; + return &request; })); sink_->flushTrace(request_header, empty_header_, request_info); - - EXPECT_CALL(*request, cancel()); - tls_.shutdownThread_(); } TEST(LightStepUtilityTest, HeadersNotSet) { diff --git a/test/common/upstream/cluster_manager_impl_test.cc b/test/common/upstream/cluster_manager_impl_test.cc index 1ac2f68afc47..0bc5694e7ac1 100644 --- a/test/common/upstream/cluster_manager_impl_test.cc +++ b/test/common/upstream/cluster_manager_impl_test.cc @@ -11,6 +11,7 @@ using testing::_; using testing::NiceMock; +using testing::Return; using testing::ReturnNew; using testing::SaveArg; @@ -28,46 +29,188 @@ class ClusterManagerImplForTest : public ClusterManagerImpl { MOCK_METHOD1(allocateConnPool_, Http::ConnectionPool::Instance*(ConstHostPtr host)); }; -TEST(ClusterManagerImplTest, DynamicHostRemove) { +class ClusterManagerImplTest : public testing::Test { +public: + void create(const Json::Object& config) { + cluster_manager_.reset(new ClusterManagerImplForTest(config, stats_, tls_, dns_resolver_, + ssl_context_manager_, runtime_, random_, + "us-east-1d")); + } + + Stats::IsolatedStoreImpl stats_; + NiceMock tls_; + NiceMock dns_resolver_; + NiceMock runtime_; + NiceMock random_; + Ssl::ContextManagerImpl ssl_context_manager_{runtime_}; + std::unique_ptr cluster_manager_; +}; + +TEST_F(ClusterManagerImplTest, NoSdsConfig) { std::string json = R"EOF( { - "cluster_manager": { - "clusters": [ - { - "name": "cluster_1", - "connect_timeout_ms": 250, - "type": "strict_dns", - "lb_type": "round_robin", - "hosts": [{"url": "tcp://localhost:11001"}] - }] - } + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "sds", + "lb_type": "round_robin" + }] } )EOF"; Json::StringLoader loader(json); + EXPECT_THROW(create(loader), EnvoyException); +} - Stats::IsolatedStoreImpl stats; - NiceMock tls; - NiceMock dns_resolver; - Network::DnsResolver::ResolveCb dns_callback; - Event::MockTimer* dns_timer_ = new NiceMock(&dns_resolver.dispatcher_); - NiceMock runtime; - NiceMock random; - Ssl::ContextManagerImpl ssl_context_manager(runtime); - EXPECT_CALL(dns_resolver, resolve(_, _)).WillRepeatedly(SaveArg<1>(&dns_callback)); +TEST_F(ClusterManagerImplTest, UnknownClusterType) { + std::string json = R"EOF( + { + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "foo", + "lb_type": "round_robin" + }] + } + )EOF"; + + Json::StringLoader loader(json); + EXPECT_THROW(create(loader), EnvoyException); +} + +TEST_F(ClusterManagerImplTest, DuplicateCluster) { + std::string json = R"EOF( + { + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "static", + "lb_type": "round_robin", + "hosts": [{"url": "tcp://127.0.0.1:11001"}] + }, + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "static", + "lb_type": "round_robin", + "hosts": [{"url": "tcp://127.0.0.1:11001"}] + }] + } + )EOF"; + + Json::StringLoader loader(json); + EXPECT_THROW(create(loader), EnvoyException); +} + +TEST_F(ClusterManagerImplTest, UnknownHcType) { + std::string json = R"EOF( + { + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "static", + "lb_type": "round_robin", + "hosts": [{"url": "tcp://127.0.0.1:11001"}], + "health_check": { + "type": "foo" + } + }] + } + )EOF"; - ClusterManagerImplForTest cluster_manager(loader.getObject("cluster_manager"), stats, tls, - dns_resolver, ssl_context_manager, runtime, random, - "us-east-1d"); + Json::StringLoader loader(json); + EXPECT_THROW(create(loader), EnvoyException); +} + +TEST_F(ClusterManagerImplTest, TcpHealthChecker) { + std::string json = R"EOF( + { + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "static", + "lb_type": "round_robin", + "hosts": [{"url": "tcp://127.0.0.1:11001"}], + "health_check": { + "type": "tcp", + "timeout_ms": 1000, + "interval_ms": 1000, + "unhealthy_threshold": 2, + "healthy_threshold": 2, + "send": [ + {"binary": "01"} + ], + "receive": [ + {"binary": "02"} + ] + } + }] + } + )EOF"; + + Json::StringLoader loader(json); + Network::MockClientConnection* connection = new NiceMock(); + EXPECT_CALL(dns_resolver_.dispatcher_, createClientConnection_("tcp://127.0.0.1:11001")) + .WillOnce(Return(connection)); + create(loader); +} + +TEST_F(ClusterManagerImplTest, UnknownCluster) { + std::string json = R"EOF( + { + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "static", + "lb_type": "round_robin", + "hosts": [{"url": "tcp://127.0.0.1:11001"}] + }] + } + )EOF"; + + Json::StringLoader loader(json); + create(loader); + EXPECT_EQ(nullptr, cluster_manager_->get("hello")); + EXPECT_THROW(cluster_manager_->httpConnPoolForCluster("hello"), EnvoyException); + EXPECT_THROW(cluster_manager_->tcpConnForCluster("hello"), EnvoyException); + EXPECT_THROW(cluster_manager_->httpAsyncClientForCluster("hello"), EnvoyException); +} + +TEST_F(ClusterManagerImplTest, DynamicHostRemove) { + std::string json = R"EOF( + { + "clusters": [ + { + "name": "cluster_1", + "connect_timeout_ms": 250, + "type": "strict_dns", + "lb_type": "round_robin", + "hosts": [{"url": "tcp://localhost:11001"}] + }] + } + )EOF"; + + Json::StringLoader loader(json); + + Network::DnsResolver::ResolveCb dns_callback; + Event::MockTimer* dns_timer_ = new NiceMock(&dns_resolver_.dispatcher_); + EXPECT_CALL(dns_resolver_, resolve(_, _)).WillRepeatedly(SaveArg<1>(&dns_callback)); + create(loader); // Test for no hosts returning the correct values before we have hosts. - EXPECT_EQ(nullptr, cluster_manager.httpConnPoolForCluster("cluster_1")); - EXPECT_EQ(nullptr, cluster_manager.tcpConnForCluster("cluster_1").connection_); - EXPECT_EQ(2UL, stats.counter("cluster.cluster_1.upstream_cx_none_healthy").value()); + EXPECT_EQ(nullptr, cluster_manager_->httpConnPoolForCluster("cluster_1")); + EXPECT_EQ(nullptr, cluster_manager_->tcpConnForCluster("cluster_1").connection_); + EXPECT_EQ(2UL, stats_.counter("cluster.cluster_1.upstream_cx_none_healthy").value()); // Set up for an initialize callback. ReadyWatcher initialized; - cluster_manager.setInitializedCb([&]() -> void { initialized.ready(); }); + cluster_manager_->setInitializedCb([&]() -> void { initialized.ready(); }); EXPECT_CALL(initialized, ready()); dns_callback({"127.0.0.1", "127.0.0.2"}); @@ -75,17 +218,17 @@ TEST(ClusterManagerImplTest, DynamicHostRemove) { // After we are initialized, we should immediately get called back if someone asks for an // initialize callback. EXPECT_CALL(initialized, ready()); - cluster_manager.setInitializedCb([&]() -> void { initialized.ready(); }); + cluster_manager_->setInitializedCb([&]() -> void { initialized.ready(); }); - EXPECT_CALL(cluster_manager, allocateConnPool_(_)) + EXPECT_CALL(*cluster_manager_, allocateConnPool_(_)) .Times(2) .WillRepeatedly(ReturnNew()); // This should provide us a CP for each of the above hosts. Http::ConnectionPool::MockInstance* cp1 = dynamic_cast( - cluster_manager.httpConnPoolForCluster("cluster_1")); + cluster_manager_->httpConnPoolForCluster("cluster_1")); Http::ConnectionPool::MockInstance* cp2 = dynamic_cast( - cluster_manager.httpConnPoolForCluster("cluster_1")); + cluster_manager_->httpConnPoolForCluster("cluster_1")); EXPECT_NE(cp1, cp2); @@ -100,7 +243,7 @@ TEST(ClusterManagerImplTest, DynamicHostRemove) { // Make sure we get back the same connection pool for the 2nd host as we did before the change. Http::ConnectionPool::MockInstance* cp3 = dynamic_cast( - cluster_manager.httpConnPoolForCluster("cluster_1")); + cluster_manager_->httpConnPoolForCluster("cluster_1")); EXPECT_EQ(cp2, cp3); // Now add and remove a host that we never have a conn pool to. This should not lead to any diff --git a/test/common/upstream/sds_test.cc b/test/common/upstream/sds_test.cc index 3979e3d3f475..8e16e3fb42d1 100644 --- a/test/common/upstream/sds_test.cc +++ b/test/common/upstream/sds_test.cc @@ -12,7 +12,6 @@ using testing::DoAll; using testing::Invoke; using testing::NiceMock; using testing::Return; -using testing::ReturnNew; using testing::SaveArg; using testing::WithArg; @@ -20,7 +19,9 @@ namespace Upstream { class SdsTest : public testing::Test { protected: - SdsTest() : sds_config_{"us-east-1a", "sds", std::chrono::milliseconds(30000)} { + SdsTest() + : sds_config_{"us-east-1a", "sds", std::chrono::milliseconds(30000)}, + request_(&cm_.async_client_) { std::string raw_config = R"EOF( { "name": "name", @@ -60,9 +61,8 @@ class SdsTest : public testing::Test { } void setupPoolFailure() { - client_ = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("sds")).WillOnce(Return(client_)); - EXPECT_CALL(*client_, send_(_, _, _)) + EXPECT_CALL(cm_, httpAsyncClientForCluster("sds")).WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) .WillOnce(Invoke([](Http::MessagePtr&, Http::AsyncClient::Callbacks& callbacks, Optional) -> Http::AsyncClient::Request* { callbacks.onFailure(Http::AsyncClient::FailureReason::Reset); @@ -71,11 +71,9 @@ class SdsTest : public testing::Test { } void setupRequest() { - client_ = new NiceMock(); - EXPECT_CALL(cm_, httpAsyncClientForCluster_("sds")).WillOnce(Return(client_)); - EXPECT_CALL(*client_, send_(_, _, _)) - .WillOnce(DoAll(WithArg<1>(SaveArgAddress(&callbacks_)), - ReturnNew>(client_))); + EXPECT_CALL(cm_, httpAsyncClientForCluster("sds")).WillOnce(ReturnRef(cm_.async_client_)); + EXPECT_CALL(cm_.async_client_, send_(_, _, _)) + .WillOnce(DoAll(WithArg<1>(SaveArgAddress(&callbacks_)), Return(&request_))); } Stats::IsolatedStoreImpl stats_; @@ -85,15 +83,16 @@ class SdsTest : public testing::Test { Event::MockDispatcher dispatcher_; std::unique_ptr cluster_; Event::MockTimer* timer_; - Http::MockAsyncClient* client_; Http::AsyncClient::Callbacks* callbacks_; ReadyWatcher membership_updated_; NiceMock random_; + Http::MockAsyncClientRequest request_; }; TEST_F(SdsTest, Shutdown) { setupRequest(); cluster_->initialize(); + EXPECT_CALL(request_, cancel()); cluster_->shutdown(); } @@ -167,11 +166,6 @@ TEST_F(SdsTest, NoHealthChecker) { EXPECT_EQ(13UL, cluster_->hosts().size()); EXPECT_EQ(50U, canary_host->weight()); EXPECT_EQ(50UL, cluster_->stats().max_host_weight_.value()); - - // No healthy SDS hosts. - EXPECT_CALL(cm_, httpAsyncClientForCluster_("sds")).WillOnce(Return(nullptr)); - EXPECT_CALL(*timer_, enableTimer(_)); - timer_->callback_(); } TEST_F(SdsTest, HealthChecker) { diff --git a/test/mocks/http/mocks.h b/test/mocks/http/mocks.h index 85735f06a56f..6de18aa3d0de 100644 --- a/test/mocks/http/mocks.h +++ b/test/mocks/http/mocks.h @@ -281,9 +281,9 @@ class MockAsyncClient : public AsyncClient { MOCK_METHOD0(onRequestDestroy, void()); // Http::AsyncClient - RequestPtr send(MessagePtr&& request, Callbacks& callbacks, - const Optional& timeout) override { - return RequestPtr{send_(request, callbacks, timeout)}; + Request* send(MessagePtr&& request, Callbacks& callbacks, + const Optional& timeout) override { + return send_(request, callbacks, timeout); } MOCK_METHOD3(send_, Request*(MessagePtr& request, Callbacks& callbacks, diff --git a/test/mocks/upstream/mocks.h b/test/mocks/upstream/mocks.h index 31b91d8e8f38..a016feae258a 100644 --- a/test/mocks/upstream/mocks.h +++ b/test/mocks/upstream/mocks.h @@ -66,22 +66,18 @@ class MockClusterManager : public ClusterManager { return {Network::ClientConnectionPtr{data.connection_}, data.host_}; } - Http::AsyncClientPtr httpAsyncClientForCluster(const std::string& cluster) override { - return Http::AsyncClientPtr{httpAsyncClientForCluster_(cluster)}; - } - // Upstream::ClusterManager MOCK_METHOD1(setInitializedCb, void(std::function)); MOCK_METHOD0(clusters, std::unordered_map()); MOCK_METHOD1(get, const Cluster*(const std::string& cluster)); - MOCK_METHOD1(has, bool(const std::string& cluster)); MOCK_METHOD1(httpConnPoolForCluster, Http::ConnectionPool::Instance*(const std::string& cluster)); MOCK_METHOD1(tcpConnForCluster_, MockHost::MockCreateConnectionData(const std::string& cluster)); - MOCK_METHOD1(httpAsyncClientForCluster_, Http::AsyncClient*(const std::string& cluster)); + MOCK_METHOD1(httpAsyncClientForCluster, Http::AsyncClient&(const std::string& cluster)); MOCK_METHOD0(shutdown, void()); NiceMock conn_pool_; NiceMock cluster_; + NiceMock async_client_; }; class MockHealthChecker : public HealthChecker {