diff --git a/pkg/networkservice/chains/nsmgr/upstreamrefresh_test.go b/pkg/networkservice/chains/nsmgr/upstreamrefresh_test.go new file mode 100644 index 000000000..57a93093e --- /dev/null +++ b/pkg/networkservice/chains/nsmgr/upstreamrefresh_test.go @@ -0,0 +1,246 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nsmgr_test + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "go.uber.org/goleak" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/networkservicemesh/api/pkg/api/networkservice" + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/stretchr/testify/require" + + "github.com/networkservicemesh/sdk/pkg/networkservice/chains/client" + "github.com/networkservicemesh/sdk/pkg/networkservice/common/monitor" + "github.com/networkservicemesh/sdk/pkg/networkservice/common/upstreamrefresh" + "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/count" + "github.com/networkservicemesh/sdk/pkg/tools/sandbox" +) + +func Test_UpstreamRefreshClient(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + domain := sandbox.NewBuilder(ctx, t). + SetNodesCount(1). + SetNSMgrProxySupplier(nil). + SetRegistryProxySupplier(nil). + Build() + + nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) + + nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService("my-service")) + require.NoError(t, err) + + nseReg := defaultRegistryEndpoint(nsReg.Name) + + // This NSE will send REFRESH_REQUESTED events if mtu will be changed + counter := new(count.Server) + _ = domain.Nodes[0].NewEndpoint( + ctx, + nseReg, + sandbox.GenerateTestToken, + newRefreshSenderServer(), + counter, + ) + + // Create the first client + nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithAdditionalFunctionality(upstreamrefresh.NewClient(ctx))) + reqCtx, reqClose := context.WithTimeout(ctx, time.Second) + defer reqClose() + + req := defaultRequest(nsReg.Name) + req.Connection.Id = uuid.New().String() + req.GetConnection().GetContext().MTU = defaultMtu + + conn, err := nsc.Request(reqCtx, req) + require.NoError(t, err) + require.Equal(t, 1, counter.UniqueRequests()) + + // Create the second client + nsc2 := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithAdditionalFunctionality(upstreamrefresh.NewClient(ctx))) + reqCtx2, reqClose2 := context.WithTimeout(ctx, time.Second) + defer reqClose2() + + // Change MTU for the second client + req2 := defaultRequest(nsReg.Name) + req2.Connection.Id = uuid.New().String() + req2.GetConnection().GetContext().MTU = 1000 + + conn2, err := nsc2.Request(reqCtx2, req2) + require.NoError(t, err) + require.Equal(t, 2, counter.UniqueRequests()) + + // The request from the second client should trigger refresh on the first one + require.Eventually(t, func() bool { return counter.Requests() == 3 }, timeout, tick) + + _, err = nsc.Close(ctx, conn) + require.NoError(t, err) + _, err = nsc.Close(ctx, conn2) + require.NoError(t, err) +} + +func Test_UpstreamRefreshClient_LocalNotifications(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + domain := sandbox.NewBuilder(ctx, t). + SetNodesCount(1). + SetNSMgrProxySupplier(nil). + SetRegistryProxySupplier(nil). + Build() + + nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) + + nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService("my-service")) + require.NoError(t, err) + + // Create the first NSE + nseReg := ®istry.NetworkServiceEndpoint{ + Name: "final-endpoint1", + NetworkServiceNames: []string{nsReg.Name}, + } + counter1 := new(count.Server) + _ = domain.Nodes[0].NewEndpoint( + ctx, + nseReg, + sandbox.GenerateTestToken, + newRefreshSenderServer(), + counter1, + ) + + // Create the second NSE + nseReg2 := ®istry.NetworkServiceEndpoint{ + Name: "final-endpoint2", + NetworkServiceNames: []string{nsReg.Name}, + } + counter2 := new(count.Server) + _ = domain.Nodes[0].NewEndpoint( + ctx, + nseReg2, + sandbox.GenerateTestToken, + newRefreshSenderServer(), + counter2, + ) + + // Create the client + nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithAdditionalFunctionality(upstreamrefresh.NewClient(ctx, upstreamrefresh.WithLocalNotifications()))) + + // Send request --> NSE1 + reqCtx, reqClose := context.WithTimeout(ctx, time.Second) + defer reqClose() + + req := defaultRequest(nsReg.Name) + req.Connection.Id = "1" + req.GetConnection().NetworkServiceEndpointName = nseReg.Name + req.GetConnection().GetContext().MTU = defaultMtu + + conn, err := nsc.Request(reqCtx, req) + require.NoError(t, err) + require.Equal(t, 1, counter1.UniqueRequests()) + + // Send request2 --> NSE2 + reqCtx2, reqClose2 := context.WithTimeout(ctx, time.Second) + defer reqClose2() + + req2 := defaultRequest(nsReg.Name) + req2.Connection.Id = "2" + req2.GetConnection().NetworkServiceEndpointName = nseReg2.Name + req2.GetConnection().GetContext().MTU = defaultMtu + + conn2, err := nsc.Request(reqCtx2, req2) + require.NoError(t, err) + require.Equal(t, 1, counter2.UniqueRequests()) + + // Send request3 --> NSE1 with different MTU + reqCtx3, reqClose3 := context.WithTimeout(ctx, time.Second) + defer reqClose3() + + req3 := defaultRequest(nsReg.Name) + req3.Connection.Id = "3" + req3.GetConnection().NetworkServiceEndpointName = nseReg.Name + req3.GetConnection().GetContext().MTU = 1000 + + conn3, err := nsc.Request(reqCtx3, req3) + require.NoError(t, err) + require.Equal(t, 2, counter1.UniqueRequests()) + + // Third request should trigger the first and the second to refresh their connections even if they connected to different endpoints + require.Eventually(t, func() bool { return counter2.Requests() == 2 }, timeout, tick) + require.Equal(t, 1, counter2.UniqueRequests()) + + _, err = nsc.Close(ctx, conn) + require.NoError(t, err) + _, err = nsc.Close(ctx, conn2) + require.NoError(t, err) + _, err = nsc.Close(ctx, conn3) + require.NoError(t, err) +} + +type refreshSenderServer struct { + m map[string]*networkservice.Connection + mtu uint32 +} + +const defaultMtu = 9000 + +func newRefreshSenderServer() *refreshSenderServer { + return &refreshSenderServer{ + m: make(map[string]*networkservice.Connection), + mtu: defaultMtu, + } +} + +func (r *refreshSenderServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { + conn, err := next.Server(ctx).Request(ctx, request) + if err != nil { + return nil, err + } + if conn.GetContext().GetMTU() != r.mtu { + if _, ok := r.m[conn.Id]; ok { + return conn, err + } + ec, _ := monitor.LoadEventConsumer(ctx, false) + + connectionsToSend := make(map[string]*networkservice.Connection) + for k, v := range r.m { + connectionsToSend[k] = v.Clone() + connectionsToSend[k].State = networkservice.State_REFRESH_REQUESTED + } + _ = ec.Send(&networkservice.ConnectionEvent{ + Type: networkservice.ConnectionEventType_UPDATE, + Connections: connectionsToSend, + }) + } + r.m[conn.Id] = conn + + return conn, err +} + +func (r *refreshSenderServer) Close(ctx context.Context, conn *networkservice.Connection) (*emptypb.Empty, error) { + return next.Server(ctx).Close(ctx, conn) +} diff --git a/pkg/networkservice/common/upstreamrefresh/client.go b/pkg/networkservice/common/upstreamrefresh/client.go index 37b5a24c0..823353946 100644 --- a/pkg/networkservice/common/upstreamrefresh/client.go +++ b/pkg/networkservice/common/upstreamrefresh/client.go @@ -27,6 +27,7 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/common/clientconn" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata" "github.com/networkservicemesh/sdk/pkg/tools/extend" "github.com/networkservicemesh/sdk/pkg/tools/postpone" ) @@ -55,6 +56,11 @@ func (u *upstreamRefreshClient) Request(ctx context.Context, request *networkser if cancelEventLoop, loaded := loadAndDelete(ctx); loaded { cancelEventLoop() } + u.localNotifier.unsubscribe(request.GetConnection().GetId()) + + if u.localNotifier != nil { + storeLocalNotifier(ctx, metadata.IsClient(u), u.localNotifier) + } conn, err := next.Client(ctx).Request(ctx, request, opts...) if err != nil { diff --git a/pkg/networkservice/common/upstreamrefresh/eventloop.go b/pkg/networkservice/common/upstreamrefresh/eventloop.go index b16b111e7..3d4e297a6 100644 --- a/pkg/networkservice/common/upstreamrefresh/eventloop.go +++ b/pkg/networkservice/common/upstreamrefresh/eventloop.go @@ -103,7 +103,7 @@ func (cev *eventLoop) eventLoop() { cev.logger.Debug("refresh requested from upstream") <-cev.eventFactory.Request() - cev.localNotifier.notify(cev.eventLoopCtx, cev.conn.GetId()) + cev.localNotifier.Notify(cev.eventLoopCtx, cev.conn.GetId()) case _, ok := <-localCh: if !ok { diff --git a/pkg/networkservice/common/upstreamrefresh/gen.go b/pkg/networkservice/common/upstreamrefresh/gen.go deleted file mode 100644 index 722e08ed2..000000000 --- a/pkg/networkservice/common/upstreamrefresh/gen.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) 2022 Cisco and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package upstreamrefresh - -import ( - "sync" -) - -type typeCh = chan struct{} - -//go:generate go-syncmap -output notifier_map.gen.go -type notifierMap - -// clientMap - sync.Map with key == string and value == chan struct{} -type notifierMap sync.Map diff --git a/pkg/networkservice/common/upstreamrefresh/metadata.go b/pkg/networkservice/common/upstreamrefresh/metadata.go index c368f456f..1cc58f508 100644 --- a/pkg/networkservice/common/upstreamrefresh/metadata.go +++ b/pkg/networkservice/common/upstreamrefresh/metadata.go @@ -23,6 +23,7 @@ import ( ) type key struct{} +type keyNotifier struct{} // store sets the context.CancelFunc stored in per Connection.Id metadata. func store(ctx context.Context, cancel context.CancelFunc) { @@ -39,3 +40,19 @@ func loadAndDelete(ctx context.Context) (value context.CancelFunc, ok bool) { value, ok = rawValue.(context.CancelFunc) return value, ok } + +// storeLocalNotifier sets the Notifier stored in per Connection.Id metadata. +func storeLocalNotifier(ctx context.Context, isClient bool, notifier Notifier) { + metadata.Map(ctx, isClient).Store(keyNotifier{}, notifier) +} + +// LoadLocalNotifier loads Notifier stored in per Connection.Id metadata. +// The loaded result reports whether the key was present. +func LoadLocalNotifier(ctx context.Context, isClient bool) (value Notifier, ok bool) { + rawValue, ok := metadata.Map(ctx, isClient).Load(keyNotifier{}) + if !ok { + return + } + value, ok = rawValue.(Notifier) + return value, ok +} diff --git a/pkg/networkservice/common/upstreamrefresh/notifier.go b/pkg/networkservice/common/upstreamrefresh/notifier.go index 7288e2b7d..d45188411 100644 --- a/pkg/networkservice/common/upstreamrefresh/notifier.go +++ b/pkg/networkservice/common/upstreamrefresh/notifier.go @@ -19,55 +19,75 @@ package upstreamrefresh import ( "context" + "github.com/edwarnicke/serialize" + "github.com/networkservicemesh/sdk/pkg/tools/log" ) // notifier - notifies all subscribers of an event type notifier struct { - channels notifierMap + executor serialize.Executor + channels map[string]chan struct{} } func newNotifier() *notifier { - return ¬ifier{} + return ¬ifier{ + channels: make(map[string]chan struct{}), + } } func (n *notifier) subscribe(id string) { if n == nil { return } - n.unsubscribe(id) - n.channels.Store(id, make(chan struct{})) + <-n.executor.AsyncExec(func() { + n.channels[id] = make(chan struct{}) + }) } func (n *notifier) get(id string) <-chan struct{} { if n == nil { return nil } - if v, ok := n.channels.Load(id); ok { - return v - } - return nil + var ch chan struct{} = nil + <-n.executor.AsyncExec(func() { + if v, ok := n.channels[id]; ok { + ch = v + } + }) + return ch } func (n *notifier) unsubscribe(id string) { if n == nil { return } - if v, ok := n.channels.LoadAndDelete(id); ok { - close(v) - } + <-n.executor.AsyncExec(func() { + if v, ok := n.channels[id]; ok { + close(v) + } + delete(n.channels, id) + }) } -func (n *notifier) notify(ctx context.Context, initiatorID string) { +func (n *notifier) Notify(ctx context.Context, initiatorID string) { if n == nil { return } - n.channels.Range(func(key string, value typeCh) bool { - if initiatorID == key { - return true + <-n.executor.AsyncExec(func() { + for k, v := range n.channels { + if initiatorID == k { + continue + } + log.FromContext(ctx).WithField("upstreamrefresh", "notifier").Debugf("send notification to: %v", k) + v <- struct{}{} } - log.FromContext(ctx).WithField("upstreamrefresh", "notifier").Debug("send notification to: %v", key) - value <- struct{}{} - return true }) } + +// Notifier - interface for local notifications sending +type Notifier interface { + Notify(ctx context.Context, initiatorID string) +} + +var _ Notifier = ¬ifier{} diff --git a/pkg/networkservice/common/upstreamrefresh/notifier_map.gen.go b/pkg/networkservice/common/upstreamrefresh/notifier_map.gen.go deleted file mode 100644 index a4f0e7753..000000000 --- a/pkg/networkservice/common/upstreamrefresh/notifier_map.gen.go +++ /dev/null @@ -1,73 +0,0 @@ -// Code generated by "-output notifier_map.gen.go -type notifierMap -output notifier_map.gen.go -type notifierMap"; DO NOT EDIT. -package upstreamrefresh - -import ( - "sync" // Used by sync.Map. -) - -// Generate code that will fail if the constants change value. -func _() { - // An "cannot convert notifierMap literal (type notifierMap) to type sync.Map" compiler error signifies that the base type have changed. - // Re-run the go-syncmap command to generate them again. - _ = (sync.Map)(notifierMap{}) -} - -var _nil_notifierMap_typeCh_value = func() (val typeCh) { return }() - -// Load returns the value stored in the map for a key, or nil if no -// value is present. -// The ok result indicates whether value was found in the map. -func (m *notifierMap) Load(key string) (typeCh, bool) { - value, ok := (*sync.Map)(m).Load(key) - if value == nil { - return _nil_notifierMap_typeCh_value, ok - } - return value.(typeCh), ok -} - -// Store sets the value for a key. -func (m *notifierMap) Store(key string, value typeCh) { - (*sync.Map)(m).Store(key, value) -} - -// LoadOrStore returns the existing value for the key if present. -// Otherwise, it stores and returns the given value. -// The loaded result is true if the value was loaded, false if stored. -func (m *notifierMap) LoadOrStore(key string, value typeCh) (typeCh, bool) { - actual, loaded := (*sync.Map)(m).LoadOrStore(key, value) - if actual == nil { - return _nil_notifierMap_typeCh_value, loaded - } - return actual.(typeCh), loaded -} - -// LoadAndDelete deletes the value for a key, returning the previous value if any. -// The loaded result reports whether the key was present. -func (m *notifierMap) LoadAndDelete(key string) (value typeCh, loaded bool) { - actual, loaded := (*sync.Map)(m).LoadAndDelete(key) - if actual == nil { - return _nil_notifierMap_typeCh_value, loaded - } - return actual.(typeCh), loaded -} - -// Delete deletes the value for a key. -func (m *notifierMap) Delete(key string) { - (*sync.Map)(m).Delete(key) -} - -// Range calls f sequentially for each key and value present in the map. -// If f returns false, range stops the iteration. -// -// Range does not necessarily correspond to any consistent snapshot of the Map's -// contents: no key will be visited more than once, but if the value for any key -// is stored or deleted concurrently, Range may reflect any mapping for that key -// from any point during the Range call. -// -// Range may be O(N) with the number of elements in the map even if f returns -// false after a constant number of calls. -func (m *notifierMap) Range(f func(key string, value typeCh) bool) { - (*sync.Map)(m).Range(func(key, value interface{}) bool { - return f(key.(string), value.(typeCh)) - }) -} diff --git a/pkg/networkservice/connectioncontext/mtu/vl3mtu/client.go b/pkg/networkservice/connectioncontext/mtu/vl3mtu/client.go index 3d4b0eb25..706dcb44e 100644 --- a/pkg/networkservice/connectioncontext/mtu/vl3mtu/client.go +++ b/pkg/networkservice/connectioncontext/mtu/vl3mtu/client.go @@ -21,11 +21,14 @@ import ( "sync" "sync/atomic" - "github.com/networkservicemesh/api/pkg/api/networkservice" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/emptypb" + "github.com/networkservicemesh/api/pkg/api/networkservice" + + "github.com/networkservicemesh/sdk/pkg/networkservice/common/upstreamrefresh" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata" ) type vl3MtuClient struct { @@ -59,7 +62,11 @@ func (v *vl3MtuClient) Request(ctx context.Context, request *networkservice.Netw } // Update MTU of the vl3 - v.updateMinMTU(conn) + if v.updateMinMTU(conn) { + if ln, ok := upstreamrefresh.LoadLocalNotifier(ctx, metadata.IsClient(v)); ok { + ln.Notify(ctx, conn.GetId()) + } + } return conn, nil } @@ -68,15 +75,17 @@ func (v *vl3MtuClient) Close(ctx context.Context, conn *networkservice.Connectio return next.Client(ctx).Close(ctx, conn, opts...) } -func (v *vl3MtuClient) updateMinMTU(conn *networkservice.Connection) { +// updateMinMTU - returns true if mtu was updated +func (v *vl3MtuClient) updateMinMTU(conn *networkservice.Connection) bool { if atomic.LoadUint32(&v.minMtu) <= conn.GetContext().GetMTU() { - return + return false } v.m.Lock() defer v.m.Unlock() if atomic.LoadUint32(&v.minMtu) <= conn.GetContext().GetMTU() { - return + return false } atomic.StoreUint32(&v.minMtu, conn.GetContext().GetMTU()) + return true } diff --git a/pkg/networkservice/connectioncontext/mtu/vl3mtu/client_test.go b/pkg/networkservice/connectioncontext/mtu/vl3mtu/client_test.go index 28e98a5b2..dba5063f9 100644 --- a/pkg/networkservice/connectioncontext/mtu/vl3mtu/client_test.go +++ b/pkg/networkservice/connectioncontext/mtu/vl3mtu/client_test.go @@ -27,6 +27,7 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/connectioncontext/mtu/vl3mtu" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata" ) func Test_vl3MtuClient(t *testing.T) { @@ -36,6 +37,7 @@ func Test_vl3MtuClient(t *testing.T) { defer cancel() client := next.NewNetworkServiceClient( + metadata.NewClient(), vl3mtu.NewClient(), )