diff --git a/pkg/registry/common/expire/nse_server.go b/pkg/registry/common/expire/nse_server.go index 0ffcbefec..be0d5edb4 100644 --- a/pkg/registry/common/expire/nse_server.go +++ b/pkg/registry/common/expire/nse_server.go @@ -1,6 +1,6 @@ // Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // -// Copyright (c) 2023 Cisco and/or its affiliates. +// Copyright (c) 2023-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -29,9 +29,11 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/registry/common/begin" + "github.com/networkservicemesh/sdk/pkg/registry/common/updatepath" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/tools/clock" "github.com/networkservicemesh/sdk/pkg/tools/log" + "github.com/networkservicemesh/sdk/pkg/tools/token" ) type expireNSEServer struct { @@ -67,9 +69,11 @@ func (s *expireNSEServer) Register(ctx context.Context, nse *registry.NetworkSer requestTimeout = 0 } - expirationTime := nse.GetExpirationTime().AsTime() - if nse.GetExpirationTime() == nil { - expirationTime = timeClock.Now().Add(s.defaultExpiration).Local() + // Select the min(tokenExpirationTime, peerExpirationTime, defaultExpirationTime) + expirationTime, expirationTimeSelected := s.selectMinExpirationTime(ctx) + + // Update nse ExpirationTime if expirationTime is before + if expirationTimeSelected && (nse.GetExpirationTime() == nil || expirationTime.Before(nse.GetExpirationTime().AsTime().Local())) { nse.ExpirationTime = timestamppb.New(expirationTime) logger.Infof("selected expiration time %v for %v", expirationTime, nse.GetName()) } @@ -79,7 +83,7 @@ func (s *expireNSEServer) Register(ctx context.Context, nse *registry.NetworkSer return nil, err } - if nseExpirationTime := resp.GetExpirationTime().AsTime().Local(); nseExpirationTime.Before(expirationTime) { + if nseExpirationTime := resp.GetExpirationTime().AsTime().Local(); expirationTimeSelected && nseExpirationTime.Before(expirationTime) { expirationTime = nseExpirationTime logger.Infof("selected expiration time %v for %v", expirationTime, resp.GetName()) } @@ -114,3 +118,34 @@ func (s *expireNSEServer) Unregister(ctx context.Context, nse *registry.NetworkS } return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) } + +func (s *expireNSEServer) selectMinExpirationTime(ctx context.Context) (time.Time, bool) { + timeClock := clock.FromContext(ctx) + + var expirationTime *time.Time + if tokenExpirationTime := updatepath.ExpirationTimeFromContext(ctx); tokenExpirationTime != nil { + expirationTime = tokenExpirationTime + } else { + log.FromContext(ctx).Warn("error during getting token expiration time from the context") + } + + if _, peerExpirationTime, peerTokenErr := token.FromContext(ctx); peerTokenErr == nil { + if expirationTime == nil || peerExpirationTime.Before(*expirationTime) { + expirationTime = &peerExpirationTime + } + } else { + log.FromContext(ctx).Warnf("error during getting peer expiration time from the context: %v", peerTokenErr) + } + + if s.defaultExpiration != 0 { + defaultExpirationTime := timeClock.Now().Add(s.defaultExpiration).Local() + if expirationTime == nil || defaultExpirationTime.Before(*expirationTime) { + expirationTime = &defaultExpirationTime + } + } + if expirationTime == nil { + return time.Time{}, false + } + + return *expirationTime, true +} diff --git a/pkg/registry/common/expire/nse_server_test.go b/pkg/registry/common/expire/nse_server_test.go index 82cfe4a4b..774fd3296 100644 --- a/pkg/registry/common/expire/nse_server_test.go +++ b/pkg/registry/common/expire/nse_server_test.go @@ -1,6 +1,6 @@ // Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // -// Copyright (c) 2023 Cisco and/or its affiliates. +// Copyright (c) 2023-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -170,6 +170,28 @@ func TestExpireNSEServer_ShouldSetDefaultExpiration(t *testing.T) { require.Equal(t, expireTimeout, clockMock.Until(resp.ExpirationTime.AsTime())) } +func TestExpireNSEServer_ShouldUseLessExpirationTime_DefaultExpireTimeout(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clockMock := clockmock.New(ctx) + ctx = clock.WithClock(ctx, clockMock) + + s := next.NewNetworkServiceEndpointRegistryServer( + injectpeertoken.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + updatepath.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout)), + begin.NewNetworkServiceEndpointRegistryServer(), + expire.NewNetworkServiceEndpointRegistryServer(ctx, expire.WithDefaultExpiration(expireTimeout/2)), + ) + + resp, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{Name: "nse-1"}) + require.NoError(t, err) + + require.Equal(t, expireTimeout/2, clockMock.Until(resp.ExpirationTime.AsTime())) +} + func TestExpireNSEServer_ShouldUseLessExpirationTimeFromResponse(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) @@ -187,6 +209,7 @@ func TestExpireNSEServer_ShouldUseLessExpirationTimeFromResponse(t *testing.T) { new(remoteNSEServer), // <-- GRPC invocation begin.NewNetworkServiceEndpointRegistryServer(), updatepath.NewNetworkServiceEndpointRegistryServer(generateTestToken(ctx, expireTimeout/2)), + expire.NewNetworkServiceEndpointRegistryServer(ctx), ) resp, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{Name: "nse-1"}) diff --git a/pkg/registry/common/updatepath/context.go b/pkg/registry/common/updatepath/context.go new file mode 100644 index 000000000..ce63d78a9 --- /dev/null +++ b/pkg/registry/common/updatepath/context.go @@ -0,0 +1,37 @@ +// Copyright (c) 2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package updatepath + +import ( + "context" + "time" +) + +type key struct{} + +// ExpirationTimeFromContext returns the expiration time stored in context +func ExpirationTimeFromContext(ctx context.Context) *time.Time { + if value, ok := ctx.Value(key{}).(*time.Time); ok { + return value + } + return nil +} + +// withExpirationTime sets the expiration time stored in context +func withExpirationTime(ctx context.Context, t *time.Time) context.Context { + return context.WithValue(ctx, key{}, t) +} diff --git a/pkg/registry/common/updatepath/nse_server.go b/pkg/registry/common/updatepath/nse_server.go index e9e5bf2ce..86e182f55 100644 --- a/pkg/registry/common/updatepath/nse_server.go +++ b/pkg/registry/common/updatepath/nse_server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Cisco and/or its affiliates. +// Copyright (c) 2022-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -21,7 +21,6 @@ import ( "github.com/golang/protobuf/ptypes/empty" "github.com/pkg/errors" - "google.golang.org/protobuf/types/known/timestamppb" "github.com/networkservicemesh/api/pkg/api/registry" @@ -46,7 +45,7 @@ func (s *updatePathNSEServer) Register(ctx context.Context, nse *registry.Networ path := grpcmetadata.PathFromContext(ctx) // Update path - peerTok, peerExpirationTime, peerTokenErr := token.FromContext(ctx) + peerTok, _, peerTokenErr := token.FromContext(ctx) if peerTokenErr != nil { log.FromContext(ctx).Warnf("an error during getting peer token from the context: %+v", peerTokenErr) } @@ -71,12 +70,7 @@ func (s *updatePathNSEServer) Register(ctx context.Context, nse *registry.Networ nse.PathIds = updatePathIds(nse.PathIds, int(path.Index-1), peerID.String()) nse.PathIds = updatePathIds(nse.PathIds, int(path.Index), id.String()) - if nse.GetExpirationTime() == nil || expirationTime.Before(nse.GetExpirationTime().AsTime().Local()) { - nse.ExpirationTime = timestamppb.New(expirationTime) - } - if peerTokenErr == nil && peerExpirationTime.Before(nse.GetExpirationTime().AsTime().Local()) { - nse.ExpirationTime = timestamppb.New(peerExpirationTime) - } + ctx = withExpirationTime(ctx, &expirationTime) nse, err = next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) if err != nil { diff --git a/pkg/tools/token/context.go b/pkg/tools/token/context.go index 6b0d15a81..451cc458b 100644 --- a/pkg/tools/token/context.go +++ b/pkg/tools/token/context.go @@ -1,6 +1,6 @@ // Copyright (c) 2021 Doc.ai and/or its affiliates. // -// Copyright (c) 2023 Cisco and/or its affiliates. +// Copyright (c) 2023-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -37,7 +37,7 @@ func FromContext(ctx context.Context) (string, time.Time, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - return "", time.Time{}, errors.New("metadata is missed in ctx") + return "", time.Time{}, errors.New("grpc metadata is missed in ctx") } token, err := getSingleValue(md, tokenKey)