diff --git a/pkg/networkservice/chains/nsmgr/server_test.go b/pkg/networkservice/chains/nsmgr/server_test.go index e90722bf5..06a54cc7e 100644 --- a/pkg/networkservice/chains/nsmgr/server_test.go +++ b/pkg/networkservice/chains/nsmgr/server_test.go @@ -35,6 +35,7 @@ import ( "github.com/networkservicemesh/api/pkg/api/networkservice" "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/cls" kernelmech "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/kernel" + "github.com/networkservicemesh/api/pkg/api/networkservice/payload" "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/networkservice/common/clienturl" @@ -507,6 +508,80 @@ func TestNSMGR_PassThroughLocal(t *testing.T) { require.Equal(t, 5*(nsesCount-1)+5, len(conn.Path.PathSegments)) } +func TestNSMGR_ShouldCleanAllClientAndEndpointGoroutines(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + domain := sandbox.NewBuilder(t). + SetNodesCount(1). + SetRegistryProxySupplier(nil). + SetContext(ctx). + Build() + defer domain.Cleanup() + + // We have lazy initialization in some chain elements in both networkservice, registry chains. So registering an + // endpoint and requesting it from client can result in new endless NSMgr goroutines. + + testNSEAndClient(ctx, t, domain, ®istry.NetworkServiceEndpoint{ + Name: "endpoint-init", + NetworkServiceNames: []string{"service-init"}, + }) + + // At this moment all possible endless NSMgr goroutines have been started. So we expect all newly created goroutines + // to be canceled no later than some of these events: + // 1. GRPC request context cancel + // 2. NSC connection close + // 3. NSE unregister + + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + testNSEAndClient(ctx, t, domain, ®istry.NetworkServiceEndpoint{ + Name: "endpoint-final", + NetworkServiceNames: []string{"service-final"}, + }) +} + +func testNSEAndClient( + ctx context.Context, + t *testing.T, + domain *sandbox.Domain, + nseReg *registry.NetworkServiceEndpoint, +) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + _, err := sandbox.NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken, domain.Nodes[0].NSMgr) + require.NoError(t, err) + + nsc := sandbox.NewClient(ctx, sandbox.GenerateTestToken, domain.Nodes[0].NSMgr.URL) + + conn, err := nsc.Request(ctx, &networkservice.NetworkServiceRequest{ + MechanismPreferences: []*networkservice.Mechanism{ + {Cls: cls.LOCAL, Type: kernelmech.MECHANISM}, + }, + Connection: &networkservice.Connection{ + NetworkService: nseReg.NetworkServiceNames[0], + }, + }) + require.NoError(t, err) + + _, err = nsc.Close(ctx, conn) + require.NoError(t, err) + + _, err = domain.Nodes[0].NSMgr.NetworkServiceEndpointRegistryServer().Unregister(ctx, nseReg) + require.NoError(t, err) + + for _, name := range nseReg.NetworkServiceNames { + _, err = domain.Nodes[0].NSMgr.NetworkServiceRegistryServer().Unregister(ctx, ®istry.NetworkService{ + Name: name, + Payload: payload.IP, + }) + require.NoError(t, err) + } +} + type passThroughClient struct { networkService string networkServiceEndpointName string diff --git a/pkg/registry/common/querycache/nse_client.go b/pkg/registry/common/querycache/nse_client.go index 9f4b1e47d..021304b6c 100644 --- a/pkg/registry/common/querycache/nse_client.go +++ b/pkg/registry/common/querycache/nse_client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -21,9 +21,10 @@ import ( "context" "github.com/golang/protobuf/ptypes/empty" - "github.com/networkservicemesh/api/pkg/api/registry" "google.golang.org/grpc" + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/memory" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/registry/core/streamchannel" @@ -42,34 +43,47 @@ func (q *queryCacheNSEClient) Find(ctx context.Context, in *registry.NetworkServ if in.Watch { return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) } + if nse, ok := q.cache.Load(in.String()); ok { resultCh := make(chan *registry.NetworkServiceEndpoint, 1) resultCh <- nse close(resultCh) return streamchannel.NewNetworkServiceEndpointFindClient(ctx, resultCh), nil } + client, err := next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) if err != nil { return nil, err } + nses := registry.ReadNetworkServiceEndpointList(client) + resultCh := make(chan *registry.NetworkServiceEndpoint, len(nses)) for _, nse := range nses { + resultCh <- nse + nseQuery := ®istry.NetworkServiceEndpointQuery{ NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ Name: nse.Name, }, } - resultCh <- nse + key := nseQuery.String() q.cache.Store(key, nse) + go func() { defer q.cache.Delete(key) + + findCtx, findCancel := context.WithCancel(q.chainCtx) + defer findCancel() + nseQuery.Watch = true - stream, err := next.NetworkServiceEndpointRegistryClient(ctx).Find(q.chainCtx, nseQuery, opts...) + + stream, err := next.NetworkServiceEndpointRegistryClient(ctx).Find(findCtx, nseQuery, opts...) if err != nil { return } + for update, err := stream.Recv(); err == nil; update, err = stream.Recv() { if update.Name != nseQuery.NetworkServiceEndpoint.Name { continue @@ -82,6 +96,7 @@ func (q *queryCacheNSEClient) Find(ctx context.Context, in *registry.NetworkServ }() } close(resultCh) + return streamchannel.NewNetworkServiceEndpointFindClient(ctx, resultCh), nil } diff --git a/pkg/registry/common/querycache/nse_client_test.go b/pkg/registry/common/querycache/nse_client_test.go index c476757c9..f664f1cb5 100644 --- a/pkg/registry/common/querycache/nse_client_test.go +++ b/pkg/registry/common/querycache/nse_client_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -106,3 +106,52 @@ func Test_QueryCacheServer_ShouldCacheNSEs(t *testing.T) { }, time.Second, time.Second/10) } } + +func Test_QueryCacheServer_ShouldCleanupGoroutinesOnNSEUnregister(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mem := memory.NewNetworkServiceEndpointRegistryServer() + + reg, err := func() (*registry.NetworkServiceEndpoint, error) { + registerCtx, registerCancel := context.WithCancel(ctx) + defer registerCancel() + + return mem.Register(registerCtx, ®istry.NetworkServiceEndpoint{ + Name: "nse-1", + }) + }() + require.NoError(t, err) + + client := next.NewNetworkServiceEndpointRegistryClient( + querycache.NewClient(ctx), + adapters.NetworkServiceEndpointServerToClient(mem), + ) + + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + // 1. Find + findCtx, findCancel := context.WithCancel(ctx) + + _, err = client.Find(findCtx, ®istry.NetworkServiceEndpointQuery{ + NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ + Name: reg.Name, + }, + }) + require.NoError(t, err) + + findCancel() + + // 2. Wait a bit for the (cache -> registry) stream to start + <-time.After(1 * time.Millisecond) + + // 3. Unregister + unregisterCtx, unregisterCancel := context.WithCancel(ctx) + + _, err = mem.Unregister(unregisterCtx, reg) + require.NoError(t, err) + + unregisterCancel() +} diff --git a/pkg/tools/sandbox/utils.go b/pkg/tools/sandbox/utils.go index b40e3af46..0d7515e20 100644 --- a/pkg/tools/sandbox/utils.go +++ b/pkg/tools/sandbox/utils.go @@ -22,14 +22,14 @@ import ( "net/url" "time" - "github.com/networkservicemesh/sdk/pkg/tools/logger" + "google.golang.org/protobuf/types/known/timestamppb" - "github.com/golang/protobuf/ptypes" "github.com/google/uuid" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "github.com/networkservicemesh/api/pkg/api/networkservice" + "github.com/networkservicemesh/api/pkg/api/networkservice/payload" "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/networkservice/chains/client" @@ -38,6 +38,7 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/common/authorize" "github.com/networkservicemesh/sdk/pkg/networkservice/common/clienturl" "github.com/networkservicemesh/sdk/pkg/tools/clienturlctx" + "github.com/networkservicemesh/sdk/pkg/tools/logger" "github.com/networkservicemesh/sdk/pkg/tools/opentracing" "github.com/networkservicemesh/sdk/pkg/tools/token" ) @@ -50,7 +51,9 @@ func GenerateTestToken(_ credentials.AuthInfo) (tokenValue string, expireTime ti // NewEndpoint creates endpoint and registers it into passed NSMgr. func NewEndpoint(ctx context.Context, nse *registry.NetworkServiceEndpoint, generatorFunc token.GeneratorFunc, mgr nsmgr.Nsmgr, additionalFunctionality ...networkservice.NetworkServiceServer) (*EndpointEntry, error) { ep := endpoint.NewServer(ctx, nse.Name, authorize.NewServer(), generatorFunc, additionalFunctionality...) + ctx = logger.WithLog(ctx) + u := &url.URL{Scheme: "tcp", Host: "127.0.0.1:0"} var err error if nse.Url != "" { @@ -60,26 +63,33 @@ func NewEndpoint(ctx context.Context, nse *registry.NetworkServiceEndpoint, gene } } serve(ctx, u, ep.Register) + if nse.Url == "" { nse.Url = u.String() } if nse.ExpirationTime == nil { - deadline := time.Now().Add(time.Hour) - expirationTime, err := ptypes.TimestampProto(deadline) - if err != nil { - return nil, err - } - nse.ExpirationTime = expirationTime + nse.ExpirationTime = timestamppb.New(time.Now().Add(time.Hour)) } - if _, err := mgr.NetworkServiceEndpointRegistryServer().Register(ctx, nse); err != nil { + + var reg *registry.NetworkServiceEndpoint + if reg, err = mgr.NetworkServiceEndpointRegistryServer().Register(ctx, nse); err != nil { return nil, err } + + nse.Name = reg.Name + nse.ExpirationTime = reg.ExpirationTime + for _, service := range nse.NetworkServiceNames { - if _, err := mgr.NetworkServiceRegistryServer().Register(ctx, ®istry.NetworkService{Name: service, Payload: "IP"}); err != nil { + if _, err := mgr.NetworkServiceRegistryServer().Register(ctx, ®istry.NetworkService{ + Name: service, + Payload: payload.IP, + }); err != nil { return nil, err } } + logger.Log(ctx).Infof("Started listen endpoint %v on %v.", nse.Name, u.String()) + return &EndpointEntry{Endpoint: ep, URL: u}, nil }