Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions node/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,24 @@ func (r *Router) Broadcast(stream orderer.AtomicBroadcast_BroadcastServer) error

r.init()

ctx := stream.Context()
feedbackChan := make(chan Response, 1000)
errCh := make(chan error, 2)

exit := make(chan struct{})
defer func() {
close(exit)
}()
defer close(exit)

feedbackChan := make(chan Response, 1000)
go sendFeedbackOnBroadcastStream(stream, exit, feedbackChan)
go sendFeedbackOnBroadcastStream(stream, errCh, exit, feedbackChan)

for {
reqEnv, err := stream.Recv()
if err == io.EOF {
return nil
errCh <- nil
break
}
if err != nil {
return err
errCh <- err
break
}

atomic.AddUint64(&r.incoming, 1)
Expand All @@ -149,6 +152,22 @@ func (r *Router) Broadcast(stream orderer.AtomicBroadcast_BroadcastServer) error
tr := &TrackedRequest{request: request, responses: feedbackChan, reqID: reqID}
shardRouter.NewForward(tr)
}

select {
case <-ctx.Done():
r.logger.Infof("broadcast is closing, context canceled: %v", ctx.Err())
return ctx.Err()
case err := <-errCh:
if err != nil {
r.logger.Infof("broadcast is closing, Received error: %v", err)
return err
} else {
r.logger.Infof("Received EOF from client, broadcast closing (recv)")
<-ctx.Done()
r.logger.Infof("broadcast is closing, context canceled: %v", ctx.Err())
return ctx.Err()
}
}
}

func (r *Router) init() {
Expand Down Expand Up @@ -283,13 +302,20 @@ func sendFeedbackOnSubmitStream(stream protos.RequestTransmit_SubmitStreamServer
}
}

func sendFeedbackOnBroadcastStream(stream orderer.AtomicBroadcast_BroadcastServer, exit chan struct{}, feedbackChan chan Response) {
func sendFeedbackOnBroadcastStream(stream orderer.AtomicBroadcast_BroadcastServer, errCh chan error, exit chan struct{}, feedbackChan chan Response) {
ctx := stream.Context()
for {
select {
case <-exit:
case <-ctx.Done():
errCh <- ctx.Err()
return
case response := <-feedbackChan:
stream.Send(responseToBroadcastResponse(&response))
if err := stream.Send(responseToBroadcastResponse(&response)); err != nil {
errCh <- err
return // or just print error and continue?
}
case <-exit:
return
}
}
}
Expand Down
69 changes: 69 additions & 0 deletions node/router/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,72 @@ func createAndStartRouter(t *testing.T, partyID types.PartyID, ca tlsgen.CA, bat

return r
}

func TestClientCloseSend(t *testing.T) {
grpclog.SetLoggerV2(&testutil.SilentLogger{})

testSetup := createRouterTestSetup(t, types.PartyID(1), 1, true, false)
err := createServerTLSClientConnection(testSetup, testSetup.ca)
require.NoError(t, err)
require.NotNil(t, testSetup.clientConn)

defer testSetup.Close()
numOfRequests := 555
res := submitBroadcastRequestsWithCloseSend(testSetup.clientConn, numOfRequests)
require.NoError(t, res.err)

require.Eventually(t, func() bool {
return testSetup.batchers[0].ReceivedMessageCount() == uint32(numOfRequests)
}, 10*time.Second, 10*time.Millisecond)
}

func submitBroadcastRequestsWithCloseSend(conn *grpc.ClientConn, numOfRequests int) (res testStreamResult) {
res = testStreamResult{
failRequests: numOfRequests,
}

cl := ab.NewAtomicBroadcastClient(conn)

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

stream, err := cl.Broadcast(ctx)
if err != nil {
res.err = err
return
}

buff := make([]byte, 300)
for j := 0; j < numOfRequests; j++ {
binary.BigEndian.PutUint32(buff, uint32(j))
env := tx.CreateStructuredEnvelope(buff)
err := stream.Send(env)
if err != nil {
return
}
}

stream.CloseSend()

for j := 0; j < numOfRequests; j++ {
select {
default:
resp, err := stream.Recv()
if err != nil {
res.err = fmt.Errorf("error receiving response: %s", err)
}
if resp.Status != common.Status_SUCCESS {
requestErr := fmt.Errorf("receiving response with error: %s", resp.Info)
res.respondsErrors = append(res.respondsErrors, requestErr)
res.err = requestErr
} else {
res.successRequests++
res.failRequests--
}
case <-ctx.Done():
res.err = fmt.Errorf("a time out occured during submitting request: %w", ctx.Err())
}
}

return
}