diff --git a/server/BUILD.bazel b/server/BUILD.bazel index 5857a2801db70..9548fa1998a10 100644 --- a/server/BUILD.bazel +++ b/server/BUILD.bazel @@ -217,6 +217,7 @@ go_test( "@com_github_tikv_client_go_v2//tikvrpc", "@io_etcd_go_etcd_tests_v3//integration", "@io_opencensus_go//stats/view", + "@org_uber_go_atomic//:atomic", "@org_uber_go_goleak//:goleak", "@org_uber_go_zap//:zap", ], diff --git a/server/conn.go b/server/conn.go index b65538e999f88..227870f5e16ac 100644 --- a/server/conn.go +++ b/server/conn.go @@ -361,14 +361,6 @@ func closeConn(cc *clientConn, connections int) error { // This is because closeConn() might be called after a connection read-timeout. logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) } - if cc.bufReadConn != nil { - err = cc.bufReadConn.Close() - if err != nil { - // We need to expect connection might have already disconnected. - // This is because closeConn() might be called after a connection read-timeout. - logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) - } - } // Close statements and session // This will release advisory locks, row locks, etc. if ctx := cc.getCtx(); ctx != nil { diff --git a/server/conn_test.go b/server/conn_test.go index 55e9d97b5c4aa..a4aca4600c24a 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -1581,7 +1581,7 @@ func TestMaxAllowedPacket(t *testing.T) { bytes := append([]byte{0x00, 0x04, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 999)))...) _, err := inBuffer.Write(bytes) require.NoError(t, err) - brc := newBufferedReadConn(&bytesConn{inBuffer}) + brc := newBufferedReadConn(&bytesConn{b: inBuffer}) pkt := newPacketIO(brc) pkt.setMaxAllowedPacket(maxAllowedPacket) readBytes, err = pkt.readPacket() @@ -1594,7 +1594,7 @@ func TestMaxAllowedPacket(t *testing.T) { bytes = append([]byte{0x01, 0x04, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 1000)))...) _, err = inBuffer.Write(bytes) require.NoError(t, err) - brc = newBufferedReadConn(&bytesConn{inBuffer}) + brc = newBufferedReadConn(&bytesConn{b: inBuffer}) pkt = newPacketIO(brc) pkt.setMaxAllowedPacket(maxAllowedPacket) _, err = pkt.readPacket() @@ -1606,7 +1606,7 @@ func TestMaxAllowedPacket(t *testing.T) { bytes = append([]byte{0x01, 0x02, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 488)))...) _, err = inBuffer.Write(bytes) require.NoError(t, err) - brc = newBufferedReadConn(&bytesConn{inBuffer}) + brc = newBufferedReadConn(&bytesConn{b: inBuffer}) pkt = newPacketIO(brc) pkt.setMaxAllowedPacket(maxAllowedPacket) readBytes, err = pkt.readPacket() @@ -1617,7 +1617,7 @@ func TestMaxAllowedPacket(t *testing.T) { bytes = append([]byte{0x01, 0x02, 0x00, 0x01}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("b", 488)))...) _, err = inBuffer.Write(bytes) require.NoError(t, err) - brc = newBufferedReadConn(&bytesConn{inBuffer}) + brc = newBufferedReadConn(&bytesConn{b: inBuffer}) pkt.setBufferedReadConn(brc) readBytes, err = pkt.readPacket() require.NoError(t, err) @@ -1830,6 +1830,7 @@ func TestProcessInfoForExecuteCommand(t *testing.T) { 0x0A, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0})) require.Equal(t, cc.ctx.Session.ShowProcess().Info, "select sum(col1) from t where col1 < ? and col1 > 100") } + func TestCloseConn(t *testing.T) { var outBuffer bytes.Buffer @@ -1840,7 +1841,12 @@ func TestCloseConn(t *testing.T) { drv := NewTiDBDriver(store) server, err := NewServer(cfg, drv) require.NoError(t, err) - + var inBuffer bytes.Buffer + _, err = inBuffer.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x01}) + require.NoError(t, err) + // Test read one packet + brc := newBufferedReadConn(&bytesConn{b: inBuffer}) + require.NoError(t, err) cc := &clientConn{ connectionID: 0, salt: []byte{ @@ -1851,11 +1857,12 @@ func TestCloseConn(t *testing.T) { pkt: &packetIO{ bufWriter: bufio.NewWriter(&outBuffer), }, - collation: mysql.DefaultCollationID, - peerHost: "localhost", - alloc: arena.NewAllocator(512), - chunkAlloc: chunk.NewAllocator(), - capability: mysql.ClientProtocol41, + collation: mysql.DefaultCollationID, + peerHost: "localhost", + alloc: arena.NewAllocator(512), + chunkAlloc: chunk.NewAllocator(), + capability: mysql.ClientProtocol41, + bufReadConn: brc, } var wg sync.WaitGroup diff --git a/server/packetio_test.go b/server/packetio_test.go index fc0b38a23169b..0e9cde8bb26d2 100644 --- a/server/packetio_test.go +++ b/server/packetio_test.go @@ -21,8 +21,10 @@ import ( "testing" "time" + "github.com/pingcap/errors" "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" + "go.uber.org/atomic" ) func BenchmarkPacketIOWrite(b *testing.B) { @@ -64,7 +66,7 @@ func TestPacketIORead(t *testing.T) { _, err := inBuffer.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x01}) require.NoError(t, err) // Test read one packet - brc := newBufferedReadConn(&bytesConn{inBuffer}) + brc := newBufferedReadConn(&bytesConn{b: inBuffer}) pkt := newPacketIO(brc) readBytes, err := pkt.readPacket() require.NoError(t, err) @@ -86,7 +88,7 @@ func TestPacketIORead(t *testing.T) { _, err = inBuffer.Write(buf) require.NoError(t, err) // Test read multiple packets - brc = newBufferedReadConn(&bytesConn{inBuffer}) + brc = newBufferedReadConn(&bytesConn{b: inBuffer}) pkt = newPacketIO(brc) readBytes, err = pkt.readPacket() require.NoError(t, err) @@ -96,7 +98,8 @@ func TestPacketIORead(t *testing.T) { } type bytesConn struct { - b bytes.Buffer + b bytes.Buffer + closed atomic.Bool } func (c *bytesConn) Read(b []byte) (n int, err error) { @@ -108,6 +111,10 @@ func (c *bytesConn) Write(b []byte) (n int, err error) { } func (c *bytesConn) Close() error { + if c.closed.Load() { + return errors.New("already closed") + } + c.closed.Store(true) return nil }