Skip to content

Commit

Permalink
This is an automated cherry-pick of pingcap#53690
Browse files Browse the repository at this point in the history
Signed-off-by: ti-chi-bot <ti-community-prow-bot@tidb.io>
  • Loading branch information
jackysp authored and ti-chi-bot committed Aug 5, 2024
1 parent d97c194 commit a8023bf
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 21 deletions.
4 changes: 4 additions & 0 deletions server/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,11 @@ go_test(
"@com_github_tikv_client_go_v2//tikvrpc",
"@io_etcd_go_etcd_tests_v3//integration",
"@io_opencensus_go//stats/view",
<<<<<<< HEAD
"@org_golang_x_exp//slices",
=======
"@org_uber_go_atomic//:atomic",
>>>>>>> 05a1ad36ce8 (server: fix connection double close (#53690))
"@org_uber_go_goleak//:goleak",
"@org_uber_go_zap//:zap",
],
Expand Down
8 changes: 0 additions & 8 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,14 +336,6 @@ func closeConn(cc *clientConn, connections int) error {
// This is because closeConn() might be called after a connection read-timeout.
logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err))
}
if cc.bufReadConn != nil {
err = cc.bufReadConn.Close()
if err != nil {
// We need to expect connection might have already disconnected.
// This is because closeConn() might be called after a connection read-timeout.
logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err))
}
}
// Close statements and session
// This will release advisory locks, row locks, etc.
if ctx := cc.getCtx(); ctx != nil {
Expand Down
29 changes: 19 additions & 10 deletions server/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1692,7 +1692,7 @@ func TestMaxAllowedPacket(t *testing.T) {
bytes := append([]byte{0x00, 0x04, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 999)))...)
_, err := inBuffer.Write(bytes)
require.NoError(t, err)
brc := newBufferedReadConn(&bytesConn{inBuffer})
brc := newBufferedReadConn(&bytesConn{b: inBuffer})
pkt := newPacketIO(brc)
pkt.setMaxAllowedPacket(maxAllowedPacket)
readBytes, err = pkt.readPacket()
Expand All @@ -1705,7 +1705,7 @@ func TestMaxAllowedPacket(t *testing.T) {
bytes = append([]byte{0x01, 0x04, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 1000)))...)
_, err = inBuffer.Write(bytes)
require.NoError(t, err)
brc = newBufferedReadConn(&bytesConn{inBuffer})
brc = newBufferedReadConn(&bytesConn{b: inBuffer})
pkt = newPacketIO(brc)
pkt.setMaxAllowedPacket(maxAllowedPacket)
_, err = pkt.readPacket()
Expand All @@ -1717,7 +1717,7 @@ func TestMaxAllowedPacket(t *testing.T) {
bytes = append([]byte{0x01, 0x02, 0x00, 0x00}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("a", 488)))...)
_, err = inBuffer.Write(bytes)
require.NoError(t, err)
brc = newBufferedReadConn(&bytesConn{inBuffer})
brc = newBufferedReadConn(&bytesConn{b: inBuffer})
pkt = newPacketIO(brc)
pkt.setMaxAllowedPacket(maxAllowedPacket)
readBytes, err = pkt.readPacket()
Expand All @@ -1728,7 +1728,7 @@ func TestMaxAllowedPacket(t *testing.T) {
bytes = append([]byte{0x01, 0x02, 0x00, 0x01}, []byte(fmt.Sprintf("SELECT length('%s') as len;", strings.Repeat("b", 488)))...)
_, err = inBuffer.Write(bytes)
require.NoError(t, err)
brc = newBufferedReadConn(&bytesConn{inBuffer})
brc = newBufferedReadConn(&bytesConn{b: inBuffer})
pkt.setBufferedReadConn(brc)
readBytes, err = pkt.readPacket()
require.NoError(t, err)
Expand Down Expand Up @@ -1987,6 +1987,7 @@ func TestProcessInfoForExecuteCommand(t *testing.T) {
require.Equal(t, cc.ctx.Session.ShowProcess().Info, "select sum(col1) from t where col1 < ? and col1 > 100")
}

<<<<<<< HEAD
func TestLDAPAuthSwitch(t *testing.T) {
store := testkit.CreateMockStore(t)
cfg := newTestConfig()
Expand Down Expand Up @@ -2026,6 +2027,8 @@ func TestLDAPAuthSwitch(t *testing.T) {
require.Equal(t, []byte(mysql.AuthMySQLClearPassword), respAuthSwitch)
}

=======
>>>>>>> 05a1ad36ce8 (server: fix connection double close (#53690))
func TestCloseConn(t *testing.T) {
var outBuffer bytes.Buffer

Expand All @@ -2036,7 +2039,12 @@ func TestCloseConn(t *testing.T) {
drv := NewTiDBDriver(store)
server, err := NewServer(cfg, drv)
require.NoError(t, err)

var inBuffer bytes.Buffer
_, err = inBuffer.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x01})
require.NoError(t, err)
// Test read one packet
brc := newBufferedReadConn(&bytesConn{b: inBuffer})
require.NoError(t, err)
cc := &clientConn{
connectionID: 0,
salt: []byte{
Expand All @@ -2047,11 +2055,12 @@ func TestCloseConn(t *testing.T) {
pkt: &packetIO{
bufWriter: bufio.NewWriter(&outBuffer),
},
collation: mysql.DefaultCollationID,
peerHost: "localhost",
alloc: arena.NewAllocator(512),
chunkAlloc: chunk.NewAllocator(),
capability: mysql.ClientProtocol41,
collation: mysql.DefaultCollationID,
peerHost: "localhost",
alloc: arena.NewAllocator(512),
chunkAlloc: chunk.NewAllocator(),
capability: mysql.ClientProtocol41,
bufReadConn: brc,
}

var wg sync.WaitGroup
Expand Down
13 changes: 10 additions & 3 deletions server/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import (
"testing"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/mysql"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
)

func BenchmarkPacketIOWrite(b *testing.B) {
Expand Down Expand Up @@ -64,7 +66,7 @@ func TestPacketIORead(t *testing.T) {
_, err := inBuffer.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x01})
require.NoError(t, err)
// Test read one packet
brc := newBufferedReadConn(&bytesConn{inBuffer})
brc := newBufferedReadConn(&bytesConn{b: inBuffer})
pkt := newPacketIO(brc)
readBytes, err := pkt.readPacket()
require.NoError(t, err)
Expand All @@ -86,7 +88,7 @@ func TestPacketIORead(t *testing.T) {
_, err = inBuffer.Write(buf)
require.NoError(t, err)
// Test read multiple packets
brc = newBufferedReadConn(&bytesConn{inBuffer})
brc = newBufferedReadConn(&bytesConn{b: inBuffer})
pkt = newPacketIO(brc)
readBytes, err = pkt.readPacket()
require.NoError(t, err)
Expand All @@ -96,7 +98,8 @@ func TestPacketIORead(t *testing.T) {
}

type bytesConn struct {
b bytes.Buffer
b bytes.Buffer
closed atomic.Bool
}

func (c *bytesConn) Read(b []byte) (n int, err error) {
Expand All @@ -108,6 +111,10 @@ func (c *bytesConn) Write(b []byte) (n int, err error) {
}

func (c *bytesConn) Close() error {
if c.closed.Load() {
return errors.New("already closed")
}
c.closed.Store(true)
return nil
}

Expand Down

0 comments on commit a8023bf

Please sign in to comment.