diff --git a/client/clientfactory.go b/client/clientfactory.go index db809aae66b..92e0caf551f 100644 --- a/client/clientfactory.go +++ b/client/clientfactory.go @@ -42,7 +42,6 @@ import ( "github.com/uber/cadence/client/wrappers/metered" "github.com/uber/cadence/client/wrappers/thrift" timeoutwrapper "github.com/uber/cadence/client/wrappers/timeout" - "github.com/uber/cadence/common" "github.com/uber/cadence/common/dynamicconfig" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/membership" @@ -68,7 +67,7 @@ type ( DomainIDToNameFunc func(string) (string, error) rpcClientFactory struct { - rpcFactory common.RPCFactory + rpcFactory rpc.Factory resolver membership.Resolver metricsClient metrics.Client dynConfig *dynamicconfig.Collection @@ -79,7 +78,7 @@ type ( // NewRPCClientFactory creates an instance of client factory that knows how to dispatch RPC calls. func NewRPCClientFactory( - rpcFactory common.RPCFactory, + rpcFactory rpc.Factory, resolver membership.Resolver, metricsClient metrics.Client, dc *dynamicconfig.Collection, diff --git a/cmd/server/cadence/server.go b/cmd/server/cadence/server.go index 15d3163cc76..1e087f7b04f 100644 --- a/cmd/server/cadence/server.go +++ b/cmd/server/cadence/server.go @@ -156,7 +156,7 @@ func (s *server) startService() common.Daemon { params.MetricScope = svcCfg.Metrics.NewScope(params.Logger, params.Name) - rpcParams, err := rpc.NewParams(params.Name, s.cfg, dc) + rpcParams, err := rpc.NewParams(params.Name, s.cfg, dc, params.Logger) if err != nil { log.Fatalf("error creating rpc factory params: %v", err) } @@ -170,7 +170,7 @@ func (s *server) startService() common.Daemon { peerProvider, err := ringpopprovider.New( params.Name, &s.cfg.Ringpop, - rpcFactory.GetChannel(), + rpcFactory.GetTChannel(), membership.PortMap{ membership.PortGRPC: svcCfg.RPC.GRPCPort, membership.PortTchannel: svcCfg.RPC.Port, diff --git a/common/dynamicconfig/constants.go b/common/dynamicconfig/constants.go index c232ad6c0fb..1a903a59b41 100644 --- a/common/dynamicconfig/constants.go +++ b/common/dynamicconfig/constants.go @@ -1588,6 +1588,12 @@ const ( // Default value: false // Allowed filters: N/A EnableSQLAsyncTransaction + // EnableConnectionRetainingDirectChooser is the key for enabling connection retaining direct yarpc chooser + // KeyName: system.enableConnectionRetainingDirectChooser + // Value type: Bool + // Default value: false + // Allowed filters: N/A + EnableConnectionRetainingDirectChooser // key for frontend @@ -3950,6 +3956,11 @@ var BoolKeys = map[BoolKey]DynamicBool{ Description: "EnableSQLAsyncTransaction is the key for enabling async transaction", DefaultValue: false, }, + EnableConnectionRetainingDirectChooser: { + KeyName: "system.enableConnectionRetainingDirectChooser", + Description: "EnableConnectionRetainingDirectChooser is the key for enabling connection retaining direct chooser", + DefaultValue: false, + }, EnableClientVersionCheck: { KeyName: "frontend.enableClientVersionCheck", Description: "EnableClientVersionCheck is enables client version check for frontend", diff --git a/common/rpc.go b/common/headers.go similarity index 89% rename from common/rpc.go rename to common/headers.go index 3ac42837049..30b7d25eb94 100644 --- a/common/rpc.go +++ b/common/headers.go @@ -18,14 +18,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -//go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination rpc_mock.go -self_package github.com/uber/cadence/common - package common -import ( - "go.uber.org/yarpc" -) - const ( // LibraryVersionHeaderName refers to the name of the // tchannel / http header that contains the client @@ -61,11 +55,3 @@ const ( // ClientIsolationGroupHeaderName refers to the name of the header that contains the isolation group which the client request is from ClientIsolationGroupHeaderName = "cadence-client-isolation-group" ) - -type ( - // RPCFactory Creates a dispatcher that knows how to transport requests. - RPCFactory interface { - GetDispatcher() *yarpc.Dispatcher - GetMaxMessageSize() int - } -) diff --git a/common/log/tag/values.go b/common/log/tag/values.go index f212c435604..b71738b8da6 100644 --- a/common/log/tag/values.go +++ b/common/log/tag/values.go @@ -138,6 +138,7 @@ var ( ComponentMapQ = component("mapq") ComponentMapQTree = component("mapq-tree") ComponentMapQTreeNode = component("mapq-tree-node") + ComponentRPCFactory = component("rpc-factory") ) // Pre-defined values for TagSysLifecycle diff --git a/common/membership/hashring.go b/common/membership/hashring.go index 72ecf53050d..b622822a30b 100644 --- a/common/membership/hashring.go +++ b/common/membership/hashring.go @@ -213,7 +213,7 @@ func (r *ring) notifySubscribers(msg ChangedEvent) { select { case ch <- &msg: default: - r.logger.Error("subscriber notification failed", tag.Name(name)) + r.logger.Warn("subscriber notification failed", tag.Name(name)) } } } diff --git a/common/membership/resolver.go b/common/membership/resolver.go index 1a6cf2ee6ab..2606bdd5680 100644 --- a/common/membership/resolver.go +++ b/common/membership/resolver.go @@ -48,6 +48,7 @@ type ( // Resolver provides membership information for all cadence services. Resolver interface { common.Daemon + // WhoAmI returns self host details. // To be consistent with peer provider, it is advised to use peer provider // to return this information diff --git a/common/resource/params.go b/common/resource/params.go index 0c16a4e5006..7b57a71c827 100644 --- a/common/resource/params.go +++ b/common/resource/params.go @@ -44,6 +44,7 @@ import ( "github.com/uber/cadence/common/metrics" "github.com/uber/cadence/common/partition" "github.com/uber/cadence/common/pinot" + "github.com/uber/cadence/common/rpc" ) type ( @@ -58,7 +59,7 @@ type ( MetricScope tally.Scope MembershipResolver membership.Resolver - RPCFactory common.RPCFactory + RPCFactory rpc.Factory PProfInitializer common.PProfInitializer PersistenceConfig config.Persistence ClusterMetadata cluster.Metadata diff --git a/common/resource/resource.go b/common/resource/resource.go index 796b0846313..4f5aa81c6e8 100644 --- a/common/resource/resource.go +++ b/common/resource/resource.go @@ -49,7 +49,7 @@ import ( "github.com/uber/cadence/common/partition" "github.com/uber/cadence/common/persistence" persistenceClient "github.com/uber/cadence/common/persistence/client" - "github.com/uber/cadence/common/quotas/global/rpc" + qrpc "github.com/uber/cadence/common/quotas/global/rpc" "github.com/uber/cadence/common/service" ) @@ -96,7 +96,7 @@ type ( GetMatchingClient() matching.Client GetHistoryRawClient() history.Client GetHistoryClient() history.Client - GetRatelimiterAggregatorsClient() rpc.Client + GetRatelimiterAggregatorsClient() qrpc.Client GetRemoteAdminClient(cluster string) admin.Client GetRemoteFrontendClient(cluster string) frontend.Client GetClientBean() client.Bean diff --git a/common/resource/resourceImpl.go b/common/resource/resourceImpl.go index 5a26c3bb6d8..0fa491b08ab 100644 --- a/common/resource/resourceImpl.go +++ b/common/resource/resourceImpl.go @@ -58,8 +58,9 @@ import ( "github.com/uber/cadence/common/partition" "github.com/uber/cadence/common/persistence" persistenceClient "github.com/uber/cadence/common/persistence/client" - "github.com/uber/cadence/common/quotas/global/rpc" + qrpc "github.com/uber/cadence/common/quotas/global/rpc" "github.com/uber/cadence/common/quotas/permember" + "github.com/uber/cadence/common/rpc" "github.com/uber/cadence/common/service" ) @@ -141,7 +142,7 @@ type ( pprofInitializer common.PProfInitializer runtimeMetricsReporter *metrics.RuntimeMetricsReporter - rpcFactory common.RPCFactory + rpcFactory rpc.Factory isolationGroups isolationgroup.State isolationGroupConfigStore configstore.Client @@ -149,7 +150,7 @@ type ( asyncWorkflowQueueProvider queue.Provider - ratelimiterAggregatorClient rpc.Client + ratelimiterAggregatorClient qrpc.Client } ) @@ -304,7 +305,7 @@ func New( } partitioner := ensurePartitionerOrDefault(params, isolationGroupState) - ratelimiterAggs := rpc.New( + ratelimiterAggs := qrpc.New( historyRawClient, // no retries, will retry internally if needed clientBean.GetHistoryPeers(), logger, @@ -384,7 +385,6 @@ func New( // Start all resources func (h *Impl) Start() { - if !atomic.CompareAndSwapInt32( &h.status, common.DaemonStatusInitialized, @@ -399,6 +399,9 @@ func (h *Impl) Start() { if err := h.pprofInitializer.Start(); err != nil { h.logger.WithTags(tag.Error(err)).Fatal("fail to start PProf") } + + h.rpcFactory.Start(h.membershipResolver) + if err := h.dispatcher.Start(); err != nil { h.logger.WithTags(tag.Error(err)).Fatal("fail to start dispatcher") } @@ -423,7 +426,6 @@ func (h *Impl) Start() { // Stop stops all resources func (h *Impl) Stop() { - if !atomic.CompareAndSwapInt32( &h.status, common.DaemonStatusStarted, @@ -435,9 +437,12 @@ func (h *Impl) Stop() { h.domainCache.Stop() h.domainMetricsScopeCache.Stop() h.membershipResolver.Stop() + if err := h.dispatcher.Stop(); err != nil { h.logger.WithTags(tag.Error(err)).Error("failed to stop dispatcher") } + h.rpcFactory.Stop() + h.runtimeMetricsReporter.Stop() h.persistenceBean.Close() if h.isolationGroupConfigStore != nil { @@ -555,7 +560,7 @@ func (h *Impl) GetHistoryClient() history.Client { return h.historyClient } -func (h *Impl) GetRatelimiterAggregatorsClient() rpc.Client { +func (h *Impl) GetRatelimiterAggregatorsClient() qrpc.Client { return h.ratelimiterAggregatorClient } diff --git a/common/rpc/direct_peer_chooser.go b/common/rpc/direct_peer_chooser.go new file mode 100644 index 00000000000..1e515e4ebc4 --- /dev/null +++ b/common/rpc/direct_peer_chooser.go @@ -0,0 +1,144 @@ +// Copyright (c) 2021 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package rpc + +import ( + "context" + "sync" + + "go.uber.org/yarpc/api/peer" + "go.uber.org/yarpc/api/transport" + "go.uber.org/yarpc/peer/direct" + "go.uber.org/yarpc/yarpcerrors" + + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log" + "github.com/uber/cadence/common/log/tag" + "github.com/uber/cadence/common/membership" +) + +// directPeerChooser is a peer.Chooser that chooses a peer based on the shard key. +// Peers are managed by the peerList and peers are reused across multiple requests. +type directPeerChooser struct { + serviceName string + logger log.Logger + t peer.Transport + enableConnRetainMode dynamicconfig.BoolPropertyFn + legacyChooser peer.Chooser + legacyChooserErr error + mu sync.RWMutex +} + +func newDirectChooser(serviceName string, t peer.Transport, logger log.Logger, enableConnRetainMode dynamicconfig.BoolPropertyFn) *directPeerChooser { + return &directPeerChooser{ + serviceName: serviceName, + logger: logger, + t: t, + enableConnRetainMode: enableConnRetainMode, + } +} + +// Start statisfies the peer.Chooser interface. +func (g *directPeerChooser) Start() error { + c, ok := g.getLegacyChooser() + if ok { + return c.Start() + } + + return nil // no-op +} + +// Stop statisfies the peer.Chooser interface. +func (g *directPeerChooser) Stop() error { + c, ok := g.getLegacyChooser() + if ok { + return c.Stop() + } + + return nil // no-op +} + +// IsRunning statisfies the peer.Chooser interface. +func (g *directPeerChooser) IsRunning() bool { + c, ok := g.getLegacyChooser() + if ok { + return c.IsRunning() + } + + return true // no-op +} + +// Choose returns an existing peer for the shard key. +func (g *directPeerChooser) Choose(ctx context.Context, req *transport.Request) (peer peer.Peer, onFinish func(error), err error) { + if g.enableConnRetainMode != nil && !g.enableConnRetainMode() { + return g.chooseFromLegacyDirectPeerChooser(ctx, req) + } + + if req.ShardKey == "" { + return nil, nil, yarpcerrors.InvalidArgumentErrorf("chooser requires ShardKey to be non-empty") + } + + // TODO: implement connection retain mode + return nil, nil, yarpcerrors.UnimplementedErrorf("direct peer chooser conn retain mode unimplemented") +} + +func (g *directPeerChooser) UpdatePeers(members []membership.HostInfo) { + // TODO: implement + g.logger.Debug("direct peer chooser got a membership update", tag.Counter(len(members))) +} + +func (g *directPeerChooser) chooseFromLegacyDirectPeerChooser(ctx context.Context, req *transport.Request) (peer.Peer, func(error), error) { + c, ok := g.getLegacyChooser() + if !ok { + return nil, nil, yarpcerrors.InternalErrorf("failed to get legacy direct peer chooser") + } + + return c.Choose(ctx, req) +} + +func (g *directPeerChooser) getLegacyChooser() (peer.Chooser, bool) { + g.mu.RLock() + + if g.legacyChooser != nil { + // Legacy chooser already created, return it + g.mu.RUnlock() + return g.legacyChooser, true + } + + if g.legacyChooserErr != nil { + // There was an error creating the legacy chooser, return false + g.mu.RUnlock() + return nil, false + } + + g.mu.RUnlock() + + g.mu.Lock() + g.legacyChooser, g.legacyChooserErr = direct.New(direct.Configuration{}, g.t) + g.mu.Unlock() + + if g.legacyChooserErr != nil { + g.logger.Error("failed to create legacy direct peer chooser", tag.Error(g.legacyChooserErr)) + return nil, false + } + + return g.legacyChooser, true +} diff --git a/common/rpc/direct_peer_chooser_test.go b/common/rpc/direct_peer_chooser_test.go new file mode 100644 index 00000000000..722f773872e --- /dev/null +++ b/common/rpc/direct_peer_chooser_test.go @@ -0,0 +1,96 @@ +// Copyright (c) 2021 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package rpc + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + "go.uber.org/yarpc/api/transport" + "go.uber.org/yarpc/transport/grpc" + + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log/testlogger" +) + +func TestDirectChooser(t *testing.T) { + req := &transport.Request{ + Caller: "caller", + Service: "service", + ShardKey: "shard1", + } + + tests := []struct { + desc string + retainConn bool + req *transport.Request + wantChooseErr bool + }{ + { + desc: "don't retain connection", + retainConn: false, + req: req, + }, + { + desc: "retain connection", + retainConn: true, + req: req, + wantChooseErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.desc, func(t *testing.T) { + defer goleak.VerifyNone(t) + + logger := testlogger.New(t) + serviceName := "service" + directConnRetainFn := func(opts ...dynamicconfig.FilterOption) bool { return tc.retainConn } + grpcTransport := grpc.NewTransport() + + chooser := newDirectChooser(serviceName, grpcTransport, logger, directConnRetainFn) + if err := chooser.Start(); err != nil { + t.Fatalf("failed to start direct peer chooser: %v", err) + } + + assert.True(t, chooser.IsRunning()) + + peer, onFinish, err := chooser.Choose(context.Background(), tc.req) + if tc.wantChooseErr != (err != nil) { + t.Fatalf("Choose() err = %v, wantChooseErr = %v", err, tc.wantChooseErr) + } + + if err == nil { + assert.NotNil(t, peer) + assert.NotNil(t, onFinish) + + // call onFinish to release the peer + onFinish(nil) + } + + if err := chooser.Stop(); err != nil { + t.Fatalf("failed to stop direct peer chooser: %v", err) + } + }) + } +} diff --git a/common/rpc/factory.go b/common/rpc/factory.go index 4d96a217281..cd1a0b76ed0 100644 --- a/common/rpc/factory.go +++ b/common/rpc/factory.go @@ -21,10 +21,12 @@ package rpc import ( + "context" "crypto/tls" "fmt" "net" nethttp "net/http" + "sync" "go.uber.org/yarpc" "go.uber.org/yarpc/transport/grpc" @@ -34,19 +36,32 @@ import ( "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/log/tag" + "github.com/uber/cadence/common/membership" ) -const defaultGRPCSizeLimit = 4 * 1024 * 1024 +const ( + defaultGRPCSizeLimit = 4 * 1024 * 1024 + factoryComponentName = "rpc-factory" +) -// Factory is an implementation of common.RPCFactory interface -type Factory struct { +// Factory is an implementation of rpc.Factory interface +type FactoryImpl struct { maxMessageSize int channel tchannel.Channel dispatcher *yarpc.Dispatcher + outbounds *Outbounds + logger log.Logger + serviceName string + wg sync.WaitGroup + ctx context.Context + cancelFn context.CancelFunc + peerLister PeerLister } // NewFactory builds a new rpc.Factory -func NewFactory(logger log.Logger, p Params) *Factory { +func NewFactory(logger log.Logger, p Params) *FactoryImpl { + logger = logger.WithTags(tag.ComponentRPCFactory) + inbounds := yarpc.Inbounds{} // Create TChannel transport // This is here only because ringpop expects tchannel.ChannelTransport, @@ -96,7 +111,6 @@ func NewFactory(logger log.Logger, p Params) *Factory { return } nethttp.NotFound(w, r) - return }) } @@ -115,7 +129,7 @@ func NewFactory(logger log.Logger, p Params) *Factory { logger.Info("Listening for HTTP requests", tag.Address(p.HTTP.Address)) } // Create outbounds - outbounds := yarpc.Outbounds{} + outbounds := &Outbounds{} if p.OutboundsBuilder != nil { outbounds, err = p.OutboundsBuilder.Build(grpcTransport, tchannel) if err != nil { @@ -126,35 +140,88 @@ func NewFactory(logger log.Logger, p Params) *Factory { dispatcher := yarpc.NewDispatcher(yarpc.Config{ Name: p.ServiceName, Inbounds: inbounds, - Outbounds: outbounds, + Outbounds: outbounds.Outbounds, InboundMiddleware: p.InboundMiddleware, OutboundMiddleware: p.OutboundMiddleware, }) - return &Factory{ + ctx, cancel := context.WithCancel(context.Background()) + return &FactoryImpl{ maxMessageSize: p.GRPCMaxMsgSize, dispatcher: dispatcher, channel: ch.Channel(), + outbounds: outbounds, + serviceName: p.ServiceName, + logger: logger, + ctx: ctx, + cancelFn: cancel, } } // GetDispatcher return a cached dispatcher -func (d *Factory) GetDispatcher() *yarpc.Dispatcher { +func (d *FactoryImpl) GetDispatcher() *yarpc.Dispatcher { return d.dispatcher } // GetChannel returns Tchannel Channel used by Ringpop -func (d *Factory) GetChannel() tchannel.Channel { +func (d *FactoryImpl) GetTChannel() tchannel.Channel { return d.channel } -func (d *Factory) GetMaxMessageSize() int { +func (d *FactoryImpl) GetMaxMessageSize() int { if d.maxMessageSize == 0 { return defaultGRPCSizeLimit } return d.maxMessageSize } +func (d *FactoryImpl) Start(peerLister PeerLister) error { + // subscribe to membership changes and notify outbounds builder for peer updates + d.peerLister = peerLister + ch := make(chan *membership.ChangedEvent, 1) + if err := d.peerLister.Subscribe(d.serviceName, factoryComponentName, ch); err != nil { + return fmt.Errorf("rpc factory failed to subscribe to membership updates: %v", err) + } + d.wg.Add(1) + go d.listenMembershipChanges(ch) + + return nil +} + +func (d *FactoryImpl) Stop() error { + d.logger.Info("stopping rpc factory") + if err := d.peerLister.Unsubscribe(d.serviceName, factoryComponentName); err != nil { + d.logger.Error("rpc factory failed to unsubscribe from membership updates", tag.Error(err)) + } + + d.cancelFn() + d.wg.Wait() + + d.logger.Info("stopped rpc factory") + return nil +} + +func (d *FactoryImpl) listenMembershipChanges(ch chan *membership.ChangedEvent) { + defer d.wg.Done() + + for { + select { + case <-ch: + d.logger.Debug("rpc factory received membership changed event") + members, err := d.peerLister.Members(d.serviceName) + if err != nil { + d.logger.Error("rpc factory failed to get members from membership resolver", tag.Error(err)) + continue + } + + d.outbounds.UpdatePeers(members) + case <-d.ctx.Done(): + d.logger.Info("rpc factory stopped so listenMembershipChanges returning") + return + } + } +} + func createDialer(transport *grpc.Transport, tlsConfig *tls.Config) *grpc.Dialer { var dialOptions []grpc.DialOption if tlsConfig != nil { diff --git a/common/rpc/factory_mock.go b/common/rpc/factory_mock.go new file mode 100644 index 00000000000..4caa8a29ff2 --- /dev/null +++ b/common/rpc/factory_mock.go @@ -0,0 +1,196 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Code generated by MockGen. DO NOT EDIT. +// Source: types.go + +// Package rpc is a generated GoMock package. +package rpc + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + yarpc "go.uber.org/yarpc" + tchannel "go.uber.org/yarpc/transport/tchannel" + + membership "github.com/uber/cadence/common/membership" +) + +// MockFactory is a mock of Factory interface. +type MockFactory struct { + ctrl *gomock.Controller + recorder *MockFactoryMockRecorder +} + +// MockFactoryMockRecorder is the mock recorder for MockFactory. +type MockFactoryMockRecorder struct { + mock *MockFactory +} + +// NewMockFactory creates a new mock instance. +func NewMockFactory(ctrl *gomock.Controller) *MockFactory { + mock := &MockFactory{ctrl: ctrl} + mock.recorder = &MockFactoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockFactory) EXPECT() *MockFactoryMockRecorder { + return m.recorder +} + +// GetDispatcher mocks base method. +func (m *MockFactory) GetDispatcher() *yarpc.Dispatcher { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDispatcher") + ret0, _ := ret[0].(*yarpc.Dispatcher) + return ret0 +} + +// GetDispatcher indicates an expected call of GetDispatcher. +func (mr *MockFactoryMockRecorder) GetDispatcher() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDispatcher", reflect.TypeOf((*MockFactory)(nil).GetDispatcher)) +} + +// GetMaxMessageSize mocks base method. +func (m *MockFactory) GetMaxMessageSize() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMaxMessageSize") + ret0, _ := ret[0].(int) + return ret0 +} + +// GetMaxMessageSize indicates an expected call of GetMaxMessageSize. +func (mr *MockFactoryMockRecorder) GetMaxMessageSize() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaxMessageSize", reflect.TypeOf((*MockFactory)(nil).GetMaxMessageSize)) +} + +// GetTChannel mocks base method. +func (m *MockFactory) GetTChannel() tchannel.Channel { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTChannel") + ret0, _ := ret[0].(tchannel.Channel) + return ret0 +} + +// GetTChannel indicates an expected call of GetTChannel. +func (mr *MockFactoryMockRecorder) GetTChannel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTChannel", reflect.TypeOf((*MockFactory)(nil).GetTChannel)) +} + +// Start mocks base method. +func (m *MockFactory) Start(arg0 PeerLister) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Start", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Start indicates an expected call of Start. +func (mr *MockFactoryMockRecorder) Start(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockFactory)(nil).Start), arg0) +} + +// Stop mocks base method. +func (m *MockFactory) Stop() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stop") + ret0, _ := ret[0].(error) + return ret0 +} + +// Stop indicates an expected call of Stop. +func (mr *MockFactoryMockRecorder) Stop() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockFactory)(nil).Stop)) +} + +// MockPeerLister is a mock of PeerLister interface. +type MockPeerLister struct { + ctrl *gomock.Controller + recorder *MockPeerListerMockRecorder +} + +// MockPeerListerMockRecorder is the mock recorder for MockPeerLister. +type MockPeerListerMockRecorder struct { + mock *MockPeerLister +} + +// NewMockPeerLister creates a new mock instance. +func NewMockPeerLister(ctrl *gomock.Controller) *MockPeerLister { + mock := &MockPeerLister{ctrl: ctrl} + mock.recorder = &MockPeerListerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPeerLister) EXPECT() *MockPeerListerMockRecorder { + return m.recorder +} + +// Members mocks base method. +func (m *MockPeerLister) Members(service string) ([]membership.HostInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Members", service) + ret0, _ := ret[0].([]membership.HostInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Members indicates an expected call of Members. +func (mr *MockPeerListerMockRecorder) Members(service interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Members", reflect.TypeOf((*MockPeerLister)(nil).Members), service) +} + +// Subscribe mocks base method. +func (m *MockPeerLister) Subscribe(service, name string, notifyChannel chan<- *membership.ChangedEvent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Subscribe", service, name, notifyChannel) + ret0, _ := ret[0].(error) + return ret0 +} + +// Subscribe indicates an expected call of Subscribe. +func (mr *MockPeerListerMockRecorder) Subscribe(service, name, notifyChannel interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockPeerLister)(nil).Subscribe), service, name, notifyChannel) +} + +// Unsubscribe mocks base method. +func (m *MockPeerLister) Unsubscribe(service, name string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Unsubscribe", service, name) + ret0, _ := ret[0].(error) + return ret0 +} + +// Unsubscribe indicates an expected call of Unsubscribe. +func (mr *MockPeerListerMockRecorder) Unsubscribe(service, name interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unsubscribe", reflect.TypeOf((*MockPeerLister)(nil).Unsubscribe), service, name) +} diff --git a/common/rpc/factory_test.go b/common/rpc/factory_test.go new file mode 100644 index 00000000000..6b039cc8116 --- /dev/null +++ b/common/rpc/factory_test.go @@ -0,0 +1,119 @@ +// Copyright (c) 2017 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package rpc + +import ( + "sync" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" + + "github.com/uber/cadence/common/log/testlogger" + "github.com/uber/cadence/common/membership" +) + +func TestNewFactory(t *testing.T) { + ctrl := gomock.NewController(t) + logger := testlogger.New(t) + serviceName := "service" + ob := NewMockOutboundsBuilder(ctrl) + ob.EXPECT().Build(gomock.Any(), gomock.Any()).Return(&Outbounds{}, nil).Times(1) + grpcMsgSize := 4 * 1024 * 1024 + f := NewFactory(logger, Params{ + ServiceName: serviceName, + TChannelAddress: "localhost:0", + GRPCMaxMsgSize: grpcMsgSize, + GRPCAddress: "localhost:0", + HTTP: &httpParams{ + Address: "localhost:0", + }, + OutboundsBuilder: ob, + }) + + if f == nil { + t.Fatal("NewFactory returned nil") + } + + assert.NotNil(t, f.GetDispatcher(), "GetDispatcher returned nil") + assert.NotNil(t, f.GetTChannel(), "GetTChannel returned nil") + assert.Equal(t, grpcMsgSize, f.GetMaxMessageSize(), "GetMaxMessageSize returned wrong value") +} + +func TestStartStop(t *testing.T) { + defer goleak.VerifyNone(t) + + ctrl := gomock.NewController(t) + logger := testlogger.New(t) + serviceName := "service" + ob := NewMockOutboundsBuilder(ctrl) + var mu sync.Mutex + var gotMembers []membership.HostInfo + outbounds := &Outbounds{ + onUpdatePeers: func(members []membership.HostInfo) { + mu.Lock() + defer mu.Unlock() + gotMembers = members + }, + } + ob.EXPECT().Build(gomock.Any(), gomock.Any()).Return(outbounds, nil).Times(1) + grpcMsgSize := 4 * 1024 * 1024 + f := NewFactory(logger, Params{ + ServiceName: serviceName, + TChannelAddress: "localhost:0", + GRPCMaxMsgSize: grpcMsgSize, + GRPCAddress: "localhost:0", + HTTP: &httpParams{ + Address: "localhost:0", + }, + OutboundsBuilder: ob, + }) + + members := []membership.HostInfo{ + membership.NewHostInfo("localhost:9191"), + membership.NewHostInfo("localhost:9192"), + } + peerLister := membership.NewMockResolver(ctrl) + peerLister.EXPECT().Subscribe(serviceName, factoryComponentName, gomock.Any()). + DoAndReturn(func(service, name string, notifyChannel chan<- *membership.ChangedEvent) error { + // Notify the channel once to validate listening logic is working + notifyChannel <- &membership.ChangedEvent{} + return nil + }).Times(1) + peerLister.EXPECT().Unsubscribe(serviceName, factoryComponentName).Return(nil).Times(1) + peerLister.EXPECT().Members(serviceName).Return(members, nil).Times(1) + + if err := f.Start(peerLister); err != nil { + t.Fatalf("Factory.Start() returned error: %v", err) + } + + // Wait for membership changes to be processed + time.Sleep(100 * time.Millisecond) + mu.Lock() + assert.Equal(t, members, gotMembers, "UpdatePeers not called with expected members") + mu.Unlock() + + if err := f.Stop(); err != nil { + t.Fatalf("Factory.Stop() returned error: %v", err) + } +} diff --git a/common/rpc/outbounds.go b/common/rpc/outbounds.go index dcb604ed964..8c17122744b 100644 --- a/common/rpc/outbounds.go +++ b/common/rpc/outbounds.go @@ -18,6 +18,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +//go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination outbounds_mock.go -self_package github.com/uber/cadence/common/rpc + package rpc import ( @@ -28,12 +30,13 @@ import ( "go.uber.org/yarpc" "go.uber.org/yarpc/api/middleware" "go.uber.org/yarpc/api/transport" - "go.uber.org/yarpc/peer/direct" "go.uber.org/yarpc/transport/grpc" "go.uber.org/yarpc/transport/tchannel" "github.com/uber/cadence/common/authorization" "github.com/uber/cadence/common/config" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/membership" "github.com/uber/cadence/common/service" ) @@ -46,21 +49,34 @@ const ( // OutboundsBuilder allows defining outbounds for the dispatcher type OutboundsBuilder interface { - Build(*grpc.Transport, *tchannel.Transport) (yarpc.Outbounds, error) + // Build creates yarpc outbounds given transport instances for either gRPC and TChannel based on the configuration + Build(*grpc.Transport, *tchannel.Transport) (*Outbounds, error) +} + +type Outbounds struct { + yarpc.Outbounds + onUpdatePeers func([]membership.HostInfo) } -type multiOutbounds struct { +func (o *Outbounds) UpdatePeers(peers []membership.HostInfo) { + if o.onUpdatePeers != nil { + o.onUpdatePeers(peers) + } +} + +type multiOutboundsBuilder struct { builders []OutboundsBuilder } // CombineOutbounds takes multiple outbound builders and combines them func CombineOutbounds(builders ...OutboundsBuilder) OutboundsBuilder { - return multiOutbounds{builders} + return multiOutboundsBuilder{builders} } -func (b multiOutbounds) Build(grpc *grpc.Transport, tchannel *tchannel.Transport) (yarpc.Outbounds, error) { +func (b multiOutboundsBuilder) Build(grpc *grpc.Transport, tchannel *tchannel.Transport) (*Outbounds, error) { outbounds := yarpc.Outbounds{} var errs error + var callbacks []func([]membership.HostInfo) for _, builder := range b.builders { builderOutbounds, err := builder.Build(grpc, tchannel) if err != nil { @@ -68,7 +84,11 @@ func (b multiOutbounds) Build(grpc *grpc.Transport, tchannel *tchannel.Transport continue } - for name, outbound := range builderOutbounds { + if builderOutbounds.onUpdatePeers != nil { + callbacks = append(callbacks, builderOutbounds.onUpdatePeers) + } + + for name, outbound := range builderOutbounds.Outbounds { if _, exists := outbounds[name]; exists { errs = multierr.Append(errs, fmt.Errorf("outbound %q already configured", name)) break @@ -76,7 +96,15 @@ func (b multiOutbounds) Build(grpc *grpc.Transport, tchannel *tchannel.Transport outbounds[name] = outbound } } - return outbounds, errs + + return &Outbounds{ + Outbounds: outbounds, + onUpdatePeers: func(peers []membership.HostInfo) { + for _, callback := range callbacks { + callback(peers) + } + }, + }, errs } type publicClientOutbound struct { @@ -106,17 +134,19 @@ func newPublicClientOutbound(config *config.Config) (publicClientOutbound, error return publicClientOutbound{config.PublicClient.HostPort, isGrpc, authMiddleware}, nil } -func (b publicClientOutbound) Build(grpc *grpc.Transport, tchannel *tchannel.Transport) (yarpc.Outbounds, error) { +func (b publicClientOutbound) Build(grpc *grpc.Transport, tchannel *tchannel.Transport) (*Outbounds, error) { var outbound transport.UnaryOutbound if b.isGRPC { outbound = grpc.NewSingleOutbound(b.address) } else { outbound = tchannel.NewSingleOutbound(b.address) } - return yarpc.Outbounds{ - OutboundPublicClient: { - ServiceName: service.Frontend, - Unary: middleware.ApplyUnaryOutbound(outbound, b.authMiddleware), + return &Outbounds{ + Outbounds: yarpc.Outbounds{ + OutboundPublicClient: { + ServiceName: service.Frontend, + Unary: middleware.ApplyUnaryOutbound(outbound, b.authMiddleware), + }, }, }, nil } @@ -130,7 +160,7 @@ func NewCrossDCOutbounds(clusterGroup map[string]config.ClusterInformation, pcf return crossDCOutbounds{clusterGroup, pcf} } -func (b crossDCOutbounds) Build(grpcTransport *grpc.Transport, tchannelTransport *tchannel.Transport) (yarpc.Outbounds, error) { +func (b crossDCOutbounds) Build(grpcTransport *grpc.Transport, tchannelTransport *tchannel.Transport) (*Outbounds, error) { outbounds := yarpc.Outbounds{} for clusterName, clusterInfo := range b.clusterGroup { if !clusterInfo.Enabled { @@ -140,7 +170,7 @@ func (b crossDCOutbounds) Build(grpcTransport *grpc.Transport, tchannelTransport var outbound transport.UnaryOutbound switch clusterInfo.RPCTransport { case tchannel.TransportName: - peerChooser, err := b.pcf.CreatePeerChooser(tchannelTransport, clusterInfo.RPCAddress) + peerChooser, err := b.pcf.CreatePeerChooser(tchannelTransport, PeerChooserOptions{Address: clusterInfo.RPCAddress}) if err != nil { return nil, err } @@ -150,7 +180,7 @@ func (b crossDCOutbounds) Build(grpcTransport *grpc.Transport, tchannelTransport if err != nil { return nil, err } - peerChooser, err := b.pcf.CreatePeerChooser(createDialer(grpcTransport, tlsConfig), clusterInfo.RPCAddress) + peerChooser, err := b.pcf.CreatePeerChooser(createDialer(grpcTransport, tlsConfig), PeerChooserOptions{Address: clusterInfo.RPCAddress}) if err != nil { return nil, err } @@ -176,40 +206,52 @@ func (b crossDCOutbounds) Build(grpcTransport *grpc.Transport, tchannelTransport )), } } - return outbounds, nil + return &Outbounds{Outbounds: outbounds}, nil } type directOutbound struct { - serviceName string - grpcEnabled bool - tlsConfig *tls.Config + serviceName string + grpcEnabled bool + tlsConfig *tls.Config + pcf PeerChooserFactory + enableConnRetainMode dynamicconfig.BoolPropertyFn } -func NewDirectOutbound(serviceName string, grpcEnabled bool, tlsConfig *tls.Config) OutboundsBuilder { - return directOutbound{serviceName, grpcEnabled, tlsConfig} +func NewDirectOutboundBuilder(serviceName string, grpcEnabled bool, tlsConfig *tls.Config, pcf PeerChooserFactory, enableConnRetainMode dynamicconfig.BoolPropertyFn) OutboundsBuilder { + return directOutbound{serviceName, grpcEnabled, tlsConfig, pcf, enableConnRetainMode} } -func (o directOutbound) Build(grpc *grpc.Transport, tchannel *tchannel.Transport) (yarpc.Outbounds, error) { +func (o directOutbound) Build(grpc *grpc.Transport, tchannel *tchannel.Transport) (*Outbounds, error) { var outbound transport.UnaryOutbound + opts := PeerChooserOptions{ + EnableConnectionRetainingDirectChooser: o.enableConnRetainMode, + ServiceName: o.serviceName, + } + + var err error + var directChooser PeerChooser if o.grpcEnabled { - directChooser, err := direct.New(direct.Configuration{}, createDialer(grpc, o.tlsConfig)) + directChooser, err = o.pcf.CreatePeerChooser(createDialer(grpc, o.tlsConfig), opts) if err != nil { return nil, err } outbound = grpc.NewOutbound(directChooser) } else { - directChooser, err := direct.New(direct.Configuration{}, tchannel) + directChooser, err = o.pcf.CreatePeerChooser(tchannel, opts) if err != nil { return nil, err } outbound = tchannel.NewOutbound(directChooser) } - return yarpc.Outbounds{ - o.serviceName: { - ServiceName: o.serviceName, - Unary: middleware.ApplyUnaryOutbound(outbound, &ResponseInfoMiddleware{}), + return &Outbounds{ + Outbounds: yarpc.Outbounds{ + o.serviceName: { + ServiceName: o.serviceName, + Unary: middleware.ApplyUnaryOutbound(outbound, &ResponseInfoMiddleware{}), + }, }, + onUpdatePeers: directChooser.UpdatePeers, }, nil } diff --git a/common/rpc/outbounds_mock.go b/common/rpc/outbounds_mock.go new file mode 100644 index 00000000000..ce1dd17cc6d --- /dev/null +++ b/common/rpc/outbounds_mock.go @@ -0,0 +1,73 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Code generated by MockGen. DO NOT EDIT. +// Source: outbounds.go + +// Package rpc is a generated GoMock package. +package rpc + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + grpc "go.uber.org/yarpc/transport/grpc" + tchannel "go.uber.org/yarpc/transport/tchannel" +) + +// MockOutboundsBuilder is a mock of OutboundsBuilder interface. +type MockOutboundsBuilder struct { + ctrl *gomock.Controller + recorder *MockOutboundsBuilderMockRecorder +} + +// MockOutboundsBuilderMockRecorder is the mock recorder for MockOutboundsBuilder. +type MockOutboundsBuilderMockRecorder struct { + mock *MockOutboundsBuilder +} + +// NewMockOutboundsBuilder creates a new mock instance. +func NewMockOutboundsBuilder(ctrl *gomock.Controller) *MockOutboundsBuilder { + mock := &MockOutboundsBuilder{ctrl: ctrl} + mock.recorder = &MockOutboundsBuilderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOutboundsBuilder) EXPECT() *MockOutboundsBuilderMockRecorder { + return m.recorder +} + +// Build mocks base method. +func (m *MockOutboundsBuilder) Build(arg0 *grpc.Transport, arg1 *tchannel.Transport) (*Outbounds, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Build", arg0, arg1) + ret0, _ := ret[0].(*Outbounds) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Build indicates an expected call of Build. +func (mr *MockOutboundsBuilderMockRecorder) Build(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Build", reflect.TypeOf((*MockOutboundsBuilder)(nil).Build), arg0, arg1) +} diff --git a/common/rpc/outbounds_test.go b/common/rpc/outbounds_test.go index fd1e088d0c2..fb83c4c5fcd 100644 --- a/common/rpc/outbounds_test.go +++ b/common/rpc/outbounds_test.go @@ -35,6 +35,8 @@ import ( "go.uber.org/yarpc/transport/tchannel" "github.com/uber/cadence/common/config" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log/testlogger" "github.com/uber/cadence/common/service" ) @@ -45,7 +47,7 @@ func TestCombineOutbounds(t *testing.T) { combined := CombineOutbounds() outbounds, err := combined.Build(grpc, tchannel) assert.NoError(t, err) - assert.Empty(t, outbounds) + assert.Empty(t, outbounds.Outbounds) combined = CombineOutbounds(fakeOutboundBuilder{err: errors.New("err-A")}) _, err = combined.Build(grpc, tchannel) @@ -75,7 +77,7 @@ func TestCombineOutbounds(t *testing.T) { "A": {}, "B": {}, "C": {}, - }, outbounds) + }, outbounds.Outbounds) } func TestPublicClientOutbound(t *testing.T) { @@ -119,8 +121,9 @@ func TestPublicClientOutbound(t *testing.T) { grpc := &grpc.Transport{} tchannel := &tchannel.Transport{} - outbounds, err := builder.Build(grpc, tchannel) + o, err := builder.Build(grpc, tchannel) require.NoError(t, err) + outbounds := o.Outbounds assert.Equal(t, outbounds[OutboundPublicClient].ServiceName, service.Frontend) assert.NotNil(t, outbounds[OutboundPublicClient].Unary) } @@ -146,8 +149,9 @@ func TestCrossDCOutbounds(t *testing.T) { "cluster-B": {Enabled: true, RPCName: "cadence-frontend", RPCAddress: "address-B", RPCTransport: "tchannel"}, "cluster-C": {Enabled: false}, } - outbounds, err := NewCrossDCOutbounds(clusterGroup, &fakePeerChooserFactory{}).Build(grpc, tchannel) + o, err := NewCrossDCOutbounds(clusterGroup, &fakePeerChooserFactory{}).Build(grpc, tchannel) assert.NoError(t, err) + outbounds := o.Outbounds assert.Equal(t, 2, len(outbounds)) assert.Equal(t, "cadence-frontend", outbounds["cluster-A"].ServiceName) assert.Equal(t, "cadence-frontend", outbounds["cluster-B"].ServiceName) @@ -158,14 +162,18 @@ func TestCrossDCOutbounds(t *testing.T) { func TestDirectOutbound(t *testing.T) { grpc := &grpc.Transport{} tchannel := &tchannel.Transport{} + logger := testlogger.New(t) + falseFn := func(opts ...dynamicconfig.FilterOption) bool { return false } - outbounds, err := NewDirectOutbound("cadence-history", false, nil).Build(grpc, tchannel) + o, err := NewDirectOutboundBuilder("cadence-history", false, nil, NewDirectPeerChooserFactory("cadence-history", logger), falseFn).Build(grpc, tchannel) assert.NoError(t, err) + outbounds := o.Outbounds assert.Equal(t, "cadence-history", outbounds["cadence-history"].ServiceName) assert.NotNil(t, outbounds["cadence-history"].Unary) - outbounds, err = NewDirectOutbound("cadence-history", true, nil).Build(grpc, tchannel) + o, err = NewDirectOutboundBuilder("cadence-history", true, nil, NewDirectPeerChooserFactory("cadence-history", logger), falseFn).Build(grpc, tchannel) assert.NoError(t, err) + outbounds = o.Outbounds assert.Equal(t, "cadence-history", outbounds["cadence-history"].ServiceName) assert.NotNil(t, outbounds["cadence-history"].Unary) } @@ -196,12 +204,16 @@ type fakeOutboundBuilder struct { err error } -func (b fakeOutboundBuilder) Build(*grpc.Transport, *tchannel.Transport) (yarpc.Outbounds, error) { - return b.outbounds, b.err +func (b fakeOutboundBuilder) Build(*grpc.Transport, *tchannel.Transport) (*Outbounds, error) { + return &Outbounds{Outbounds: b.outbounds}, b.err } type fakePeerChooserFactory struct{} -func (f fakePeerChooserFactory) CreatePeerChooser(transport peer.Transport, address string) (peer.Chooser, error) { - return direct.New(direct.Configuration{}, transport) +func (f fakePeerChooserFactory) CreatePeerChooser(transport peer.Transport, opts PeerChooserOptions) (PeerChooser, error) { + chooser, err := direct.New(direct.Configuration{}, transport) + if err != nil { + return nil, err + } + return &defaultPeerChooser{Chooser: chooser}, nil } diff --git a/common/rpc/params.go b/common/rpc/params.go index ff5dc6582ee..10b58faad6b 100644 --- a/common/rpc/params.go +++ b/common/rpc/params.go @@ -33,6 +33,7 @@ import ( "github.com/uber/cadence/common/config" "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/service" ) @@ -61,7 +62,7 @@ type httpParams struct { } // NewParams creates parameters for rpc.Factory from the given config -func NewParams(serviceName string, config *config.Config, dc *dynamicconfig.Collection) (Params, error) { +func NewParams(serviceName string, config *config.Config, dc *dynamicconfig.Collection, logger log.Logger) (Params, error) { serviceConfig, err := config.GetServiceConfig(serviceName) if err != nil { return Params{}, err @@ -138,8 +139,20 @@ func NewParams(serviceName string, config *config.Config, dc *dynamicconfig.Coll GRPCAddress: net.JoinHostPort(listenIP.String(), strconv.Itoa(int(serviceConfig.RPC.GRPCPort))), GRPCMaxMsgSize: serviceConfig.RPC.GRPCMaxMsgSize, OutboundsBuilder: CombineOutbounds( - NewDirectOutbound(service.History, enableGRPCOutbound, outboundTLS[service.History]), - NewDirectOutbound(service.Matching, enableGRPCOutbound, outboundTLS[service.Matching]), + NewDirectOutboundBuilder( + service.History, + enableGRPCOutbound, + outboundTLS[service.History], + NewDirectPeerChooserFactory(service.History, logger), + dc.GetBoolProperty(dynamicconfig.EnableConnectionRetainingDirectChooser), + ), + NewDirectOutboundBuilder( + service.Matching, + enableGRPCOutbound, + outboundTLS[service.Matching], + NewDirectPeerChooserFactory(service.Matching, logger), + dc.GetBoolProperty(dynamicconfig.EnableConnectionRetainingDirectChooser), + ), publicClientOutbound, ), InboundTLS: inboundTLS, diff --git a/common/rpc/params_test.go b/common/rpc/params_test.go index 994c1cff09f..cdfe9b6f38f 100644 --- a/common/rpc/params_test.go +++ b/common/rpc/params_test.go @@ -28,6 +28,7 @@ import ( "github.com/uber/cadence/common/config" "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log/testlogger" "github.com/uber/cadence/common/service" ) @@ -39,47 +40,48 @@ func TestNewParams(t *testing.T) { PublicClient: config.PublicClient{HostPort: "localhost:9999"}, Services: map[string]config.Service{"frontend": svc}} } + logger := testlogger.New(t) - _, err := NewParams(serviceName, &config.Config{}, dc) + _, err := NewParams(serviceName, &config.Config{}, dc, logger) assert.EqualError(t, err, "no config section for service: frontend") - _, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnLocalHost: true, BindOnIP: "1.2.3.4"}}), dc) + _, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnLocalHost: true, BindOnIP: "1.2.3.4"}}), dc, logger) assert.EqualError(t, err, "get listen IP: bindOnLocalHost and bindOnIP are mutually exclusive") - _, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnIP: "invalidIP"}}), dc) + _, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnIP: "invalidIP"}}), dc, logger) assert.EqualError(t, err, "get listen IP: unable to parse bindOnIP value or it is not an IPv4 or IPv6 address: invalidIP") - _, err = NewParams(serviceName, &config.Config{Services: map[string]config.Service{"frontend": {}}}, dc) + _, err = NewParams(serviceName, &config.Config{Services: map[string]config.Service{"frontend": {}}}, dc, logger) assert.EqualError(t, err, "public client outbound: need to provide an endpoint config for PublicClient") - _, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnLocalHost: true, TLS: config.TLS{Enabled: true, CertFile: "invalid", KeyFile: "invalid"}}}), dc) + _, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnLocalHost: true, TLS: config.TLS{Enabled: true, CertFile: "invalid", KeyFile: "invalid"}}}), dc, logger) assert.EqualError(t, err, "inbound TLS config: open invalid: no such file or directory") _, err = NewParams(serviceName, &config.Config{Services: map[string]config.Service{ "frontend": {RPC: config.RPC{BindOnLocalHost: true}}, "history": {RPC: config.RPC{TLS: config.TLS{Enabled: true, CaFile: "invalid"}}}, - }}, dc) + }}, dc, logger) assert.EqualError(t, err, "outbound cadence-history TLS config: open invalid: no such file or directory") - params, err := NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnLocalHost: true, Port: 1111, GRPCPort: 2222, GRPCMaxMsgSize: 3333}}), dc) + params, err := NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnLocalHost: true, Port: 1111, GRPCPort: 2222, GRPCMaxMsgSize: 3333}}), dc, logger) assert.NoError(t, err) assert.Equal(t, "127.0.0.1:1111", params.TChannelAddress) assert.Equal(t, "127.0.0.1:2222", params.GRPCAddress) assert.Equal(t, 3333, params.GRPCMaxMsgSize) assert.Nil(t, params.InboundTLS) - params, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnLocalHost: true, HTTP: &config.HTTP{Port: 8800}}}), dc) + params, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnLocalHost: true, HTTP: &config.HTTP{Port: 8800}}}), dc, logger) assert.NoError(t, err) assert.Equal(t, "127.0.0.1:8800", params.HTTP.Address) - params, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnLocalHost: true, HTTP: &config.HTTP{}}}), dc) + params, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnLocalHost: true, HTTP: &config.HTTP{}}}), dc, logger) assert.Error(t, err) - params, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnIP: "1.2.3.4", GRPCPort: 2222}}), dc) + params, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{BindOnIP: "1.2.3.4", GRPCPort: 2222}}), dc, logger) assert.NoError(t, err) assert.Equal(t, "1.2.3.4:2222", params.GRPCAddress) - params, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{GRPCPort: 2222, TLS: config.TLS{Enabled: true}}}), dc) + params, err = NewParams(serviceName, makeConfig(config.Service{RPC: config.RPC{GRPCPort: 2222, TLS: config.TLS{Enabled: true}}}), dc, logger) assert.NoError(t, err) ip, port, err := net.SplitHostPort(params.GRPCAddress) assert.NoError(t, err) diff --git a/common/rpc/peer_chooser.go b/common/rpc/peer_chooser.go index b2667f71b2a..d2c16344101 100644 --- a/common/rpc/peer_chooser.go +++ b/common/rpc/peer_chooser.go @@ -26,21 +26,56 @@ import ( "go.uber.org/yarpc/api/peer" "go.uber.org/yarpc/peer/roundrobin" + "github.com/uber/cadence/common/dynamicconfig" "github.com/uber/cadence/common/log" + "github.com/uber/cadence/common/membership" ) const defaultDNSRefreshInterval = time.Second * 10 type ( + PeerChooserOptions struct { + // Address is the target dns address. Used by dns peer chooser. + Address string + + // ServiceName is the name of service. Used by direct peer chooser. + ServiceName string + + // EnableConnectionRetainingDirectChooser is used by direct peer chooser. + // If false, yarpc's own default direct peer chooser will be used which doesn't retain connections. + // If true, cadence's own direct peer chooser will be used which retains connections. + EnableConnectionRetainingDirectChooser dynamicconfig.BoolPropertyFn + } PeerChooserFactory interface { - CreatePeerChooser(transport peer.Transport, address string) (peer.Chooser, error) + CreatePeerChooser(transport peer.Transport, opts PeerChooserOptions) (PeerChooser, error) } + + PeerChooser interface { + peer.Chooser + + // UpdatePeers updates the list of peers if needed. + UpdatePeers([]membership.HostInfo) + } + dnsPeerChooserFactory struct { interval time.Duration logger log.Logger } + + directPeerChooserFactory struct { + serviceName string + logger log.Logger + choosers []*directPeerChooser + } ) +type defaultPeerChooser struct { + peer.Chooser +} + +// UpdatePeers is a no-op for defaultPeerChooser. It is added to satisfy the PeerChooser interface. +func (d *defaultPeerChooser) UpdatePeers(peers []membership.HostInfo) {} + func NewDNSPeerChooserFactory(interval time.Duration, logger log.Logger) PeerChooserFactory { if interval <= 0 { interval = defaultDNSRefreshInterval @@ -49,12 +84,24 @@ func NewDNSPeerChooserFactory(interval time.Duration, logger log.Logger) PeerCho return &dnsPeerChooserFactory{interval, logger} } -func (f *dnsPeerChooserFactory) CreatePeerChooser(transport peer.Transport, address string) (peer.Chooser, error) { +func (f *dnsPeerChooserFactory) CreatePeerChooser(transport peer.Transport, opts PeerChooserOptions) (PeerChooser, error) { peerList := roundrobin.New(transport) - peerListUpdater, err := newDNSUpdater(peerList, address, f.interval, f.logger) + peerListUpdater, err := newDNSUpdater(peerList, opts.Address, f.interval, f.logger) if err != nil { return nil, err } peerListUpdater.Start() - return peerList, nil + return &defaultPeerChooser{Chooser: peerList}, nil +} + +func NewDirectPeerChooserFactory(serviceName string, logger log.Logger) PeerChooserFactory { + return &directPeerChooserFactory{ + logger: logger, + } +} + +func (f *directPeerChooserFactory) CreatePeerChooser(transport peer.Transport, opts PeerChooserOptions) (PeerChooser, error) { + c := newDirectChooser(f.serviceName, transport, f.logger, opts.EnableConnectionRetainingDirectChooser) + f.choosers = append(f.choosers, c) + return c, nil } diff --git a/common/rpc/peer_chooser_test.go b/common/rpc/peer_chooser_test.go index f000c1df85d..efd24db99bb 100644 --- a/common/rpc/peer_chooser_test.go +++ b/common/rpc/peer_chooser_test.go @@ -29,10 +29,30 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/yarpc/api/peer" "go.uber.org/yarpc/api/transport" + "go.uber.org/yarpc/transport/grpc" + "github.com/uber/cadence/common/dynamicconfig" "github.com/uber/cadence/common/log" + "github.com/uber/cadence/common/log/testlogger" ) +type ( + fakePeerTransport struct{} + fakePeer struct{} +) + +func (t *fakePeerTransport) RetainPeer(peer.Identifier, peer.Subscriber) (peer.Peer, error) { + return &fakePeer{}, nil +} +func (t *fakePeerTransport) ReleasePeer(peer.Identifier, peer.Subscriber) error { + return nil +} + +func (p *fakePeer) Identifier() string { return "fakePeer" } +func (p *fakePeer) Status() peer.Status { return peer.Status{ConnectionStatus: peer.Available} } +func (p *fakePeer) StartRequest() {} +func (p *fakePeer) EndRequest() {} + func TestDNSPeerChooserFactory(t *testing.T) { logger := log.NewNoop() ctx := context.Background() @@ -42,10 +62,10 @@ func TestDNSPeerChooserFactory(t *testing.T) { peerTransport := &fakePeerTransport{} // Ensure invalid address returns error - _, err := factory.CreatePeerChooser(peerTransport, "invalid address") + _, err := factory.CreatePeerChooser(peerTransport, PeerChooserOptions{Address: "invalid address"}) assert.EqualError(t, err, "incorrect DNS:Port format") - chooser, err := factory.CreatePeerChooser(peerTransport, "localhost:1234") + chooser, err := factory.CreatePeerChooser(peerTransport, PeerChooserOptions{Address: "localhost:1234"}) require.NoError(t, err) require.NoError(t, chooser.Start()) @@ -60,19 +80,24 @@ func TestDNSPeerChooserFactory(t *testing.T) { assert.Equal(t, "fakePeer", peer.Identifier()) } -type ( - fakePeerTransport struct{} - fakePeer struct{} -) +func TestDirectPeerChooserFactory(t *testing.T) { + logger := testlogger.New(t) + serviceName := "service" + pcf := NewDirectPeerChooserFactory(serviceName, logger) + directConnRetainFn := func(opts ...dynamicconfig.FilterOption) bool { return false } + grpcTransport := grpc.NewTransport() + chooser, err := pcf.CreatePeerChooser(grpcTransport, PeerChooserOptions{ + ServiceName: serviceName, + EnableConnectionRetainingDirectChooser: directConnRetainFn, + }) + if err != nil { + t.Fatalf("Failed to create direct peer chooser: %v", err) + } + if chooser == nil { + t.Fatal("Failed to create direct peer chooser: nil") + } -func (t *fakePeerTransport) RetainPeer(peer.Identifier, peer.Subscriber) (peer.Peer, error) { - return &fakePeer{}, nil + if _, dc := chooser.(*directPeerChooser); !dc { + t.Fatalf("Want chooser be of type (*directPeerChooser), got %d", chooser) + } } -func (t *fakePeerTransport) ReleasePeer(peer.Identifier, peer.Subscriber) error { - return nil -} - -func (p *fakePeer) Identifier() string { return "fakePeer" } -func (p *fakePeer) Status() peer.Status { return peer.Status{ConnectionStatus: peer.Available} } -func (p *fakePeer) StartRequest() {} -func (p *fakePeer) EndRequest() {} diff --git a/common/rpc/types.go b/common/rpc/types.go new file mode 100644 index 00000000000..085c48d5686 --- /dev/null +++ b/common/rpc/types.go @@ -0,0 +1,45 @@ +// Copyright (c) 2021 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +//go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination factory_mock.go -self_package github.com/uber/cadence/common/rpc + +package rpc + +import ( + "go.uber.org/yarpc" + "go.uber.org/yarpc/transport/tchannel" + + "github.com/uber/cadence/common/membership" +) + +// Factory Creates a dispatcher that knows how to transport requests. +type Factory interface { + GetDispatcher() *yarpc.Dispatcher + GetMaxMessageSize() int + Start(PeerLister) error + GetTChannel() tchannel.Channel + Stop() error +} + +type PeerLister interface { + Subscribe(service, name string, notifyChannel chan<- *membership.ChangedEvent) error + Unsubscribe(service, name string) error + Members(service string) ([]membership.HostInfo, error) +} diff --git a/common/rpc_mock.go b/common/rpc_mock.go deleted file mode 100644 index ad16acb2ab5..00000000000 --- a/common/rpc_mock.go +++ /dev/null @@ -1,85 +0,0 @@ -// The MIT License (MIT) - -// Copyright (c) 2017-2020 Uber Technologies Inc. - -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -// Code generated by MockGen. DO NOT EDIT. -// Source: rpc.go - -// Package common is a generated GoMock package. -package common - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - yarpc "go.uber.org/yarpc" -) - -// MockRPCFactory is a mock of RPCFactory interface. -type MockRPCFactory struct { - ctrl *gomock.Controller - recorder *MockRPCFactoryMockRecorder -} - -// MockRPCFactoryMockRecorder is the mock recorder for MockRPCFactory. -type MockRPCFactoryMockRecorder struct { - mock *MockRPCFactory -} - -// NewMockRPCFactory creates a new mock instance. -func NewMockRPCFactory(ctrl *gomock.Controller) *MockRPCFactory { - mock := &MockRPCFactory{ctrl: ctrl} - mock.recorder = &MockRPCFactoryMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockRPCFactory) EXPECT() *MockRPCFactoryMockRecorder { - return m.recorder -} - -// GetDispatcher mocks base method. -func (m *MockRPCFactory) GetDispatcher() *yarpc.Dispatcher { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetDispatcher") - ret0, _ := ret[0].(*yarpc.Dispatcher) - return ret0 -} - -// GetDispatcher indicates an expected call of GetDispatcher. -func (mr *MockRPCFactoryMockRecorder) GetDispatcher() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDispatcher", reflect.TypeOf((*MockRPCFactory)(nil).GetDispatcher)) -} - -// GetMaxMessageSize mocks base method. -func (m *MockRPCFactory) GetMaxMessageSize() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMaxMessageSize") - ret0, _ := ret[0].(int) - return ret0 -} - -// GetMaxMessageSize indicates an expected call of GetMaxMessageSize. -func (mr *MockRPCFactoryMockRecorder) GetMaxMessageSize() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaxMessageSize", reflect.TypeOf((*MockRPCFactory)(nil).GetMaxMessageSize)) -} diff --git a/host/onebox.go b/host/onebox.go index a76f18ccd32..b1f658b3fe0 100644 --- a/host/onebox.go +++ b/host/onebox.go @@ -1026,7 +1026,7 @@ func newPublicClient(dispatcher *yarpc.Dispatcher) cwsc.Interface { ) } -func (c *cadenceImpl) newRPCFactory(serviceName string, host membership.HostInfo) common.RPCFactory { +func (c *cadenceImpl) newRPCFactory(serviceName string, host membership.HostInfo) rpc.Factory { tchannelAddress, err := host.GetNamedAddress(membership.PortTchannel) if err != nil { c.logger.Fatal("failed to get PortTchannel port from host", tag.Value(host), tag.Error(err)) @@ -1042,6 +1042,9 @@ func (c *cadenceImpl) newRPCFactory(serviceName string, host membership.HostInfo c.logger.Fatal("failed to get frontend PortGRPC", tag.Value(c.FrontendHost()), tag.Error(err)) } + directOutboundPCF := rpc.NewDirectPeerChooserFactory(serviceName, c.logger) + directConnRetainFn := func(opts ...dynamicconfig.FilterOption) bool { return false } + return rpc.NewFactory(c.logger, rpc.Params{ ServiceName: serviceName, TChannelAddress: tchannelAddress, @@ -1055,8 +1058,8 @@ func (c *cadenceImpl) newRPCFactory(serviceName string, host membership.HostInfo &singleGRPCOutbound{testOutboundName(serviceName), serviceName, grpcAddress}, &singleGRPCOutbound{rpc.OutboundPublicClient, service.Frontend, frontendGrpcAddress}, rpc.NewCrossDCOutbounds(c.clusterMetadata.GetAllClusterInfo(), rpc.NewDNSPeerChooserFactory(0, c.logger)), - rpc.NewDirectOutbound(service.History, true, nil), - rpc.NewDirectOutbound(service.Matching, true, nil), + rpc.NewDirectOutboundBuilder(service.History, true, nil, directOutboundPCF, directConnRetainFn), + rpc.NewDirectOutboundBuilder(service.Matching, true, nil, directOutboundPCF, directConnRetainFn), ), }) } @@ -1072,11 +1075,13 @@ type singleGRPCOutbound struct { address string } -func (b singleGRPCOutbound) Build(grpc *grpc.Transport, _ *tchannel.Transport) (yarpc.Outbounds, error) { - return yarpc.Outbounds{ - b.outboundName: { - ServiceName: b.serviceName, - Unary: grpc.NewSingleOutbound(b.address), +func (b singleGRPCOutbound) Build(grpc *grpc.Transport, _ *tchannel.Transport) (*rpc.Outbounds, error) { + return &rpc.Outbounds{ + Outbounds: yarpc.Outbounds{ + b.outboundName: { + ServiceName: b.serviceName, + Unary: grpc.NewSingleOutbound(b.address), + }, }, }, nil } diff --git a/host/service.go b/host/service.go index 9095aa4f6fb..368e5be8a6a 100644 --- a/host/service.go +++ b/host/service.go @@ -43,6 +43,7 @@ import ( "github.com/uber/cadence/common/metrics" "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/resource" + "github.com/uber/cadence/common/rpc" ) type ( @@ -74,7 +75,7 @@ type ( hostInfo membership.HostInfo dispatcher *yarpc.Dispatcher membershipResolver membership.Resolver - rpcFactory common.RPCFactory + rpcFactory rpc.Factory pprofInitializer common.PProfInitializer clientBean client.Bean timeSource clock.TimeSource