Skip to content

Commit

Permalink
begin: update context after a successful request (#1370)
Browse files Browse the repository at this point in the history
Signed-off-by: Artem Glazychev <artem.glazychev@xored.com>

Signed-off-by: Artem Glazychev <artem.glazychev@xored.com>
  • Loading branch information
glazychev-art authored Oct 13, 2022
1 parent d4d43ae commit 9ae27f8
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 84 deletions.
2 changes: 1 addition & 1 deletion pkg/networkservice/common/begin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ func (b *beginClient) Request(ctx context.Context, request *networkservice.Netwo
conn, err = b.Request(ctx, request, opts...)
return
}
eventFactoryClient.updateContext(ctx)

ctx = withEventFactory(ctx, eventFactoryClient)
request.Connection = mergeConnection(eventFactoryClient.returnedConnection, request.GetConnection(), eventFactoryClient.request.GetConnection())
Expand All @@ -83,6 +82,7 @@ func (b *beginClient) Request(ctx context.Context, request *networkservice.Netwo
eventFactoryClient.state = established

eventFactoryClient.returnedConnection = conn.Clone()
eventFactoryClient.updateContext(ctx)
})
return conn, err
}
Expand Down
56 changes: 38 additions & 18 deletions pkg/networkservice/common/begin/event_factory_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"testing"
"time"

"github.com/pkg/errors"

"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"google.golang.org/grpc"
Expand All @@ -38,13 +40,13 @@ import (
func TestRefresh_Client(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

syncChan := make(chan struct{})
checkCtxCl := &checkContextClient{t: t}
eventFactoryCl := &eventFactoryClient{ch: syncChan}
eventFactoryCl := &eventFactoryClient{}
client := chain.NewNetworkServiceClient(
begin.NewClient(),
checkCtxCl,
eventFactoryCl,
&failedNSEClient{},
)

ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -62,27 +64,36 @@ func TestRefresh_Client(t *testing.T) {

// Change context value before refresh Request
ctx = context.WithValue(ctx, contextKey{}, "value_2")
checkCtxCl.setExpectedValue("value_2")

// Call refresh that will fail
request.Connection = conn.Clone()
request.Connection.NetworkServiceEndpointName = failedNSENameClient
checkCtxCl.setExpectedValue("value_2")
_, err = client.Request(ctx, request.Clone())
assert.Error(t, err)

// Call refresh
// Call refresh from eventFactory. We are expecting the previous value in the context
checkCtxCl.setExpectedValue("value_1")
eventFactoryCl.callRefresh()

// Call refresh that will successful
request.Connection.NetworkServiceEndpointName = ""
checkCtxCl.setExpectedValue("value_2")
conn, err = client.Request(ctx, request.Clone())
assert.NotNil(t, t, conn)
assert.NoError(t, err)

// Call refresh from eventFactory. We are expecting updated value in the context
eventFactoryCl.callRefresh()
<-syncChan
}

// This test reproduces the situation when Close and Request were called at the same time
// nolint:dupl
func TestRefreshDuringClose_Client(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

syncChan := make(chan struct{})
checkCtxCl := &checkContextClient{t: t}
eventFactoryCl := &eventFactoryClient{ch: syncChan}
eventFactoryCl := &eventFactoryClient{}
client := chain.NewNetworkServiceClient(
begin.NewClient(),
checkCtxCl,
Expand All @@ -109,7 +120,6 @@ func TestRefreshDuringClose_Client(t *testing.T) {

// Call Close from eventFactory
eventFactoryCl.callClose()
<-syncChan

// Call refresh (should be called at the same time as Close)
conn, err = client.Request(ctx, request.Clone())
Expand All @@ -118,12 +128,10 @@ func TestRefreshDuringClose_Client(t *testing.T) {

// Call refresh from eventFactory. We are expecting updated value in the context
eventFactoryCl.callRefresh()
<-syncChan
}

type eventFactoryClient struct {
ctx context.Context
ch chan<- struct{}
}

func (s *eventFactoryClient) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) {
Expand All @@ -139,18 +147,12 @@ func (s *eventFactoryClient) Close(ctx context.Context, conn *networkservice.Con

func (s *eventFactoryClient) callClose() {
eventFactory := begin.FromContext(s.ctx)
go func() {
s.ch <- struct{}{}
eventFactory.Close()
}()
eventFactory.Close()
}

func (s *eventFactoryClient) callRefresh() {
eventFactory := begin.FromContext(s.ctx)
go func() {
s.ch <- struct{}{}
eventFactory.Request()
}()
<-eventFactory.Request()
}

type contextKey struct{}
Expand All @@ -172,3 +174,21 @@ func (c *checkContextClient) Close(ctx context.Context, conn *networkservice.Con
func (c *checkContextClient) setExpectedValue(value string) {
c.expectedValue = value
}

const failedNSENameClient = "failedNSE"

type failedNSEClient struct{}

func (f *failedNSEClient) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) {
if request.GetConnection().NetworkServiceEndpointName == failedNSENameClient {
return nil, errors.New("failed")
}
return next.Client(ctx).Request(ctx, request, opts...)
}

func (f *failedNSEClient) Close(ctx context.Context, conn *networkservice.Connection, opts ...grpc.CallOption) (*emptypb.Empty, error) {
if conn.NetworkServiceEndpointName == failedNSENameClient {
return nil, errors.New("failed")
}
return next.Client(ctx).Close(ctx, conn, opts...)
}
56 changes: 38 additions & 18 deletions pkg/networkservice/common/begin/event_factory_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"testing"
"time"

"github.com/pkg/errors"

"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"google.golang.org/protobuf/types/known/emptypb"
Expand All @@ -37,13 +39,13 @@ import (
func TestRefresh_Server(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

syncChan := make(chan struct{})
checkCtxServ := &checkContextServer{t: t}
eventFactoryServ := &eventFactoryServer{ch: syncChan}
eventFactoryServ := &eventFactoryServer{}
server := chain.NewNetworkServiceServer(
begin.NewServer(),
checkCtxServ,
eventFactoryServ,
&failedNSEServer{},
)

ctx, cancel := context.WithCancel(context.Background())
Expand All @@ -61,27 +63,36 @@ func TestRefresh_Server(t *testing.T) {

// Change context value before refresh Request
ctx = context.WithValue(ctx, contextKey{}, "value_2")
checkCtxServ.setExpectedValue("value_2")

// Call refresh that will fail
request.Connection = conn.Clone()
request.Connection.NetworkServiceEndpointName = failedNSENameServer
checkCtxServ.setExpectedValue("value_2")
_, err = server.Request(ctx, request.Clone())
assert.Error(t, err)

// Call refresh
// Call refresh from eventFactory. We are expecting the previous value in the context
checkCtxServ.setExpectedValue("value_1")
eventFactoryServ.callRefresh()

// Call refresh that will successful
request.Connection.NetworkServiceEndpointName = ""
checkCtxServ.setExpectedValue("value_2")
conn, err = server.Request(ctx, request.Clone())
assert.NotNil(t, t, conn)
assert.NoError(t, err)

// Call refresh from eventFactory. We are expecting updated value in the context
eventFactoryServ.callRefresh()
<-syncChan
}

// This test reproduces the situation when Close and Request were called at the same time
// nolint:dupl
func TestRefreshDuringClose_Server(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

syncChan := make(chan struct{})
checkCtxServ := &checkContextServer{t: t}
eventFactoryServ := &eventFactoryServer{ch: syncChan}
eventFactoryServ := &eventFactoryServer{}
server := chain.NewNetworkServiceServer(
begin.NewServer(),
checkCtxServ,
Expand All @@ -108,7 +119,6 @@ func TestRefreshDuringClose_Server(t *testing.T) {

// Call Close from eventFactory
eventFactoryServ.callClose()
<-syncChan

// Call refresh (should be called at the same time as Close)
conn, err = server.Request(ctx, request.Clone())
Expand All @@ -117,12 +127,10 @@ func TestRefreshDuringClose_Server(t *testing.T) {

// Call refresh from eventFactory. We are expecting updated value in the context
eventFactoryServ.callRefresh()
<-syncChan
}

type eventFactoryServer struct {
ctx context.Context
ch chan<- struct{}
}

func (e *eventFactoryServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
Expand All @@ -138,18 +146,12 @@ func (e *eventFactoryServer) Close(ctx context.Context, conn *networkservice.Con

func (e *eventFactoryServer) callClose() {
eventFactory := begin.FromContext(e.ctx)
go func() {
e.ch <- struct{}{}
eventFactory.Close()
}()
eventFactory.Close()
}

func (e *eventFactoryServer) callRefresh() {
eventFactory := begin.FromContext(e.ctx)
go func() {
e.ch <- struct{}{}
eventFactory.Request()
}()
<-eventFactory.Request()
}

type checkContextServer struct {
Expand All @@ -169,3 +171,21 @@ func (c *checkContextServer) Close(ctx context.Context, conn *networkservice.Con
func (c *checkContextServer) setExpectedValue(value string) {
c.expectedValue = value
}

const failedNSENameServer = "failedNSE"

type failedNSEServer struct{}

func (f *failedNSEServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
if request.GetConnection().NetworkServiceEndpointName == failedNSENameServer {
return nil, errors.New("failed")
}
return next.Server(ctx).Request(ctx, request)
}

func (f *failedNSEServer) Close(ctx context.Context, conn *networkservice.Connection) (*emptypb.Empty, error) {
if conn.NetworkServiceEndpointName == failedNSENameServer {
return nil, errors.New("failed")
}
return next.Server(ctx).Close(ctx, conn)
}
2 changes: 1 addition & 1 deletion pkg/networkservice/common/begin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo
conn, err = b.Request(ctx, request)
return
}
eventFactoryServer.updateContext(ctx)

ctx = withEventFactory(ctx, eventFactoryServer)
conn, err = next.Server(ctx).Request(ctx, request)
Expand All @@ -77,6 +76,7 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo
eventFactoryServer.state = established

eventFactoryServer.returnedConnection = conn.Clone()
eventFactoryServer.updateContext(ctx)
})
return conn, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/registry/common/begin/ns_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ func (b *beginNSClient) Register(ctx context.Context, in *registry.NetworkServic
resp, err = b.Register(ctx, in, opts...)
return
}
eventFactoryClient.updateContext(ctx)

ctx = withEventFactory(ctx, eventFactoryClient)
resp, err = next.NetworkServiceRegistryClient(ctx).Register(ctx, in, opts...)
Expand All @@ -76,6 +75,7 @@ func (b *beginNSClient) Register(ctx context.Context, in *registry.NetworkServic
eventFactoryClient.state = established
eventFactoryClient.registration = mergeNS(in, resp.Clone())
eventFactoryClient.response = resp.Clone()
eventFactoryClient.updateContext(ctx)
})
return resp, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/registry/common/begin/ns_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ func (b *beginNSServer) Register(ctx context.Context, in *registry.NetworkServic
resp, err = b.Register(ctx, in)
return
}
eventFactoryServer.updateContext(ctx)

ctx = withEventFactory(ctx, eventFactoryServer)
resp, err = next.NetworkServiceRegistryServer(ctx).Register(ctx, in)
Expand All @@ -74,6 +73,7 @@ func (b *beginNSServer) Register(ctx context.Context, in *registry.NetworkServic
eventFactoryServer.registration = mergeNS(in, resp)
eventFactoryServer.state = established
eventFactoryServer.response = resp
eventFactoryServer.updateContext(ctx)
})
return resp, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/registry/common/begin/nse_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ func (b *beginNSEClient) Register(ctx context.Context, in *registry.NetworkServi
resp, err = b.Register(ctx, in, opts...)
return
}
eventFactoryClient.updateContext(ctx)

ctx = withEventFactory(ctx, eventFactoryClient)
resp, err = next.NetworkServiceEndpointRegistryClient(ctx).Register(ctx, in, opts...)
Expand All @@ -76,6 +75,7 @@ func (b *beginNSEClient) Register(ctx context.Context, in *registry.NetworkServi
eventFactoryClient.state = established
eventFactoryClient.registration = mergeNSE(in, resp.Clone())
eventFactoryClient.response = resp.Clone()
eventFactoryClient.updateContext(ctx)
})
return resp, err
}
Expand Down
Loading

0 comments on commit 9ae27f8

Please sign in to comment.