diff --git a/pkg/networkservice/common/discoverforwarder/server.go b/pkg/networkservice/common/discoverforwarder/server.go index a9d84ba49..3ba3ea27d 100644 --- a/pkg/networkservice/common/discoverforwarder/server.go +++ b/pkg/networkservice/common/discoverforwarder/server.go @@ -69,6 +69,10 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks var forwarderName = loadForwarderName(ctx) var logger = log.FromContext(ctx).WithField("discoverForwarderServer", "request") + if request.GetConnection().State == networkservice.State_RESELECT_REQUESTED { + forwarderName = "" + } + ns, err := d.discoverNetworkService(ctx, request.GetConnection().GetNetworkService(), request.GetConnection().GetPayload()) if err != nil { return nil, err @@ -96,7 +100,7 @@ func (d *discoverForwarderServer) Request(ctx context.Context, request *networks return nil, errors.New("no candidates found") } - if forwarderName == "" { + if forwarderName == "" && request.GetConnection().GetState() != networkservice.State_RESELECT_REQUESTED { segments := request.Connection.GetPath().GetPathSegments() if pathIndex := int(request.Connection.GetPath().Index); len(segments) > pathIndex+1 { for i, candidate := range nses { diff --git a/pkg/networkservice/ipam/groupipam/server_test.go b/pkg/networkservice/ipam/groupipam/server_test.go index 22bcdaec3..0888c47a6 100644 --- a/pkg/networkservice/ipam/groupipam/server_test.go +++ b/pkg/networkservice/ipam/groupipam/server_test.go @@ -21,6 +21,7 @@ import ( "net" "testing" + "github.com/google/uuid" "github.com/networkservicemesh/api/pkg/api/networkservice" "github.com/stretchr/testify/require" @@ -126,3 +127,60 @@ func Test_NewServer_GroupOfCustomIPAMServers(t *testing.T) { require.NoError(t, err) requireConns(t, conn4, []string{"172.92.0.3/16", "fd00::3/8"}) } + +func TestOutOfIPs(t *testing.T) { + _, ipNet, err := net.ParseCIDR("192.168.1.2/31") + require.NoError(t, err) + + srv1 := groupipam.NewServer([][]*net.IPNet{{ipNet}}) + srv2 := groupipam.NewServer([][]*net.IPNet{{ipNet}}) + + req1 := &networkservice.NetworkServiceRequest{ + Connection: &networkservice.Connection{ + Id: uuid.NewString(), + Context: &networkservice.ConnectionContext{ + IpContext: new(networkservice.IPContext), + }, + }, + } + + req2 := &networkservice.NetworkServiceRequest{ + Connection: &networkservice.Connection{ + Id: uuid.NewString(), + Context: &networkservice.ConnectionContext{ + IpContext: new(networkservice.IPContext), + }, + }, + } + for i := 0; i < 100; i++ { + conn1, err := srv1.Request(context.Background(), req1) + require.NoError(t, err) + requireConns(t, conn1, "192.168.1.2/32", "192.168.1.3/32") + req1.Connection = conn1 + + conn2, err := srv2.Request(context.Background(), req2) + require.NoError(t, err) + requireConns(t, conn2, "192.168.1.2/32", "192.168.1.3/32") + req2.Connection = conn2 + + _, err = srv1.Request(context.Background(), req2) + require.Error(t, err) + + _, err = srv2.Request(context.Background(), req1) + require.Error(t, err) + + _, err = srv2.Close(context.Background(), req2.GetConnection()) + require.NoError(t, err) + _, err = srv1.Close(context.Background(), req1.GetConnection()) + require.NoError(t, err) + } +} + +func requireConns(t *testing.T, conn *networkservice.Connection, dstAddr, srcAddr string) { + for i, src := range conn.GetContext().GetIpContext().GetSrcIpAddrs() { + require.Equal(t, srcAddr, src, i) + } + for i, dst := range conn.GetContext().GetIpContext().GetDstIpAddrs() { + require.Equal(t, dstAddr, dst, i) + } +} diff --git a/pkg/networkservice/ipam/point2pointipam/server.go b/pkg/networkservice/ipam/point2pointipam/server.go index a9e9874d4..6e83498b9 100644 --- a/pkg/networkservice/ipam/point2pointipam/server.go +++ b/pkg/networkservice/ipam/point2pointipam/server.go @@ -219,7 +219,7 @@ func (s *ipamServer) Close(ctx context.Context, conn *networkservice.Connection) return nil, errors.Wrap(s.initErr, "failed to init IPAM server during close") } - if connInfo, ok := s.Load(conn.GetId()); ok { + if connInfo, ok := s.LoadAndDelete(conn.GetId()); ok { s.free(connInfo) }