diff --git a/pkg/networkservice/common/discoverforwarder/metadata.go b/pkg/networkservice/common/discoverforwarder/metadata.go index b675fd369..22451b10a 100644 --- a/pkg/networkservice/common/discoverforwarder/metadata.go +++ b/pkg/networkservice/common/discoverforwarder/metadata.go @@ -1,5 +1,7 @@ // Copyright (c) 2021 Doc.ai and/or its affiliates. // +// Copyright (c) 2023 Cisco and/or its affiliates. +// // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,16 +24,21 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata" ) -type selectedForworderKey struct{} +type selectedForwarderKey struct{} + +type selectedForwarderVal struct { + name string + active bool +} -func loadForwarderName(ctx context.Context) string { - v, ok := metadata.Map(ctx, false).Load(selectedForworderKey{}) +func loadForwarder(ctx context.Context) *selectedForwarderVal { + v, ok := metadata.Map(ctx, false).Load(selectedForwarderKey{}) if !ok { - return "" + return nil } - return v.(string) + return v.(*selectedForwarderVal) } -func storeForwarderName(ctx context.Context, v string) { - metadata.Map(ctx, false).Store(selectedForworderKey{}, v) +func storeForwarder(ctx context.Context, v *selectedForwarderVal) { + metadata.Map(ctx, false).Store(selectedForwarderKey{}, v) } diff --git a/pkg/networkservice/common/discoverforwarder/server.go b/pkg/networkservice/common/discoverforwarder/server.go index e306ed01c..9c0f5a540 100644 --- a/pkg/networkservice/common/discoverforwarder/server.go +++ b/pkg/networkservice/common/discoverforwarder/server.go @@ -65,11 +65,12 @@ func NewServer(nsClient registry.NetworkServiceRegistryClient, nseClient registr return result } +// nolint:gocyclo func (d *discoverForwarderServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { - var forwarderName = loadForwarderName(ctx) + var forwarder = loadForwarder(ctx) var logger = log.FromContext(ctx).WithField("discoverForwarderServer", "request") - if forwarderName == "" { + if forwarder == nil || !forwarder.active { ns, err := d.discoverNetworkService(ctx, request.GetConnection().GetNetworkService(), request.GetConnection().GetPayload()) if err != nil { return nil, err @@ -117,7 +118,10 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks resp, err := next.Server(ctx).Request(clienturlctx.WithClientURL(ctx, u), request.Clone()) if err == nil { - storeForwarderName(ctx, candidate.Name) + storeForwarder(ctx, &selectedForwarderVal{ + name: candidate.Name, + active: true, + }) return resp, nil } logger.Errorf("forwarder=%v url=%v returned error=%v", candidate.Name, candidate.Url, err.Error()) @@ -129,18 +133,18 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks stream, err := d.nseClient.Find(ctx, ®istry.NetworkServiceEndpointQuery{ NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ - Name: forwarderName, + Name: forwarder.name, Url: d.nsmgrURL, }, }) if err != nil { logger.Errorf("can not open registry nse stream by forwarder name. Error: %v", err.Error()) - return nil, errors.Wrapf(err, "failed to find %s on %s", forwarderName, d.nsmgrURL) + return nil, errors.Wrapf(err, "failed to find %s on %s", forwarder.name, d.nsmgrURL) } nses := registry.ReadNetworkServiceEndpointList(stream) if len(nses) == 0 { - storeForwarderName(ctx, "") + storeForwarder(ctx, nil) return nil, errors.New("forwarder not found") } @@ -152,27 +156,27 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks conn, err := next.Server(ctx).Request(clienturlctx.WithClientURL(ctx, u), request) if err != nil { - storeForwarderName(ctx, "") + forwarder.active = false } return conn, err } func (d *discoverForwarderServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) { - var forwarderName = loadForwarderName(ctx) + var forwarder = loadForwarder(ctx) var logger = log.FromContext(ctx).WithField("discoverForwarderServer", "request") - if forwarderName == "" { + if forwarder == nil { return nil, errors.New("forwarder is not selected") } stream, err := d.nseClient.Find(ctx, ®istry.NetworkServiceEndpointQuery{ NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ - Name: forwarderName, + Name: forwarder.name, Url: d.nsmgrURL, }, }) if err != nil { logger.Errorf("can not open registry nse stream by forwarder name. Error: %v", err.Error()) - return nil, errors.Wrapf(err, "failed to find %s on %s", forwarderName, d.nsmgrURL) + return nil, errors.Wrapf(err, "failed to find %s on %s", forwarder.name, d.nsmgrURL) } nses := registry.ReadNetworkServiceEndpointList(stream)