diff --git a/pkg/networkservice/chains/nsmgr/server_heal_test.go b/pkg/networkservice/chains/nsmgr/server_heal_test.go index 81029d8c58..a73fd48f2e 100644 --- a/pkg/networkservice/chains/nsmgr/server_heal_test.go +++ b/pkg/networkservice/chains/nsmgr/server_heal_test.go @@ -72,7 +72,7 @@ func testNSMGRHealEndpoint(t *testing.T, restored bool) { nseCtx, nseCtxCancel := context.WithTimeout(context.Background(), time.Second) defer nseCtxCancel() - nse, err := domain.Nodes[0].NewEndpoint(nseCtx, nseReg, sandbox.GenerateExpiringToken(time.Second), counter) + nse, err := domain.Nodes[0].NewEndpoint(nseCtx, nseReg, sandbox.GenerateExpiringToken(ctx, time.Second), counter) require.NoError(t, err) request := &networkservice.NetworkServiceRequest{ @@ -86,7 +86,7 @@ func testNSMGRHealEndpoint(t *testing.T, restored bool) { }, } - nsc := domain.Nodes[1].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[1].NewClient(ctx, sandbox.GenerateTestToken(ctx)) conn, err := nsc.Request(ctx, request.Clone()) require.NoError(t, err) @@ -106,7 +106,7 @@ func testNSMGRHealEndpoint(t *testing.T, restored bool) { if restored { nseReg2.Url = nse.URL.String() } - _, err = domain.Nodes[0].NewEndpoint(ctx, nseReg2, sandbox.GenerateTestToken, counter) + _, err = domain.Nodes[0].NewEndpoint(ctx, nseReg2, sandbox.GenerateTestToken(ctx), counter) require.NoError(t, err) // Wait NSE expired and reconnecting to the new NSE @@ -134,7 +134,7 @@ func TestNSMGR_HealLocalForwarder(t *testing.T) { nil, { ForwarderCtx: forwarderCtx, - ForwarderGenerateTokenFunc: sandbox.GenerateExpiringToken(time.Second), + ForwarderGenerateTokenFunc: sandbox.GenerateExpiringToken(forwarderCtx, time.Second), }, } @@ -149,7 +149,7 @@ func TestNSMGR_HealLocalForwarderRestored(t *testing.T) { nil, { ForwarderCtx: forwarderCtx, - ForwarderGenerateTokenFunc: sandbox.GenerateExpiringToken(time.Second), + ForwarderGenerateTokenFunc: sandbox.GenerateExpiringToken(forwarderCtx, time.Second), }, } @@ -163,7 +163,7 @@ func TestNSMGR_HealRemoteForwarder(t *testing.T) { customConfig := []*sandbox.NodeConfig{ { ForwarderCtx: forwarderCtx, - ForwarderGenerateTokenFunc: sandbox.GenerateExpiringToken(time.Second), + ForwarderGenerateTokenFunc: sandbox.GenerateExpiringToken(forwarderCtx, time.Second), }, } @@ -177,7 +177,7 @@ func TestNSMGR_HealRemoteForwarderRestored(t *testing.T) { customConfig := []*sandbox.NodeConfig{ { ForwarderCtx: forwarderCtx, - ForwarderGenerateTokenFunc: sandbox.GenerateExpiringToken(time.Second), + ForwarderGenerateTokenFunc: sandbox.GenerateExpiringToken(forwarderCtx, time.Second), }, } @@ -204,7 +204,7 @@ func testNSMGRHealForwarder(t *testing.T, nodeNum int, restored bool, customConf } counter := &counterServer{} - _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken, counter) + _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken(ctx), counter) require.NoError(t, err) request := &networkservice.NetworkServiceRequest{ @@ -218,7 +218,7 @@ func testNSMGRHealForwarder(t *testing.T, nodeNum int, restored bool, customConf }, } - nsc := domain.Nodes[1].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[1].NewClient(ctx, sandbox.GenerateTestToken(ctx)) conn, err := nsc.Request(ctx, request.Clone()) require.NoError(t, err) @@ -239,7 +239,7 @@ func testNSMGRHealForwarder(t *testing.T, nodeNum int, restored bool, customConf if restored { forwarderReg.Url = domain.Nodes[nodeNum].Forwarder[0].URL.String() } - _, err = domain.Nodes[nodeNum].NewForwarder(ctx, forwarderReg, sandbox.GenerateTestToken) + _, err = domain.Nodes[nodeNum].NewForwarder(ctx, forwarderReg, sandbox.GenerateTestToken(ctx)) require.NoError(t, err) // Wait Cross NSE expired and reconnecting through the new Cross NSE @@ -285,9 +285,9 @@ func TestNSMGR_HealRemoteNSMgrRestored(t *testing.T) { customConfig := []*sandbox.NodeConfig{ { NsmgrCtx: nsmgrCtx, - NsmgrGenerateTokenFunc: sandbox.GenerateExpiringToken(time.Second), + NsmgrGenerateTokenFunc: sandbox.GenerateExpiringToken(nsmgrCtx, time.Second), ForwarderCtx: nsmgrCtx, - ForwarderGenerateTokenFunc: sandbox.GenerateExpiringToken(time.Second), + ForwarderGenerateTokenFunc: sandbox.GenerateExpiringToken(nsmgrCtx, time.Second), }, } @@ -314,7 +314,7 @@ func testNSMGRHealNSMgr(t *testing.T, nodeNum int, customConfig []*sandbox.NodeC } counter := &counterServer{} - nse, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken, counter) + nse, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken(ctx), counter) require.NoError(t, err) request := &networkservice.NetworkServiceRequest{ @@ -328,7 +328,7 @@ func testNSMGRHealNSMgr(t *testing.T, nodeNum int, customConfig []*sandbox.NodeC }, } - nsc := domain.Nodes[1].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[1].NewClient(ctx, sandbox.GenerateTestToken(ctx)) conn, err := nsc.Request(ctx, request.Clone()) require.NoError(t, err) @@ -342,14 +342,14 @@ func testNSMGRHealNSMgr(t *testing.T, nodeNum int, customConfig []*sandbox.NodeC require.Eventually(t, checkURLFree(domain.Nodes[nodeNum].NSMgr.URL.Host), timeout, tick) - restoredNSMgrEntry, restoredNSMgrResources := builder.NewNSMgr(ctx, domain.Nodes[nodeNum], domain.Nodes[nodeNum].NSMgr.URL.Host, domain.Registry.URL, sandbox.GenerateTestToken) + restoredNSMgrEntry, restoredNSMgrResources := builder.NewNSMgr(ctx, domain.Nodes[nodeNum], domain.Nodes[nodeNum].NSMgr.URL.Host, domain.Registry.URL, sandbox.GenerateTestToken(ctx)) domain.Nodes[nodeNum].NSMgr = restoredNSMgrEntry domain.AddResources(restoredNSMgrResources) forwarderReg := ®istry.NetworkServiceEndpoint{ Name: "forwarder-restored", } - _, err = domain.Nodes[nodeNum].NewForwarder(ctx, forwarderReg, sandbox.GenerateTestToken) + _, err = domain.Nodes[nodeNum].NewForwarder(ctx, forwarderReg, sandbox.GenerateTestToken(ctx)) require.NoError(t, err) nseReg.Url = nse.URL.String() @@ -383,9 +383,9 @@ func TestNSMGR_HealRemoteNSMgr(t *testing.T) { customConfig := []*sandbox.NodeConfig{ { NsmgrCtx: nsmgrCtx, - NsmgrGenerateTokenFunc: sandbox.GenerateExpiringToken(time.Second), + NsmgrGenerateTokenFunc: sandbox.GenerateExpiringToken(nsmgrCtx, time.Second), ForwarderCtx: nsmgrCtx, - ForwarderGenerateTokenFunc: sandbox.GenerateExpiringToken(time.Second), + ForwarderGenerateTokenFunc: sandbox.GenerateExpiringToken(nsmgrCtx, time.Second), }, } @@ -407,7 +407,7 @@ func TestNSMGR_HealRemoteNSMgr(t *testing.T) { } counter := &counterServer{} - _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken, counter) + _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken(ctx), counter) require.NoError(t, err) request := &networkservice.NetworkServiceRequest{ @@ -421,7 +421,7 @@ func TestNSMGR_HealRemoteNSMgr(t *testing.T) { }, } - nsc := domain.Nodes[1].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[1].NewClient(ctx, sandbox.GenerateTestToken(ctx)) conn, err := nsc.Request(ctx, request.Clone()) require.NoError(t, err) @@ -435,7 +435,7 @@ func TestNSMGR_HealRemoteNSMgr(t *testing.T) { Name: "final-endpoint-2", NetworkServiceNames: []string{"my-service-remote"}, } - _, err = domain.Nodes[2].NewEndpoint(ctx, nseReg2, sandbox.GenerateTestToken, counter) + _, err = domain.Nodes[2].NewEndpoint(ctx, nseReg2, sandbox.GenerateTestToken(ctx), counter) require.NoError(t, err) // Wait Cross NSE expired and reconnecting through the new Cross NSE diff --git a/pkg/networkservice/chains/nsmgr/server_test.go b/pkg/networkservice/chains/nsmgr/server_test.go index 661265393e..b7f5df71ff 100644 --- a/pkg/networkservice/chains/nsmgr/server_test.go +++ b/pkg/networkservice/chains/nsmgr/server_test.go @@ -80,10 +80,10 @@ func TestNSMGR_RemoteUsecase_Parallel(t *testing.T) { Name: "final-endpoint", NetworkServiceNames: []string{"my-service-remote"}, } - _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken, counter) + _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken(ctx), counter) require.NoError(t, err) }() - nsc := domain.Nodes[1].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[1].NewClient(ctx, sandbox.GenerateTestToken(ctx)) conn, err := nsc.Request(ctx, request.Clone()) require.NoError(t, err) @@ -155,12 +155,12 @@ func TestNSMGR_SelectsRestartingEndpoint(t *testing.T) { // 2. Postpone endpoint start time.AfterFunc(time.Second, func() { s := grpc.NewServer() - endpoint.NewServer(ctx, sandbox.GenerateTestToken).Register(s) + endpoint.NewServer(ctx, sandbox.GenerateTestToken(ctx)).Register(s) _ = s.Serve(netListener) }) // 3. Create client and request endpoint - nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken(ctx)) conn, err := nsc.Request(ctx, request.Clone()) require.NoError(t, err) @@ -203,7 +203,7 @@ func TestNSMGR_RemoteUsecase_BusyEndpoints(t *testing.T) { Name: "final-endpoint-" + strconv.Itoa(id), NetworkServiceNames: []string{"my-service-remote"}, } - _, err := domain.Nodes[1].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken, newBusyEndpoint()) + _, err := domain.Nodes[1].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken(ctx), newBusyEndpoint()) require.NoError(t, err) wg.Done() }(i) @@ -215,10 +215,10 @@ func TestNSMGR_RemoteUsecase_BusyEndpoints(t *testing.T) { Name: "final-endpoint-3", NetworkServiceNames: []string{"my-service-remote"}, } - _, err := domain.Nodes[1].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken, counter) + _, err := domain.Nodes[1].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken(ctx), counter) require.NoError(t, err) }() - nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken(ctx)) conn, err := nsc.Request(ctx, request.Clone()) require.NoError(t, err) @@ -257,7 +257,7 @@ func TestNSMGR_RemoteUsecase(t *testing.T) { } counter := &counterServer{} - _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken, counter) + _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken(ctx), counter) require.NoError(t, err) request := &networkservice.NetworkServiceRequest{ @@ -271,7 +271,7 @@ func TestNSMGR_RemoteUsecase(t *testing.T) { }, } - nsc := domain.Nodes[1].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[1].NewClient(ctx, sandbox.GenerateTestToken(ctx)) conn, err := nsc.Request(ctx, request.Clone()) require.NoError(t, err) @@ -318,10 +318,10 @@ func TestNSMGR_ConnectToDeadNSE(t *testing.T) { counter := &counterServer{} nseCtx, killNse := context.WithCancel(ctx) - _, err := domain.Nodes[0].NewEndpoint(nseCtx, nseReg, sandbox.GenerateTestToken, counter) + _, err := domain.Nodes[0].NewEndpoint(nseCtx, nseReg, sandbox.GenerateTestToken(ctx), counter) require.NoError(t, err) - nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken(ctx)) request := &networkservice.NetworkServiceRequest{ MechanismPreferences: []*networkservice.Mechanism{ @@ -371,10 +371,10 @@ func TestNSMGR_LocalUsecase(t *testing.T) { } counter := &counterServer{} - _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken, counter) + _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken(ctx), counter) require.NoError(t, err) - nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken(ctx)) request := &networkservice.NetworkServiceRequest{ MechanismPreferences: []*networkservice.Mechanism{ @@ -433,10 +433,10 @@ func TestNSMGR_PassThroughRemote(t *testing.T) { chain.NewNetworkServiceServer( clienturl.NewServer(domain.Nodes[i].NSMgr.URL), connect.NewServer(ctx, - sandbox.NewCrossConnectClientFactory(sandbox.GenerateTestToken, + sandbox.NewCrossConnectClientFactory(sandbox.GenerateTestToken(ctx), newPassTroughClient(fmt.Sprintf("my-service-remote-%v", i-1)), kernel.NewClient()), - connect.WithDialOptions(sandbox.DefaultDialOptions(sandbox.GenerateTestToken)...), + connect.WithDialOptions(sandbox.DefaultDialOptions(ctx, sandbox.GenerateTestToken(ctx))...), ), ), } @@ -445,11 +445,11 @@ func TestNSMGR_PassThroughRemote(t *testing.T) { Name: fmt.Sprintf("endpoint-%v", i), NetworkServiceNames: []string{fmt.Sprintf("my-service-remote-%v", i)}, } - _, err := domain.Nodes[i].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken, additionalFunctionality...) + _, err := domain.Nodes[i].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken(ctx), additionalFunctionality...) require.NoError(t, err) } - nsc := domain.Nodes[nodesCount-1].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[nodesCount-1].NewClient(ctx, sandbox.GenerateTestToken(ctx)) request := &networkservice.NetworkServiceRequest{ MechanismPreferences: []*networkservice.Mechanism{ @@ -492,10 +492,10 @@ func TestNSMGR_PassThroughLocal(t *testing.T) { chain.NewNetworkServiceServer( clienturl.NewServer(domain.Nodes[0].NSMgr.URL), connect.NewServer(ctx, - sandbox.NewCrossConnectClientFactory(sandbox.GenerateTestToken, + sandbox.NewCrossConnectClientFactory(sandbox.GenerateTestToken(ctx), newPassTroughClient(fmt.Sprintf("my-service-remote-%v", i-1)), kernel.NewClient()), - connect.WithDialOptions(sandbox.DefaultDialOptions(sandbox.GenerateTestToken)...), + connect.WithDialOptions(sandbox.DefaultDialOptions(ctx, sandbox.GenerateTestToken(ctx))...), ), ), } @@ -504,11 +504,11 @@ func TestNSMGR_PassThroughLocal(t *testing.T) { Name: fmt.Sprintf("endpoint-%v", i), NetworkServiceNames: []string{fmt.Sprintf("my-service-remote-%v", i)}, } - _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken, additionalFunctionality...) + _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken(ctx), additionalFunctionality...) require.NoError(t, err) } - nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken(ctx)) request := &networkservice.NetworkServiceRequest{ MechanismPreferences: []*networkservice.Mechanism{ @@ -556,15 +556,15 @@ func TestNSMGR_ShouldCorrectlyAddForwardersWithSameNames(t *testing.T) { // 1. Add forwarders forwarder1Reg := forwarderReg.Clone() - _, err := domain.Nodes[0].NewForwarder(ctx, forwarder1Reg, sandbox.GenerateTestToken) + _, err := domain.Nodes[0].NewForwarder(ctx, forwarder1Reg, sandbox.GenerateTestToken(ctx)) require.NoError(t, err) forwarder2Reg := forwarderReg.Clone() - _, err = domain.Nodes[0].NewForwarder(ctx, forwarder2Reg, sandbox.GenerateTestToken) + _, err = domain.Nodes[0].NewForwarder(ctx, forwarder2Reg, sandbox.GenerateTestToken(ctx)) require.NoError(t, err) forwarder3Reg := forwarderReg.Clone() - _, err = domain.Nodes[0].NewForwarder(ctx, forwarder3Reg, sandbox.GenerateTestToken) + _, err = domain.Nodes[0].NewForwarder(ctx, forwarder3Reg, sandbox.GenerateTestToken(ctx)) require.NoError(t, err) // 2. Wait for refresh @@ -606,17 +606,17 @@ func TestNSMGR_ShouldCorrectlyAddEndpointsWithSameNames(t *testing.T) { Name: "endpoint", } - nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken(ctx)) // 1. Add endpoints nse1Reg := nseReg.Clone() nse1Reg.NetworkServiceNames = []string{"service-1"} - _, err := domain.Nodes[0].NewEndpoint(ctx, nse1Reg, sandbox.GenerateTestToken) + _, err := domain.Nodes[0].NewEndpoint(ctx, nse1Reg, sandbox.GenerateTestToken(ctx)) require.NoError(t, err) nse2Reg := nseReg.Clone() nse2Reg.NetworkServiceNames = []string{"service-2"} - _, err = domain.Nodes[0].NewEndpoint(ctx, nse2Reg, sandbox.GenerateTestToken) + _, err = domain.Nodes[0].NewEndpoint(ctx, nse2Reg, sandbox.GenerateTestToken(ctx)) require.NoError(t, err) // 2. Wait for refresh @@ -689,10 +689,10 @@ func testNSEAndClient( ctx, cancel := context.WithCancel(ctx) defer cancel() - _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken) + _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken(ctx)) require.NoError(t, err) - nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken(ctx)) conn, err := nsc.Request(ctx, &networkservice.NetworkServiceRequest{ MechanismPreferences: []*networkservice.Mechanism{ diff --git a/pkg/networkservice/chains/nsmgrproxy/server_test.go b/pkg/networkservice/chains/nsmgrproxy/server_test.go index e80d0c9d68..97b092ecf0 100644 --- a/pkg/networkservice/chains/nsmgrproxy/server_test.go +++ b/pkg/networkservice/chains/nsmgrproxy/server_test.go @@ -63,10 +63,10 @@ func TestNSMGR_InterdomainUseCase(t *testing.T) { NetworkServiceNames: []string{"my-service-interdomain"}, } - _, err := domain2.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken) + _, err := domain2.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken(ctx)) require.NoError(t, err) - nsc := domain1.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken) + nsc := domain1.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken(ctx)) request := &networkservice.NetworkServiceRequest{ MechanismPreferences: []*networkservice.Mechanism{ diff --git a/pkg/networkservice/common/heal/server_test.go b/pkg/networkservice/common/heal/server_test.go index 351c8ebca1..64190ddfc8 100644 --- a/pkg/networkservice/common/heal/server_test.go +++ b/pkg/networkservice/common/heal/server_test.go @@ -76,14 +76,14 @@ func TestHealClient_Request(t *testing.T) { server := chain.NewNetworkServiceServer( updatepath.NewServer("testServer"), monitor.NewServer(ctx, &monitorServer), - updatetoken.NewServer(sandbox.GenerateTestToken), + updatetoken.NewServer(sandbox.GenerateTestToken(ctx)), ) healServer := heal.NewServer(ctx, addressof.NetworkServiceClient(onHeal)) client := chain.NewNetworkServiceClient( updatepath.NewClient("testClient"), adapters.NewServerToClient(healServer), heal.NewClient(ctx, adapters.NewMonitorServerToClient(monitorServer)), - adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateTestToken)), + adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateTestToken(ctx))), adapters.NewServerToClient(server), ) @@ -146,10 +146,10 @@ func TestHealClient_EmptyInit(t *testing.T) { updatepath.NewClient("testClient"), adapters.NewServerToClient(healServer), heal.NewClient(ctx, eventchannel.NewMonitorConnectionClient(eventCh)), - adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateTestToken)), + adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateTestToken(ctx))), updatepath.NewClient("testServer"), eventTrigger, - adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateTestToken)), + adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateTestToken(ctx))), ) requestCtx, reqCancelFunc := context.WithTimeout(ctx, waitForTimeout) diff --git a/pkg/networkservice/common/refresh/client_test.go b/pkg/networkservice/common/refresh/client_test.go index 62d5541cf8..800902fc3b 100644 --- a/pkg/networkservice/common/refresh/client_test.go +++ b/pkg/networkservice/common/refresh/client_test.go @@ -64,7 +64,7 @@ func TestRefreshClient_ValidRefresh(t *testing.T) { serialize.NewClient(), updatepath.NewClient("refresh"), refresh.NewClient(ctx), - adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateExpiringToken(expireTimeout))), + adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateExpiringToken(ctx, expireTimeout))), cloneClient, ) @@ -104,7 +104,7 @@ func TestRefreshClient_StopRefreshAtClose(t *testing.T) { serialize.NewClient(), updatepath.NewClient("refresh"), refresh.NewClient(ctx), - adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateExpiringToken(expireTimeout))), + adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateExpiringToken(ctx, expireTimeout))), cloneClient, ) @@ -138,7 +138,7 @@ func TestRefreshClient_RestartsRefreshAtAnotherRequest(t *testing.T) { serialize.NewClient(), updatepath.NewClient("refresh"), refresh.NewClient(ctx), - adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateExpiringToken(expireTimeout))), + adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateExpiringToken(ctx, expireTimeout))), cloneClient, ) @@ -190,7 +190,7 @@ func TestRefreshClient_CheckRaceConditions(t *testing.T) { serialize.NewClient(), updatepath.NewClient("foo"), refresh.NewClient(ctx), - adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateExpiringToken(conf.expireTimeout))), + adapters.NewServerToClient(updatetoken.NewServer(sandbox.GenerateExpiringToken(ctx, conf.expireTimeout))), adapters.NewServerToClient(refreshTester), ) @@ -207,7 +207,7 @@ func TestRefreshClient_Sandbox(t *testing.T) { SetNodesCount(2). SetContext(ctx). SetRegistryProxySupplier(nil). - SetTokenGenerateFunc(sandbox.GenerateTestToken). + SetTokenGenerateFunc(sandbox.GenerateTestToken(ctx)). Build() defer domain.Cleanup() @@ -217,10 +217,10 @@ func TestRefreshClient_Sandbox(t *testing.T) { } refreshSrv := newRefreshTesterServer(t, sandboxMinDuration, sandboxExpireTimeout) - _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken, refreshSrv) + _, err := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken(ctx), refreshSrv) require.NoError(t, err) - nscTokenGenerator := sandbox.GenerateExpiringToken(sandboxExpireTimeout) + nscTokenGenerator := sandbox.GenerateExpiringToken(ctx, sandboxExpireTimeout) nsc := domain.Nodes[1].NewClient(ctx, nscTokenGenerator) refreshSrv.beforeRequest("test-conn") diff --git a/pkg/tools/sandbox/builder.go b/pkg/tools/sandbox/builder.go index 0e357d74b3..0b770e30d7 100644 --- a/pkg/tools/sandbox/builder.go +++ b/pkg/tools/sandbox/builder.go @@ -72,7 +72,6 @@ func NewBuilder(t *testing.T) *Builder { supplyRegistryProxy: proxydns.NewServer, supplyNSMgrProxy: nsmgrproxy.NewServer, setupNode: defaultSetupNode(t), - generateTokenFunc: GenerateTestToken, registryExpiryDuration: time.Minute, } } @@ -87,6 +86,10 @@ func (b *Builder) Build() *Domain { } ctx = log.Join(ctx, log.Empty()) + if b.generateTokenFunc == nil { + b.generateTokenFunc = GenerateTestToken(ctx) + } + domain := new(Domain) domain.NSMgrProxy = b.newNSMgrProxy(ctx) if domain.NSMgrProxy == nil { @@ -134,26 +137,15 @@ func (b *Builder) SetCustomConfig(config []*NodeConfig) *Builder { if customConfig.NsmgrCtx != nil { nodeConfig.NsmgrCtx = customConfig.NsmgrCtx - } else if nodeConfig.NsmgrCtx == nil { - nodeConfig.NsmgrCtx = b.ctx } - if customConfig.NsmgrGenerateTokenFunc != nil { nodeConfig.NsmgrGenerateTokenFunc = customConfig.NsmgrGenerateTokenFunc - } else if nodeConfig.NsmgrGenerateTokenFunc == nil { - nodeConfig.NsmgrGenerateTokenFunc = b.generateTokenFunc } - if customConfig.ForwarderCtx != nil { nodeConfig.ForwarderCtx = customConfig.ForwarderCtx - } else if nodeConfig.ForwarderCtx == nil { - nodeConfig.ForwarderCtx = b.ctx } - if customConfig.ForwarderGenerateTokenFunc != nil { nodeConfig.ForwarderGenerateTokenFunc = customConfig.ForwarderGenerateTokenFunc - } else if nodeConfig.ForwarderGenerateTokenFunc == nil { - nodeConfig.ForwarderGenerateTokenFunc = b.generateTokenFunc } b.nodesConfig = append(b.nodesConfig, nodeConfig) @@ -223,7 +215,7 @@ func (b *Builder) SetRegistryExpiryDuration(registryExpiryDuration time.Duration } func (b *Builder) dialContext(ctx context.Context, u *url.URL) *grpc.ClientConn { - conn, err := grpc.DialContext(ctx, grpcutils.URLToTarget(u), DefaultDialOptions(b.generateTokenFunc)...) + conn, err := grpc.DialContext(ctx, grpcutils.URLToTarget(u), DefaultDialOptions(ctx, b.generateTokenFunc)...) b.resources = append(b.resources, func() { _ = conn.Close() }) @@ -238,7 +230,7 @@ func (b *Builder) newNSMgrProxy(ctx context.Context) *EndpointEntry { name := "nsmgr-proxy-" + uuid.New().String() mgr := b.supplyNSMgrProxy(ctx, b.generateTokenFunc, nsmgrproxy.WithName(name), - nsmgrproxy.WithDialOptions(DefaultDialOptions(b.generateTokenFunc)...)) + nsmgrproxy.WithDialOptions(DefaultDialOptions(ctx, b.generateTokenFunc)...)) serveURL := &url.URL{Scheme: "tcp", Host: "127.0.0.1:0"} serve(ctx, serveURL, mgr.Register) log.FromContext(ctx).Infof("%v listen on: %v", name, serveURL) @@ -279,7 +271,7 @@ func (b *Builder) newNSMgr(ctx context.Context, address string, registryURL *url Url: serveURL.String(), } - mgr := b.supplyNSMgr(ctx, nsmgrReg, authorize.NewServer(authorize.Any()), generateTokenFunc, registryCC, DefaultDialOptions(generateTokenFunc)...) + mgr := b.supplyNSMgr(ctx, nsmgrReg, authorize.NewServer(authorize.Any()), generateTokenFunc, registryCC, DefaultDialOptions(ctx, generateTokenFunc)...) serve(ctx, serveURL, mgr.Register) log.FromContext(ctx).Infof("%v listen on: %v", nsmgrReg.Name, serveURL) @@ -310,7 +302,7 @@ func (b *Builder) newRegistryProxy(ctx context.Context, nsmgrProxyURL *url.URL) if b.supplyRegistryProxy == nil { return nil } - result := b.supplyRegistryProxy(ctx, b.Resolver, b.DNSDomainName, nsmgrProxyURL, DefaultDialOptions(b.generateTokenFunc)...) + result := b.supplyRegistryProxy(ctx, b.Resolver, b.DNSDomainName, nsmgrProxyURL, DefaultDialOptions(ctx, b.generateTokenFunc)...) serveURL := &url.URL{Scheme: "tcp", Host: "127.0.0.1:0"} serve(ctx, serveURL, result.Register) log.FromContext(ctx).Infof("registry-proxy-dns listen on: %v", serveURL) @@ -324,7 +316,7 @@ func (b *Builder) newRegistry(ctx context.Context, proxyRegistryURL *url.URL) *R if b.supplyRegistry == nil { return nil } - result := b.supplyRegistry(ctx, b.registryExpiryDuration, proxyRegistryURL, DefaultDialOptions(b.generateTokenFunc)...) + result := b.supplyRegistry(ctx, b.registryExpiryDuration, proxyRegistryURL, DefaultDialOptions(ctx, b.generateTokenFunc)...) serveURL := &url.URL{Scheme: "tcp", Host: "127.0.0.1:0"} serve(ctx, serveURL, result.Register) log.FromContext(ctx).Infof("Registry listen on: %v", serveURL) @@ -335,10 +327,23 @@ func (b *Builder) newRegistry(ctx context.Context, proxyRegistryURL *url.URL) *R } func (b *Builder) newNode(ctx context.Context, registryURL *url.URL, nodeConfig *NodeConfig) *Node { + if nodeConfig.NsmgrCtx == nil { + nodeConfig.NsmgrCtx = ctx + } + if nodeConfig.NsmgrGenerateTokenFunc == nil { + nodeConfig.NsmgrGenerateTokenFunc = b.generateTokenFunc + } + if nodeConfig.ForwarderCtx == nil { + nodeConfig.ForwarderCtx = ctx + } + if nodeConfig.ForwarderGenerateTokenFunc == nil { + nodeConfig.ForwarderGenerateTokenFunc = b.generateTokenFunc + } + nsmgrEntry := b.newNSMgr(nodeConfig.NsmgrCtx, "127.0.0.1:0", registryURL, nodeConfig.NsmgrGenerateTokenFunc) node := &Node{ - ctx: b.ctx, + ctx: ctx, NSMgr: nsmgrEntry, } diff --git a/pkg/tools/sandbox/node.go b/pkg/tools/sandbox/node.go index 6b4e008dde..fdb3614614 100644 --- a/pkg/tools/sandbox/node.go +++ b/pkg/tools/sandbox/node.go @@ -64,7 +64,7 @@ func (n *Node) NewForwarder( client.NewCrossConnectClientFactory( client.WithName(nse.Name), ), - connect.WithDialOptions(DefaultDialOptions(generatorFunc)...), + connect.WithDialOptions(DefaultDialOptions(ctx, generatorFunc)...), ), ) @@ -159,8 +159,7 @@ func (n *Node) NewClient( additionalFunctionality ...networkservice.NetworkServiceClient, ) networkservice.NetworkServiceClient { ctx = log.Join(ctx, log.Empty()) - cc, err := grpc.DialContext(ctx, grpcutils.URLToTarget(n.NSMgr.URL), DefaultDialOptions(generatorFunc)..., - ) + cc, err := grpc.DialContext(ctx, grpcutils.URLToTarget(n.NSMgr.URL), DefaultDialOptions(ctx, generatorFunc)...) if err != nil { log.FromContext(ctx).Fatalf("Failed to dial node NSMgr: %s", err.Error()) } diff --git a/pkg/tools/sandbox/utils.go b/pkg/tools/sandbox/utils.go index 64407ea5cd..eda6d090c2 100644 --- a/pkg/tools/sandbox/utils.go +++ b/pkg/tools/sandbox/utils.go @@ -29,6 +29,7 @@ import ( "github.com/networkservicemesh/api/pkg/api/networkservice" "github.com/networkservicemesh/sdk/pkg/networkservice/chains/client" + "github.com/networkservicemesh/sdk/pkg/tools/clock" "github.com/networkservicemesh/sdk/pkg/tools/opentracing" "github.com/networkservicemesh/sdk/pkg/tools/token" ) @@ -71,15 +72,18 @@ func WithInsecureStreamRPCCredentials() grpc.DialOption { } // GenerateTestToken generates test token -func GenerateTestToken(_ credentials.AuthInfo) (tokenValue string, expireTime time.Time, err error) { - return "TestToken", time.Now().Add(time.Hour).Local(), nil +func GenerateTestToken(ctx context.Context) token.GeneratorFunc { + clockTime := clock.FromContext(ctx) + return func(credentials.AuthInfo) (token string, expireTime time.Time, err error) { + return "TestToken", clockTime.Now().Add(time.Hour).Local(), nil + } } // GenerateExpiringToken returns a token generator with the specified expiration duration. -func GenerateExpiringToken(duration time.Duration) token.GeneratorFunc { - value := fmt.Sprintf("TestToken-%s", duration) +func GenerateExpiringToken(ctx context.Context, duration time.Duration) token.GeneratorFunc { + clockTime := clock.FromContext(ctx) return func(_ credentials.AuthInfo) (tokenValue string, expireTime time.Time, err error) { - return value, time.Now().Add(duration).Local(), nil + return fmt.Sprintf("TestToken-%s", duration), clockTime.Now().Add(duration).Local(), nil } } @@ -92,7 +96,8 @@ func NewCrossConnectClientFactory(generatorFunc token.GeneratorFunc, additionalF } // DefaultDialOptions returns default dial options for sandbox testing -func DefaultDialOptions(genTokenFunc token.GeneratorFunc) []grpc.DialOption { +func DefaultDialOptions(ctx context.Context, genTokenFunc token.GeneratorFunc) []grpc.DialOption { + clockTime := clock.FromContext(ctx) return append([]grpc.DialOption{ grpc.WithInsecure(), grpc.WithBlock(), @@ -104,5 +109,25 @@ func DefaultDialOptions(genTokenFunc token.GeneratorFunc) []grpc.DialOption { grpcfd.WithChainUnaryInterceptor(), WithInsecureRPCCredentials(), WithInsecureStreamRPCCredentials(), + withClockUnaryInterceptor(clockTime), + withClockStreamInterceptor(clockTime), }, opentracing.WithTracingDial()...) } + +func withClockUnaryInterceptor(clockTime clock.Clock) grpc.DialOption { + return grpc.WithUnaryInterceptor( + func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctx = clock.WithClock(ctx, clockTime) + return invoker(ctx, method, req, reply, cc, opts...) + }, + ) +} + +func withClockStreamInterceptor(clockTime clock.Clock) grpc.DialOption { + return grpc.WithStreamInterceptor( + func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + ctx = clock.WithClock(ctx, clockTime) + return streamer(ctx, desc, cc, method, opts...) + }, + ) +}