diff --git a/binarylog/binarylog_end2end_test.go b/binarylog/binarylog_end2end_test.go index 66bb7bda3af4..277c17a10726 100644 --- a/binarylog/binarylog_end2end_test.go +++ b/binarylog/binarylog_end2end_test.go @@ -31,10 +31,12 @@ import ( "github.com/golang/protobuf/proto" "google.golang.org/grpc" "google.golang.org/grpc/binarylog" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/grpclog" iblog "google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -1059,3 +1061,39 @@ func (s) TestServerBinaryLogFullDuplexError(t *testing.T) { t.Fatal(err) } } + +// TestCanceledStatus ensures a server that responds with a Canceled status has +// its trailers logged appropriately and is not treated as a canceled RPC. +func (s) TestCanceledStatus(t *testing.T) { + defer testSink.clear() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + const statusMsgWant = "server returned Canceled" + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + grpc.SetTrailer(ctx, metadata.Pairs("key", "value")) + return nil, status.Error(codes.Canceled, statusMsgWant) + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Canceled { + t.Fatalf("Received unexpected error from UnaryCall: %v; want Canceled", err) + } + + got := testSink.logEntries(true) + last := got[len(got)-1] + if last.Type != binlogpb.GrpcLogEntry_EVENT_TYPE_SERVER_TRAILER || + last.GetTrailer().GetStatusCode() != uint32(codes.Canceled) || + last.GetTrailer().GetStatusMessage() != statusMsgWant || + len(last.GetTrailer().GetMetadata().GetEntry()) != 1 || + last.GetTrailer().GetMetadata().GetEntry()[0].GetKey() != "key" || + string(last.GetTrailer().GetMetadata().GetEntry()[0].GetValue()) != "value" { + t.Fatalf("Got binary log: %+v; want last entry is server trailing with status Canceled", got) + } +} diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index afcda3602f0c..badab8acf3b1 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -1505,14 +1505,15 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { return } - isHeader := false - - // If headerChan hasn't been closed yet - if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) { - s.headerValid = true - if !endStream { - // HEADERS frame block carries a Response-Headers. - isHeader = true + // For headers, set them in s.header and close headerChan. For trailers or + // trailers-only, closeStream will set the trailers and close headerChan as + // needed. + if !endStream { + // If headerChan hasn't been closed yet (expected, given we checked it + // above, but something else could have potentially closed the whole + // stream). + if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) { + s.headerValid = true // These values can be set without any synchronization because // stream goroutine will read it only after seeing a closed // headerChan which we'll close after setting this. @@ -1520,15 +1521,12 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { if len(mdata) > 0 { s.header = mdata } - } else { - // HEADERS frame block carries a Trailers-Only. - s.noHeaders = true + close(s.headerChan) } - close(s.headerChan) } for _, sh := range t.statsHandlers { - if isHeader { + if !endStream { inHeader := &stats.InHeader{ Client: true, WireLength: int(frame.Header().Length), @@ -1554,9 +1552,10 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { statusGen = status.New(rawStatusCode, grpcMessage) } - // if client received END_STREAM from server while stream was still active, send RST_STREAM - rst := s.getState() == streamActive - t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, statusGen, mdata, true) + // If client received END_STREAM from server while stream was still active, + // send RST_STREAM. + rstStream := s.getState() == streamActive + t.closeStream(s, io.EOF, rstStream, http2.ErrCodeNo, statusGen, mdata, true) } // readServerPreface reads and handles the initial settings frame from the diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 99e184a13978..74a811fc0590 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -43,10 +43,6 @@ import ( "google.golang.org/grpc/tap" ) -// ErrNoHeaders is used as a signal that a trailers only response was received, -// and is not a real error. -var ErrNoHeaders = errors.New("stream has no headers") - const logLevel = 2 type bufferPool struct { @@ -390,14 +386,10 @@ func (s *Stream) Header() (metadata.MD, error) { } s.waitOnHeader() - if !s.headerValid { + if !s.headerValid || s.noHeaders { return nil, s.status.Err() } - if s.noHeaders { - return nil, ErrNoHeaders - } - return s.header.Copy(), nil } diff --git a/rpc_util.go b/rpc_util.go index 56451d07758b..b7723aa09cbb 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -867,15 +867,18 @@ func Errorf(c codes.Code, format string, a ...any) error { return status.Errorf(c, format, a...) } +var errContextCanceled = status.Error(codes.Canceled, context.Canceled.Error()) +var errContextDeadline = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) + // toRPCErr converts an error into an error from the status package. func toRPCErr(err error) error { switch err { case nil, io.EOF: return err case context.DeadlineExceeded: - return status.Error(codes.DeadlineExceeded, err.Error()) + return errContextDeadline case context.Canceled: - return status.Error(codes.Canceled, err.Error()) + return errContextCanceled case io.ErrUnexpectedEOF: return status.Error(codes.Internal, err.Error()) } diff --git a/stream.go b/stream.go index d7fb37c986b8..cf73147bdd22 100644 --- a/stream.go +++ b/stream.go @@ -789,23 +789,23 @@ func (cs *clientStream) withRetry(op func(a *csAttempt) error, onSuccess func()) func (cs *clientStream) Header() (metadata.MD, error) { var m metadata.MD - noHeader := false err := cs.withRetry(func(a *csAttempt) error { var err error m, err = a.s.Header() - if err == transport.ErrNoHeaders { - noHeader = true - return nil - } return toRPCErr(err) }, cs.commitAttemptLocked) + if m == nil && err == nil { + // The stream ended with success. Finish the clientStream. + err = io.EOF + } + if err != nil { cs.finish(err) return nil, err } - if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged && !noHeader { + if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged && m != nil { // Only log if binary log is on and header has not been logged, and // there is actually headers to log. logEntry := &binarylog.ServerHeader{ @@ -821,6 +821,7 @@ func (cs *clientStream) Header() (metadata.MD, error) { binlog.Log(cs.ctx, logEntry) } } + return m, nil } @@ -929,24 +930,6 @@ func (cs *clientStream) RecvMsg(m any) error { if err != nil || !cs.desc.ServerStreams { // err != nil or non-server-streaming indicates end of stream. cs.finish(err) - - if len(cs.binlogs) != 0 { - // finish will not log Trailer. Log Trailer here. - logEntry := &binarylog.ServerTrailer{ - OnClientSide: true, - Trailer: cs.Trailer(), - Err: err, - } - if logEntry.Err == io.EOF { - logEntry.Err = nil - } - if peer, ok := peer.FromContext(cs.Context()); ok { - logEntry.PeerAddr = peer.Addr - } - for _, binlog := range cs.binlogs { - binlog.Log(cs.ctx, logEntry) - } - } } return err } @@ -1002,18 +985,30 @@ func (cs *clientStream) finish(err error) { } } } + cs.mu.Unlock() - // For binary logging. only log cancel in finish (could be caused by RPC ctx - // canceled or ClientConn closed). Trailer will be logged in RecvMsg. - // - // Only one of cancel or trailer needs to be logged. In the cases where - // users don't call RecvMsg, users must have already canceled the RPC. - if len(cs.binlogs) != 0 && status.Code(err) == codes.Canceled { - c := &binarylog.Cancel{ - OnClientSide: true, - } - for _, binlog := range cs.binlogs { - binlog.Log(cs.ctx, c) + // Only one of cancel or trailer needs to be logged. + if len(cs.binlogs) != 0 { + switch err { + case errContextCanceled, errContextDeadline, ErrClientConnClosing: + c := &binarylog.Cancel{ + OnClientSide: true, + } + for _, binlog := range cs.binlogs { + binlog.Log(cs.ctx, c) + } + default: + logEntry := &binarylog.ServerTrailer{ + OnClientSide: true, + Trailer: cs.Trailer(), + Err: err, + } + if peer, ok := peer.FromContext(cs.Context()); ok { + logEntry.PeerAddr = peer.Addr + } + for _, binlog := range cs.binlogs { + binlog.Log(cs.ctx, logEntry) + } } } if err == nil { diff --git a/test/end2end_test.go b/test/end2end_test.go index cde1edef5ee3..948382d04f99 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -6328,12 +6328,11 @@ func (s) TestGlobalBinaryLoggingOptions(t *testing.T) { return &testpb.SimpleResponse{}, nil }, FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error { - for { - _, err := stream.Recv() - if err == io.EOF { - return nil - } + _, err := stream.Recv() + if err == io.EOF { + return nil } + return status.Errorf(codes.Unknown, "expected client to call CloseSend") }, } diff --git a/test/retry_test.go b/test/retry_test.go index bfbc051dcdee..49becb359097 100644 --- a/test/retry_test.go +++ b/test/retry_test.go @@ -211,6 +211,11 @@ func (s) TestRetryStreaming(t *testing.T) { return nil } } + sHdr := func() serverOp { + return func(stream testgrpc.TestService_FullDuplexCallServer) error { + return stream.SendHeader(metadata.Pairs("test_header", "test_value")) + } + } sRes := func(b byte) serverOp { return func(stream testgrpc.TestService_FullDuplexCallServer) error { msg := res(b) @@ -222,7 +227,7 @@ func (s) TestRetryStreaming(t *testing.T) { } sErr := func(c codes.Code) serverOp { return func(stream testgrpc.TestService_FullDuplexCallServer) error { - return status.New(c, "").Err() + return status.New(c, "this is a test error").Err() } } sCloseSend := func() serverOp { @@ -270,7 +275,7 @@ func (s) TestRetryStreaming(t *testing.T) { } cErr := func(c codes.Code) clientOp { return func(stream testgrpc.TestService_FullDuplexCallClient) error { - want := status.New(c, "").Err() + want := status.New(c, "this is a test error").Err() if c == codes.OK { want = io.EOF } @@ -309,6 +314,11 @@ func (s) TestRetryStreaming(t *testing.T) { cHdr := func() clientOp { return func(stream testgrpc.TestService_FullDuplexCallClient) error { _, err := stream.Header() + if err == io.EOF { + // The stream ended successfully; convert to nil to avoid + // erroring the test case. + err = nil + } return err } } @@ -362,9 +372,13 @@ func (s) TestRetryStreaming(t *testing.T) { sReq(1), sRes(3), sErr(codes.Unavailable), }, clientOps: []clientOp{cReq(1), cRes(3), cErr(codes.Unavailable)}, + }, { + desc: "Retry via ClientStream.Header()", + serverOps: []serverOp{sReq(1), sErr(codes.Unavailable), sReq(1), sAttempts(1)}, + clientOps: []clientOp{cReq(1), cHdr() /* this should cause a retry */, cErr(codes.OK)}, }, { desc: "No retry after header", - serverOps: []serverOp{sReq(1), sErr(codes.Unavailable)}, + serverOps: []serverOp{sReq(1), sHdr(), sErr(codes.Unavailable)}, clientOps: []clientOp{cReq(1), cHdr(), cErr(codes.Unavailable)}, }, { desc: "No retry after context",