From dc57328ab864aa4f2b2583d5a8037233f5e29aba Mon Sep 17 00:00:00 2001 From: Artem Glazychev Date: Mon, 19 Dec 2022 12:47:18 +0700 Subject: [PATCH] Fix timeout/expire chain elements Signed-off-by: Artem Glazychev --- pkg/networkservice/chains/nsmgr/heal_test.go | 18 +- pkg/networkservice/chains/nsmgr/server.go | 2 +- .../chains/nsmgr/single_test.go | 210 ++++++++++++++++-- .../chains/nsmgrproxy/server.go | 2 +- pkg/networkservice/common/timeout/server.go | 12 +- pkg/registry/chains/memory/server.go | 2 +- pkg/registry/chains/proxydns/server.go | 2 +- pkg/registry/common/expire/nse_server.go | 29 +-- pkg/registry/common/expire/nse_server_test.go | 27 ++- .../common/grpcmetadata/common_test.go | 15 +- pkg/registry/common/grpcmetadata/ns_client.go | 15 +- pkg/registry/common/grpcmetadata/ns_test.go | 86 +++++-- .../common/grpcmetadata/nse_client.go | 15 +- pkg/registry/common/grpcmetadata/nse_test.go | 103 +++++++-- .../common/updatepath/ns_server_test.go | 10 +- pkg/registry/common/updatepath/nse_server.go | 12 +- .../common/updatepath/nse_server_test.go | 10 +- .../utils/inject/injectpeertoken/ns_server.go | 14 +- .../inject/injectpeertoken/nse_server.go | 15 +- .../opa/policies/common/tokens_expired.rego | 8 +- 20 files changed, 493 insertions(+), 114 deletions(-) diff --git a/pkg/networkservice/chains/nsmgr/heal_test.go b/pkg/networkservice/chains/nsmgr/heal_test.go index c3dfd2be05..89bbdf4b0a 100644 --- a/pkg/networkservice/chains/nsmgr/heal_test.go +++ b/pkg/networkservice/chains/nsmgr/heal_test.go @@ -538,10 +538,6 @@ func testNSMGRCloseHeal(t *testing.T, withNSEExpiration bool) { SetNSMgrProxySupplier(nil). SetRegistryProxySupplier(nil) - if withNSEExpiration { - builder = builder.SetRegistryExpiryDuration(time.Second / 2) - } - domain := builder.Build() nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) @@ -549,9 +545,13 @@ func testNSMGRCloseHeal(t *testing.T, withNSEExpiration bool) { nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService(t.Name())) require.NoError(t, err) - nseCtx, nseCtxCancel := context.WithCancel(ctx) - - domain.Nodes[0].NewEndpoint(nseCtx, defaultRegistryEndpoint(nsReg.Name), sandbox.GenerateTestToken) + nseCtx, nseCtxCancel := context.WithTimeout(ctx, time.Second/2) + if withNSEExpiration { + // NSE will be unregistered after (tokenTimeout - registerTimeout) + domain.Nodes[0].NewEndpoint(nseCtx, defaultRegistryEndpoint(nsReg.Name), sandbox.GenerateExpiringToken(time.Second)) + } else { + domain.Nodes[0].NewEndpoint(nseCtx, defaultRegistryEndpoint(nsReg.Name), sandbox.GenerateTestToken) + } request := defaultRequest(nsReg.Name) @@ -601,10 +601,6 @@ func testNSMGRCloseHeal(t *testing.T, withNSEExpiration bool) { nscCtxCancel() - for _, fwd := range domain.Nodes[0].Forwarders { - fwd.Cancel() - } - require.Eventually(t, func() bool { logrus.Error(goleak.Find()) return goleak.Find(ignoreCurrent) == nil diff --git a/pkg/networkservice/chains/nsmgr/server.go b/pkg/networkservice/chains/nsmgr/server.go index eaf4ac4ed1..02854d08e4 100644 --- a/pkg/networkservice/chains/nsmgr/server.go +++ b/pkg/networkservice/chains/nsmgr/server.go @@ -271,8 +271,8 @@ func NewServer(ctx context.Context, tokenGenerator token.GeneratorFunc, options } var nseRegistry = chain.NewNetworkServiceEndpointRegistryServer( - grpcmetadata.NewNetworkServiceEndpointRegistryServer(), begin.NewNetworkServiceEndpointRegistryServer(), + grpcmetadata.NewNetworkServiceEndpointRegistryServer(), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGenerator), opts.authorizeNSERegistryServer, registryclientinfo.NewNetworkServiceEndpointRegistryServer(), diff --git a/pkg/networkservice/chains/nsmgr/single_test.go b/pkg/networkservice/chains/nsmgr/single_test.go index 8203ff2413..023aa51738 100644 --- a/pkg/networkservice/chains/nsmgr/single_test.go +++ b/pkg/networkservice/chains/nsmgr/single_test.go @@ -27,25 +27,32 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" - "github.com/networkservicemesh/api/pkg/api/networkservice" - "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/cls" - kernelmech "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/kernel" - registryapi "github.com/networkservicemesh/api/pkg/api/registry" "github.com/stretchr/testify/require" "go.uber.org/goleak" "google.golang.org/grpc" "google.golang.org/grpc/credentials" - registryclient "github.com/networkservicemesh/sdk/pkg/registry/chains/client" + "github.com/networkservicemesh/api/pkg/api/networkservice" + "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/cls" + kernelmech "github.com/networkservicemesh/api/pkg/api/networkservice/mechanisms/kernel" + registryapi "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/networkservice/chains/client" + "github.com/networkservicemesh/sdk/pkg/networkservice/chains/endpoint" "github.com/networkservicemesh/sdk/pkg/networkservice/chains/nsmgr" + "github.com/networkservicemesh/sdk/pkg/networkservice/common/authorize" "github.com/networkservicemesh/sdk/pkg/networkservice/common/excludedprefixes" "github.com/networkservicemesh/sdk/pkg/networkservice/ipam/point2pointipam" + countutils "github.com/networkservicemesh/sdk/pkg/networkservice/utils/count" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/inject/injecterror" "github.com/networkservicemesh/sdk/pkg/registry" + registryclient "github.com/networkservicemesh/sdk/pkg/registry/chains/client" "github.com/networkservicemesh/sdk/pkg/registry/chains/memory" - "github.com/networkservicemesh/sdk/pkg/registry/common/authorize" + authorizeregistry "github.com/networkservicemesh/sdk/pkg/registry/common/authorize" + "github.com/networkservicemesh/sdk/pkg/registry/common/sendfd" + injecterrorregistry "github.com/networkservicemesh/sdk/pkg/registry/utils/inject/injecterror" "github.com/networkservicemesh/sdk/pkg/tools/clientinfo" + "github.com/networkservicemesh/sdk/pkg/tools/grpcutils" "github.com/networkservicemesh/sdk/pkg/tools/sandbox" "github.com/networkservicemesh/sdk/pkg/tools/token" ) @@ -500,13 +507,13 @@ func Test_FailedRegistryAuthorization(t *testing.T) { nsmgrSuppier := func(ctx context.Context, tokenGenerator token.GeneratorFunc, options ...nsmgr.Option) nsmgr.Nsmgr { options = append(options, nsmgr.WithAuthorizeNSERegistryServer( - authorize.NewNetworkServiceEndpointRegistryServer(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), + authorizeregistry.NewNetworkServiceEndpointRegistryServer(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), nsmgr.WithAuthorizeNSRegistryServer( - authorize.NewNetworkServiceRegistryServer(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), + authorizeregistry.NewNetworkServiceRegistryServer(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), nsmgr.WithAuthorizeNSERegistryClient( - authorize.NewNetworkServiceEndpointRegistryClient(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), + authorizeregistry.NewNetworkServiceEndpointRegistryClient(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), nsmgr.WithAuthorizeNSRegistryClient( - authorize.NewNetworkServiceRegistryClient(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), + authorizeregistry.NewNetworkServiceRegistryClient(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), ) return nsmgr.NewServer(ctx, tokenGenerator, options...) } @@ -526,7 +533,7 @@ func Test_FailedRegistryAuthorization(t *testing.T) { memory.WithProxyRegistryURL(proxyRegistryURL), memory.WithDialOptions(options...), memory.WithAuthorizeNSRegistryServer( - authorize.NewNetworkServiceRegistryServer(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), + authorizeregistry.NewNetworkServiceRegistryServer(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"))), ) } domain := sandbox.NewBuilder(ctx, t). @@ -553,7 +560,7 @@ func Test_FailedRegistryAuthorization(t *testing.T) { nsRegistryClient1 := domain.NewNSRegistryClient(ctx, tokenGeneratorFunc("spiffe://test.com/ns-1"), registryclient.WithAuthorizeNSRegistryClient( - authorize.NewNetworkServiceRegistryClient(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego")))) + authorizeregistry.NewNetworkServiceRegistryClient(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego")))) ns1 := defaultRegistryService("ns-1") _, err := nsRegistryClient1.Register(ctx, ns1) @@ -561,9 +568,186 @@ func Test_FailedRegistryAuthorization(t *testing.T) { nsRegistryClient2 := domain.NewNSRegistryClient(ctx, tokenGeneratorFunc("spiffe://test.com/ns-2"), registryclient.WithAuthorizeNSRegistryClient( - authorize.NewNetworkServiceRegistryClient(authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego")))) + authorizeregistry.NewNetworkServiceRegistryClient(authorizeregistry.WithPolicies("etc/nsm/opa/registry/client_allowed.rego")))) ns2 := defaultRegistryService("ns-1") _, err = nsRegistryClient2.Register(ctx, ns2) require.Error(t, err) } + +func createAuthorizedEndpoint(ctx context.Context, t *testing.T, ns string, nsmgrURL *url.URL, counter networkservice.NetworkServiceServer) { + nseReg := defaultRegistryEndpoint(ns) + + nse := endpoint.NewServer(ctx, sandbox.GenerateTestToken, + endpoint.WithName("final-endpoint"), + endpoint.WithAuthorizeServer(authorize.NewServer(authorize.WithPolicies("etc/nsm/opa/common/tokens_expired.rego"))), + endpoint.WithAdditionalFunctionality(counter), + ) + + nseServer := grpc.NewServer() + nse.Register(nseServer) + nseURL := &url.URL{Scheme: "tcp", Host: "127.0.0.1:0"} + errCh := grpcutils.ListenAndServe(ctx, nseURL, nseServer) + select { + case err := <-errCh: + require.NoError(t, err) + default: + } + + nseRegistryClient := registryclient.NewNetworkServiceEndpointRegistryClient( + ctx, + registryclient.WithClientURL(nsmgrURL), + registryclient.WithDialOptions(sandbox.DialOptions(sandbox.WithTokenGenerator(sandbox.GenerateTestToken))...), + registryclient.WithNSEAdditionalFunctionality(sendfd.NewNetworkServiceEndpointRegistryClient()), + ) + + nseReg.Url = nseURL.String() + _, err := nseRegistryClient.Register(ctx, nseReg.Clone()) + require.NoError(t, err) +} + +// This test checks timeout on sandbox +// We run nsmgr and NSE with networkservice authorize chain element (tokens_expired.rego) +func Test_Timeout(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + // timeout chain element will call Close() after (tokenTimeout - requestTimeout) + // to be sure that token is not expired + tokenTimeout := time.Second * 2 + requestTimeout := time.Second + time.Millisecond*500 + + chainCtx, chainCtxCancel := context.WithTimeout(context.Background(), time.Second*5) + defer chainCtxCancel() + + // Set tokens_expired policy + nsmgrSuppier := func(ctx context.Context, tokenGenerator token.GeneratorFunc, options ...nsmgr.Option) nsmgr.Nsmgr { + options = append(options, + nsmgr.WithAuthorizeServer(authorize.NewServer(authorize.WithPolicies("etc/nsm/opa/common/tokens_expired.rego"))), + ) + return nsmgr.NewServer(ctx, tokenGenerator, options...) + } + + domain := sandbox.NewBuilder(chainCtx, t). + SetNodesCount(1). + SetNSMgrSupplier(nsmgrSuppier). + Build() + + nsRegistryClient := domain.NewNSRegistryClient(chainCtx, sandbox.GenerateTestToken) + ns := defaultRegistryService("ns") + + nsReg, err := nsRegistryClient.Register(chainCtx, ns) + require.NoError(t, err) + + counter := new(countutils.Server) + + createAuthorizedEndpoint(chainCtx, t, ns.Name, domain.Nodes[0].NSMgr.URL, counter) + + // Set an expiring token. + // Add injecterror to allow only the first Request. All subsequent ones will fall. + // This emulates the death of the client. + nsc := domain.Nodes[0].NewClient(chainCtx, + sandbox.GenerateExpiringToken(tokenTimeout), + client.WithAdditionalFunctionality( + injecterror.NewClient(injecterror.WithRequestErrorTimes(1, -1)), + ), + ) + + request := defaultRequest(nsReg.Name) + requestCtx, requestCtxCancel := context.WithTimeout(context.Background(), requestTimeout) + defer requestCtxCancel() + + conn, err := nsc.Request(requestCtx, request) + require.NoError(t, err) + require.NotNil(t, conn) + // Closes equal to 0 for now + require.Equal(t, 1, counter.Requests()) + require.Equal(t, 0, counter.Closes()) + + // Waiting for the timeout + require.Eventually(t, func() bool { return counter.Closes() == 1 }, tokenTimeout, time.Millisecond*100) +} + +// This test checks registry expire on sandbox +// We run nsmgr and registry with registry authorize chain element (tokens_expired.rego) +func Test_Expire(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + // expire chain element will call Unregister() after (tokenTimeout - registerTimeout) + // to be sure that token is not expired + tokenTimeout := time.Second * 2 + registerTimeout := time.Second + time.Millisecond*500 + + chainCtx, chainCtxCancel := context.WithTimeout(context.Background(), time.Second*5) + defer chainCtxCancel() + + // Set tokens_expired policy for nsmgr and registry + nsmgrSuppier := func(ctx context.Context, tokenGenerator token.GeneratorFunc, options ...nsmgr.Option) nsmgr.Nsmgr { + options = append(options, + nsmgr.WithAuthorizeNSERegistryServer( + authorizeregistry.NewNetworkServiceEndpointRegistryServer(authorizeregistry.WithPolicies("etc/nsm/opa/common/tokens_expired.rego"))), + ) + return nsmgr.NewServer(ctx, tokenGenerator, options...) + } + + registrySupplier := func( + ctx context.Context, + tokenGenerator token.GeneratorFunc, + expiryDuration time.Duration, + proxyRegistryURL *url.URL, + options ...grpc.DialOption) registry.Registry { + return memory.NewServer( + ctx, + tokenGenerator, + memory.WithExpireDuration(expiryDuration), + memory.WithProxyRegistryURL(proxyRegistryURL), + memory.WithDialOptions(options...), + memory.WithAuthorizeNSRegistryServer( + authorizeregistry.NewNetworkServiceRegistryServer(authorizeregistry.WithPolicies("etc/nsm/opa/common/tokens_expired.rego"))), + ) + } + + domain := sandbox.NewBuilder(chainCtx, t). + SetNodesCount(1). + SetNSMgrSupplier(nsmgrSuppier). + SetRegistrySupplier(registrySupplier). + Build() + + nsRegistryClient := domain.NewNSRegistryClient(chainCtx, sandbox.GenerateTestToken) + ns := defaultRegistryService("ns") + + nsReg, err := nsRegistryClient.Register(chainCtx, ns) + require.NoError(t, err) + + // Set an expiring token. + // Add injecterrorregistry to allow only the first Register. All subsequent ones will fall. + // This emulates the death of the NSE. + nseRegistryClient := registryclient.NewNetworkServiceEndpointRegistryClient(chainCtx, + registryclient.WithClientURL(domain.Nodes[0].NSMgr.URL), + registryclient.WithDialOptions(sandbox.DialOptions(sandbox.WithTokenGenerator(sandbox.GenerateExpiringToken(tokenTimeout)))...), + registryclient.WithNSEAdditionalFunctionality( + injecterrorregistry.NewNetworkServiceEndpointRegistryClient( + injecterrorregistry.WithRegisterErrorTimes(1, -1), + injecterrorregistry.WithFindErrorTimes())), + ) + + registerCtx, registerCtxCancel := context.WithTimeout(context.Background(), registerTimeout) + defer registerCtxCancel() + _, err = nseRegistryClient.Register(registerCtx, ®istryapi.NetworkServiceEndpoint{ + Name: "final-endpoint", + Url: "nseURL", + NetworkServiceNames: []string{nsReg.Name}, + }) + require.NoError(t, err) + + // Wait for the endpoint expiration + time.Sleep(tokenTimeout) + stream, err := nseRegistryClient.Find(chainCtx, ®istryapi.NetworkServiceEndpointQuery{ + NetworkServiceEndpoint: ®istryapi.NetworkServiceEndpoint{ + Name: "final-endpoint", + }, + }) + require.NoError(t, err) + + // Eventually expire will call Unregister + require.Len(t, registryapi.ReadNetworkServiceEndpointList(stream), 0) +} diff --git a/pkg/networkservice/chains/nsmgrproxy/server.go b/pkg/networkservice/chains/nsmgrproxy/server.go index d8c2b99b96..62288023b5 100644 --- a/pkg/networkservice/chains/nsmgrproxy/server.go +++ b/pkg/networkservice/chains/nsmgrproxy/server.go @@ -291,8 +291,8 @@ func NewServer(ctx context.Context, regURL, proxyURL *url.URL, tokenGenerator to ) var nseServerChain = chain.NewNetworkServiceEndpointRegistryServer( - grpcmetadata.NewNetworkServiceEndpointRegistryServer(), begin.NewNetworkServiceEndpointRegistryServer(), + grpcmetadata.NewNetworkServiceEndpointRegistryServer(), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGenerator), opts.authorizeNSERegistryServer, clienturl.NewNetworkServiceEndpointRegistryServer(proxyURL), diff --git a/pkg/networkservice/common/timeout/server.go b/pkg/networkservice/common/timeout/server.go index f82eb2ff0c..45515ddc2b 100644 --- a/pkg/networkservice/common/timeout/server.go +++ b/pkg/networkservice/common/timeout/server.go @@ -51,6 +51,14 @@ func NewServer(ctx context.Context) networkservice.NetworkServiceServer { } func (s *timeoutServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (conn *networkservice.Connection, err error) { + timeClock := clock.FromContext(ctx) + + deadline, ok := ctx.Deadline() + requestTimeout := timeClock.Until(deadline) + if !ok { + requestTimeout = 0 + } + conn, err = next.Server(ctx).Request(ctx, request) if err != nil { return nil, err @@ -67,8 +75,8 @@ func (s *timeoutServer) Request(ctx context.Context, request *networkservice.Net } store(ctx, metadata.IsClient(s), cancel) eventFactory := begin.FromContext(ctx) - timeClock := clock.FromContext(ctx) - afterCh := timeClock.After(timeClock.Until(expirationTime)) + afterCh := timeClock.After(timeClock.Until(expirationTime) - requestTimeout) + go func(cancelCtx context.Context, afterCh <-chan time.Time) { select { case <-cancelCtx.Done(): diff --git a/pkg/registry/chains/memory/server.go b/pkg/registry/chains/memory/server.go index c35b845954..33d370976f 100644 --- a/pkg/registry/chains/memory/server.go +++ b/pkg/registry/chains/memory/server.go @@ -135,8 +135,8 @@ func NewServer(ctx context.Context, tokenGenerator token.GeneratorFunc, options } nseChain := chain.NewNetworkServiceEndpointRegistryServer( - grpcmetadata.NewNetworkServiceEndpointRegistryServer(), begin.NewNetworkServiceEndpointRegistryServer(), + grpcmetadata.NewNetworkServiceEndpointRegistryServer(), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGenerator), opts.authorizeNSERegistryServer, switchcase.NewNetworkServiceEndpointRegistryServer(switchcase.NSEServerCase{ diff --git a/pkg/registry/chains/proxydns/server.go b/pkg/registry/chains/proxydns/server.go index b5bcd136ae..7d1d39e549 100644 --- a/pkg/registry/chains/proxydns/server.go +++ b/pkg/registry/chains/proxydns/server.go @@ -109,8 +109,8 @@ func NewServer(ctx context.Context, tokenGenerator token.GeneratorFunc, dnsResol } nseChain := chain.NewNetworkServiceEndpointRegistryServer( - grpcmetadata.NewNetworkServiceEndpointRegistryServer(), begin.NewNetworkServiceEndpointRegistryServer(), + grpcmetadata.NewNetworkServiceEndpointRegistryServer(), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGenerator), opts.authorizeNSERegistryServer, dnsresolve.NewNetworkServiceEndpointRegistryServer(dnsresolve.WithResolver(dnsResolver)), diff --git a/pkg/registry/common/expire/nse_server.go b/pkg/registry/common/expire/nse_server.go index bca5375f42..c7f381bc69 100644 --- a/pkg/registry/common/expire/nse_server.go +++ b/pkg/registry/common/expire/nse_server.go @@ -32,35 +32,38 @@ import ( ) type expireNSEServer struct { - nseExpiration time.Duration - ctx context.Context + defaultNseExpiration time.Duration + ctx context.Context cancelsMap } // NewNetworkServiceEndpointRegistryServer creates a new NetworkServiceServer chain element that implements unregister // of expired connections for the subsequent chain elements. -func NewNetworkServiceEndpointRegistryServer(ctx context.Context, nseExpiration time.Duration) registry.NetworkServiceEndpointRegistryServer { +func NewNetworkServiceEndpointRegistryServer(ctx context.Context, defaultNseExpiration time.Duration) registry.NetworkServiceEndpointRegistryServer { return &expireNSEServer{ - nseExpiration: nseExpiration, - ctx: ctx, + defaultNseExpiration: defaultNseExpiration, + ctx: ctx, } } func (s *expireNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { factory := begin.FromContext(ctx) timeClock := clock.FromContext(ctx) - expirationTime := timeClock.Now().Add(s.nseExpiration).Local() logger := log.FromContext(ctx).WithField("expireNSEServer", "Register") - if nse.GetExpirationTime() != nil { - if nseExpirationTime := nse.GetExpirationTime().AsTime().Local(); nseExpirationTime.Before(expirationTime) { - expirationTime = nseExpirationTime - logger.Infof("selected expiration time %v for %v", expirationTime, nse.GetName()) - } + deadline, ok := ctx.Deadline() + requestTimeout := timeClock.Until(deadline) + if !ok { + requestTimeout = 0 } - nse.ExpirationTime = timestamppb.New(expirationTime) + expirationTime := nse.GetExpirationTime().AsTime() + if nse.GetExpirationTime() == nil { + expirationTime = timeClock.Now().Add(s.defaultNseExpiration).Local() + nse.ExpirationTime = timestamppb.New(expirationTime) + logger.Infof("selected expiration time %v for %v", expirationTime, nse.GetName()) + } resp, err := next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) if err != nil { @@ -78,7 +81,7 @@ func (s *expireNSEServer) Register(ctx context.Context, nse *registry.NetworkSer } s.cancelsMap.Store(nse.GetName(), cancel) - expireCh := timeClock.After(timeClock.Until(expirationTime.Local())) + expireCh := timeClock.After(timeClock.Until(expirationTime.Local()) - requestTimeout) go func() { select { diff --git a/pkg/registry/common/expire/nse_server_test.go b/pkg/registry/common/expire/nse_server_test.go index 0b53f76461..a8f35fb4ec 100644 --- a/pkg/registry/common/expire/nse_server_test.go +++ b/pkg/registry/common/expire/nse_server_test.go @@ -23,9 +23,11 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v4" "github.com/golang/protobuf/ptypes/empty" "github.com/stretchr/testify/require" "go.uber.org/goleak" + "google.golang.org/grpc/credentials" "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/timestamppb" @@ -36,11 +38,14 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/common/localbypass" "github.com/networkservicemesh/sdk/pkg/registry/common/memory" "github.com/networkservicemesh/sdk/pkg/registry/common/refresh" + "github.com/networkservicemesh/sdk/pkg/registry/common/updatepath" "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/registry/utils/inject/injecterror" + "github.com/networkservicemesh/sdk/pkg/registry/utils/inject/injectpeertoken" "github.com/networkservicemesh/sdk/pkg/tools/clock" "github.com/networkservicemesh/sdk/pkg/tools/clockmock" + "github.com/networkservicemesh/sdk/pkg/tools/token" ) const ( @@ -125,6 +130,20 @@ func TestExpireNSEServer_ShouldUseLessExpirationTimeFromInput_AndWork(t *testing }, testWait, testTick) } +func generateTestToken(ctx context.Context, duration time.Duration) token.GeneratorFunc { + return func(_ credentials.AuthInfo) (string, time.Time, error) { + expireTime := clock.FromContext(ctx).Now().Add(duration).Local() + + claims := jwt.RegisteredClaims{ + Subject: "spiffe://test.com/subject", + ExpiresAt: jwt.NewNumericDate(expireTime), + } + + tok, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("supersecret")) + return tok, expireTime, err + } +} + func TestExpireNSEServer_ShouldUseLessExpirationTimeFromResponse(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) @@ -136,10 +155,12 @@ func TestExpireNSEServer_ShouldUseLessExpirationTimeFromResponse(t *testing.T) { s := next.NewNetworkServiceEndpointRegistryServer( begin.NewNetworkServiceEndpointRegistryServer(), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + updatepath.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), new(remoteNSEServer), // <-- GRPC invocation begin.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout/2), + updatepath.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout/2)), ) resp, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{Name: "nse-1"}) @@ -161,7 +182,9 @@ func TestExpireNSEServer_ShouldRemoveNSEAfterExpirationTime(t *testing.T) { s := next.NewNetworkServiceEndpointRegistryServer( begin.NewNetworkServiceEndpointRegistryServer(), - expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + updatepath.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + expire.NewNetworkServiceEndpointRegistryServer(ctx, expireTimeout*2), new(remoteNSEServer), // <-- GRPC invocation mem, ) diff --git a/pkg/registry/common/grpcmetadata/common_test.go b/pkg/registry/common/grpcmetadata/common_test.go index a71d0add8c..cf984f2926 100644 --- a/pkg/registry/common/grpcmetadata/common_test.go +++ b/pkg/registry/common/grpcmetadata/common_test.go @@ -19,6 +19,8 @@ package grpcmetadata_test import ( "time" + "github.com/networkservicemesh/sdk/pkg/tools/clockmock" + "github.com/golang-jwt/jwt/v4" "google.golang.org/grpc/credentials" @@ -29,9 +31,16 @@ const ( key = "supersecret" ) -func tokenGeneratorFunc(spiffeID string) token.GeneratorFunc { +// tokenGeneratorFunc generates new tokens automatically (based on time change). +// time.Second + smth - the time tick for jwt is a second. +func tokenGeneratorFunc(clock *clockmock.Mock, spiffeID string) token.GeneratorFunc { return func(peerAuthInfo credentials.AuthInfo) (string, time.Time, error) { - tok, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"sub": spiffeID}).SignedString([]byte(key)) - return tok, time.Date(3000, 1, 1, 1, 1, 1, 1, time.UTC), err + clock.Add(time.Second + time.Millisecond*10) + tok, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": spiffeID, + "exp": jwt.NewNumericDate(clock.Now().Add(time.Hour)), + }, + ).SignedString([]byte(key)) + return tok, clock.Now(), err } } diff --git a/pkg/registry/common/grpcmetadata/ns_client.go b/pkg/registry/common/grpcmetadata/ns_client.go index a7459e2dbb..8a5d95c70b 100644 --- a/pkg/registry/common/grpcmetadata/ns_client.go +++ b/pkg/registry/common/grpcmetadata/ns_client.go @@ -75,5 +75,18 @@ func (c *grpcMetadataNSClient) Unregister(ctx context.Context, ns *registry.Netw return nil, err } - return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, ns, opts...) + var header metadata.MD + opts = append(opts, grpc.Header(&header)) + + resp, err := next.NetworkServiceRegistryClient(ctx).Unregister(ctx, ns, opts...) + if err != nil { + return nil, err + } + + newpath, err := fromMD(header) + if err == nil { + path.Index = newpath.Index + path.PathSegments = newpath.PathSegments + } + return resp, nil } diff --git a/pkg/registry/common/grpcmetadata/ns_test.go b/pkg/registry/common/grpcmetadata/ns_test.go index a4694ee92d..d315ff18ab 100644 --- a/pkg/registry/common/grpcmetadata/ns_test.go +++ b/pkg/registry/common/grpcmetadata/ns_test.go @@ -24,8 +24,10 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/protobuf/types/known/emptypb" "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" "github.com/networkservicemesh/sdk/pkg/registry/common/updatepath" @@ -33,8 +35,8 @@ import ( "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/registry/utils/checks/checkcontext" "github.com/networkservicemesh/sdk/pkg/registry/utils/inject/injectpeertoken" - - "go.uber.org/goleak" + "github.com/networkservicemesh/sdk/pkg/tools/clock" + "github.com/networkservicemesh/sdk/pkg/tools/clockmock" ) const ( @@ -43,22 +45,64 @@ const ( serverID = "spiffe://test.com/server" ) +type pathCheckerNSClient struct { + funcBefore func(ctx context.Context) *grpcmetadata.Path + funcAfter func(ctx context.Context, pBefore *grpcmetadata.Path) +} + +func newPathCheckerNSClient(t *testing.T, expectedPathIndex int) registry.NetworkServiceRegistryClient { + client := &pathCheckerNSClient{} + + client.funcBefore = func(ctx context.Context) *grpcmetadata.Path { + p := grpcmetadata.PathFromContext(ctx).Clone() + require.Equal(t, int(p.Index), expectedPathIndex) + + return p + } + client.funcAfter = func(ctx context.Context, pBefore *grpcmetadata.Path) { + pAfter := grpcmetadata.PathFromContext(ctx).Clone() + require.Equal(t, int(pAfter.Index), expectedPathIndex) + for i := expectedPathIndex; i < len(pBefore.PathSegments); i++ { + require.NotEqual(t, pBefore.PathSegments[i].Token, pAfter.PathSegments[i].Token) + } + } + return client +} + +func (p *pathCheckerNSClient) Register(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { + pBefore := p.funcBefore(ctx) + r, e := next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...) + p.funcAfter(ctx, pBefore) + return r, e +} + +func (p *pathCheckerNSClient) Find(ctx context.Context, query *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { + return next.NetworkServiceRegistryClient(ctx).Find(ctx, query, opts...) +} + +func (p *pathCheckerNSClient) Unregister(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*emptypb.Empty, error) { + pBefore := p.funcBefore(ctx) + r, e := next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in, opts...) + p.funcAfter(ctx, pBefore) + return r, e +} + func TestGRPCMetadataNetworkService(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) - ctx, cacncel := context.WithTimeout(context.Background(), time.Second) - defer cacncel() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() - clientToken, _, _ := tokenGeneratorFunc(clientID)(nil) - proxyToken, _, _ := tokenGeneratorFunc(proxyID)(nil) + clockMock := clockmock.New(ctx) + ctx = clock.WithClock(ctx, clockMock) serverLis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) server := next.NewNetworkServiceRegistryServer( - injectpeertoken.NewNetworkServiceRegistryServer(proxyToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clockMock, proxyID)), grpcmetadata.NewNetworkServiceRegistryServer(), - updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(serverID)), + updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clockMock, serverID)), checkcontext.NewNSServer(t, func(t *testing.T, ctx context.Context) { path := grpcmetadata.PathFromContext(ctx) require.Equal(t, int(path.Index), 2) @@ -83,14 +127,15 @@ func TestGRPCMetadataNetworkService(t *testing.T) { }() proxyServer := next.NewNetworkServiceRegistryServer( - injectpeertoken.NewNetworkServiceRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clockMock, clientID)), grpcmetadata.NewNetworkServiceRegistryServer(), - updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(proxyID)), + updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clockMock, proxyID)), checkcontext.NewNSServer(t, func(t *testing.T, ctx context.Context) { path := grpcmetadata.PathFromContext(ctx) require.Equal(t, int(path.Index), 1) }), adapters.NetworkServiceClientToServer(next.NewNetworkServiceRegistryClient( + newPathCheckerNSClient(t, 1), grpcmetadata.NewNetworkServiceRegistryClient(), registry.NewNetworkServiceRegistryClient(serverConn), )), @@ -111,6 +156,7 @@ func TestGRPCMetadataNetworkService(t *testing.T) { }() client := next.NewNetworkServiceRegistryClient( + newPathCheckerNSClient(t, 0), grpcmetadata.NewNetworkServiceRegistryClient(), registry.NewNetworkServiceRegistryClient(conn)) @@ -124,6 +170,10 @@ func TestGRPCMetadataNetworkService(t *testing.T) { require.Equal(t, int(path.Index), 0) require.Len(t, path.PathSegments, 3) + // Simulate refresh + _, err = client.Register(ctx, ns) + require.NoError(t, err) + _, err = client.Unregister(ctx, ns) require.NoError(t, err) @@ -134,10 +184,11 @@ func TestGRPCMetadataNetworkService(t *testing.T) { func TestGRPCMetadataNetworkService_BackwardCompatibility(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) - ctx, cacncel := context.WithTimeout(context.Background(), time.Second) - defer cacncel() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() - clientToken, _, _ := tokenGeneratorFunc(clientID)(nil) + clockMock := clockmock.New(ctx) + ctx = clock.WithClock(ctx, clockMock) serverLis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -165,9 +216,9 @@ func TestGRPCMetadataNetworkService_BackwardCompatibility(t *testing.T) { }() proxyServer := next.NewNetworkServiceRegistryServer( - injectpeertoken.NewNetworkServiceRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clockMock, clientID)), grpcmetadata.NewNetworkServiceRegistryServer(), - updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(proxyID)), + updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clockMock, proxyID)), checkcontext.NewNSServer(t, func(t *testing.T, ctx context.Context) { path := grpcmetadata.PathFromContext(ctx) require.Equal(t, int(path.Index), 1) @@ -196,6 +247,7 @@ func TestGRPCMetadataNetworkService_BackwardCompatibility(t *testing.T) { }() client := next.NewNetworkServiceRegistryClient( + newPathCheckerNSClient(t, 0), grpcmetadata.NewNetworkServiceRegistryClient(), registry.NewNetworkServiceRegistryClient(conn)) @@ -209,6 +261,10 @@ func TestGRPCMetadataNetworkService_BackwardCompatibility(t *testing.T) { require.Equal(t, int(path.Index), 0) require.Len(t, path.PathSegments, 2) + // Simulate refresh + _, err = client.Register(ctx, ns) + require.NoError(t, err) + _, err = client.Unregister(ctx, ns) require.NoError(t, err) } diff --git a/pkg/registry/common/grpcmetadata/nse_client.go b/pkg/registry/common/grpcmetadata/nse_client.go index e5ec57e32b..b77ba989b0 100644 --- a/pkg/registry/common/grpcmetadata/nse_client.go +++ b/pkg/registry/common/grpcmetadata/nse_client.go @@ -73,5 +73,18 @@ func (c *grpcMetadataNSEClient) Unregister(ctx context.Context, nse *registry.Ne return nil, err } - return next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, nse, opts...) + var header metadata.MD + opts = append(opts, grpc.Header(&header)) + + resp, err := next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, nse, opts...) + if err != nil { + return nil, err + } + + newpath, err := fromMD(header) + if err == nil { + path.Index = newpath.Index + path.PathSegments = newpath.PathSegments + } + return resp, nil } diff --git a/pkg/registry/common/grpcmetadata/nse_test.go b/pkg/registry/common/grpcmetadata/nse_test.go index 83ed2520f0..949965aec7 100644 --- a/pkg/registry/common/grpcmetadata/nse_test.go +++ b/pkg/registry/common/grpcmetadata/nse_test.go @@ -22,36 +22,86 @@ import ( "testing" "time" - "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/golang/protobuf/ptypes/empty" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata" "github.com/networkservicemesh/sdk/pkg/registry/common/updatepath" "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/registry/utils/checks/checkcontext" "github.com/networkservicemesh/sdk/pkg/registry/utils/inject/injectpeertoken" - - "go.uber.org/goleak" + "github.com/networkservicemesh/sdk/pkg/tools/clock" + "github.com/networkservicemesh/sdk/pkg/tools/clockmock" ) +type pathCheckerNSEClient struct { + funcBefore func(ctx context.Context) *grpcmetadata.Path + funcAfter func(ctx context.Context, pBefore *grpcmetadata.Path) +} + +func newPathCheckerNSEClient(t *testing.T, expectedPathIndex int) registry.NetworkServiceEndpointRegistryClient { + client := &pathCheckerNSEClient{} + + client.funcBefore = func(ctx context.Context) *grpcmetadata.Path { + p := grpcmetadata.PathFromContext(ctx).Clone() + require.Equal(t, int(p.Index), expectedPathIndex) + + return p + } + client.funcAfter = func(ctx context.Context, pBefore *grpcmetadata.Path) { + pAfter := grpcmetadata.PathFromContext(ctx).Clone() + require.Equal(t, int(pAfter.Index), expectedPathIndex) + for i := expectedPathIndex; i < len(pBefore.PathSegments); i++ { + require.NotEqual(t, pBefore.PathSegments[i].Token, pAfter.PathSegments[i].Token) + } + } + return client +} + +func (p *pathCheckerNSEClient) Register(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*registry.NetworkServiceEndpoint, error) { + pBefore := p.funcBefore(ctx) + r, e := next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...) + p.funcAfter(ctx, pBefore) + return r, e +} + +func (p *pathCheckerNSEClient) Find(ctx context.Context, in *registry.NetworkServiceEndpointQuery, opts ...grpc.CallOption) (registry.NetworkServiceEndpointRegistry_FindClient, error) { + return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, in, opts...) +} + +func (p *pathCheckerNSEClient) Unregister(ctx context.Context, in *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) (*empty.Empty, error) { + pBefore := p.funcBefore(ctx) + r, e := next.NetworkServiceEndpointRegistryClient(ctx).Unregister(ctx, in, opts...) + p.funcAfter(ctx, pBefore) + return r, e +} + +// This test checks that registry Path is correctly updated and passed through grpc metadata +// Test scheme: client ---> proxyServer ---> server func TestGRPCMetadataNetworkServiceEndpoint(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() - clientToken, _, _ := tokenGeneratorFunc(clientID)(nil) - proxyToken, _, _ := tokenGeneratorFunc(proxyID)(nil) + // Add clockMock to the context + clockMock := clockmock.New(ctx) + ctx = clock.WithClock(ctx, clockMock) serverLis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) + // tokenGeneratorFunc generates new tokens automatically server := next.NewNetworkServiceEndpointRegistryServer( - injectpeertoken.NewNetworkServiceEndpointRegistryServer(proxyToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clockMock, proxyID)), grpcmetadata.NewNetworkServiceEndpointRegistryServer(), - updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(serverID)), + updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clockMock, serverID)), checkcontext.NewNSEServer(t, func(t *testing.T, ctx context.Context) { path := grpcmetadata.PathFromContext(ctx) require.Equal(t, int(path.Index), 2) @@ -76,14 +126,15 @@ func TestGRPCMetadataNetworkServiceEndpoint(t *testing.T) { }() proxyServer := next.NewNetworkServiceEndpointRegistryServer( - injectpeertoken.NewNetworkServiceEndpointRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clockMock, clientID)), grpcmetadata.NewNetworkServiceEndpointRegistryServer(), - updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(proxyID)), + updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clockMock, proxyID)), checkcontext.NewNSEServer(t, func(t *testing.T, ctx context.Context) { path := grpcmetadata.PathFromContext(ctx) require.Equal(t, int(path.Index), 1) }), adapters.NetworkServiceEndpointClientToServer(next.NewNetworkServiceEndpointRegistryClient( + newPathCheckerNSEClient(t, 1), grpcmetadata.NewNetworkServiceEndpointRegistryClient(), registry.NewNetworkServiceEndpointRegistryClient(serverConn), )), @@ -104,10 +155,7 @@ func TestGRPCMetadataNetworkServiceEndpoint(t *testing.T) { }() client := next.NewNetworkServiceEndpointRegistryClient( - checkcontext.NewNSEClient(t, func(t *testing.T, ctx context.Context) { - path := grpcmetadata.PathFromContext(ctx) - require.Equal(t, int(path.Index), 0) - }), + newPathCheckerNSEClient(t, 0), grpcmetadata.NewNetworkServiceEndpointRegistryClient(), registry.NewNetworkServiceEndpointRegistryClient(conn)) @@ -121,6 +169,10 @@ func TestGRPCMetadataNetworkServiceEndpoint(t *testing.T) { require.Equal(t, int(path.Index), 0) require.Len(t, path.PathSegments, 3) + // Simulate refresh + _, err = client.Register(ctx, nse) + require.NoError(t, err) + _, err = client.Unregister(ctx, nse) require.NoError(t, err) @@ -131,10 +183,12 @@ func TestGRPCMetadataNetworkServiceEndpoint(t *testing.T) { func TestGRPCMetadataNetworkServiceEndpoint_BackwardCompatibility(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) - ctx, cacncel := context.WithTimeout(context.Background(), time.Second) - defer cacncel() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() - clientToken, _, _ := tokenGeneratorFunc(clientID)(nil) + // Add clockMock to the context + clockMock := clockmock.New(ctx) + ctx = clock.WithClock(ctx, clockMock) serverLis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -162,9 +216,9 @@ func TestGRPCMetadataNetworkServiceEndpoint_BackwardCompatibility(t *testing.T) }() proxyServer := next.NewNetworkServiceEndpointRegistryServer( - injectpeertoken.NewNetworkServiceEndpointRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clockMock, clientID)), grpcmetadata.NewNetworkServiceEndpointRegistryServer(), - updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(proxyID)), + updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clockMock, proxyID)), checkcontext.NewNSEServer(t, func(t *testing.T, ctx context.Context) { path := grpcmetadata.PathFromContext(ctx) require.Equal(t, int(path.Index), 1) @@ -193,19 +247,24 @@ func TestGRPCMetadataNetworkServiceEndpoint_BackwardCompatibility(t *testing.T) }() client := next.NewNetworkServiceEndpointRegistryClient( + newPathCheckerNSEClient(t, 0), grpcmetadata.NewNetworkServiceEndpointRegistryClient(), registry.NewNetworkServiceEndpointRegistryClient(conn)) path := grpcmetadata.Path{} ctx = grpcmetadata.PathWithContext(ctx, &path) - ns := ®istry.NetworkServiceEndpoint{Name: "ns"} - _, err = client.Register(ctx, ns) + nse := ®istry.NetworkServiceEndpoint{Name: "ns"} + _, err = client.Register(ctx, nse) require.NoError(t, err) require.Equal(t, int(path.Index), 0) require.Len(t, path.PathSegments, 2) - _, err = client.Unregister(ctx, ns) + // Simulate refresh + _, err = client.Register(ctx, nse) + require.NoError(t, err) + + _, err = client.Unregister(ctx, nse) require.NoError(t, err) } diff --git a/pkg/registry/common/updatepath/ns_server_test.go b/pkg/registry/common/updatepath/ns_server_test.go index d5894d3b00..124e0d2865 100644 --- a/pkg/registry/common/updatepath/ns_server_test.go +++ b/pkg/registry/common/updatepath/ns_server_test.go @@ -53,7 +53,7 @@ var nsSamples = []*nsSample{ } server := next.NewNetworkServiceRegistryServer( - injectpeertoken.NewNetworkServiceRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clientID)), updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(serverID)), ) @@ -107,9 +107,9 @@ var nsSamples = []*nsSample{ } server := next.NewNetworkServiceRegistryServer( - injectpeertoken.NewNetworkServiceRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clientID)), updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(proxyID)), - injectpeertoken.NewNetworkServiceRegistryServer(proxyToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(proxyID)), updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(serverID)), ) @@ -155,9 +155,9 @@ var nsSamples = []*nsSample{ } server := next.NewNetworkServiceRegistryServer( - injectpeertoken.NewNetworkServiceRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(clientID)), updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(proxyID)), - injectpeertoken.NewNetworkServiceRegistryServer(proxyToken), + injectpeertoken.NewNetworkServiceRegistryServer(tokenGeneratorFunc(proxyID)), updatepath.NewNetworkServiceRegistryServer(tokenGeneratorFunc(serverID)), ) diff --git a/pkg/registry/common/updatepath/nse_server.go b/pkg/registry/common/updatepath/nse_server.go index b072e9b68c..54fd5f95c7 100644 --- a/pkg/registry/common/updatepath/nse_server.go +++ b/pkg/registry/common/updatepath/nse_server.go @@ -21,6 +21,7 @@ import ( "github.com/golang/protobuf/ptypes/empty" "github.com/pkg/errors" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/networkservicemesh/api/pkg/api/registry" @@ -45,11 +46,11 @@ func (s *updatePathNSEServer) Register(ctx context.Context, nse *registry.Networ path := grpcmetadata.PathFromContext(ctx) // Update path - peerTok, _, tokenErr := token.FromContext(ctx) + peerTok, peerExpirationTime, tokenErr := token.FromContext(ctx) if tokenErr != nil { log.FromContext(ctx).Warnf("an error during getting peer token from the context: %+v", tokenErr) } - tok, _, tokenErr := generateToken(ctx, s.tokenGenerator) + tok, expirationTime, tokenErr := generateToken(ctx, s.tokenGenerator) if tokenErr != nil { return nil, errors.Wrap(tokenErr, "an error during generating token") } @@ -70,6 +71,13 @@ func (s *updatePathNSEServer) Register(ctx context.Context, nse *registry.Networ nse.PathIds = updatePathIds(nse.PathIds, int(path.Index-1), peerID.String()) nse.PathIds = updatePathIds(nse.PathIds, int(path.Index), id.String()) + if nse.GetExpirationTime() == nil || peerExpirationTime.Before(nse.GetExpirationTime().AsTime().Local()) { + nse.ExpirationTime = timestamppb.New(peerExpirationTime) + } + if expirationTime.Before(nse.GetExpirationTime().AsTime().Local()) { + nse.ExpirationTime = timestamppb.New(expirationTime) + } + nse, err = next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) if err != nil { return nil, err diff --git a/pkg/registry/common/updatepath/nse_server_test.go b/pkg/registry/common/updatepath/nse_server_test.go index ad6344dd2c..bb4e5935e5 100644 --- a/pkg/registry/common/updatepath/nse_server_test.go +++ b/pkg/registry/common/updatepath/nse_server_test.go @@ -53,7 +53,7 @@ var nseSamples = []*nseSample{ } server := next.NewNetworkServiceEndpointRegistryServer( - injectpeertoken.NewNetworkServiceEndpointRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clientID)), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(serverID)), ) @@ -107,9 +107,9 @@ var nseSamples = []*nseSample{ } server := next.NewNetworkServiceEndpointRegistryServer( - injectpeertoken.NewNetworkServiceEndpointRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clientID)), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(proxyID)), - injectpeertoken.NewNetworkServiceEndpointRegistryServer(proxyToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(proxyID)), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(serverID)), ) @@ -155,9 +155,9 @@ var nseSamples = []*nseSample{ } server := next.NewNetworkServiceEndpointRegistryServer( - injectpeertoken.NewNetworkServiceEndpointRegistryServer(clientToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(clientID)), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(proxyID)), - injectpeertoken.NewNetworkServiceEndpointRegistryServer(proxyToken), + injectpeertoken.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(proxyID)), updatepath.NewNetworkServiceEndpointRegistryServer(tokenGeneratorFunc(serverID)), ) diff --git a/pkg/registry/utils/inject/injectpeertoken/ns_server.go b/pkg/registry/utils/inject/injectpeertoken/ns_server.go index 0499a54eb4..be060c2253 100644 --- a/pkg/registry/utils/inject/injectpeertoken/ns_server.go +++ b/pkg/registry/utils/inject/injectpeertoken/ns_server.go @@ -19,6 +19,8 @@ package injectpeertoken import ( "context" + "github.com/networkservicemesh/sdk/pkg/tools/token" + "google.golang.org/protobuf/types/known/emptypb" "github.com/networkservicemesh/api/pkg/api/registry" @@ -27,18 +29,19 @@ import ( ) type injectSpiffeIDNSServer struct { - peerToken string + tokenGenerator token.GeneratorFunc } // NewNetworkServiceRegistryServer returns a server chain element putting spiffeID to context on Register and Unregister -func NewNetworkServiceRegistryServer(peerToken string) registry.NetworkServiceRegistryServer { +func NewNetworkServiceRegistryServer(tokenGenerator token.GeneratorFunc) registry.NetworkServiceRegistryServer { return &injectSpiffeIDNSServer{ - peerToken: peerToken, + tokenGenerator: tokenGenerator, } } func (s *injectSpiffeIDNSServer) Register(ctx context.Context, ns *registry.NetworkService) (*registry.NetworkService, error) { - ctx = withPeerToken(ctx, s.peerToken) + peerToken, _, _ := s.tokenGenerator(nil) + ctx = withPeerToken(ctx, peerToken) return next.NetworkServiceRegistryServer(ctx).Register(ctx, ns) } @@ -47,6 +50,7 @@ func (s *injectSpiffeIDNSServer) Find(query *registry.NetworkServiceQuery, serve } func (s *injectSpiffeIDNSServer) Unregister(ctx context.Context, ns *registry.NetworkService) (*emptypb.Empty, error) { - ctx = withPeerToken(ctx, s.peerToken) + peerToken, _, _ := s.tokenGenerator(nil) + ctx = withPeerToken(ctx, peerToken) return next.NetworkServiceRegistryServer(ctx).Unregister(ctx, ns) } diff --git a/pkg/registry/utils/inject/injectpeertoken/nse_server.go b/pkg/registry/utils/inject/injectpeertoken/nse_server.go index 4fbdecb80f..20e05a0136 100644 --- a/pkg/registry/utils/inject/injectpeertoken/nse_server.go +++ b/pkg/registry/utils/inject/injectpeertoken/nse_server.go @@ -24,21 +24,23 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/token" ) type injectSpiffeIDNSEServer struct { - peerToken string + tokenGenerator token.GeneratorFunc } -// NewNetworkServiceEndpointRegistryServer returns a server chain element putting spiffeID to context on Register and Unregister -func NewNetworkServiceEndpointRegistryServer(peerToken string) registry.NetworkServiceEndpointRegistryServer { +// NewNetworkServiceEndpointRegistryServer returns a server chain element putting peer token to context on Register and Unregister +func NewNetworkServiceEndpointRegistryServer(tokenGenerator token.GeneratorFunc) registry.NetworkServiceEndpointRegistryServer { return &injectSpiffeIDNSEServer{ - peerToken: peerToken, + tokenGenerator: tokenGenerator, } } func (s *injectSpiffeIDNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { - ctx = withPeerToken(ctx, s.peerToken) + peerToken, _, _ := s.tokenGenerator(nil) + ctx = withPeerToken(ctx, peerToken) return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) } @@ -47,6 +49,7 @@ func (s *injectSpiffeIDNSEServer) Find(query *registry.NetworkServiceEndpointQue } func (s *injectSpiffeIDNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) { - ctx = withPeerToken(ctx, s.peerToken) + peerToken, _, _ := s.tokenGenerator(nil) + ctx = withPeerToken(ctx, peerToken) return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) } diff --git a/pkg/tools/opa/policies/common/tokens_expired.rego b/pkg/tools/opa/policies/common/tokens_expired.rego index 4f44d1a4fa..5a87cb286c 100644 --- a/pkg/tools/opa/policies/common/tokens_expired.rego +++ b/pkg/tools/opa/policies/common/tokens_expired.rego @@ -1,4 +1,4 @@ -# Copyright (c) 2020 Cisco and/or its affiliates. +# Copyright (c) 2020-2022 Cisco and/or its affiliates. # # SPDX-License-Identifier: Apache-2.0 # @@ -19,11 +19,11 @@ package nsm default valid = false valid { - count({x | input.path_segments[x]; token_expired(input.path_segments[x].token)}) == count(input.path_segments) + count({x | input.path_segments[x]; token_alive(input.path_segments[x].token)}) == count(input.path_segments) } -token_expired(token) { - print(token) +# alive means not expired +token_alive(token) { [_, payload, _] := io.jwt.decode(token) now < payload.exp }