diff --git a/pkg/networkservice/common/discover/server.go b/pkg/networkservice/common/discover/server.go index 1c2fa42fd..33b5e2081 100644 --- a/pkg/networkservice/common/discover/server.go +++ b/pkg/networkservice/common/discover/server.go @@ -109,8 +109,10 @@ func (d *discoverCandidatesServer) discoverNetworkServiceEndpoint(ctx context.Co } nseList := registry.ReadNetworkServiceEndpointList(nseStream) - if len(nseList) != 0 { - return nseList[0], nil + for _, nse := range nseList { + if nse.Name == nseName { + return nse, nil + } } query.Watch = true @@ -119,7 +121,16 @@ func (d *discoverCandidatesServer) discoverNetworkServiceEndpoint(ctx context.Co if err != nil { return nil, errors.WithStack(err) } - return nseStream.Recv() + for { + var nse *registry.NetworkServiceEndpoint + if nse, err = nseStream.Recv(); err != nil { + return nil, errors.WithStack(err) + } + + if nse.Name == nseName { + return nse, nil + } + } } func (d *discoverCandidatesServer) discoverNetworkServiceEndpoints(ctx context.Context, ns *registry.NetworkService, labels map[string]string) ([]*registry.NetworkServiceEndpoint, error) { @@ -152,7 +163,7 @@ func (d *discoverCandidatesServer) discoverNetworkServiceEndpoints(ctx context.C for { var nse *registry.NetworkServiceEndpoint if nse, err = nseStream.Recv(); err != nil { - return nil, err + return nil, errors.WithStack(err) } result = matchEndpoint(labels, ns, nse) @@ -176,8 +187,10 @@ func (d *discoverCandidatesServer) discoverNetworkService(ctx context.Context, n } nsList := registry.ReadNetworkServiceList(nsStream) - if len(nsList) != 0 { - return nsList[0], nil + for _, ns := range nsList { + if ns.Name == name { + return ns, nil + } } ctx, cancelFind := context.WithCancel(ctx) @@ -189,5 +202,14 @@ func (d *discoverCandidatesServer) discoverNetworkService(ctx context.Context, n if err != nil { return nil, errors.WithStack(err) } - return nsStream.Recv() + for { + var ns *registry.NetworkService + if ns, err = nsStream.Recv(); err != nil { + return nil, errors.WithStack(err) + } + + if ns.Name == name { + return ns, nil + } + } } diff --git a/pkg/networkservice/common/discover/server_test.go b/pkg/networkservice/common/discover/server_test.go index 415bf441a..0f615ccc4 100644 --- a/pkg/networkservice/common/discover/server_test.go +++ b/pkg/networkservice/common/discover/server_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -377,3 +377,122 @@ func TestNoMatchServiceEndpointFound(t *testing.T) { _, err = server.Request(ctx, request) require.Error(t, err) } + +func TestMatchExactService(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + nsServer := memory.NewNetworkServiceRegistryServer() + nseServer := registrynext.NewNetworkServiceEndpointRegistryServer( + setid.NewNetworkServiceEndpointRegistryServer(), + memory.NewNetworkServiceEndpointRegistryServer(), + ) + + nsName := networkServiceName() + server := next.NewNetworkServiceServer( + discover.NewServer( + adapters.NetworkServiceServerToClient(nsServer), + adapters.NetworkServiceEndpointServerToClient(nseServer)), + checkcontext.NewServer(t, func(t *testing.T, ctx context.Context) { + nses := discover.Candidates(ctx).Endpoints + require.Len(t, nses, 1) + require.Equal(t, nsName, nses[0].NetworkServiceNames[0]) + }), + ) + + // 1. Register NS, NSE with wrong name + wrongNSName := nsName + "-wrong" + _, err := nsServer.Register(context.Background(), ®istry.NetworkService{ + Name: wrongNSName, + }) + require.NoError(t, err) + _, err = nseServer.Register(context.Background(), ®istry.NetworkServiceEndpoint{ + NetworkServiceNames: []string{wrongNSName}, + }) + require.NoError(t, err) + + // 2. Try to discover NSE by the right NS name + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + request := &networkservice.NetworkServiceRequest{ + Connection: &networkservice.Connection{ + NetworkService: nsName, + }, + } + + _, err = server.Request(ctx, request.Clone()) + require.Error(t, err) + + // 3. Register NS, NSE with the right name + _, err = nsServer.Register(context.Background(), ®istry.NetworkService{ + Name: nsName, + }) + require.NoError(t, err) + _, err = nseServer.Register(context.Background(), ®istry.NetworkServiceEndpoint{ + NetworkServiceNames: []string{nsName}, + }) + require.NoError(t, err) + + // 4. Try to discover NSE by the right NS name + ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + _, err = server.Request(ctx, request.Clone()) + require.NoError(t, err) +} + +func TestMatchExactEndpoint(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + nseServer := registrynext.NewNetworkServiceEndpointRegistryServer( + setid.NewNetworkServiceEndpointRegistryServer(), + memory.NewNetworkServiceEndpointRegistryServer(), + ) + + nseName := "final-endpoint" + u := "tcp://" + nseName + server := next.NewNetworkServiceServer( + discover.NewServer( + adapters.NetworkServiceServerToClient(memory.NewNetworkServiceRegistryServer()), + adapters.NetworkServiceEndpointServerToClient(nseServer)), + checkcontext.NewServer(t, func(t *testing.T, ctx context.Context) { + require.Equal(t, u, clienturlctx.ClientURL(ctx).String()) + }), + ) + + // 1. Register NSE with wrong name + wrongNSEName := nseName + "-wrong" + wrongURL := u + "-wrong" + _, err := nseServer.Register(context.Background(), ®istry.NetworkServiceEndpoint{ + Name: wrongNSEName, + Url: wrongURL, + }) + require.NoError(t, err) + + // 2. Try to discover NSE by the right name + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + request := &networkservice.NetworkServiceRequest{ + Connection: &networkservice.Connection{ + NetworkServiceEndpointName: nseName, + }, + } + + _, err = server.Request(ctx, request.Clone()) + require.Error(t, err) + + // 3. Register NSE with the right name + _, err = nseServer.Register(context.Background(), ®istry.NetworkServiceEndpoint{ + Name: nseName, + Url: u, + }) + require.NoError(t, err) + + // 4. Try to discover NSE by the right name + ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + _, err = server.Request(ctx, request.Clone()) + require.NoError(t, err) +}