diff --git a/api/v3rpc/rpctypes/error.go b/api/v3rpc/rpctypes/error.go index 26e3fd378c14..24fdc85736c4 100644 --- a/api/v3rpc/rpctypes/error.go +++ b/api/v3rpc/rpctypes/error.go @@ -35,6 +35,8 @@ var ( ErrGRPCLeaseExist = status.New(codes.FailedPrecondition, "etcdserver: lease already exists").Err() ErrGRPCLeaseTTLTooLarge = status.New(codes.OutOfRange, "etcdserver: too large lease TTL").Err() + ErrGRPCWatchCanceled = status.New(codes.Canceled, "etcdserver: watch canceled").Err() + ErrGRPCMemberExist = status.New(codes.FailedPrecondition, "etcdserver: member ID already exist").Err() ErrGRPCPeerURLExist = status.New(codes.FailedPrecondition, "etcdserver: Peer URLs already exists").Err() ErrGRPCMemberNotEnoughStarted = status.New(codes.FailedPrecondition, "etcdserver: re-configuration failed due to not enough started members").Err() diff --git a/server/etcdserver/api/v3rpc/interceptor.go b/server/etcdserver/api/v3rpc/interceptor.go index 124522e6b573..e2745bccf064 100644 --- a/server/etcdserver/api/v3rpc/interceptor.go +++ b/server/etcdserver/api/v3rpc/interceptor.go @@ -231,8 +231,8 @@ func newStreamInterceptor(s *etcdserver.EtcdServer) grpc.StreamServerInterceptor return rpctypes.ErrGRPCNoLeader } - cctx, cancel := context.WithCancel(ss.Context()) - ss = serverStreamWithCtx{ctx: cctx, cancel: &cancel, ServerStream: ss} + ctx := newCancellableContext(ss.Context()) + ss = serverStreamWithCtx{ctx: newCancellableContext(ss.Context()), ServerStream: ss} smap.mu.Lock() smap.streams[ss] = struct{}{} @@ -242,7 +242,8 @@ func newStreamInterceptor(s *etcdserver.EtcdServer) grpc.StreamServerInterceptor smap.mu.Lock() delete(smap.streams, ss) smap.mu.Unlock() - cancel() + // TODO: investigate whether the reason for cancellation here is useful to know + ctx.Cancel(nil) }() } } @@ -251,10 +252,52 @@ func newStreamInterceptor(s *etcdserver.EtcdServer) grpc.StreamServerInterceptor } } +// cancellableContext wraps a context with new cancellable context that allows a +// specific cancellation error to be preserved and later retrieved using the +// Context.Err() function. This is so downstream context users can disambiguate +// the reason for the cancellation which could be from the client (for example) +// or from this interceptor code. +type cancellableContext struct { + context.Context + + lock sync.RWMutex + cancel context.CancelFunc + cancelReason error +} + +func newCancellableContext(parent context.Context) *cancellableContext { + ctx, cancel := context.WithCancel(parent) + return &cancellableContext{ + Context: ctx, + cancel: cancel, + } +} + +// Cancel stores the cancellation reason and then delegates to context.WithCancel +// against the parent context. +func (c *cancellableContext) Cancel(reason error) { + c.lock.Lock() + c.cancelReason = reason + c.lock.Unlock() + c.cancel() +} + +// Err will return the preserved cancel reason error if present, and will +// otherwise return the underlying error from the parent context. +func (c *cancellableContext) Err() error { + c.lock.RLock() + defer c.lock.RUnlock() + if c.cancelReason != nil { + return c.cancelReason + } + return c.Context.Err() +} + type serverStreamWithCtx struct { grpc.ServerStream - ctx context.Context - cancel *context.CancelFunc + + // ctx is used so that we can preserve a reason for cancellation. + ctx *cancellableContext } func (ssc serverStreamWithCtx) Context() context.Context { return ssc.ctx } @@ -286,7 +329,7 @@ func monitorLeader(s *etcdserver.EtcdServer) *streamsMap { smap.mu.Lock() for ss := range smap.streams { if ssWithCtx, ok := ss.(serverStreamWithCtx); ok { - (*ssWithCtx.cancel)() + ssWithCtx.ctx.Cancel(rpctypes.ErrGRPCNoLeader) <-ss.Context().Done() } } diff --git a/server/etcdserver/api/v3rpc/watch.go b/server/etcdserver/api/v3rpc/watch.go index df876232cf56..4531dbe60bf5 100644 --- a/server/etcdserver/api/v3rpc/watch.go +++ b/server/etcdserver/api/v3rpc/watch.go @@ -197,15 +197,25 @@ func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) { } }() + // TODO: There's a race here. When a stream is closed (e.g. due to a cancellation), + // the underlying error (e.g. a gRPC stream error) may be returned and handled + // through errc if the recv goroutine finishes before the send goroutine. + // When the recv goroutine wins, the stream error is retained. When recv loses + // the race, the underlying error is lost (unless the root error is propagated + // through Context.Err() which is not always the case (as callers have to decide + // to implement a custom context to do so). The stdlib context package builtins + // may be insufficient to carry semantically useful errors around and should be + // revisited. select { case err = <-errc: + if err == context.Canceled { + err = rpctypes.ErrGRPCWatchCanceled + } close(sws.ctrlStream) - case <-stream.Context().Done(): err = stream.Context().Err() - // the only server-side cancellation is noleader for now. if err == context.Canceled { - err = rpctypes.ErrGRPCNoLeader + err = rpctypes.ErrGRPCWatchCanceled } } diff --git a/tests/e2e/metrics_test.go b/tests/e2e/metrics_test.go index 6ae6d2fed91f..3ace2ea36aba 100644 --- a/tests/e2e/metrics_test.go +++ b/tests/e2e/metrics_test.go @@ -49,6 +49,7 @@ func metricsTest(cx ctlCtx) { {"/metrics", fmt.Sprintf("etcd_mvcc_delete_total 3")}, {"/metrics", fmt.Sprintf(`etcd_server_version{server_version="%s"} 1`, version.Version)}, {"/metrics", fmt.Sprintf(`etcd_cluster_version{cluster_version="%s"} 1`, version.Cluster(version.Version))}, + {"/metrics", fmt.Sprintf(`grpc_server_handled_total{grpc_code="Canceled",grpc_method="Watch",grpc_service="etcdserverpb.Watch",grpc_type="bidi_stream"} 6`)}, {"/health", `{"health":"true","reason":""}`}, } { i++ @@ -58,7 +59,9 @@ func metricsTest(cx ctlCtx) { if err := ctlV3Del(cx, []string{fmt.Sprintf("%d", i)}, 1); err != nil { cx.t.Fatal(err) } - + if err := ctlV3Watch(cx, []string{"k", "--rev", "1"}, []kvExec{{key: "k", val: "v"}}...); err != nil { + cx.t.Fatal(err) + } if err := cURLGet(cx.epc, cURLReq{endpoint: test.endpoint, expected: test.expected, metricsURLScheme: cx.cfg.metricsURLScheme}); err != nil { cx.t.Fatalf("failed get with curl (%v)", err) }