diff --git a/pkg/networkservice/chains/nsmgr/heal_test.go b/pkg/networkservice/chains/nsmgr/heal_test.go index f88875c631..3616f880ba 100644 --- a/pkg/networkservice/chains/nsmgr/heal_test.go +++ b/pkg/networkservice/chains/nsmgr/heal_test.go @@ -504,7 +504,7 @@ func testForwarderShouldBeSelectedCorrectlyOnNSMgrRestart(t *testing.T, nodeNum, NetworkServiceNames: []string{"my-ns"}, } - nseEntry := domain.Nodes[nodeNum].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken) + domain.Nodes[nodeNum].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken) nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken) @@ -536,26 +536,5 @@ func testForwarderShouldBeSelectedCorrectlyOnNSMgrRestart(t *testing.T, nodeNum, }, }, }, sandbox.GenerateTestToken) - - _, err = domain.Nodes[nodeNum].NSMgr.NetworkServiceEndpointRegistryServer().Register(ctx, ®istry.NetworkServiceEndpoint{ - Name: expectedForwarderName, - Url: domain.Nodes[nodeNum].Forwarders[expectedForwarderName].URL.String(), - NetworkServiceNames: []string{"forwarder"}, - NetworkServiceLabels: map[string]*registry.NetworkServiceLabels{ - "forwarder": { - Labels: map[string]string{ - "p2p": "true", - }, - }, - }, - }) - require.NoError(t, err) - - _, err = domain.Nodes[nodeNum].NSMgr.NetworkServiceEndpointRegistryServer().Register(ctx, ®istry.NetworkServiceEndpoint{ - Name: nseReg.Name, - Url: nseEntry.URL.String(), - NetworkServiceNames: nseReg.NetworkServiceNames, - }) - require.NoError(t, err) } } diff --git a/pkg/networkservice/common/discoverforwarder/server.go b/pkg/networkservice/common/discoverforwarder/server.go index 68635fae5e..91d9bf093c 100644 --- a/pkg/networkservice/common/discoverforwarder/server.go +++ b/pkg/networkservice/common/discoverforwarder/server.go @@ -67,6 +67,13 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks var forwarderName = loadForwarderName(ctx) var logger = log.FromContext(ctx).WithField("discoverForwarderServer", "request") + if forwarderName == "" { + segments := request.Connection.GetPath().GetPathSegments() + if pathIndex := int(request.Connection.GetPath().Index); len(segments) > pathIndex+1 { + forwarderName = segments[pathIndex+1].Name + } + } + if forwarderName == "" { ns, err := d.discoverNetworkService(ctx, request.GetConnection().GetNetworkService(), request.GetConnection().GetPayload()) if err != nil { @@ -93,17 +100,6 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks return nil, errors.New("no candidates found") } - segments := request.Connection.GetPath().GetPathSegments() - if pathIndex := int(request.Connection.GetPath().Index); len(segments) > pathIndex+1 { - datapathForwarder := segments[pathIndex+1].Name - for i, candidate := range nses { - if candidate.Name == datapathForwarder { - nses[0], nses[i] = nses[i], nses[0] - break - } - } - } - var candidatesErr = errors.New("all forwarders have failed") // TODO: Should we consider about load balancing? @@ -128,6 +124,7 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks return nil, candidatesErr } + stream, err := d.nseClient.Find(ctx, ®istry.NetworkServiceEndpointQuery{ NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ Name: forwarderName,