From f98513f5f9acbff8ef09397a38fc17bc2444ea69 Mon Sep 17 00:00:00 2001 From: Artem Glazychev Date: Fri, 27 Nov 2020 13:47:28 +0700 Subject: [PATCH] Replacing comparison of proto messages (#601) * Replacing comparison of proto messages Signed-off-by: Artem Glazychev * Add tests for network service endpoints matching by label Signed-off-by: Artem Glazychev Signed-off-by: Sergey Ershov --- go.mod | 1 + pkg/networkservice/common/heal/client_test.go | 7 +- .../roundrobin/round_robin_selector_test.go | 5 +- pkg/registry/memory/nse_server_test.go | 131 ++++++++++++++++++ pkg/tools/matchutils/utils.go | 26 +++- 5 files changed, 161 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index a68274b886..6d9c754b62 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/fsnotify/fsnotify v1.4.9 github.com/ghodss/yaml v1.0.0 github.com/golang/protobuf v1.4.3 + github.com/google/go-cmp v0.5.2 github.com/google/uuid v1.1.2 github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 github.com/hashicorp/go-multierror v1.0.0 diff --git a/pkg/networkservice/common/heal/client_test.go b/pkg/networkservice/common/heal/client_test.go index 1ac96e94c4..c3e42e567b 100644 --- a/pkg/networkservice/common/heal/client_test.go +++ b/pkg/networkservice/common/heal/client_test.go @@ -19,10 +19,11 @@ package heal_test import ( "context" "io/ioutil" - "reflect" "testing" "time" + "google.golang.org/protobuf/proto" + "github.com/networkservicemesh/sdk/pkg/tools/sandbox" "github.com/golang/protobuf/ptypes/empty" @@ -194,11 +195,11 @@ func TestNewClient_MissingConnectionsInInit(t *testing.T) { defer cancel() conn, err := client.Request(ctx, &networkservice.NetworkServiceRequest{Connection: conns[0]}) require.Nil(t, err) - require.True(t, reflect.DeepEqual(conn, conns[0])) + require.True(t, proto.Equal(conn, conns[0])) conn, err = client.Request(ctx, &networkservice.NetworkServiceRequest{Connection: conns[1]}) require.Nil(t, err) - require.True(t, reflect.DeepEqual(conn, conns[1])) + require.True(t, proto.Equal(conn, conns[1])) eventCh <- &networkservice.ConnectionEvent{ Type: networkservice.ConnectionEventType_INITIAL_STATE_TRANSFER, diff --git a/pkg/networkservice/common/roundrobin/round_robin_selector_test.go b/pkg/networkservice/common/roundrobin/round_robin_selector_test.go index cc5a4751ba..b1564cc7ed 100644 --- a/pkg/networkservice/common/roundrobin/round_robin_selector_test.go +++ b/pkg/networkservice/common/roundrobin/round_robin_selector_test.go @@ -17,9 +17,10 @@ package roundrobin import ( - "reflect" "testing" + "google.golang.org/protobuf/proto" + "github.com/networkservicemesh/api/pkg/api/registry" "go.uber.org/goleak" ) @@ -228,7 +229,7 @@ func Test_roundRobinSelector_SelectEndpoint(t *testing.T) { defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) rr := newRoundRobinSelector() for _, tt := range tests { - if got := rr.selectEndpoint(tt.args.ns, tt.args.networkServiceEndpoints); !reflect.DeepEqual(got, tt.want) { + if got := rr.selectEndpoint(tt.args.ns, tt.args.networkServiceEndpoints); !proto.Equal(got, tt.want) { t.Errorf("%s: roundRobinSelector.selectEndpoint() = %v, want %v", tt.name, got, tt.want) } } diff --git a/pkg/registry/memory/nse_server_test.go b/pkg/registry/memory/nse_server_test.go index f164a6bbcb..0d3d8ab315 100644 --- a/pkg/registry/memory/nse_server_test.go +++ b/pkg/registry/memory/nse_server_test.go @@ -105,3 +105,134 @@ func TestNetworkServiceEndpointRegistryServer_RegisterAndFindWatch(t *testing.T) close(ch) } + +func TestNetworkServiceEndpointRegistryServer_RegisterAndFindByLabel(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + s := next.NewNetworkServiceEndpointRegistryServer(memory.NewNetworkServiceEndpointRegistryServer()) + + _, err := s.Register(context.Background(), createLabeledNSE1()) + require.NoError(t, err) + + _, err = s.Register(context.Background(), createLabeledNSE2()) + require.NoError(t, err) + + _, err = s.Register(context.Background(), createLabeledNSE3()) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan *registry.NetworkServiceEndpoint, 1) + _ = s.Find(®istry.NetworkServiceEndpointQuery{ + NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ + NetworkServiceLabels: map[string]*registry.NetworkServiceLabels{ + "Service1": { + Labels: map[string]string{ + "c": "d", + }, + }, + }, + }, + }, streamchannel.NewNetworkServiceEndpointFindServer(ctx, ch)) + + require.Equal(t, createLabeledNSE2(), <-ch) + cancel() + close(ch) +} + +func TestNetworkServiceEndpointRegistryServer_RegisterAndFindByLabelWatch(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + s := next.NewNetworkServiceEndpointRegistryServer(memory.NewNetworkServiceEndpointRegistryServer()) + + _, err := s.Register(context.Background(), createLabeledNSE1()) + require.NoError(t, err) + + _, err = s.Register(context.Background(), createLabeledNSE2()) + require.NoError(t, err) + + _, err = s.Register(context.Background(), createLabeledNSE3()) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ch := make(chan *registry.NetworkServiceEndpoint, 1) + go func() { + _ = s.Find(®istry.NetworkServiceEndpointQuery{ + Watch: true, + NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ + NetworkServiceLabels: map[string]*registry.NetworkServiceLabels{ + "Service1": { + Labels: map[string]string{ + "c": "d", + }, + }, + }, + }, + }, streamchannel.NewNetworkServiceEndpointFindServer(ctx, ch)) + }() + + require.Equal(t, createLabeledNSE2(), <-ch) + + expected, err := s.Register(context.Background(), createLabeledNSE2()) + require.NoError(t, err) + require.Equal(t, expected, <-ch) + + close(ch) +} + +func createLabeledNSE1() *registry.NetworkServiceEndpoint { + labels := map[string]*registry.NetworkServiceLabels{ + "Service1": { + Labels: map[string]string{ + "foo": "bar", + }, + }, + } + return ®istry.NetworkServiceEndpoint{ + Name: "nse1", + NetworkServiceNames: []string{ + "Service1", + }, + NetworkServiceLabels: labels, + } +} + +func createLabeledNSE2() *registry.NetworkServiceEndpoint { + labels := map[string]*registry.NetworkServiceLabels{ + "Service1": { + Labels: map[string]string{ + "a": "b", + "c": "d", + }, + }, + "Service2": { + Labels: map[string]string{ + "1": "2", + "3": "4", + }, + }, + } + return ®istry.NetworkServiceEndpoint{ + Name: "nse2", + NetworkServiceNames: []string{ + "Service1", "Service2", + }, + NetworkServiceLabels: labels, + } +} + +func createLabeledNSE3() *registry.NetworkServiceEndpoint { + labels := map[string]*registry.NetworkServiceLabels{ + "Service555": { + Labels: map[string]string{ + "a": "b", + "c": "d", + }, + }, + } + return ®istry.NetworkServiceEndpoint{ + Name: "nse3", + NetworkServiceNames: []string{ + "Service1", + }, + NetworkServiceLabels: labels, + } +} diff --git a/pkg/tools/matchutils/utils.go b/pkg/tools/matchutils/utils.go index ef408bd91b..ba4bc767ea 100644 --- a/pkg/tools/matchutils/utils.go +++ b/pkg/tools/matchutils/utils.go @@ -18,9 +18,11 @@ package matchutils import ( - "reflect" "strings" + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/proto" + "github.com/networkservicemesh/api/pkg/api/registry" ) @@ -28,19 +30,35 @@ import ( func MatchNetworkServices(left, right *registry.NetworkService) bool { return (left.Name == "" || strings.Contains(right.Name, left.Name)) && (left.Payload == "" || left.Payload == right.Payload) && - (left.Matches == nil || reflect.DeepEqual(left.Matches, right.Matches)) + (left.Matches == nil || cmp.Equal(left.Matches, right.Matches, cmp.Comparer(proto.Equal))) } // MatchNetworkServiceEndpoints returns true if two network service endpoints are matched func MatchNetworkServiceEndpoints(left, right *registry.NetworkServiceEndpoint) bool { return (left.Name == "" || strings.Contains(right.Name, left.Name)) && - (left.NetworkServiceLabels == nil || reflect.DeepEqual(left.NetworkServiceLabels, right.NetworkServiceLabels)) && + (left.NetworkServiceLabels == nil || labelsContains(right.NetworkServiceLabels, left.NetworkServiceLabels)) && (left.ExpirationTime == nil || left.ExpirationTime.Seconds == right.ExpirationTime.Seconds) && (left.NetworkServiceNames == nil || contains(right.NetworkServiceNames, left.NetworkServiceNames)) && (left.Url == "" || strings.Contains(right.Url, left.Url)) } -func contains(what, where []string) bool { +func labelsContains(where, what map[string]*registry.NetworkServiceLabels) bool { + for lService, lLabels := range what { + rService, ok := where[lService] + if !ok { + return false + } + for lKey, lVal := range lLabels.Labels { + rVal, ok := rService.Labels[lKey] + if !ok || lVal != rVal { + return false + } + } + } + return true +} + +func contains(where, what []string) bool { set := make(map[string]struct{}) for _, s := range what { set[s] = struct{}{}