Use correct address for CloseAddr (#1140)
* wip

Signed-off-by: Ping Yu <yuping@pingcap.com>

* CloseAddr with ver

Signed-off-by: Ping Yu <yuping@pingcap.com>

* fix ErrConn

Signed-off-by: Ping Yu <yuping@pingcap.com>

* fix ut

Signed-off-by: Ping Yu <yuping@pingcap.com>

* polish

Signed-off-by: Ping Yu <yuping@pingcap.com>

* polish

Signed-off-by: Ping Yu <yuping@pingcap.com>

* polish

Signed-off-by: Ping Yu <yuping@pingcap.com>

---------

Signed-off-by: Ping Yu <yuping@pingcap.com>
Signed-off-by: zzm <zhouzemin@pingcap.com>
pingyu authored and zeminzhou committed Feb 28, 2024
1 parent f732af4 commit 191e357
Showing 5 changed files with 149 additions and 14 deletions.
73 changes: 64 additions & 9 deletions internal/client/client.go
@@ -109,9 +109,45 @@ type Client interface {
SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error)
}

// ClientExt is a client with extended interfaces.
type ClientExt interface {
// CloseAddrVer closes gRPC connections to the address with additional `ver` parameter.
// Each new connection will have an incremented `ver` value, and attempts to close a previous `ver` will be ignored.
// Passing `math.MaxUint64` as the `ver` parameter will forcefully close all connections to the address.
CloseAddrVer(addr string, ver uint64) error
}

// ErrConn wraps error with target address and version of the connection.
type ErrConn struct {
Err error
Addr string
Ver uint64
}

func (e *ErrConn) Error() string {
return fmt.Sprintf("[%s](%d) %s", e.Addr, e.Ver, e.Err.Error())
}

func (e *ErrConn) Unwrap() error {
return e.Err
}

func WrapErrConn(err error, conn *connArray) error {
if err == nil {
return nil
}
return &ErrConn{
Err: err,
Addr: conn.target,
Ver: conn.ver,
}
}

type connArray struct {
// The target host.
target string
// version of the connection array, increased by 1 on reconnect.
ver uint64

index uint32
v []*monitoredConn
@@ -125,9 +161,10 @@ type connArray struct {
monitor *connMonitor
}

func newConnArray(maxSize uint, addr string, security config.Security,
func newConnArray(maxSize uint, addr string, ver uint64, security config.Security,
idleNotify *uint32, enableBatch bool, dialTimeout time.Duration, m *connMonitor, opts []grpc.DialOption) (*connArray, error) {
a := &connArray{
ver: ver,
index: 0,
v: make([]*monitoredConn, maxSize),
streamTimeout: make(chan *tikvrpc.Lease, 1024),
@@ -390,6 +427,7 @@ type RPCClient struct {
sync.RWMutex

conns map[string]*connArray
vers map[string]uint64
option *option

idleNotify uint32
@@ -405,6 +443,7 @@ type RPCClient struct {
func NewRPCClient(opts ...Opt) *RPCClient {
cli := &RPCClient{
conns: make(map[string]*connArray),
vers: make(map[string]uint64),
option: &option{
dialTimeout: dialTimeout,
},
@@ -452,9 +491,11 @@ func (c *RPCClient) createConnArray(addr string, enableBatch bool, opts ...func(
for _, opt := range opts {
opt(&client)
}
ver := c.vers[addr] + 1
array, err = newConnArray(
client.GrpcConnectionCount,
addr,
ver,
c.option.security,
&c.idleNotify,
enableBatch,
@@ -466,6 +507,7 @@ func (c *RPCClient) createConnArray(addr string, enableBatch bool, opts ...func(
return nil, err
}
c.conns[addr] = array
c.vers[addr] = ver
}
return array, nil
}
@@ -603,6 +645,10 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
return nil, err
}

wrapErrConn := func(resp *tikvrpc.Response, err error) (*tikvrpc.Response, error) {
return resp, WrapErrConn(err, connArray)
}

start := time.Now()
staleRead := req.GetStaleRead()
defer func() {
@@ -625,7 +671,7 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
if config.GetGlobalConfig().TiKVClient.MaxBatchSize > 0 && enableBatch {
if batchReq := req.ToBatchCommandsRequest(); batchReq != nil {
defer trace.StartRegion(ctx, req.Type.String()).End()
return sendBatchRequest(ctx, addr, req.ForwardedHost, connArray.batchConn, batchReq, timeout)
return wrapErrConn(sendBatchRequest(ctx, addr, req.ForwardedHost, connArray.batchConn, batchReq, timeout))
}
}

@@ -639,7 +685,7 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
client := debugpb.NewDebugClient(clientConn)
ctx1, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return tikvrpc.CallDebugRPC(ctx1, client, req)
return wrapErrConn(tikvrpc.CallDebugRPC(ctx1, client, req))
}

client := tikvpb.NewTikvClient(clientConn)
@@ -650,16 +696,16 @@ func (c *RPCClient) sendRequest(ctx context.Context, addr string, req *tikvrpc.R
}
switch req.Type {
case tikvrpc.CmdBatchCop:
return c.getBatchCopStreamResponse(ctx, client, req, timeout, connArray)
return wrapErrConn(c.getBatchCopStreamResponse(ctx, client, req, timeout, connArray))
case tikvrpc.CmdCopStream:
return c.getCopStreamResponse(ctx, client, req, timeout, connArray)
return wrapErrConn(c.getCopStreamResponse(ctx, client, req, timeout, connArray))
case tikvrpc.CmdMPPConn:
return c.getMPPStreamResponse(ctx, client, req, timeout, connArray)
return wrapErrConn(c.getMPPStreamResponse(ctx, client, req, timeout, connArray))
}
// Or else it's a unary call.
ctx1, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return tikvrpc.CallRPC(ctx1, client, req)
return wrapErrConn(tikvrpc.CallRPC(ctx1, client, req))
}

// SendRequest sends a Request to server and receives Response.
@@ -793,11 +839,20 @@ func (c *RPCClient) Close() error {

// CloseAddr closes gRPC connections to the address.
func (c *RPCClient) CloseAddr(addr string) error {
return c.CloseAddrVer(addr, math.MaxUint64)
}

func (c *RPCClient) CloseAddrVer(addr string, ver uint64) error {
c.Lock()
conn, ok := c.conns[addr]
if ok {
delete(c.conns, addr)
logutil.BgLogger().Debug("close connection", zap.String("target", addr))
if conn.ver <= ver {
delete(c.conns, addr)
logutil.BgLogger().Debug("close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("conn.ver", conn.ver))
} else {
logutil.BgLogger().Debug("ignore close connection", zap.String("target", addr), zap.Uint64("ver", ver), zap.Uint64("conn.ver", conn.ver))
conn = nil
}
}
c.Unlock()

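The versioning scheme introduced above works as follows: every (re)connect to an address increments a per-address counter in RPCClient.vers, the resulting connArray records that value in its ver field, and CloseAddrVer only tears the connection down when conn.ver <= ver, while math.MaxUint64 forces closure regardless. The standalone Go sketch below illustrates that scheme using only the standard library; the pool type, its method names, and the address are invented for this example and are not part of the client-go API.

package main

import (
	"fmt"
	"math"
	"sync"
)

type conn struct{ ver uint64 }

type pool struct {
	mu    sync.Mutex
	conns map[string]*conn
	vers  map[string]uint64
}

func newPool() *pool {
	return &pool{conns: map[string]*conn{}, vers: map[string]uint64{}}
}

// connect creates (or recreates) the connection for addr, tagging it with the
// next version number for that address.
func (p *pool) connect(addr string) *conn {
	p.mu.Lock()
	defer p.mu.Unlock()
	ver := p.vers[addr] + 1
	c := &conn{ver: ver}
	p.conns[addr], p.vers[addr] = c, ver
	return c
}

// closeAddrVer drops the connection only if it is not newer than ver, so a
// close aimed at an old connection cannot tear down a fresh reconnect.
func (p *pool) closeAddrVer(addr string, ver uint64) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if c, ok := p.conns[addr]; ok && c.ver <= ver {
		delete(p.conns, addr)
	}
}

func main() {
	p := newPool()
	first := p.connect("tikv-1:20160") // ver 1
	p.connect("tikv-1:20160")          // reconnect, ver 2

	p.closeAddrVer("tikv-1:20160", first.ver)   // stale ver: ignored
	fmt.Println(p.conns["tikv-1:20160"] != nil) // true

	p.closeAddrVer("tikv-1:20160", math.MaxUint64) // forced close
	fmt.Println(p.conns["tikv-1:20160"] == nil)    // true
}

Run as a main package, it prints true twice: the close aimed at the stale version is ignored, and the forced close removes the current connection.
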
6 changes: 4 additions & 2 deletions internal/client/client_batch.go
@@ -830,16 +830,18 @@ func (c *RPCClient) recycleIdleConnArray() {
start := time.Now()

var addrs []string
var vers []uint64
c.RLock()
for _, conn := range c.conns {
if conn.batchConn != nil && conn.isIdle() {
addrs = append(addrs, conn.target)
vers = append(vers, conn.ver)
}
}
c.RUnlock()

for _, addr := range addrs {
c.CloseAddr(addr)
for i, addr := range addrs {
c.CloseAddrVer(addr, vers[i])
}

metrics.TiKVBatchClientRecycle.Observe(time.Since(start).Seconds())
41 changes: 39 additions & 2 deletions internal/client/client_test.go
@@ -79,12 +79,18 @@ func TestConn(t *testing.T) {
assert.Nil(t, err)
assert.False(t, conn2.Get() == conn1.Get())

assert.Nil(t, client.CloseAddr(addr))
ver := conn2.ver
assert.Nil(t, client.CloseAddrVer(addr, ver-1))
_, ok := client.conns[addr]
assert.True(t, ok)
assert.Nil(t, client.CloseAddrVer(addr, ver))
_, ok = client.conns[addr]
assert.False(t, ok)

conn3, err := client.getConnArray(addr, true)
assert.Nil(t, err)
assert.NotNil(t, conn3)
assert.Equal(t, ver+1, conn3.ver)

client.Close()
conn4, err := client.getConnArray(addr, true)
@@ -135,7 +141,7 @@ func TestSendWhenReconnect(t *testing.T) {

req := tikvrpc.NewRequest(tikvrpc.CmdEmpty, &tikvpb.BatchCommandsEmptyRequest{})
_, err = rpcClient.SendRequest(context.Background(), addr, req, 100*time.Second)
assert.True(t, err.Error() == "no available connections")
assert.EqualError(t, err, fmt.Sprintf("[%s](%d) no available connections", addr, 1))
server.Stop()
}

@@ -723,3 +729,34 @@ func TestBatchClientRecoverAfterServerRestart(t *testing.T) {
require.NoError(t, err)
}
}

func TestErrConn(t *testing.T) {
e := errors.New("conn error")
err1 := &ErrConn{Err: e, Addr: "127.0.0.1", Ver: 10}
err2 := &ErrConn{Err: e, Addr: "127.0.0.1", Ver: 10}

e3 := errors.New("conn error 3")
err3 := &ErrConn{Err: e3}

err4 := errors.New("not ErrConn")

assert.True(t, errors.Is(err1, err1))
assert.True(t, errors.Is(fmt.Errorf("%w", err1), err1))
assert.False(t, errors.Is(fmt.Errorf("%w", err2), err1)) // err2 != err1
assert.False(t, errors.Is(fmt.Errorf("%w", err4), err1))

var errConn *ErrConn
assert.True(t, errors.As(err1, &errConn))
assert.Equal(t, "127.0.0.1", errConn.Addr)
assert.EqualValues(t, 10, errConn.Ver)
assert.EqualError(t, errConn.Err, "conn error")

assert.True(t, errors.As(err3, &errConn))
assert.EqualError(t, e3, "conn error 3")

assert.False(t, errors.As(err4, &errConn))

errMsg := errors.New("unknown")
assert.True(t, errors.As(err1, &errMsg))
assert.EqualError(t, err1, errMsg.Error())
}
17 changes: 16 additions & 1 deletion internal/locate/region_request.go
@@ -215,6 +215,14 @@ func (s *RegionRequestSender) GetClient() client.Client {
return s.client
}

// getClientExt returns the client with ClientExt interface.
// Return nil if the client does not implement ClientExt.
// Don't use in critical path.
func (s *RegionRequestSender) getClientExt() client.ClientExt {
ext, _ := s.client.(client.ClientExt)
return ext
}

// SetStoreAddr specifies the dest store address.
func (s *RegionRequestSender) SetStoreAddr(addr string) {
s.storeAddr = addr
@@ -1836,7 +1844,14 @@ func (s *RegionRequestSender) onSendFail(bo *retry.Backoffer, ctx *RPCContext, r
// Canceled by gRPC remote may happen when tikv is killed and exiting.
// Close the connection, backoff, and retry.
logutil.Logger(bo.GetCtx()).Warn("receive a grpc cancel signal", zap.Error(err))
s.client.CloseAddr(ctx.Addr)
var errConn *client.ErrConn
if errors.As(err, &errConn) {
if ext := s.getClientExt(); ext != nil {
ext.CloseAddrVer(errConn.Addr, errConn.Ver)
} else {
s.client.CloseAddr(errConn.Addr)
}
}
}
}

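With errors from sendRequest wrapped in ErrConn, the change above lets onSendFail recover the exact (addr, ver) pair that produced a gRPC cancel and close only that connection, falling back to CloseAddr when the client does not implement ClientExt. The standalone sketch below mirrors that wrap/unwrap pattern with a local connErr type (it does not import the internal package); errors.As still recovers the address and version even after further fmt.Errorf("%w", ...) wrapping.

package main

import (
	"errors"
	"fmt"
)

// connErr mirrors the shape of ErrConn for this example: it carries the
// target address and connection version alongside the underlying error.
type connErr struct {
	err  error
	addr string
	ver  uint64
}

func (e *connErr) Error() string { return fmt.Sprintf("[%s](%d) %s", e.addr, e.ver, e.err) }
func (e *connErr) Unwrap() error { return e.err }

func main() {
	base := errors.New("context canceled")
	wrapped := fmt.Errorf("send request failed: %w",
		&connErr{err: base, addr: "tikv-1:20160", ver: 3})

	var ce *connErr
	if errors.As(wrapped, &ce) {
		// Close exactly the connection that produced the error.
		fmt.Printf("close %s at ver %d\n", ce.addr, ce.ver)
	}

	// The original cause is still reachable through the chain.
	fmt.Println(errors.Is(wrapped, base)) // true
}
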
26 changes: 26 additions & 0 deletions internal/locate/region_request_test.go
@@ -37,6 +37,7 @@ package locate
import (
"context"
"fmt"
"math"
"math/rand"
"net"
"sync"
@@ -99,14 +100,20 @@ func (s *testRegionRequestToSingleStoreSuite) TearDownTest() {
type fnClient struct {
fn func(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error)
closedAddr string
closedVer uint64
}

func (f *fnClient) Close() error {
return nil
}

func (f *fnClient) CloseAddr(addr string) error {
return f.CloseAddrVer(addr, math.MaxUint64)
}

func (f *fnClient) CloseAddrVer(addr string, ver uint64) error {
f.closedAddr = addr
f.closedVer = ver
return nil
}

@@ -664,6 +671,8 @@ func (s *testRegionRequestToSingleStoreSuite) TestCloseConnectionOnStoreNotMatch
regionErr, _ := resp.GetRegionError()
s.NotNil(regionErr)
s.Equal(target, client.closedAddr)
var expected uint64 = math.MaxUint64
s.Equal(expected, client.closedVer)
}

func (s *testRegionRequestToSingleStoreSuite) TestStaleReadRetry() {
@@ -824,3 +833,20 @@ func (s *testRegionRequestToSingleStoreSuite) TestCountReplicaNumber() {
s.Equal(4, s.regionRequestSender.countReplicaNumber(peers)) // Only count 1 tiflash replica for tiflash write-nodes.
}
}

type emptyClient struct {
client.Client
}

func (s *testRegionRequestToSingleStoreSuite) TestClientExt() {
var cli client.Client = client.NewRPCClient()
sender := NewRegionRequestSender(s.cache, cli)
s.NotNil(sender.client)
s.NotNil(sender.getClientExt())
cli.Close()

cli = &emptyClient{}
sender = NewRegionRequestSender(s.cache, cli)
s.NotNil(sender.client)
s.Nil(sender.getClientExt())
}
