Skip to content

Commit

Permalink
server: fix connection double close (#53690) (#55186)
Browse files Browse the repository at this point in the history
close #53689
  • Loading branch information
ti-chi-bot authored Aug 5, 2024
1 parent d97c194 commit b0823fa
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 21 deletions.
1 change: 1 addition & 0 deletions server/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ go_test(
"@io_etcd_go_etcd_tests_v3//integration",
"@io_opencensus_go//stats/view",
"@org_golang_x_exp//slices",
"@org_uber_go_atomic//:atomic",
"@org_uber_go_goleak//:goleak",
"@org_uber_go_zap//:zap",
],
Expand Down
8 changes: 0 additions & 8 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,14 +336,6 @@ func closeConn(cc *clientConn, connections int) error {
// This is because closeConn() might be called after a connection read-timeout.
logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err))
}
if cc.bufReadConn != nil {
err = cc.bufReadConn.Close()
if err != nil {
// We need to expect connection might have already disconnected.
// This is because closeConn() might be called after a connection read-timeout.
logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err))
}
}
// Close statements and session
// This will release advisory locks, row locks, etc.
if ctx := cc.getCtx(); ctx != nil {
Expand Down
26 changes: 16 additions & 10 deletions server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1692,7 +1692,7 @@ func TestMaxAllowedPacket(t *testing.T) {
bytes := append([]byte{0x00, 0x04, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 999)))...)
_, err := inBuffer.Write(bytes)
require.NoError(t, err)
brc := newBufferedReadConn(&bytesConn{inBuffer})
brc := newBufferedReadConn(&bytesConn{b: inBuffer})
pkt := newPacketIO(brc)
pkt.setMaxAllowedPacket(maxAllowedPacket)
readBytes, err = pkt.readPacket()
Expand All @@ -1705,7 +1705,7 @@ func TestMaxAllowedPacket(t *testing.T) {
bytes = append([]byte{0x01, 0x04, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 1000)))...)
_, err = inBuffer.Write(bytes)
require.NoError(t, err)
brc = newBufferedReadConn(&bytesConn{inBuffer})
brc = newBufferedReadConn(&bytesConn{b: inBuffer})
pkt = newPacketIO(brc)
pkt.setMaxAllowedPacket(maxAllowedPacket)
_, err = pkt.readPacket()
Expand All @@ -1717,7 +1717,7 @@ func TestMaxAllowedPacket(t *testing.T) {
bytes = append([]byte{0x01, 0x02, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 488)))...)
_, err = inBuffer.Write(bytes)
require.NoError(t, err)
brc = newBufferedReadConn(&bytesConn{inBuffer})
brc = newBufferedReadConn(&bytesConn{b: inBuffer})
pkt = newPacketIO(brc)
pkt.setMaxAllowedPacket(maxAllowedPacket)
readBytes, err = pkt.readPacket()
Expand All @@ -1728,7 +1728,7 @@ func TestMaxAllowedPacket(t *testing.T) {
bytes = append([]byte{0x01, 0x02, 0x00, 0x01}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("b", 488)))...)
_, err = inBuffer.Write(bytes)
require.NoError(t, err)
brc = newBufferedReadConn(&bytesConn{inBuffer})
brc = newBufferedReadConn(&bytesConn{b: inBuffer})
pkt.setBufferedReadConn(brc)
readBytes, err = pkt.readPacket()
require.NoError(t, err)
Expand Down Expand Up @@ -2036,7 +2036,12 @@ func TestCloseConn(t *testing.T) {
drv := NewTiDBDriver(store)
server, err := NewServer(cfg, drv)
require.NoError(t, err)

var inBuffer bytes.Buffer
_, err = inBuffer.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x01})
require.NoError(t, err)
// Test read one packet
brc := newBufferedReadConn(&bytesConn{b: inBuffer})
require.NoError(t, err)
cc := &clientConn{
connectionID: 0,
salt: []byte{
Expand All @@ -2047,11 +2052,12 @@ func TestCloseConn(t *testing.T) {
pkt: &packetIO{
bufWriter: bufio.NewWriter(&outBuffer),
},
collation: mysql.DefaultCollationID,
peerHost: "localhost",
alloc: arena.NewAllocator(512),
chunkAlloc: chunk.NewAllocator(),
capability: mysql.ClientProtocol41,
collation: mysql.DefaultCollationID,
peerHost: "localhost",
alloc: arena.NewAllocator(512),
chunkAlloc: chunk.NewAllocator(),
capability: mysql.ClientProtocol41,
bufReadConn: brc,
}

var wg sync.WaitGroup
Expand Down
13 changes: 10 additions & 3 deletions server/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import (
"testing"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/mysql"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
)

func BenchmarkPacketIOWrite(b *testing.B) {
Expand Down Expand Up @@ -64,7 +66,7 @@ func TestPacketIORead(t *testing.T) {
_, err := inBuffer.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x01})
require.NoError(t, err)
// Test read one packet
brc := newBufferedReadConn(&bytesConn{inBuffer})
brc := newBufferedReadConn(&bytesConn{b: inBuffer})
pkt := newPacketIO(brc)
readBytes, err := pkt.readPacket()
require.NoError(t, err)
Expand All @@ -86,7 +88,7 @@ func TestPacketIORead(t *testing.T) {
_, err = inBuffer.Write(buf)
require.NoError(t, err)
// Test read multiple packets
brc = newBufferedReadConn(&bytesConn{inBuffer})
brc = newBufferedReadConn(&bytesConn{b: inBuffer})
pkt = newPacketIO(brc)
readBytes, err = pkt.readPacket()
require.NoError(t, err)
Expand All @@ -96,7 +98,8 @@ func TestPacketIORead(t *testing.T) {
}

type bytesConn struct {
b bytes.Buffer
b bytes.Buffer
closed atomic.Bool
}

func (c *bytesConn) Read(b []byte) (n int, err error) {
Expand All @@ -108,6 +111,10 @@ func (c *bytesConn) Write(b []byte) (n int, err error) {
}

func (c *bytesConn) Close() error {
if c.closed.Load() {
return errors.New("already closed")
}
c.closed.Store(true)
return nil
}

Expand Down

0 comments on commit b0823fa

Please sign in to comment.