diff --git a/pkg/networkservice/chains/nsmgr/heal_test.go b/pkg/networkservice/chains/nsmgr/heal_test.go index 509947151..ac54b5b22 100644 --- a/pkg/networkservice/chains/nsmgr/heal_test.go +++ b/pkg/networkservice/chains/nsmgr/heal_test.go @@ -1,6 +1,6 @@ // Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // -// Copyright (c) 2023 Cisco and/or its affiliates. +// Copyright (c) 2023-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -859,9 +859,8 @@ func TestNSMGR_RefreshFailed_ControlPlaneBroken(t *testing.T) { ), ) - requestCtx, requestCalcel := context.WithTimeout(ctx, time.Second) - requestCtx = clock.WithClock(requestCtx, clk) - defer requestCalcel() + requestCtx, requestCancel := context.WithTimeout(ctx, time.Second) + defer requestCancel() // allow the first Request syncCh <- struct{}{} @@ -871,13 +870,30 @@ func TestNSMGR_RefreshFailed_ControlPlaneBroken(t *testing.T) { // refresh interval in this test is expected to be 3 minutes and a few milliseconds clk.Add(time.Second * 190) + // start goroutine that will update mock clock every 50 ms. It is needed for retry refresh + go func() { + tickerDuration := time.Millisecond * 50 + tickCh := time.Tick(tickerDuration) + for { + select { + case <-ctx.Done(): + return + case <-tickCh: + clk.Add(tickerDuration) + } + } + }() - // kill the forwarder during the healing Request (it is stopped by syncCh). Then continue - the healing process will fail. - for _, forwarder := range domain.Nodes[0].Forwarders { + // kill the forwarder during the refresh (it is stopped by syncCh). Then continue - the refresh will fail. 
+ for idx := range domain.Nodes[0].Forwarders { + forwarder := domain.Nodes[0].Forwarders[idx] forwarder.Cancel() - break + // wait until the forwarder dies + require.Eventually(t, func() bool { + return sandbox.CheckURLFree(forwarder.URL) + }, timeout, tick) } - syncCh <- struct{}{} + close(syncCh) // create a new forwarder and allow the healing Request forwarderReg := &registry.NetworkServiceEndpoint{ @@ -885,7 +901,6 @@ func TestNSMGR_RefreshFailed_ControlPlaneBroken(t *testing.T) { NetworkServiceNames: []string{"forwarder"}, } domain.Nodes[0].NewForwarder(ctx, forwarderReg, sandbox.GenerateTestToken) - syncCh <- struct{}{} // wait till Request reached NSE require.Eventually(t, func() bool { diff --git a/pkg/networkservice/chains/nsmgr/single_test.go b/pkg/networkservice/chains/nsmgr/single_test.go index e34e48787..53afb1458 100644 --- a/pkg/networkservice/chains/nsmgr/single_test.go +++ b/pkg/networkservice/chains/nsmgr/single_test.go @@ -1,6 +1,6 @@ // Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // -// Copyright (c) 2023 Cisco and/or its affiliates. +// Copyright (c) 2023-2024 Cisco and/or its affiliates. 
// // SPDX-License-Identifier: Apache-2.0 // @@ -48,7 +48,6 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/common/authorize" "github.com/networkservicemesh/sdk/pkg/networkservice/common/begin" "github.com/networkservicemesh/sdk/pkg/networkservice/common/excludedprefixes" - "github.com/networkservicemesh/sdk/pkg/networkservice/common/heal" "github.com/networkservicemesh/sdk/pkg/networkservice/ipam/point2pointipam" "github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkcontext" "github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkrequest" @@ -625,7 +624,7 @@ func Test_RestartDuringRefresh(t *testing.T) { require.NoError(t, err) var countServer count.Server - var countClint count.Client + var countClient count.Client var m sync.Once var clientFactory begin.EventFactory var destroyFwd atomic.Bool @@ -636,16 +635,21 @@ func Test_RestartDuringRefresh(t *testing.T) { NetworkServiceNames: []string{"ns"}, }, sandbox.GenerateTestToken, &countServer, checkrequest.NewServer(t, func(t *testing.T, nsr *networkservice.NetworkServiceRequest) { if destroyFwd.Load() { - e.AsyncExec(func() { - for _, fwd := range domain.Nodes[0].Forwarders { - fwd.Cancel() + <-e.AsyncExec(func() { + for idx := range domain.Nodes[0].Forwarders { + forwarder := domain.Nodes[0].Forwarders[idx] + forwarder.Cancel() + // wait until the forwarder dies + require.Eventually(t, func() bool { + return sandbox.CheckURLFree(forwarder.URL) + }, timeout, tick) } }) } })) var nsc = domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithAdditionalFunctionality( - &countClint, + &countClient, checkcontext.NewClient(t, func(t *testing.T, ctx context.Context) { m.Do(func() { clientFactory = begin.FromContext(ctx) @@ -660,7 +664,6 @@ func Test_RestartDuringRefresh(t *testing.T) { }) } }), - heal.NewClient(ctx), )) _, err = nsc.Request(ctx, &networkservice.NetworkServiceRequest{ @@ -673,16 +676,14 @@ func Test_RestartDuringRefresh(t *testing.T) { 
<-clientFactory.Request() require.Equal(t, 2, countServer.Requests()) require.Never(t, func() bool { return countServer.Requests() > 2 }, time.Second/2, time.Second/20) - destroyFwd.Store(true) for i := 0; i < 15; i++ { - var cs = countServer.Requests() destroyFwd.Store(true) err = <-clientFactory.Request() require.Error(t, err) + var cc = countClient.BackwardRequests() destroyFwd.Store(false) - var cc = countClint.Requests() - require.Eventually(t, func() bool { return cs < countServer.Requests() }, time.Second*2, time.Second/20) - require.Eventually(t, func() bool { return cc < countClint.Requests() }, time.Second*2, time.Second/20) + // Heal must be successful eventually + require.Eventually(t, func() bool { return cc < countClient.BackwardRequests() }, time.Second*2, time.Second/20) } } diff --git a/pkg/networkservice/common/connect/server_test.go b/pkg/networkservice/common/connect/server_test.go index 9bafaa9e7..77672ef4d 100644 --- a/pkg/networkservice/common/connect/server_test.go +++ b/pkg/networkservice/common/connect/server_test.go @@ -1,5 +1,7 @@ // Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // +// Copyright (c) 2024 Cisco and/or its affiliates. +// // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -197,7 +199,7 @@ func TestConnectServer_RequestParallel(t *testing.T) { connect.NewServer( next.NewNetworkServiceClient( dial.NewClient(context.Background(), - dial.WithDialTimeout(time.Second), + dial.WithDialTimeout(time.Second*2), dial.WithDialOptions(grpc.WithTransportCredentials(insecure.NewCredentials())), ), serverClient, diff --git a/pkg/networkservice/common/refresh/client.go b/pkg/networkservice/common/refresh/client.go index 191cc5fe9..c0a0d6ef0 100644 --- a/pkg/networkservice/common/refresh/client.go +++ b/pkg/networkservice/common/refresh/client.go @@ -1,6 +1,6 @@ -// Copyright (c) 2020 Cisco Systems, Inc. +// Copyright (c) 2020-2024 Cisco Systems, Inc. 
// -// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2024 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -68,18 +68,17 @@ func (t *refreshClient) Request(ctx context.Context, request *networkservice.Net store(ctx, metadata.IsClient(t), cancel) eventFactory := begin.FromContext(ctx) - clockTime := clock.FromContext(ctx) // Create the afterCh *outside* the go routine. This must be done to avoid picking up a later 'now' // from mockClock in testing - afterTicker := clockTime.Ticker(refreshAfter) + afterCh := clock.FromContext(ctx).After(refreshAfter) go func() { - defer afterTicker.Stop() for { select { case <-cancelCtx.Done(): return - case <-afterTicker.C(): + case <-afterCh: if err := <-eventFactory.Request(begin.CancelContext(cancelCtx)); err != nil { + afterCh = clock.FromContext(ctx).After(time.Millisecond * 200) logger.Warnf("refresh failed: %s", err.Error()) continue } diff --git a/pkg/networkservice/common/refresh/client_test.go b/pkg/networkservice/common/refresh/client_test.go index 0a571ac12..ee16d0914 100644 --- a/pkg/networkservice/common/refresh/client_test.go +++ b/pkg/networkservice/common/refresh/client_test.go @@ -1,5 +1,7 @@ // Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // +// Copyright (c) 2024 Cisco and/or its affiliates. 
+// // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -48,6 +50,7 @@ import ( const ( expireTimeout = 15 * time.Minute + retryTimeout = 200 * time.Millisecond testWait = 100 * time.Millisecond testTick = testWait / 100 @@ -69,19 +72,19 @@ func testTokenFuncWithTimeout(clockTime clock.Clock, timeout time.Duration) toke } } -type captureTickerDuration struct { +type captureAfterDuration struct { *clockmock.Mock - tickerDuration time.Duration + afterDuration time.Duration } -func (m *captureTickerDuration) Ticker(d time.Duration) clock.Ticker { - m.tickerDuration = d - return m.Mock.Ticker(d) +func (m *captureAfterDuration) After(d time.Duration) <-chan time.Time { + m.afterDuration = d + return m.Mock.After(d) } -func (m *captureTickerDuration) Reset(t time.Time) { - m.tickerDuration = 0 +func (m *captureAfterDuration) Reset(t time.Time) { + m.afterDuration = 0 m.Set(t) } @@ -355,7 +358,7 @@ func TestRefreshClient_CalculatesShortestTokenTimeout(t *testing.T) { timeNow := time.Date(2009, 11, 10, 23, 0, 0, 0, time.Local) - clockMock := captureTickerDuration{ + clockMock := captureAfterDuration{ Mock: clockmock.New(ctx), } @@ -389,14 +392,14 @@ func TestRefreshClient_CalculatesShortestTokenTimeout(t *testing.T) { }) require.NoError(t, err) - require.Less(t, clockMock.tickerDuration, testDataElement.ExpectedRefreshTimeout+timeoutDelta) - require.Greater(t, clockMock.tickerDuration, testDataElement.ExpectedRefreshTimeout-timeoutDelta) + require.Less(t, clockMock.afterDuration, testDataElement.ExpectedRefreshTimeout+timeoutDelta) + require.Greater(t, clockMock.afterDuration, testDataElement.ExpectedRefreshTimeout-timeoutDelta) } require.Equal(t, countClient.Requests(), len(testData)) } -func TestRefreshClient_RefreshOnRefreshFailure(t *testing.T) { +func TestRefreshClient_RetryOnRefreshFailure(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) ctx, cancel := context.WithCancel(context.Background()) @@ 
-422,7 +425,37 @@ func TestRefreshClient_RefreshOnRefreshFailure(t *testing.T) { require.Eventually(t, cloneClient.validator(2), testWait, testTick) - clockMock.Add(expireTimeout) + clockMock.Add(retryTimeout) require.Eventually(t, cloneClient.validator(3), testWait, testTick) } + +func TestRefreshClient_NoRetryOnRefreshSuccess(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clockMock := clockmock.New(ctx) + + cloneClient := &countClient{ + t: t, + } + client := testClient(ctx, testTokenFunc(clockMock), + clockMock, + cloneClient, + ) + + _, err := client.Request(ctx, &networkservice.NetworkServiceRequest{ + Connection: new(networkservice.Connection), + }) + require.NoError(t, err) + + clockMock.Add(expireTimeout) + + require.Eventually(t, cloneClient.validator(2), testWait, testTick) + + clockMock.Add(retryTimeout) + + require.Never(t, cloneClient.validator(3), testWait, testTick) +} diff --git a/pkg/networkservice/utils/count/client.go b/pkg/networkservice/utils/count/client.go index d24e09a23..f5f2fc535 100644 --- a/pkg/networkservice/utils/count/client.go +++ b/pkg/networkservice/utils/count/client.go @@ -1,5 +1,7 @@ // Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // +// Copyright (c) 2024 Cisco and/or its affiliates. 
+// // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -31,67 +33,131 @@ import ( // Client is a client type for counting Requests/Closes type Client struct { - totalRequests, totalCloses int32 - requests, closes map[string]int32 - mu sync.Mutex + totalForwardRequests, totalForwardCloses int32 + totalBackwardRequests, totalBackwardCloses int32 + forwardRequests, forwardCloses map[string]int32 + backwardRequests, backwardCloses map[string]int32 + forwardMu, backwardMu sync.Mutex } // Request performs request and increments requests count func (c *Client) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) { - c.mu.Lock() - defer c.mu.Unlock() + /* Forward pass*/ + c.forwardMu.Lock() + atomic.AddInt32(&c.totalForwardRequests, 1) + if c.forwardRequests == nil { + c.forwardRequests = make(map[string]int32) + } + c.forwardRequests[request.GetConnection().GetId()]++ + c.forwardMu.Unlock() + + /* Request */ + conn, err := next.Client(ctx).Request(ctx, request, opts...) + if err != nil { + return conn, err + } - atomic.AddInt32(&c.totalRequests, 1) - if c.requests == nil { - c.requests = make(map[string]int32) + /* Backward pass*/ + c.backwardMu.Lock() + atomic.AddInt32(&c.totalBackwardRequests, 1) + if c.backwardRequests == nil { + c.backwardRequests = make(map[string]int32) } - c.requests[request.GetConnection().GetId()]++ + c.backwardRequests[conn.GetId()]++ + c.backwardMu.Unlock() - return next.Client(ctx).Request(ctx, request, opts...) 
+ return conn, err } // Close performs close and increments closes count func (c *Client) Close(ctx context.Context, connection *networkservice.Connection, opts ...grpc.CallOption) (*empty.Empty, error) { - c.mu.Lock() - defer c.mu.Unlock() + /* Forward pass*/ + c.forwardMu.Lock() + atomic.AddInt32(&c.totalForwardCloses, 1) + if c.forwardCloses == nil { + c.forwardCloses = make(map[string]int32) + } + c.forwardCloses[connection.GetId()]++ + c.forwardMu.Unlock() - atomic.AddInt32(&c.totalCloses, 1) - if c.closes == nil { - c.closes = make(map[string]int32) + /* Close */ + r, err := next.Client(ctx).Close(ctx, connection, opts...) + if err != nil { + return r, err } - c.closes[connection.GetId()]++ - return next.Client(ctx).Close(ctx, connection, opts...) + /* Backward pass*/ + c.backwardMu.Lock() + atomic.AddInt32(&c.totalBackwardCloses, 1) + if c.backwardCloses == nil { + c.backwardCloses = make(map[string]int32) + } + c.backwardCloses[connection.GetId()]++ + c.backwardMu.Unlock() + + return r, err } -// Requests returns requests count +// Requests returns forward requests count func (c *Client) Requests() int { - return int(atomic.LoadInt32(&c.totalRequests)) + return int(atomic.LoadInt32(&c.totalForwardRequests)) } -// Closes returns closes count +// Closes returns forward closes count func (c *Client) Closes() int { - return int(atomic.LoadInt32(&c.totalCloses)) + return int(atomic.LoadInt32(&c.totalForwardCloses)) +} + +// BackwardRequests returns backward requests count +func (c *Client) BackwardRequests() int { + return int(atomic.LoadInt32(&c.totalBackwardRequests)) +} + +// BackwardCloses returns backward closes count +func (c *Client) BackwardCloses() int { + return int(atomic.LoadInt32(&c.totalBackwardCloses)) } -// UniqueRequests returns unique requests count +// UniqueRequests returns unique forward requests count func (c *Client) UniqueRequests() int { - c.mu.Lock() - defer c.mu.Unlock() + c.forwardMu.Lock() + defer c.forwardMu.Unlock() - if c.requests 
== nil { + if c.forwardRequests == nil { return 0 } - return len(c.requests) + return len(c.forwardRequests) } -// UniqueCloses returns unique closes count +// UniqueCloses returns unique forward closes count func (c *Client) UniqueCloses() int { - c.mu.Lock() - defer c.mu.Unlock() + c.forwardMu.Lock() + defer c.forwardMu.Unlock() + + if c.forwardCloses == nil { + return 0 + } + return len(c.forwardCloses) +} + +// UniqueBackwardRequests returns unique backward requests count +func (c *Client) UniqueBackwardRequests() int { + c.backwardMu.Lock() + defer c.backwardMu.Unlock() + + if c.backwardRequests == nil { + return 0 + } + return len(c.backwardRequests) +} + +// UniqueBackwardCloses returns unique backward closes count +func (c *Client) UniqueBackwardCloses() int { + c.backwardMu.Lock() + defer c.backwardMu.Unlock() - if c.closes == nil { + if c.backwardCloses == nil { return 0 } - return len(c.closes) + return len(c.backwardCloses) } diff --git a/pkg/networkservice/utils/count/server.go b/pkg/networkservice/utils/count/server.go index be27f515c..35f260441 100644 --- a/pkg/networkservice/utils/count/server.go +++ b/pkg/networkservice/utils/count/server.go @@ -1,5 +1,7 @@ // Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // +// Copyright (c) 2024 Cisco and/or its affiliates. 
+// // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -38,13 +40,13 @@ type Server struct { // Request performs request and increments requests count func (s *Server) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { s.mu.Lock() - defer s.mu.Unlock() atomic.AddInt32(&s.totalRequests, 1) if s.requests == nil { s.requests = make(map[string]int32) } s.requests[request.GetConnection().GetId()]++ + s.mu.Unlock() return next.Server(ctx).Request(ctx, request) } @@ -52,13 +54,13 @@ func (s *Server) Request(ctx context.Context, request *networkservice.NetworkSer // Close performs close and increments closes count func (s *Server) Close(ctx context.Context, connection *networkservice.Connection) (*empty.Empty, error) { s.mu.Lock() - defer s.mu.Unlock() atomic.AddInt32(&s.TotalCloses, 1) if s.closes == nil { s.closes = make(map[string]int32) } s.closes[connection.GetId()]++ + s.mu.Unlock() return next.Server(ctx).Close(ctx, connection) }