Skip to content

Commit

Permalink
Fix batch client batchSendLoop panic (#1021) (#1022)
Browse files Browse the repository at this point in the history
Signed-off-by: crazycs520 <crazycs520@gmail.com>
  • Loading branch information
crazycs520 authored Oct 18, 2023
1 parent 14934ce commit 916bb20
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 51 deletions.
15 changes: 7 additions & 8 deletions internal/apicodec/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,28 +92,27 @@ func attachAPICtx(c Codec, req *tikvrpc.Request) *tikvrpc.Request {
// Shallow copy the request to avoid concurrent modification.
r := *req

ctx := &r.Context
ctx.ApiVersion = c.GetAPIVersion()
ctx.KeyspaceId = uint32(c.GetKeyspaceID())
r.Context.ApiVersion = c.GetAPIVersion()
r.Context.KeyspaceId = uint32(c.GetKeyspaceID())

switch r.Type {
case tikvrpc.CmdMPPTask:
mpp := *r.DispatchMPPTask()
// Shallow copy the meta to avoid concurrent modification.
meta := *mpp.Meta
meta.KeyspaceId = ctx.KeyspaceId
meta.ApiVersion = ctx.ApiVersion
meta.KeyspaceId = r.Context.KeyspaceId
meta.ApiVersion = r.Context.ApiVersion
mpp.Meta = &meta
r.Req = &mpp

case tikvrpc.CmdCompact:
compact := *r.Compact()
compact.KeyspaceId = ctx.KeyspaceId
compact.ApiVersion = ctx.ApiVersion
compact.KeyspaceId = r.Context.KeyspaceId
compact.ApiVersion = r.Context.ApiVersion
r.Req = &compact
}

tikvrpc.AttachContext(&r, ctx)
tikvrpc.AttachContext(&r, r.Context)

return &r
}
12 changes: 8 additions & 4 deletions internal/client/client_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,18 @@ func (a *batchConn) fetchMorePendingRequests(

const idleTimeout = 3 * time.Minute

// BatchSendLoopPanicCounter is only used for testing.
var BatchSendLoopPanicCounter int64 = 0

func (a *batchConn) batchSendLoop(cfg config.TiKVClient) {
defer func() {
if r := recover(); r != nil {
metrics.TiKVPanicCounter.WithLabelValues(metrics.LabelBatchSendLoop).Inc()
logutil.BgLogger().Error("batchSendLoop",
zap.Reflect("r", r),
zap.Any("r", r),
zap.Stack("stack"))
logutil.BgLogger().Info("restart batchSendLoop")
atomic.AddInt64(&BatchSendLoopPanicCounter, 1)
logutil.BgLogger().Info("restart batchSendLoop", zap.Int64("count", atomic.LoadInt64(&BatchSendLoopPanicCounter)))
go a.batchSendLoop(cfg)
}
}()
Expand Down Expand Up @@ -430,7 +434,7 @@ func (s *batchCommandsStream) recv() (resp *tikvpb.BatchCommandsResponse, err er
if r := recover(); r != nil {
metrics.TiKVPanicCounter.WithLabelValues(metrics.LabelBatchRecvLoop).Inc()
logutil.BgLogger().Error("batchCommandsClient.recv panic",
zap.Reflect("r", r),
zap.Any("r", r),
zap.Stack("stack"))
err = errors.New("batch conn recv paniced")
}
Expand Down Expand Up @@ -595,7 +599,7 @@ func (c *batchCommandsClient) batchRecvLoop(cfg config.TiKVClient, tikvTransport
if r := recover(); r != nil {
metrics.TiKVPanicCounter.WithLabelValues(metrics.LabelBatchRecvLoop).Inc()
logutil.BgLogger().Error("batchRecvLoop",
zap.Reflect("r", r),
zap.Any("r", r),
zap.Stack("stack"))
logutil.BgLogger().Info("restart batchRecvLoop")
go c.batchRecvLoop(cfg, tikvTransportLayerLoad, streamClient)
Expand Down
10 changes: 5 additions & 5 deletions internal/client/client_fail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ package client

import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"
Expand All @@ -47,18 +46,19 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/config"
"github.com/tikv/client-go/v2/internal/client/mock_server"
"github.com/tikv/client-go/v2/tikvrpc"
)

func TestPanicInRecvLoop(t *testing.T) {
require.Nil(t, failpoint.Enable("tikvclient/panicInFailPendingRequests", `panic`))
require.Nil(t, failpoint.Enable("tikvclient/gotErrorInRecvLoop", `return("0")`))

server, port := startMockTikvService()
server, port := mock_server.StartMockTikvService()
require.True(t, port > 0)
defer server.Stop()

addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
addr := server.Addr()
rpcClient := NewRPCClient()
defer rpcClient.Close()
rpcClient.option.dialTimeout = time.Second / 3
Expand All @@ -81,10 +81,10 @@ func TestPanicInRecvLoop(t *testing.T) {
}

func TestRecvErrorInMultipleRecvLoops(t *testing.T) {
server, port := startMockTikvService()
server, port := mock_server.StartMockTikvService()
require.True(t, port > 0)
defer server.Stop()
addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
addr := server.Addr()

// Enable batch and limit the connection count to 1 so that
// there is only one BatchCommands stream for each host or forwarded host.
Expand Down
17 changes: 9 additions & 8 deletions internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/config"
"github.com/tikv/client-go/v2/internal/client/mock_server"
"github.com/tikv/client-go/v2/tikvrpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/metadata"
Expand Down Expand Up @@ -114,12 +115,12 @@ func TestCancelTimeoutRetErr(t *testing.T) {
}

func TestSendWhenReconnect(t *testing.T) {
server, port := startMockTikvService()
server, port := mock_server.StartMockTikvService()
require.True(t, port > 0)

rpcClient := NewRPCClient()
defer rpcClient.Close()
addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
addr := server.Addr()
conn, err := rpcClient.getConnArray(addr, true)
assert.Nil(t, err)

Expand Down Expand Up @@ -238,7 +239,7 @@ func TestCollapseResolveLock(t *testing.T) {
}

func TestForwardMetadataByUnaryCall(t *testing.T) {
server, port := startMockTikvService()
server, port := mock_server.StartMockTikvService()
require.True(t, port > 0)
defer server.Stop()
addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
Expand All @@ -253,7 +254,7 @@ func TestForwardMetadataByUnaryCall(t *testing.T) {

var checkCnt uint64
// Check no corresponding metadata if ForwardedHost is empty.
server.setMetaChecker(func(ctx context.Context) error {
server.SetMetaChecker(func(ctx context.Context) error {
atomic.AddUint64(&checkCnt, 1)
// gRPC may set some metadata by default, e.g. "context-type".
md, ok := metadata.FromIncomingContext(ctx)
Expand Down Expand Up @@ -281,7 +282,7 @@ func TestForwardMetadataByUnaryCall(t *testing.T) {
checkCnt = 0
forwardedHost := "127.0.0.1:6666"
// Check the metadata exists.
server.setMetaChecker(func(ctx context.Context) error {
server.SetMetaChecker(func(ctx context.Context) error {
atomic.AddUint64(&checkCnt, 1)
// gRPC may set some metadata by default, e.g. "context-type".
md, ok := metadata.FromIncomingContext(ctx)
Expand All @@ -306,10 +307,10 @@ func TestForwardMetadataByUnaryCall(t *testing.T) {
}

func TestForwardMetadataByBatchCommands(t *testing.T) {
server, port := startMockTikvService()
server, port := mock_server.StartMockTikvService()
require.True(t, port > 0)
defer server.Stop()
addr := fmt.Sprintf("%s:%d", "127.0.0.1", port)
addr := server.Addr()

// Enable batch and limit the connection count to 1 so that
// there is only one BatchCommands stream for each host or forwarded host.
Expand All @@ -322,7 +323,7 @@ func TestForwardMetadataByBatchCommands(t *testing.T) {

var checkCnt uint64
setCheckHandler := func(forwardedHost string) {
server.setMetaChecker(func(ctx context.Context) error {
server.SetMetaChecker(func(ctx context.Context) error {
atomic.AddUint64(&checkCnt, 1)
md, ok := metadata.FromIncomingContext(ctx)
if forwardedHost == "" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
// https://github.com/pingcap/tidb/tree/cc5e161ac06827589c4966674597c137cc9e809c/store/tikv/client/mock_tikv_service_test.go
//

package client
package mock_server

import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"

"github.com/pingcap/kvproto/pkg/coprocessor"
Expand All @@ -35,9 +36,11 @@ import (
"google.golang.org/grpc"
)

type server struct {
type MockServer struct {
tikvpb.TikvServer
grpcServer *grpc.Server
addr string
running int64 // 0: not running, 1: running
// metaChecker check the metadata of each request. Now only requests
// which need redirection set it.
metaChecker struct {
Expand All @@ -46,21 +49,28 @@ type server struct {
}
}

func (s *server) KvPrewrite(ctx context.Context, req *kvrpcpb.PrewriteRequest) (*kvrpcpb.PrewriteResponse, error) {
func (s *MockServer) KvGet(ctx context.Context, req *kvrpcpb.GetRequest) (*kvrpcpb.GetResponse, error) {
if err := s.checkMetadata(ctx); err != nil {
return nil, err
}
return &kvrpcpb.GetResponse{}, nil
}

func (s *MockServer) KvPrewrite(ctx context.Context, req *kvrpcpb.PrewriteRequest) (*kvrpcpb.PrewriteResponse, error) {
if err := s.checkMetadata(ctx); err != nil {
return nil, err
}
return &kvrpcpb.PrewriteResponse{}, nil
}

func (s *server) CoprocessorStream(req *coprocessor.Request, ss tikvpb.Tikv_CoprocessorStreamServer) error {
func (s *MockServer) CoprocessorStream(req *coprocessor.Request, ss tikvpb.Tikv_CoprocessorStreamServer) error {
if err := s.checkMetadata(ss.Context()); err != nil {
return err
}
return ss.Send(&coprocessor.Response{})
}

func (s *server) BatchCommands(ss tikvpb.Tikv_BatchCommandsServer) error {
func (s *MockServer) BatchCommands(ss tikvpb.Tikv_BatchCommandsServer) error {
if err := s.checkMetadata(ss.Context()); err != nil {
return err
}
Expand Down Expand Up @@ -91,13 +101,13 @@ func (s *server) BatchCommands(ss tikvpb.Tikv_BatchCommandsServer) error {
}
}

func (s *server) setMetaChecker(check func(context.Context) error) {
func (s *MockServer) SetMetaChecker(check func(context.Context) error) {
s.metaChecker.Lock()
s.metaChecker.check = check
s.metaChecker.Unlock()
}

func (s *server) checkMetadata(ctx context.Context) error {
func (s *MockServer) checkMetadata(ctx context.Context) error {
s.metaChecker.Lock()
defer s.metaChecker.Unlock()
if s.metaChecker.check != nil {
Expand All @@ -106,32 +116,52 @@ func (s *server) checkMetadata(ctx context.Context) error {
return nil
}

func (s *server) Stop() {
func (s *MockServer) IsRunning() bool {
return atomic.LoadInt64(&s.running) == 1
}

func (s *MockServer) Addr() string {
return s.addr
}

func (s *MockServer) Stop() {
s.grpcServer.Stop()
atomic.StoreInt64(&s.running, 0)
}

// Try to start a gRPC server and retrun the server instance and binded port.
func startMockTikvService() (*server, int) {
func (s *MockServer) Start(addr string) int {
if addr == "" {
addr = fmt.Sprintf("%s:%d", "127.0.0.1", 0)
}
port := -1
lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", "127.0.0.1", 0))
lis, err := net.Listen("tcp", addr)
if err != nil {
logutil.BgLogger().Error("can't listen", zap.Error(err))
logutil.BgLogger().Error("can't start mock tikv service because no available ports")
return nil, port
return port
}
port = lis.Addr().(*net.TCPAddr).Port

server := &server{}
s := grpc.NewServer(grpc.ConnectionTimeout(time.Minute))
tikvpb.RegisterTikvServer(s, server)
server.grpcServer = s
grpcServer := grpc.NewServer(grpc.ConnectionTimeout(time.Minute))
tikvpb.RegisterTikvServer(grpcServer, s)
s.grpcServer = grpcServer
go func() {
if err = s.Serve(lis); err != nil {
if err = grpcServer.Serve(lis); err != nil {
logutil.BgLogger().Error(
"can't serve gRPC requests",
zap.Error(err),
)
}
}()
atomic.StoreInt64(&s.running, 1)
s.addr = fmt.Sprintf("%s:%d", "127.0.0.1", port)
logutil.BgLogger().Info("mock server started", zap.String("addr", s.addr))
return port
}

// StartMockTikvService try to start a gRPC server and retrun the server instance and binded port.
func StartMockTikvService() (*MockServer, int) {
server := &MockServer{}
port := server.Start("")
return server, port
}
4 changes: 2 additions & 2 deletions internal/locate/region_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ func (c *RegionCache) checkAndResolve(needCheckStores []*Store, needCheck func(*
r := recover()
if r != nil {
logutil.BgLogger().Error("panic in the checkAndResolve goroutine",
zap.Reflect("r", r),
zap.Any("r", r),
zap.Stack("stack trace"))
}
}()
Expand Down Expand Up @@ -2975,7 +2975,7 @@ func (c *RegionCache) checkAndUpdateStoreSlowScores() {
r := recover()
if r != nil {
logutil.BgLogger().Error("panic in the checkAndUpdateStoreSlowScores goroutine",
zap.Reflect("r", r),
zap.Any("r", r),
zap.Stack("stack trace"))
}
}()
Expand Down
Loading

0 comments on commit 916bb20

Please sign in to comment.