Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stream: fix calloption.After() race in finish #3672

Merged
merged 7 commits into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ func (d *gzipDecompressor) Type() string {
type callInfo struct {
compressorType string
failFast bool
stream ClientStream
maxReceiveMessageSize *int
maxSendMessageSize *int
creds credentials.PerRPCCredentials
Expand All @@ -180,16 +179,16 @@ type CallOption interface {

// after is called after the call has completed. after cannot return an
// error, so any failures should be reported via output parameters.
after(*callInfo)
after(*callInfo, *csAttempt)
}

// EmptyCallOption does not alter the Call configuration.
// It can be embedded in another structure to carry satellite data for use
// by interceptors.
type EmptyCallOption struct{}

func (EmptyCallOption) before(*callInfo) error { return nil }
func (EmptyCallOption) after(*callInfo) {}
func (EmptyCallOption) before(*callInfo) error { return nil }
func (EmptyCallOption) after(*callInfo, *csAttempt) {}

// Header returns a CallOptions that retrieves the header metadata
// for a unary RPC.
Expand All @@ -205,9 +204,9 @@ type HeaderCallOption struct {
}

func (o HeaderCallOption) before(c *callInfo) error { return nil }
func (o HeaderCallOption) after(c *callInfo) {
if c.stream != nil {
*o.HeaderAddr, _ = c.stream.Header()
func (o HeaderCallOption) after(c *callInfo, attempt *csAttempt) {
if attempt != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we now assume this is never called with a nil attempt or attempt.s?

Copy link
Contributor Author

@menghanl menghanl Jun 10, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If rpc is canceled before we make the attempt (e.g. blocking on pick)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But there's only one call site (right?) and it's guarded by if attempt != nil && attempt.s != nil.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, right. Removed

*o.HeaderAddr, _ = attempt.s.Header()
}
}

Expand All @@ -225,9 +224,9 @@ type TrailerCallOption struct {
}

func (o TrailerCallOption) before(c *callInfo) error { return nil }
func (o TrailerCallOption) after(c *callInfo) {
if c.stream != nil {
*o.TrailerAddr = c.stream.Trailer()
func (o TrailerCallOption) after(c *callInfo, attempt *csAttempt) {
if attempt != nil {
*o.TrailerAddr = attempt.s.Trailer()
}
}

Expand All @@ -245,9 +244,9 @@ type PeerCallOption struct {
}

func (o PeerCallOption) before(c *callInfo) error { return nil }
func (o PeerCallOption) after(c *callInfo) {
if c.stream != nil {
if x, ok := peer.FromContext(c.stream.Context()); ok {
func (o PeerCallOption) after(c *callInfo, attempt *csAttempt) {
if attempt != nil {
if x, ok := peer.FromContext(attempt.s.Context()); ok {
*o.PeerAddr = *x
}
}
Expand Down Expand Up @@ -285,7 +284,7 @@ func (o FailFastCallOption) before(c *callInfo) error {
c.failFast = o.FailFast
return nil
}
func (o FailFastCallOption) after(c *callInfo) {}
func (o FailFastCallOption) after(c *callInfo, attempt *csAttempt) {}

// MaxCallRecvMsgSize returns a CallOption which sets the maximum message size
// in bytes the client can receive.
Expand All @@ -304,7 +303,7 @@ func (o MaxRecvMsgSizeCallOption) before(c *callInfo) error {
c.maxReceiveMessageSize = &o.MaxRecvMsgSize
return nil
}
func (o MaxRecvMsgSizeCallOption) after(c *callInfo) {}
func (o MaxRecvMsgSizeCallOption) after(c *callInfo, attempt *csAttempt) {}

// MaxCallSendMsgSize returns a CallOption which sets the maximum message size
// in bytes the client can send.
Expand All @@ -323,7 +322,7 @@ func (o MaxSendMsgSizeCallOption) before(c *callInfo) error {
c.maxSendMessageSize = &o.MaxSendMsgSize
return nil
}
func (o MaxSendMsgSizeCallOption) after(c *callInfo) {}
func (o MaxSendMsgSizeCallOption) after(c *callInfo, attempt *csAttempt) {}

// PerRPCCredentials returns a CallOption that sets credentials.PerRPCCredentials
// for a call.
Expand All @@ -342,7 +341,7 @@ func (o PerRPCCredsCallOption) before(c *callInfo) error {
c.creds = o.Creds
return nil
}
func (o PerRPCCredsCallOption) after(c *callInfo) {}
func (o PerRPCCredsCallOption) after(c *callInfo, attempt *csAttempt) {}

// UseCompressor returns a CallOption which sets the compressor used when
// sending the request. If WithCompressor is also set, UseCompressor has
Expand All @@ -363,7 +362,7 @@ func (o CompressorCallOption) before(c *callInfo) error {
c.compressorType = o.CompressorType
return nil
}
func (o CompressorCallOption) after(c *callInfo) {}
func (o CompressorCallOption) after(c *callInfo, attempt *csAttempt) {}

// CallContentSubtype returns a CallOption that will set the content-subtype
// for a call. For example, if content-subtype is "json", the Content-Type over
Expand Down Expand Up @@ -396,7 +395,7 @@ func (o ContentSubtypeCallOption) before(c *callInfo) error {
c.contentSubtype = o.ContentSubtype
return nil
}
func (o ContentSubtypeCallOption) after(c *callInfo) {}
func (o ContentSubtypeCallOption) after(c *callInfo, attempt *csAttempt) {}

// ForceCodec returns a CallOption that will set the given Codec to be
// used for all request and response messages for a call. The result of calling
Expand Down Expand Up @@ -428,7 +427,7 @@ func (o ForceCodecCallOption) before(c *callInfo) error {
c.codec = o.Codec
return nil
}
func (o ForceCodecCallOption) after(c *callInfo) {}
func (o ForceCodecCallOption) after(c *callInfo, attempt *csAttempt) {}

// CallCustomCodec behaves like ForceCodec, but accepts a grpc.Codec instead of
// an encoding.Codec.
Expand All @@ -450,7 +449,7 @@ func (o CustomCodecCallOption) before(c *callInfo) error {
c.codec = o.Codec
return nil
}
func (o CustomCodecCallOption) after(c *callInfo) {}
func (o CustomCodecCallOption) after(c *callInfo, attempt *csAttempt) {}

// MaxRetryRPCBufferSize returns a CallOption that limits the amount of memory
// used for buffering this RPC's requests for retry purposes.
Expand All @@ -471,7 +470,7 @@ func (o MaxRetryRPCBufferSizeCallOption) before(c *callInfo) error {
c.maxRetryRPCBufferSize = o.MaxRetryRPCBufferSize
return nil
}
func (o MaxRetryRPCBufferSizeCallOption) after(c *callInfo) {}
func (o MaxRetryRPCBufferSizeCallOption) after(c *callInfo, attempt *csAttempt) {}

// The format of the payload: compressed or not?
type payloadFormat uint8
Expand Down
20 changes: 9 additions & 11 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
}
cs.binlog = binarylog.GetMethodLogger(method)

cs.callInfo.stream = cs
// Only this initial attempt has stats/tracing.
// TODO(dfawley): move to newAttempt when per-attempt stats are implemented.
if err := cs.newAttemptLocked(sh, trInfo); err != nil {
Expand Down Expand Up @@ -799,6 +798,15 @@ func (cs *clientStream) finish(err error) {
}
cs.finished = true
cs.commitAttemptLocked()
if cs.attempt != nil {
cs.attempt.finish(err)
// after functions all rely upon having a stream.
if cs.attempt.s != nil {
for _, o := range cs.opts {
o.after(cs.callInfo, cs.attempt)
}
}
}
cs.mu.Unlock()
// For binary logging, only log cancel in finish (could be caused by RPC ctx
// canceled or ClientConn closed). Trailer will be logged in RecvMsg.
Expand All @@ -820,15 +828,6 @@ func (cs *clientStream) finish(err error) {
cs.cc.incrCallsSucceeded()
}
}
if cs.attempt != nil {
cs.attempt.finish(err)
// after functions all rely upon having a stream.
if cs.attempt.s != nil {
for _, o := range cs.opts {
o.after(cs.callInfo)
}
}
}
cs.cancel()
}

Expand Down Expand Up @@ -1066,7 +1065,6 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin
t: t,
}

as.callInfo.stream = as
s, err := as.t.NewStream(as.ctx, as.callHdr)
if err != nil {
err = toRPCErr(err)
Expand Down
59 changes: 59 additions & 0 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7128,3 +7128,62 @@ func (s) TestGzipBadChecksum(t *testing.T) {
t.Errorf("ss.client.UnaryCall(_) = _, %v\n\twant: _, status(codes.Internal, contains %q)", err, gzip.ErrChecksum)
}
}

// TestCanceledRPCCallOptionRace checks that the call options' after hooks have
// already run by the time the final Recv() of a canceled RPC returns.
//
// When an RPC is canceled, it's possible that the last Recv() returns before
// all call options' after are executed. The test starts many concurrent
// streaming RPCs, each with a grpc.Peer call option, cancels each one after the
// first response, and then verifies that the peer address has been populated
// once Recv() reports codes.Canceled.
func (s) TestCanceledRPCCallOptionRace(t *testing.T) {
	ss := &stubServer{
		fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
			// Send one response so the client's first Recv() succeeds, then
			// keep the stream open until the client cancels.
			if err := stream.Send(&testpb.StreamingOutputCallResponse{}); err != nil {
				return err
			}
			<-stream.Context().Done()
			return nil
		},
	}
	if err := ss.Start(nil); err != nil {
		t.Fatalf("Error starting endpoint server: %v", err)
	}
	defer ss.Stop()

	const count = 1000
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	var wg sync.WaitGroup
	wg.Add(count)
	for i := 0; i < count; i++ {
		go func() {
			defer wg.Done()
			var p peer.Peer
			// Each RPC gets its own cancelable context derived from the
			// shared, time-bounded one; distinct names avoid shadowing the
			// outer ctx/cancel pair.
			rpcCtx, rpcCancel := context.WithCancel(ctx)
			defer rpcCancel()

			// Only t.Errorf is used below: failing the test from a goroutine
			// other than the test goroutine must not call t.Fatal/FailNow.
			stream, err := ss.client.FullDuplexCall(rpcCtx, grpc.Peer(&p))
			if err != nil {
				t.Errorf("_.FullDuplexCall(_) = _, %v", err)
				return
			}
			if err := stream.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
				t.Errorf("_ has error %v while sending", err)
				return
			}
			if _, err := stream.Recv(); err != nil {
				t.Errorf("%v.Recv() = %v", stream, err)
				return
			}

			// Cancel the RPC; the next Recv() is expected to fail with
			// codes.Canceled.
			rpcCancel()
			if _, err := stream.Recv(); status.Code(err) != codes.Canceled {
				t.Errorf("%v completed with error %v, want %s", stream, err, codes.Canceled)
				return
			}
			// If Recv() returned before the call options' after hooks were
			// executed, p.Addr would still be unset; fail the test in that case.
			if p.Addr == nil {
				t.Errorf("peer.Addr is nil, want non-nil")
				return
			}
		}()
	}
	wg.Wait()
}