Skip to content

Commit

Permalink
Use a context with exteded timeout on Requests in begin (networkservi…
Browse files Browse the repository at this point in the history
…cemesh#1656)

* Use an extended timeout in case of reselect requests

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* add unit test

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* fix race condition

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* use a context with exteded timeout on Requests in begin

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* fix unit tests

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* fix race conditions in dial

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* fix variable name

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

* fix go linter issues

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>

---------

Signed-off-by: NikitaSkrynnik <nikita.skrynnik@xored.com>
  • Loading branch information
NikitaSkrynnik authored Aug 20, 2024
1 parent ae25bb4 commit 6fad31a
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 52 deletions.
21 changes: 12 additions & 9 deletions pkg/networkservice/common/begin/event_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/peer"

"github.com/networkservicemesh/sdk/pkg/tools/clock"
"github.com/networkservicemesh/sdk/pkg/tools/extend"
"github.com/networkservicemesh/sdk/pkg/tools/postpone"

Expand Down Expand Up @@ -160,16 +159,16 @@ type eventFactoryServer struct {
ctxFunc func() (context.Context, context.CancelFunc)
request *networkservice.NetworkServiceRequest
returnedConnection *networkservice.Connection
closeTimeout time.Duration
contextTimeout time.Duration
afterCloseFunc func()
server networkservice.NetworkServiceServer
}

func newEventFactoryServer(ctx context.Context, closeTimeout time.Duration, afterClose func()) *eventFactoryServer {
func newEventFactoryServer(ctx context.Context, contextTimeout time.Duration, afterClose func()) *eventFactoryServer {
f := &eventFactoryServer{
server: next.Server(ctx),
initialCtxFunc: postpone.Context(ctx),
closeTimeout: closeTimeout,
contextTimeout: contextTimeout,
}
f.updateContext(ctx)

Expand Down Expand Up @@ -207,7 +206,12 @@ func (f *eventFactoryServer) Request(opts ...Option) <-chan error {
default:
ctx, cancel := f.ctxFunc()
defer cancel()
conn, err := f.server.Request(ctx, f.request)

extendedCtx, cancel := context.WithTimeout(context.Background(), f.contextTimeout)
defer cancel()

extendedCtx = extend.WithValuesFromContext(extendedCtx, ctx)
conn, err := f.server.Request(extendedCtx, f.request)
if err == nil && f.request != nil {
f.request.Connection = conn
}
Expand Down Expand Up @@ -236,12 +240,11 @@ func (f *eventFactoryServer) Close(opts ...Option) <-chan error {
ctx, cancel := f.ctxFunc()
defer cancel()

c := clock.FromContext(ctx)
closeCtx, cancel := c.WithTimeout(context.Background(), f.closeTimeout)
extendedCtx, cancel := context.WithTimeout(context.Background(), f.contextTimeout)
defer cancel()

closeCtx = extend.WithValuesFromContext(closeCtx, ctx)
_, err := f.server.Close(closeCtx, f.request.GetConnection())
extendedCtx = extend.WithValuesFromContext(extendedCtx, ctx)
_, err := f.server.Close(extendedCtx, f.request.GetConnection())
f.afterCloseFunc()
ch <- err
}
Expand Down
25 changes: 8 additions & 17 deletions pkg/networkservice/common/begin/event_factory_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import (
"github.com/networkservicemesh/sdk/pkg/networkservice/core/chain"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
"github.com/networkservicemesh/sdk/pkg/tools/clock"
"github.com/networkservicemesh/sdk/pkg/tools/clockmock"
)

// This test reproduces the situation when refresh changes the eventFactory context
Expand Down Expand Up @@ -138,19 +137,12 @@ func TestContextTimeout_Server(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Add clockMock to the context
clockMock := clockmock.New(ctx)
ctx = clock.WithClock(ctx, clockMock)

ctx, cancel = clockMock.WithDeadline(ctx, clockMock.Now().Add(time.Second*3))
defer cancel()

closeTimeout := time.Minute
contextTimeout := time.Second * 2
eventFactoryServ := &eventFactoryServer{}
server := chain.NewNetworkServiceServer(
begin.NewServer(begin.WithCloseTimeout(closeTimeout)),
begin.NewServer(begin.WithContextTimeout(contextTimeout)),
eventFactoryServ,
&delayedNSEServer{t: t, closeTimeout: closeTimeout, clock: clockMock},
&delayedNSEServer{t: t, contextTimeout: contextTimeout},
)

// Do Request
Expand Down Expand Up @@ -229,9 +221,8 @@ func (f *failedNSEServer) Close(ctx context.Context, conn *networkservice.Connec

type delayedNSEServer struct {
t *testing.T
clock *clockmock.Mock
initialTimeout time.Duration
closeTimeout time.Duration
contextTimeout time.Duration
}

func (d *delayedNSEServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
Expand All @@ -247,20 +238,20 @@ func (d *delayedNSEServer) Request(ctx context.Context, request *networkservice.
d.initialTimeout = timeout
}
// All requests timeout must be equal the first
require.Equal(d.t, d.initialTimeout, timeout)
require.Less(d.t, (d.initialTimeout - timeout).Abs(), time.Second)

// Add delay
d.clock.Add(timeout / 2)
time.Sleep(timeout / 2)
return next.Server(ctx).Request(ctx, request)
}

func (d *delayedNSEServer) Close(ctx context.Context, conn *networkservice.Connection) (*emptypb.Empty, error) {
require.Greater(d.t, d.initialTimeout, time.Duration(0))

deadline, _ := ctx.Deadline()
clockTime := clock.FromContext(ctx)
timeout := time.Until(deadline)

require.Equal(d.t, d.closeTimeout, clockTime.Until(deadline))
require.Less(d.t, (d.contextTimeout - timeout).Abs(), time.Second)

return next.Server(ctx).Close(ctx, conn)
}
12 changes: 6 additions & 6 deletions pkg/networkservice/common/begin/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import (
)

type option struct {
cancelCtx context.Context
reselect bool
closeTimeout time.Duration
cancelCtx context.Context
reselect bool
contextTimeout time.Duration
}

// Option - event option
Expand All @@ -44,9 +44,9 @@ func WithReselect() Option {
}
}

// WithCloseTimeout - set a custom timeout for a context in begin.Close
func WithCloseTimeout(timeout time.Duration) Option {
// WithContextTimeout - set a custom timeout for a context in begin.Close
func WithContextTimeout(timeout time.Duration) Option {
return func(o *option) {
o.closeTimeout = timeout
o.contextTimeout = timeout
}
}
32 changes: 19 additions & 13 deletions pkg/networkservice/common/begin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,23 @@ import (

type beginServer struct {
genericsync.Map[string, *eventFactoryServer]
closeTimeout time.Duration
contextTimeout time.Duration
}

// NewServer - creates a new begin chain element
func NewServer(opts ...Option) networkservice.NetworkServiceServer {
o := &option{
cancelCtx: context.Background(),
reselect: false,
closeTimeout: time.Minute,
cancelCtx: context.Background(),
reselect: false,
contextTimeout: time.Minute,
}

for _, opt := range opts {
opt(o)
}

return &beginServer{
closeTimeout: o.closeTimeout,
contextTimeout: o.contextTimeout,
}
}

Expand All @@ -68,7 +68,7 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo
eventFactoryServer, _ := b.LoadOrStore(request.GetConnection().GetId(),
newEventFactoryServer(
ctx,
b.closeTimeout,
b.contextTimeout,
func() {
b.Delete(request.GetRequestConnection().GetId())
},
Expand All @@ -88,17 +88,24 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo
eventFactoryServer.request != nil && eventFactoryServer.request.Connection != nil {
log.FromContext(ctx).Info("Closing connection due to RESELECT_REQUESTED state")

closeCtx, cancel := context.WithTimeout(context.Background(), b.contextTimeout)
defer cancel()

eventFactoryCtx, eventFactoryCtxCancel := eventFactoryServer.ctxFunc()
_, closeErr := next.Server(eventFactoryCtx).Close(eventFactoryCtx, eventFactoryServer.request.Connection)
closeCtx = extend.WithValuesFromContext(closeCtx, eventFactoryCtx)
_, closeErr := next.Server(closeCtx).Close(closeCtx, eventFactoryServer.request.Connection)
if closeErr != nil {
log.FromContext(ctx).Errorf("Can't close old connection: %v", closeErr)
}
eventFactoryServer.state = closed
eventFactoryCtxCancel()
}

withEventFactoryCtx := withEventFactory(ctx, eventFactoryServer)
conn, err = next.Server(withEventFactoryCtx).Request(withEventFactoryCtx, request)
extendedCtx, cancel := context.WithTimeout(context.Background(), b.contextTimeout)
extendedCtx = extend.WithValuesFromContext(extendedCtx, withEventFactory(ctx, eventFactoryServer))
defer cancel()

conn, err = next.Server(extendedCtx).Request(extendedCtx, request)
if err != nil {
if eventFactoryServer.state != established {
eventFactoryServer.state = closed
Expand Down Expand Up @@ -143,14 +150,13 @@ func (b *beginServer) Close(ctx context.Context, conn *networkservice.Connection
if currentServerClient != eventFactoryServer {
return
}
closeCtx, cancel := context.WithTimeout(context.Background(), b.closeTimeout)
extendedCtx, cancel := context.WithTimeout(context.Background(), b.contextTimeout)
extendedCtx = extend.WithValuesFromContext(extendedCtx, withEventFactory(ctx, eventFactoryServer))
defer cancel()

// Always close with the last valid EventFactory we got
conn = eventFactoryServer.request.Connection
withEventFactoryCtx := withEventFactory(ctx, eventFactoryServer)
closeCtx = extend.WithValuesFromContext(closeCtx, withEventFactoryCtx)
_, err = next.Server(closeCtx).Close(closeCtx, conn)
_, err = next.Server(extendedCtx).Close(extendedCtx, conn)
eventFactoryServer.afterCloseFunc()
}):
return &emptypb.Empty{}, err
Expand Down
55 changes: 51 additions & 4 deletions pkg/networkservice/common/begin/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/networkservicemesh/api/pkg/api/networkservice"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/protobuf/types/known/emptypb"

"github.com/networkservicemesh/sdk/pkg/networkservice/common/begin"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
Expand All @@ -41,14 +42,24 @@ type waitServer struct {
}

func (s *waitServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
time.Sleep(waitTime)
s.requestDone.Store(1)
afterCh := time.After(time.Second)
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-afterCh:
s.requestDone.Add(1)
}
return next.Server(ctx).Request(ctx, request)
}

func (s *waitServer) Close(ctx context.Context, connection *networkservice.Connection) (*empty.Empty, error) {
time.Sleep(waitTime)
s.closeDone.Store(1)
afterCh := time.After(time.Second)
select {
case <-ctx.Done():
return &emptypb.Empty{}, nil
case <-afterCh:
s.closeDone.Add(1)
}
return next.Server(ctx).Close(ctx, connection)
}

Expand Down Expand Up @@ -82,3 +93,39 @@ func TestBeginWorksWithSmallTimeout(t *testing.T) {
return waitSrv.closeDone.Load() == 1
}, waitTime*2, time.Millisecond*500)
}

func TestBeginHasExtendedTimeoutOnReselect(t *testing.T) {
t.Cleanup(func() {
goleak.VerifyNone(t)
})
requestCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
defer cancel()

waitSrv := &waitServer{}
server := next.NewNetworkServiceServer(
begin.NewServer(),
waitSrv,
)

// Make a first request to create an event factory. Begin should make Request only
request := testRequest("id")
_, err := server.Request(requestCtx, request)
require.EqualError(t, err, context.DeadlineExceeded.Error())
require.Equal(t, int32(0), waitSrv.requestDone.Load())
require.Eventually(t, func() bool {
return waitSrv.requestDone.Load() == 1
}, waitTime*2, time.Millisecond*500)

// Make a second request with RESELECT_REQUESTED. Begin should make Close with extended context first and then Request
requestCtx, cancel = context.WithTimeout(context.Background(), time.Millisecond*200)
defer cancel()
newRequest := request.Clone()
newRequest.Connection.State = networkservice.State_RESELECT_REQUESTED

_, err = server.Request(requestCtx, newRequest)
require.EqualError(t, err, context.DeadlineExceeded.Error())
require.Equal(t, int32(0), waitSrv.closeDone.Load())
require.Eventually(t, func() bool {
return waitSrv.closeDone.Load() == 1 && waitSrv.requestDone.Load() == 2
}, waitTime*4, time.Millisecond*500)
}
8 changes: 6 additions & 2 deletions pkg/networkservice/common/dial/client.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021-2023 Cisco and/or its affiliates.
// Copyright (c) 2021-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -72,8 +72,12 @@ func (d *dialClient) Request(ctx context.Context, request *networkservice.Networ
return next.Client(ctx).Request(ctx, request, opts...)
}

di.mu.Lock()
dialClientURL := di.clientURL
di.mu.Unlock()

// If our existing dialer has a different URL close down the chain
if di.clientURL != nil && di.clientURL.String() != clientURL.String() {
if dialClientURL != nil && dialClientURL.String() != clientURL.String() {
closeCtx, closeCancel := closeContextFunc()
defer closeCancel()
err := di.Dial(closeCtx, di.clientURL)
Expand Down
6 changes: 5 additions & 1 deletion pkg/networkservice/common/dial/dialer.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 Cisco and/or its affiliates.
// Copyright (c) 2021-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand All @@ -20,6 +20,7 @@ import (
"context"
"net/url"
"runtime"
"sync"
"time"

"github.com/pkg/errors"
Expand All @@ -37,6 +38,7 @@ type dialer struct {
*grpc.ClientConn
dialOptions []grpc.DialOption
dialTimeout time.Duration
mu sync.Mutex
}

func newDialer(ctx context.Context, dialTimeout time.Duration, dialOptions ...grpc.DialOption) *dialer {
Expand All @@ -56,8 +58,10 @@ func (di *dialer) Dial(ctx context.Context, clientURL *url.URL) error {
di.cleanupCancel()
}

di.mu.Lock()
// Set the clientURL
di.clientURL = clientURL
di.mu.Unlock()

// Setup dialTimeout if needed
dialCtx := ctx
Expand Down

0 comments on commit 6fad31a

Please sign in to comment.