diff --git a/pkg/networkservice/chains/client/client_heal_test.go b/pkg/networkservice/chains/client/client_heal_test.go index 1f0be1460..1c5f1c1da 100644 --- a/pkg/networkservice/chains/client/client_heal_test.go +++ b/pkg/networkservice/chains/client/client_heal_test.go @@ -43,7 +43,7 @@ func TestClientHeal(t *testing.T) { nsc := client.NewClient(ctx, serverURL, - client.WithDialOptions(sandbox.DefaultDialOptions(sandbox.GenerateTestToken)...), + client.WithDialOptions(sandbox.DialOptions()...), client.WithDialTimeout(time.Second), ) _, err := nsc.Request(ctx, &networkservice.NetworkServiceRequest{}) diff --git a/pkg/networkservice/chains/nsmgr/suite_test.go b/pkg/networkservice/chains/nsmgr/suite_test.go index fdb315400..11fb0e042 100644 --- a/pkg/networkservice/chains/nsmgr/suite_test.go +++ b/pkg/networkservice/chains/nsmgr/suite_test.go @@ -162,7 +162,7 @@ func (s *nsmgrSuite) Test_SelectsRestartingEndpointUsecase() { require.NoError(t, err) nseRegistryClient := registryclient.NewNetworkServiceEndpointRegistryClient(ctx, sandbox.CloneURL(s.domain.Nodes[0].NSMgr.URL), - registryclient.WithDialOptions(sandbox.DefaultDialOptions(sandbox.GenerateTestToken)...)) + registryclient.WithDialOptions(sandbox.DialOptions()...)) nseReg, err = nseRegistryClient.Register(ctx, nseReg) require.NoError(t, err) @@ -786,7 +786,7 @@ func additionalFunctionalityChain(ctx context.Context, clientURL *url.URL, clien ), ), connect.WithDialTimeout(sandbox.DialTimeout), - connect.WithDialOptions(sandbox.DefaultDialOptions(sandbox.GenerateTestToken)...), + connect.WithDialOptions(sandbox.DialOptions()...), ), ), } diff --git a/pkg/networkservice/chains/nsmgrproxy/server_test.go b/pkg/networkservice/chains/nsmgrproxy/server_test.go index 074a73227..8e0b50e7b 100644 --- a/pkg/networkservice/chains/nsmgrproxy/server_test.go +++ b/pkg/networkservice/chains/nsmgrproxy/server_test.go @@ -578,7 +578,7 @@ func Test_Interdomain_PassThroughUsecase(t *testing.T) { kernelmech.NewClient(), )), connect.WithDialTimeout(sandbox.DialTimeout), - connect.WithDialOptions(sandbox.DefaultDialOptions(sandbox.GenerateTestToken)...), + connect.WithDialOptions(sandbox.DialOptions()...), ), ), } diff --git a/pkg/networkservice/common/connect/server_cancel_test.go b/pkg/networkservice/common/connect/server_cancel_test.go index 9483fa63b..7e1060426 100644 --- a/pkg/networkservice/common/connect/server_cancel_test.go +++ b/pkg/networkservice/common/connect/server_cancel_test.go @@ -106,7 +106,7 @@ func TestConnect_CancelDuringRequest(t *testing.T) { connect.NewServer(ctx, clientFactory, connect.WithDialTimeout(sandbox.DialTimeout), - connect.WithDialOptions(sandbox.DefaultDialOptions(sandbox.GenerateTestToken)...), + connect.WithDialOptions(sandbox.DialOptions()...), ), ), ) diff --git a/pkg/registry/chains/memory/server_test.go b/pkg/registry/chains/memory/server_test.go index f621e1f38..e5dc1428f 100644 --- a/pkg/registry/chains/memory/server_test.go +++ b/pkg/registry/chains/memory/server_test.go @@ -19,20 +19,18 @@ package memory_test import ( "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "google.golang.org/grpc" "github.com/networkservicemesh/api/pkg/api/networkservice/payload" "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" "github.com/networkservicemesh/sdk/pkg/tools/sandbox" - - "github.com/stretchr/testify/require" - - "go.uber.org/goleak" - "google.golang.org/grpc" - - "testing" - "time" ) func Test_RegistryMemory_ShouldSetDefaultPayload(t *testing.T) { @@ -48,7 +46,7 @@ func Test_RegistryMemory_ShouldSetDefaultPayload(t *testing.T) { Build() // start grpc client connection and register it - cc, err := grpc.DialContext(ctx, grpcutils.URLToTarget(domain.Registry.URL), sandbox.DefaultDialOptions(sandbox.GenerateTestToken)...) + cc, err := grpc.DialContext(ctx, grpcutils.URLToTarget(domain.Registry.URL), sandbox.DialOptions()...) require.NoError(t, err) defer func() { _ = cc.Close() diff --git a/pkg/registry/common/heal/find_test.go b/pkg/registry/common/heal/find_test.go index 013453939..08a1a8a68 100644 --- a/pkg/registry/common/heal/find_test.go +++ b/pkg/registry/common/heal/find_test.go @@ -55,7 +55,7 @@ func TestHealClient_FindTest(t *testing.T) { findCtx, findCancel := context.WithCancel(ctx) nsRegistryClient := registryclient.NewNetworkServiceRegistryClient(ctx, sandbox.CloneURL(domain.Nodes[0].NSMgr.URL), - registryclient.WithDialOptions(sandbox.DefaultDialOptions(sandbox.GenerateTestToken)...)) + registryclient.WithDialOptions(sandbox.DialOptions()...)) nsStream, err := nsRegistryClient.Find(findCtx, ®istry.NetworkServiceQuery{ NetworkService: new(registry.NetworkService), @@ -64,7 +64,7 @@ func TestHealClient_FindTest(t *testing.T) { require.NoError(t, err) nseRegistryClient := registryclient.NewNetworkServiceEndpointRegistryClient(ctx, sandbox.CloneURL(domain.Nodes[0].NSMgr.URL), - registryclient.WithDialOptions(sandbox.DefaultDialOptions(sandbox.GenerateTestToken)...)) + registryclient.WithDialOptions(sandbox.DialOptions()...)) nseStream, err := nseRegistryClient.Find(findCtx, ®istry.NetworkServiceEndpointQuery{ NetworkServiceEndpoint: new(registry.NetworkServiceEndpoint), diff --git a/pkg/tools/sandbox/builder.go b/pkg/tools/sandbox/builder.go index 0bbf85a10..13efd17b4 100644 --- a/pkg/tools/sandbox/builder.go +++ b/pkg/tools/sandbox/builder.go @@ -239,7 +239,7 @@ func (b *Builder) newRegistryProxy() *RegistryEntry { entry.Registry = b.supplyRegistryProxy( ctx, b.dnsResolver, - DefaultDialOptions(b.generateTokenFunc)..., + DialOptions(WithTokenGenerator(b.generateTokenFunc))..., ) serve(ctx, b.t, entry.URL, entry.Register) @@ -267,7 +267,7 @@ func (b *Builder) newRegistry() *RegistryEntry { ctx, b.registryExpiryDuration, nsmgrProxyURL, - DefaultDialOptions(b.generateTokenFunc)..., + DialOptions(WithTokenGenerator(b.generateTokenFunc))..., ) serve(ctx, b.t, entry.URL, entry.Register) @@ -287,6 +287,7 @@ func (b *Builder) newNSMgrProxy() *NSMgrEntry { URL: b.domain.NSMgrProxy.URL, } entry.restartableServer = newRestartableServer(b.ctx, b.t, entry.URL, func(ctx context.Context) { + dialOptions := DialOptions(WithTokenGenerator(b.generateTokenFunc)) entry.Nsmgr = b.supplyNSMgrProxy(ctx, CloneURL(b.domain.Registry.URL), CloneURL(b.domain.RegistryProxy.URL), @@ -295,9 +296,9 @@ func (b *Builder) newNSMgrProxy() *NSMgrEntry { nsmgrproxy.WithName(entry.Name), nsmgrproxy.WithConnectOptions( connect.WithDialTimeout(DialTimeout), - connect.WithDialOptions(DefaultDialOptions(b.generateTokenFunc)...)), + connect.WithDialOptions(dialOptions...)), nsmgrproxy.WithRegistryConnectOptions( - registryconnect.WithDialOptions(DefaultDialOptions(b.generateTokenFunc)...), + registryconnect.WithDialOptions(dialOptions...), ), ) serve(ctx, b.t, entry.URL, entry.Register) diff --git a/pkg/tools/sandbox/dial_options.go b/pkg/tools/sandbox/dial_options.go new file mode 100644 index 000000000..4e37db367 --- /dev/null +++ b/pkg/tools/sandbox/dial_options.go @@ -0,0 +1,68 @@ +// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sandbox + +import ( + "github.com/edwarnicke/grpcfd" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/networkservicemesh/sdk/pkg/tools/opentracing" + "github.com/networkservicemesh/sdk/pkg/tools/token" +) + +type dialOpts struct { + tokenGenerator token.GeneratorFunc +} + +// DialOption is an option pattern for DialOptions +type DialOption func(o *dialOpts) + +// WithTokenGenerator sets tokenGenerator for DialOptions +func WithTokenGenerator(tokenGenerator token.GeneratorFunc) DialOption { + return func(opts *dialOpts) { + opts.tokenGenerator = tokenGenerator + } +} + +// DialOptions is a helper method for building []grpc.DialOption for testing +func DialOptions(options ...DialOption) []grpc.DialOption { + tokenResetCh := make(chan struct{}) + close(tokenResetCh) + + opts := &dialOpts{ + tokenGenerator: GenerateTestToken, + } + for _, o := range options { + o(opts) + } + + return append([]grpc.DialOption{ + grpc.WithTransportCredentials( + grpcfdTransportCredentials(insecure.NewCredentials()), + ), + grpc.WithBlock(), + grpc.WithDefaultCallOptions( + grpc.WaitForReady(true), + grpc.PerRPCCredentials(token.NewPerRPCCredentials(opts.tokenGenerator)), + ), + grpcfd.WithChainStreamInterceptor(), + grpcfd.WithChainUnaryInterceptor(), + WithInsecureRPCCredentials(), + WithInsecureStreamRPCCredentials(), + }, opentracing.WithTracingDial()...) +} diff --git a/pkg/tools/sandbox/node.go b/pkg/tools/sandbox/node.go index 6ff03baef..c6af3889a 100644 --- a/pkg/tools/sandbox/node.go +++ b/pkg/tools/sandbox/node.go @@ -61,7 +61,7 @@ func (n *Node) NewNSMgr( serveURL = n.domain.supplyURL("nsmgr") } - dialOptions := DefaultDialOptions(generatorFunc) + dialOptions := DialOptions(WithTokenGenerator(generatorFunc)) options := []nsmgr.Option{ nsmgr.WithName(name), @@ -113,7 +113,7 @@ func (n *Node) NewForwarder( } nseClone := nse.Clone() - dialOptions := DefaultDialOptions(generatorFunc) + dialOptions := DialOptions(WithTokenGenerator(generatorFunc)) entry := &EndpointEntry{ Name: nse.Name, @@ -173,7 +173,7 @@ func (n *Node) NewEndpoint( } nseClone := nse.Clone() - dialOptions := DefaultDialOptions(generatorFunc) + dialOptions := DialOptions(WithTokenGenerator(generatorFunc)) entry := &EndpointEntry{ Name: nse.Name, @@ -217,7 +217,7 @@ func (n *Node) NewClient( return client.NewClient( ctx, CloneURL(n.NSMgr.URL), - client.WithDialOptions(DefaultDialOptions(generatorFunc)...), + client.WithDialOptions(DialOptions(WithTokenGenerator(generatorFunc))...), client.WithDialTimeout(DialTimeout), client.WithAuthorizeClient(authorize.NewClient(authorize.Any())), client.WithAdditionalFunctionality(additionalFunctionality...), diff --git a/pkg/tools/sandbox/types.go b/pkg/tools/sandbox/types.go index c6f14e858..e17aba236 100644 --- a/pkg/tools/sandbox/types.go +++ b/pkg/tools/sandbox/types.go @@ -101,5 +101,5 @@ func (d *Domain) NewNSRegistryClient(ctx context.Context, generatorFunc token.Ge } return registryclient.NewNetworkServiceRegistryClient(ctx, registryURL, - registryclient.WithDialOptions(DefaultDialOptions(generatorFunc)...)) + registryclient.WithDialOptions(DialOptions(WithTokenGenerator(generatorFunc))...)) } diff --git a/pkg/tools/sandbox/utils.go b/pkg/tools/sandbox/utils.go index a7e4a006e..71a2db053 100644 --- a/pkg/tools/sandbox/utils.go +++ b/pkg/tools/sandbox/utils.go @@ -21,15 +21,12 @@ import ( "fmt" "time" - "github.com/edwarnicke/grpcfd" "github.com/google/uuid" "google.golang.org/grpc" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" registryapi "github.com/networkservicemesh/api/pkg/api/registry" - "github.com/networkservicemesh/sdk/pkg/tools/opentracing" "github.com/networkservicemesh/sdk/pkg/tools/token" ) @@ -87,22 +84,6 @@ func GenerateExpiringToken(duration time.Duration) token.GeneratorFunc { } } -// DefaultDialOptions returns default dial options for sandbox testing -func DefaultDialOptions(genTokenFunc token.GeneratorFunc) []grpc.DialOption { - return append([]grpc.DialOption{ - grpc.WithTransportCredentials(grpcfdTransportCredentials(insecure.NewCredentials())), - grpc.WithBlock(), - grpc.WithDefaultCallOptions( - grpc.WaitForReady(true), - grpc.PerRPCCredentials(token.NewPerRPCCredentials(genTokenFunc)), - ), - grpcfd.WithChainStreamInterceptor(), - grpcfd.WithChainUnaryInterceptor(), - WithInsecureRPCCredentials(), - WithInsecureStreamRPCCredentials(), - }, opentracing.WithTracingDial()...) -} - // UniqueName creates unique name with the given prefix func UniqueName(prefix string) string { return fmt.Sprintf("%s-%s", prefix, uuid.New().String())