From 3cc6bb959266af01e57687829c9dedebc2d37f9b Mon Sep 17 00:00:00 2001 From: xhe Date: Tue, 25 Apr 2023 12:20:14 +0800 Subject: [PATCH 1/2] net: fix proxy protocol Signed-off-by: xhe --- pkg/proxy/net/packetio.go | 12 ++--- pkg/proxy/net/packetio_options.go | 2 +- pkg/proxy/net/packetio_test.go | 63 ++++++----------------- pkg/proxy/net/proxy_test.go | 65 ++++++++++++++++++++++++ pkg/proxy/proxyprotocol/definition.go | 4 +- pkg/proxy/proxyprotocol/listener_test.go | 8 +-- pkg/proxy/proxyprotocol/proxy.go | 10 ++-- pkg/proxy/proxyprotocol/proxy_test.go | 14 ++--- 8 files changed, 105 insertions(+), 73 deletions(-) create mode 100644 pkg/proxy/net/proxy_test.go diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 2033ebe6..4b6457a0 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -103,12 +103,15 @@ func NewPacketIO(conn net.Conn, opts ...PacketIOption) *PacketIO { sequence: 0, buf: buf, } - // TODO: disable it by default now p.proxyInited.Store(true) + p.ApplyOpts(opts...) + return p +} + +func (p *PacketIO) ApplyOpts(opts ...PacketIOption) { for _, opt := range opts { opt(p) } - return p } func (p *PacketIO) wrapErr(err error) error { @@ -117,10 +120,7 @@ func (p *PacketIO) wrapErr(err error) error { // Proxy returned parsed proxy header from clients if any. func (p *PacketIO) Proxy() *proxyprotocol.Proxy { - if p.proxyInited.Load() { - return p.proxy - } - return nil + return p.proxy } func (p *PacketIO) LocalAddr() net.Addr { diff --git a/pkg/proxy/net/packetio_options.go b/pkg/proxy/net/packetio_options.go index fe1fb914..2805d080 100644 --- a/pkg/proxy/net/packetio_options.go +++ b/pkg/proxy/net/packetio_options.go @@ -23,7 +23,7 @@ import ( type PacketIOption = func(*PacketIO) func WithProxy(pi *PacketIO) { - pi.proxyInited.Store(true) + pi.proxyInited.Store(false) } func WithWrapError(err error) func(pi *PacketIO) { diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go index 0340cdba..c0f057f7 100644 --- a/pkg/proxy/net/packetio_test.go +++ b/pkg/proxy/net/packetio_test.go @@ -22,62 +22,29 @@ import ( "github.com/pingcap/TiProxy/lib/config" "github.com/pingcap/TiProxy/lib/util/security" - "github.com/pingcap/TiProxy/lib/util/waitgroup" + "github.com/pingcap/TiProxy/pkg/testkit" "github.com/pingcap/tidb/parser/mysql" "github.com/stretchr/testify/require" ) func testPipeConn(t *testing.T, a func(*testing.T, *PacketIO), b func(*testing.T, *PacketIO), loop int) { - var wg waitgroup.WaitGroup - client, server := net.Pipe() - cli, srv := NewPacketIO(client), NewPacketIO(server) - if ddl, ok := t.Deadline(); ok { - require.NoError(t, client.SetDeadline(ddl)) - require.NoError(t, server.SetDeadline(ddl)) - } - for i := 0; i < loop; i++ { - wg.Run(func() { - a(t, cli) - require.NoError(t, cli.Close()) - }) - wg.Run(func() { - b(t, srv) - require.NoError(t, srv.Close()) - }) - wg.Wait() - } + testkit.TestPipeConn(t, + func(t *testing.T, c net.Conn) { + a(t, NewPacketIO(c)) + }, + func(t *testing.T, c net.Conn) { + b(t, NewPacketIO(c)) + }, loop) } func testTCPConn(t *testing.T, a func(*testing.T, *PacketIO), b func(*testing.T, *PacketIO), loop int) { - listener, err := net.Listen("tcp", "0.0.0.0:0") - require.NoError(t, err) - defer func() { - require.NoError(t, listener.Close()) - }() - var wg waitgroup.WaitGroup - for i := 0; i < loop; i++ { - wg.Run(func() { - cli, err := net.Dial("tcp", listener.Addr().String()) - require.NoError(t, err) - if ddl, ok := t.Deadline(); ok { - require.NoError(t, cli.SetDeadline(ddl)) - } - cliIO := NewPacketIO(cli) - a(t, cliIO) - require.NoError(t, cliIO.Close()) - }) - wg.Run(func() { - srv, err := listener.Accept() - require.NoError(t, err) - if ddl, ok := t.Deadline(); ok { - require.NoError(t, srv.SetDeadline(ddl)) - } - srvIO := NewPacketIO(srv) - b(t, srvIO) - require.NoError(t, srvIO.Close()) - }) - wg.Wait() - } + testkit.TestTCPConn(t, + func(t *testing.T, c net.Conn) { + a(t, NewPacketIO(c)) + }, + func(t *testing.T, c net.Conn) { + b(t, NewPacketIO(c)) + }, loop) } func TestPacketIO(t *testing.T) { diff --git a/pkg/proxy/net/proxy_test.go b/pkg/proxy/net/proxy_test.go new file mode 100644 index 00000000..8d26b00f --- /dev/null +++ b/pkg/proxy/net/proxy_test.go @@ -0,0 +1,65 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "bytes" + "io" + "net" + "testing" + + "github.com/pingcap/TiProxy/pkg/proxy/proxyprotocol" + "github.com/stretchr/testify/require" +) + +func TestProxyParse(t *testing.T) { + tcpaddr, err := net.ResolveTCPAddr("tcp", "192.168.1.1:34") + require.NoError(t, err) + + testPipeConn(t, + func(t *testing.T, cli *PacketIO) { + p := &proxyprotocol.Proxy{ + Version: proxyprotocol.ProxyVersion2, + Command: proxyprotocol.ProxyCommandLocal, + SrcAddress: tcpaddr, + DstAddress: tcpaddr, + TLV: []proxyprotocol.ProxyTlv{ + { + Typ: proxyprotocol.ProxyTlvALPN, + Content: nil, + }, + { + Typ: proxyprotocol.ProxyTlvUniqueID, + Content: []byte("test"), + }, + }, + } + b, err := p.ToBytes() + require.NoError(t, err) + _, err = io.Copy(cli.conn, bytes.NewReader(b)) + require.NoError(t, err) + err = cli.WritePacket([]byte("hello"), true) + require.NoError(t, err) + }, + func(t *testing.T, srv *PacketIO) { + srv.ApplyOpts(WithProxy) + b, err := srv.ReadPacket() + require.NoError(t, err) + require.Equal(t, "hello", string(b)) + require.Equal(t, tcpaddr.String(), srv.RemoteAddr().String()) + }, + 1, + ) +} diff --git a/pkg/proxy/proxyprotocol/definition.go b/pkg/proxy/proxyprotocol/definition.go index 3c50a830..f68e1cd7 100644 --- a/pkg/proxy/proxyprotocol/definition.go +++ b/pkg/proxy/proxyprotocol/definition.go @@ -63,8 +63,8 @@ const ( ) type ProxyTlv struct { - content []byte - typ ProxyTlvType + Content []byte + Typ ProxyTlvType } type Proxy struct { diff --git a/pkg/proxy/proxyprotocol/listener_test.go b/pkg/proxy/proxyprotocol/listener_test.go index f424c2ae..f3699cbe 100644 --- a/pkg/proxy/proxyprotocol/listener_test.go +++ b/pkg/proxy/proxyprotocol/listener_test.go @@ -43,12 +43,12 @@ func TestProxyListener(t *testing.T) { DstAddress: tcpaddr, TLV: []ProxyTlv{ { - typ: ProxyTlvALPN, - content: nil, + Typ: ProxyTlvALPN, + Content: nil, }, { - typ: ProxyTlvUniqueID, - content: []byte("test"), + Typ: ProxyTlvUniqueID, + Content: []byte("test"), }, }, } diff --git a/pkg/proxy/proxyprotocol/proxy.go b/pkg/proxy/proxyprotocol/proxy.go index e0ce875d..ad2c8331 100644 --- a/pkg/proxy/proxyprotocol/proxy.go +++ b/pkg/proxy/proxyprotocol/proxy.go @@ -92,10 +92,10 @@ func (p *Proxy) ToBytes() ([]byte, error) { buf[magicLen+1] = byte(addressFamily<<4) | byte(network&0xF) for _, tlv := range p.TLV { - buf = append(buf, byte(tlv.typ)) - tlen := len(tlv.content) + buf = append(buf, byte(tlv.Typ)) + tlen := len(tlv.Content) buf = append(buf, byte(tlen>>8), byte(tlen)) - buf = append(buf, tlv.content...) + buf = append(buf, tlv.Content...) } length := len(buf) - 4 - magicLen @@ -205,8 +205,8 @@ func ParseProxyV2(rd io.Reader) (m *Proxy, n int, err error) { length = len(buf) - 3 } m.TLV = append(m.TLV, ProxyTlv{ - typ: typ, - content: buf[3 : 3+length], + Typ: typ, + Content: buf[3 : 3+length], }) buf = buf[3+length:] } diff --git a/pkg/proxy/proxyprotocol/proxy_test.go b/pkg/proxy/proxyprotocol/proxy_test.go index 333c701f..735ef54f 100644 --- a/pkg/proxy/proxyprotocol/proxy_test.go +++ b/pkg/proxy/proxyprotocol/proxy_test.go @@ -37,12 +37,12 @@ func TestProxyParse(t *testing.T) { DstAddress: tcpaddr, TLV: []ProxyTlv{ { - typ: ProxyTlvALPN, - content: nil, + Typ: ProxyTlvALPN, + Content: nil, }, { - typ: ProxyTlvUniqueID, - content: []byte("test"), + Typ: ProxyTlvUniqueID, + Content: []byte("test"), }, }, } @@ -66,9 +66,9 @@ func TestProxyParse(t *testing.T) { require.Equal(t, ProxyVersion2, p.Version) require.Equal(t, ProxyCommandLocal, p.Command) require.Len(t, p.TLV, 2) - require.Equal(t, ProxyTlvALPN, p.TLV[0].typ) - require.Equal(t, ProxyTlvUniqueID, p.TLV[1].typ) - require.Equal(t, []byte("test"), p.TLV[1].content) + require.Equal(t, ProxyTlvALPN, p.TLV[0].Typ) + require.Equal(t, ProxyTlvUniqueID, p.TLV[1].Typ) + require.Equal(t, []byte("test"), p.TLV[1].Content) }, 1, ) From 6e11e746f8b59b98432e85ebb77d8f91f4a42b07 Mon Sep 17 00:00:00 2001 From: xhe Date: Tue, 25 Apr 2023 15:17:39 +0800 Subject: [PATCH 2/2] fix test Signed-off-by: xhe --- pkg/proxy/net/packetio_test.go | 8 ++++++-- pkg/testkit/main.go | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go index c0f057f7..f9206e2f 100644 --- a/pkg/proxy/net/packetio_test.go +++ b/pkg/proxy/net/packetio_test.go @@ -40,10 +40,14 @@ func testPipeConn(t *testing.T, a func(*testing.T, *PacketIO), b func(*testing.T func testTCPConn(t *testing.T, a func(*testing.T, *PacketIO), b func(*testing.T, *PacketIO), loop int) { testkit.TestTCPConn(t, func(t *testing.T, c net.Conn) { - a(t, NewPacketIO(c)) + cli := NewPacketIO(c) + a(t, cli) + require.NoError(t, cli.Close()) }, func(t *testing.T, c net.Conn) { - b(t, NewPacketIO(c)) + srv := NewPacketIO(c) + b(t, srv) + require.NoError(t, srv.Close()) }, loop) } diff --git a/pkg/testkit/main.go b/pkg/testkit/main.go index 221fff8b..5128cd59 100644 --- a/pkg/testkit/main.go +++ b/pkg/testkit/main.go @@ -64,7 +64,9 @@ func TestTCPConnWithListener(t *testing.T, listen func(*testing.T, string, strin require.NoError(t, cli.SetDeadline(ddl)) } a(t, cli) - require.NoError(t, cli.Close()) + if err := cli.Close(); err != nil { + require.ErrorIs(t, err, net.ErrClosed) + } }) wg.Run(func() { srv, err := listener.Accept() @@ -73,7 +75,9 @@ func TestTCPConnWithListener(t *testing.T, listen func(*testing.T, string, strin require.NoError(t, srv.SetDeadline(ddl)) } b(t, srv) - require.NoError(t, srv.Close()) + if err := srv.Close(); err != nil { + require.ErrorIs(t, err, net.ErrClosed) + } }) wg.Wait() }