From e874e60a10dc72c340c8591d9f839dd35cdb3bab Mon Sep 17 00:00:00 2001
From: Bogdan Kanivets
Date: Wed, 29 Jun 2022 01:02:46 -0700
Subject: [PATCH] server: don't panic in readonly serializable txn

Problem: We pass the gRPC context down to the applier in a readonly
serializable txn. This context can be cancelled, for example due to a
timeout, which triggers a panic inside applyTxn.

Solution: Only panic for transactions with write operations.

Fixes https://github.com/etcd-io/etcd/issues/14110

Backported from main.

Signed-off-by: Bogdan Kanivets
---
 server/etcdserver/api/v3rpc/grpc.go |  6 +--
 server/etcdserver/apply.go          | 27 +++++++----
 server/mvcc/kvstore_txn.go          |  3 +-
 tests/integration/cluster.go        | 11 ++++-
 tests/integration/v3_grpc_test.go   | 75 +++++++++++++++++++++++++++++
 5 files changed, 107 insertions(+), 15 deletions(-)

diff --git a/server/etcdserver/api/v3rpc/grpc.go b/server/etcdserver/api/v3rpc/grpc.go
index ea3dd75705fd..409a1c39a988 100644
--- a/server/etcdserver/api/v3rpc/grpc.go
+++ b/server/etcdserver/api/v3rpc/grpc.go
@@ -36,7 +36,7 @@ const (
 	maxSendBytes = math.MaxInt32
 )
 
-func Server(s *etcdserver.EtcdServer, tls *tls.Config, interceptor grpc.UnaryServerInterceptor, gopts ...grpc.ServerOption) *grpc.Server {
+func Server(s *etcdserver.EtcdServer, tls *tls.Config, interceptors []grpc.UnaryServerInterceptor, gopts ...grpc.ServerOption) *grpc.Server {
 	var opts []grpc.ServerOption
 	opts = append(opts, grpc.CustomCodec(&codec{}))
 	if tls != nil {
@@ -48,8 +48,8 @@ func Server(s *etcdserver.EtcdServer, tls *tls.Config, interceptor grpc.UnarySer
 		newUnaryInterceptor(s),
 		grpc_prometheus.UnaryServerInterceptor,
 	}
-	if interceptor != nil {
-		chainUnaryInterceptors = append(chainUnaryInterceptors, interceptor)
+	if interceptors != nil {
+		chainUnaryInterceptors = append(chainUnaryInterceptors, interceptors...)
 	}
 
 	chainStreamInterceptors := []grpc.StreamServerInterceptor{
diff --git a/server/etcdserver/apply.go b/server/etcdserver/apply.go
index 5a77ef377342..f81cb0063772 100644
--- a/server/etcdserver/apply.go
+++ b/server/etcdserver/apply.go
@@ -428,6 +428,7 @@ func (a *applierV3backend) Range(ctx context.Context, txn mvcc.TxnRead, r *pb.Ra
 }
 
 func (a *applierV3backend) Txn(ctx context.Context, rt *pb.TxnRequest) (*pb.TxnResponse, *traceutil.Trace, error) {
+	lg := a.s.Logger()
 	trace := traceutil.Get(ctx)
 	if trace.IsEmpty() {
 		trace = traceutil.New("transaction", a.s.Logger())
@@ -474,7 +475,13 @@ func (a *applierV3backend) Txn(ctx context.Context, rt *pb.TxnRequest) (*pb.TxnR
 		txn.End()
 		txn = a.s.KV().Write(trace)
 	}
-	a.applyTxn(ctx, txn, rt, txnPath, txnResp)
+	_, err := a.applyTxn(ctx, txn, rt, txnPath, txnResp)
+	if err != nil && isWrite {
+		// A txn with write operations must succeed once it has started;
+		// there is no way to recover the state after a write failure.
+		lg.Panic("unexpected error during txnWrite", zap.Error(err))
+	}
+
 	rev := txn.Rev()
 	if len(txn.Changes()) != 0 {
 		rev++
@@ -486,7 +493,7 @@ func (a *applierV3backend) Txn(ctx context.Context, rt *pb.TxnRequest) (*pb.TxnR
 		traceutil.Field{Key: "number_of_response", Value: len(txnResp.Responses)},
 		traceutil.Field{Key: "response_revision", Value: txnResp.Header.Revision},
 	)
-	return txnResp, trace, nil
+	return txnResp, trace, err
 }
 
 // newTxnResp allocates a txn response for a txn request given a path.
@@ -617,14 +624,13 @@ func compareKV(c *pb.Compare, ckv mvccpb.KeyValue) bool {
 	return true
 }
 
-func (a *applierV3backend) applyTxn(ctx context.Context, txn mvcc.TxnWrite, rt *pb.TxnRequest, txnPath []bool, tresp *pb.TxnResponse) (txns int) {
+func (a *applierV3backend) applyTxn(ctx context.Context, txn mvcc.TxnWrite, rt *pb.TxnRequest, txnPath []bool, tresp *pb.TxnResponse) (txns int, err error) {
 	trace := traceutil.Get(ctx)
 	reqs := rt.Success
 	if !txnPath[0] {
 		reqs = rt.Failure
 	}
 
-	lg := a.s.Logger()
 	for i, req := range reqs {
 		respi := tresp.Responses[i].Response
 		switch tv := req.Request.(type) {
@@ -635,7 +641,7 @@ func (a *applierV3backend) applyTxn(ctx context.Context, txn mvcc.TxnWrite, rt *
 				traceutil.Field{Key: "range_end", Value: string(tv.RequestRange.RangeEnd)})
 			resp, err := a.Range(ctx, txn, tv.RequestRange)
 			if err != nil {
-				lg.Panic("unexpected error during txn", zap.Error(err))
+				return 0, err
 			}
 			respi.(*pb.ResponseOp_ResponseRange).ResponseRange = resp
 			trace.StopSubTrace()
@@ -646,26 +652,29 @@ func (a *applierV3backend) applyTxn(ctx context.Context, txn mvcc.TxnWrite, rt *
 				traceutil.Field{Key: "req_size", Value: tv.RequestPut.Size()})
 			resp, _, err := a.Put(ctx, txn, tv.RequestPut)
 			if err != nil {
-				lg.Panic("unexpected error during txn", zap.Error(err))
+				return 0, err
 			}
 			respi.(*pb.ResponseOp_ResponsePut).ResponsePut = resp
 			trace.StopSubTrace()
 		case *pb.RequestOp_RequestDeleteRange:
 			resp, err := a.DeleteRange(txn, tv.RequestDeleteRange)
 			if err != nil {
-				lg.Panic("unexpected error during txn", zap.Error(err))
+				return 0, err
 			}
 			respi.(*pb.ResponseOp_ResponseDeleteRange).ResponseDeleteRange = resp
 		case *pb.RequestOp_RequestTxn:
 			resp := respi.(*pb.ResponseOp_ResponseTxn).ResponseTxn
-			applyTxns := a.applyTxn(ctx, txn, tv.RequestTxn, txnPath[1:], resp)
+			applyTxns, err := a.applyTxn(ctx, txn, tv.RequestTxn, txnPath[1:], resp)
+			if err != nil {
+				return 0, err
+			}
 			txns += applyTxns + 1
 			txnPath = txnPath[applyTxns+1:]
 		default:
 			// empty union
 		}
 	}
-	return txns
+	return txns, nil
 }
 
 func (a *applierV3backend) Compaction(compaction *pb.CompactionRequest) (*pb.CompactionResponse, <-chan struct{}, *traceutil.Trace, error) {
diff --git a/server/mvcc/kvstore_txn.go b/server/mvcc/kvstore_txn.go
index 9df7b79410f5..6dfa0353d5ac 100644
--- a/server/mvcc/kvstore_txn.go
+++ b/server/mvcc/kvstore_txn.go
@@ -16,6 +16,7 @@ package mvcc
 
 import (
 	"context"
+	"fmt"
 
 	"go.etcd.io/etcd/api/v3/mvccpb"
 	"go.etcd.io/etcd/pkg/v3/traceutil"
@@ -156,7 +157,7 @@ func (tr *storeTxnRead) rangeKeys(ctx context.Context, key, end []byte, curRev i
 	for i, revpair := range revpairs[:len(kvs)] {
 		select {
 		case <-ctx.Done():
-			return nil, ctx.Err()
+			return nil, fmt.Errorf("range context cancelled: %w", ctx.Err())
 		default:
 		}
 		revToBytes(revpair, revBytes)
diff --git a/tests/integration/cluster.go b/tests/integration/cluster.go
index fedad797ae9a..cccef1845456 100644
--- a/tests/integration/cluster.go
+++ b/tests/integration/cluster.go
@@ -171,6 +171,9 @@ type ClusterConfig struct {
 	WatchProgressNotifyInterval time.Duration
 
 	CorruptCheckTime time.Duration
+	// GrpcInterceptors allows adding additional interceptors to the gRPC server for testing;
+	// for example, they can be used to cancel a context on demand.
+	GrpcInterceptors []grpc.UnaryServerInterceptor
 }
 
 type cluster struct {
@@ -334,6 +337,7 @@ func (c *cluster) mustNewMember(t testutil.TB, memberNumber int64) *member {
 		leaseCheckpointInterval:      c.cfg.LeaseCheckpointInterval,
 		WatchProgressNotifyInterval:  c.cfg.WatchProgressNotifyInterval,
 		CorruptCheckTime:             c.cfg.CorruptCheckTime,
+		GrpcInterceptors:             c.cfg.GrpcInterceptors,
 	})
 	m.DiscoveryURL = c.cfg.DiscoveryURL
 	if c.cfg.UseGRPC {
@@ -609,6 +613,7 @@ type member struct {
 	closed bool
 
 	grpcServerRecorder *grpc_testing.GrpcRecorder
+	GrpcInterceptors   []grpc.UnaryServerInterceptor
 }
 
 func (m *member) GRPCURL() string { return m.grpcURL }
@@ -638,6 +643,7 @@ type memberConfig struct {
 	leaseCheckpointPersist      bool
 	WatchProgressNotifyInterval time.Duration
 	CorruptCheckTime            time.Duration
+	GrpcInterceptors            []grpc.UnaryServerInterceptor
 }
 
 // mustNewMember return an inited member with the given name. If peerTLS is
@@ -747,6 +753,7 @@ func mustNewMember(t testutil.TB, mcfg memberConfig) *member {
 	m.V2Deprecation = config.V2_DEPR_DEFAULT
 
 	m.grpcServerRecorder = &grpc_testing.GrpcRecorder{}
+	m.GrpcInterceptors = append(mcfg.GrpcInterceptors, m.grpcServerRecorder.UnaryInterceptor())
 	m.Logger = memberLogger(t, mcfg.name)
 	t.Cleanup(func() {
 		// if we didn't cleanup the logger, the consecutive test
@@ -958,8 +965,8 @@ func (m *member) Launch() error {
 			return err
 		}
 	}
-	m.grpcServer = v3rpc.Server(m.s, tlscfg, m.grpcServerRecorder.UnaryInterceptor(), m.grpcServerOpts...)
-	m.grpcServerPeer = v3rpc.Server(m.s, peerTLScfg, m.grpcServerRecorder.UnaryInterceptor())
+	m.grpcServer = v3rpc.Server(m.s, tlscfg, m.GrpcInterceptors, m.grpcServerOpts...)
+	m.grpcServerPeer = v3rpc.Server(m.s, peerTLScfg, m.GrpcInterceptors)
 	m.serverClient = v3client.New(m.s)
 	lockpb.RegisterLockServer(m.grpcServer, v3lock.NewLockServer(m.serverClient))
 	epb.RegisterElectionServer(m.grpcServer, v3election.NewElectionServer(m.serverClient))
diff --git a/tests/integration/v3_grpc_test.go b/tests/integration/v3_grpc_test.go
index 1cbee679a657..25e6afc3d48e 100644
--- a/tests/integration/v3_grpc_test.go
+++ b/tests/integration/v3_grpc_test.go
@@ -17,6 +17,7 @@ package integration
 import (
 	"bytes"
 	"context"
+	"errors"
 	"fmt"
 	"io/ioutil"
 	"math/rand"
@@ -1975,3 +1976,77 @@ func waitForRestart(t *testing.T, kvc pb.KVClient) {
 		t.Fatalf("timed out waiting for restart: %v", err)
 	}
 }
+
+func TestV3ReadonlyTxnCancelledContext(t *testing.T) {
+	BeforeTest(t)
+	clus := NewClusterV3(t, &ClusterConfig{
+		Size: 1,
+		// Context should be cancelled on the second check that happens inside rangeKeys
+		GrpcInterceptors: []grpc.UnaryServerInterceptor{injectMockContextForTxn(newMockContext(2))},
+	})
+	defer clus.Terminate(t)
+
+	kvc := toGRPC(clus.RandClient()).KV
+	pr := &pb.PutRequest{Key: []byte("abc"), Value: []byte("def")}
+	_, err := kvc.Put(context.TODO(), pr)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	txnget := &pb.RequestOp{Request: &pb.RequestOp_RequestRange{RequestRange: &pb.RangeRequest{Key: []byte("abc")}}}
+	txn := &pb.TxnRequest{Success: []*pb.RequestOp{txnget}}
+	_, err = kvc.Txn(context.TODO(), txn)
+	if err == nil || !strings.Contains(err.Error(), "range context cancelled: mock context error") {
+		t.Fatal(err)
+	}
+}
+
+type mockCtx struct {
+	calledDone int
+	doneAfter  int
+
+	donec chan struct{}
+}
+
+func newMockContext(doneAfter int) context.Context {
+	return &mockCtx{
+		calledDone: 0,
+		doneAfter:  doneAfter,
+		donec:      make(chan struct{}),
+	}
+}
+
+func (*mockCtx) Deadline() (deadline time.Time, ok bool) {
+	return
+}
+
+func (ctx *mockCtx) Done() <-chan struct{} {
+	ctx.calledDone++
+	if ctx.calledDone == ctx.doneAfter {
+		close(ctx.donec)
+	}
+	return ctx.donec
+}
+
+func (*mockCtx) Err() error {
+	return errors.New("mock context error")
+}
+
+func (*mockCtx) Value(interface{}) interface{} {
+	return nil
+}
+
+func (*mockCtx) String() string {
+	return "mock Context"
+}
+
+func injectMockContextForTxn(mctx context.Context) grpc.UnaryServerInterceptor {
+	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
+		switch req.(type) {
+		case *pb.TxnRequest:
+			return handler(mctx, req)
+		default:
+			return handler(ctx, req)
+		}
+	}
+}
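
Reviewer note: for context, a minimal client-side sketch of the failure mode this
patch addresses, assuming a clientv3 client and a local endpoint (the endpoint and
the nanosecond timeout are illustrative only, and actually hitting the cancellation
window from a real client is timing-dependent, which is why the test above injects
a mock context instead of relying on a real deadline). Before this patch the server
could panic in applyTxn; with it, the client simply receives a "range context
cancelled" error.

	package main

	import (
		"context"
		"fmt"
		"time"

		clientv3 "go.etcd.io/etcd/client/v3"
	)

	func main() {
		cli, err := clientv3.New(clientv3.Config{Endpoints: []string{"localhost:2379"}})
		if err != nil {
			panic(err)
		}
		defer cli.Close()

		// A deadline this short is likely to expire while the server is
		// still evaluating the transaction.
		ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
		defer cancel()

		// A read-only txn with a serializable Get takes the fast path in
		// which the gRPC request context is passed down to the applier.
		_, err = cli.Txn(ctx).Then(clientv3.OpGet("abc", clientv3.WithSerializable())).Commit()
		fmt.Println("txn error:", err) // an error response, not a server panic
	}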