diff --git a/node/router/router.go b/node/router/router.go index ae833c01..ee90b069 100644 --- a/node/router/router.go +++ b/node/router/router.go @@ -123,21 +123,24 @@ func (r *Router) Broadcast(stream orderer.AtomicBroadcast_BroadcastServer) error r.init() + ctx := stream.Context() + feedbackChan := make(chan Response, 1000) + errCh := make(chan error, 2) + exit := make(chan struct{}) - defer func() { - close(exit) - }() + defer close(exit) - feedbackChan := make(chan Response, 1000) - go sendFeedbackOnBroadcastStream(stream, exit, feedbackChan) + go sendFeedbackOnBroadcastStream(stream, errCh, exit, feedbackChan) for { reqEnv, err := stream.Recv() if err == io.EOF { - return nil + errCh <- nil + break } if err != nil { - return err + errCh <- err + break } atomic.AddUint64(&r.incoming, 1) @@ -149,6 +152,22 @@ func (r *Router) Broadcast(stream orderer.AtomicBroadcast_BroadcastServer) error tr := &TrackedRequest{request: request, responses: feedbackChan, reqID: reqID} shardRouter.NewForward(tr) } + + select { + case <-ctx.Done(): + r.logger.Infof("broadcast is closing, context canceled: %v", ctx.Err()) + return ctx.Err() + case err := <-errCh: + if err != nil { + r.logger.Infof("broadcast is closing, Received error: %v", err) + return err + } else { + r.logger.Infof("Received EOF from client, broadcast closing (recv)") + <-ctx.Done() + r.logger.Infof("broadcast is closing, context canceled: %v", ctx.Err()) + return ctx.Err() + } + } } func (r *Router) init() { @@ -283,13 +302,20 @@ func sendFeedbackOnSubmitStream(stream protos.RequestTransmit_SubmitStreamServer } } -func sendFeedbackOnBroadcastStream(stream orderer.AtomicBroadcast_BroadcastServer, exit chan struct{}, feedbackChan chan Response) { +func sendFeedbackOnBroadcastStream(stream orderer.AtomicBroadcast_BroadcastServer, errCh chan error, exit chan struct{}, feedbackChan chan Response) { + ctx := stream.Context() for { select { - case <-exit: + case <-ctx.Done(): + errCh <- ctx.Err() return case response := <-feedbackChan: - stream.Send(responseToBroadcastResponse(&response)) + if err := stream.Send(responseToBroadcastResponse(&response)); err != nil { + errCh <- err + return // or just print error and continue? + } + case <-exit: + return } } } diff --git a/node/router/router_test.go b/node/router/router_test.go index 644d9fb0..bcb1831b 100644 --- a/node/router/router_test.go +++ b/node/router/router_test.go @@ -612,3 +612,72 @@ func createAndStartRouter(t *testing.T, partyID types.PartyID, ca tlsgen.CA, bat return r } + +func TestClientCloseSend(t *testing.T) { + grpclog.SetLoggerV2(&testutil.SilentLogger{}) + + testSetup := createRouterTestSetup(t, types.PartyID(1), 1, true, false) + err := createServerTLSClientConnection(testSetup, testSetup.ca) + require.NoError(t, err) + require.NotNil(t, testSetup.clientConn) + + defer testSetup.Close() + numOfRequests := 555 + res := submitBroadcastRequestsWithCloseSend(testSetup.clientConn, numOfRequests) + require.NoError(t, res.err) + + require.Eventually(t, func() bool { + return testSetup.batchers[0].ReceivedMessageCount() == uint32(numOfRequests) + }, 10*time.Second, 10*time.Millisecond) +} + +func submitBroadcastRequestsWithCloseSend(conn *grpc.ClientConn, numOfRequests int) (res testStreamResult) { + res = testStreamResult{ + failRequests: numOfRequests, + } + + cl := ab.NewAtomicBroadcastClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + stream, err := cl.Broadcast(ctx) + if err != nil { + res.err = err + return + } + + buff := make([]byte, 300) + for j := 0; j < numOfRequests; j++ { + binary.BigEndian.PutUint32(buff, uint32(j)) + env := tx.CreateStructuredEnvelope(buff) + err := stream.Send(env) + if err != nil { + return + } + } + + stream.CloseSend() + + for j := 0; j < numOfRequests; j++ { + select { + default: + resp, err := stream.Recv() + if err != nil { + res.err = fmt.Errorf("error receiving response: %s", err) + } + if resp.Status != common.Status_SUCCESS { + requestErr := fmt.Errorf("receiving response with error: %s", resp.Info) + res.respondsErrors = append(res.respondsErrors, requestErr) + res.err = requestErr + } else { + res.successRequests++ + res.failRequests-- + } + case <-ctx.Done(): + res.err = fmt.Errorf("a time out occured during submitting request: %w", ctx.Err()) + } + } + + return +}