Skip to content

Commit

Permalink
proxy: listen on additional addrs (#393)
Browse files Browse the repository at this point in the history
Signed-off-by: xhe <xw897002528@gmail.com>
  • Loading branch information
xhebox authored Nov 3, 2023
1 parent 8d0bf38 commit 9dac468
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 42 deletions.
1 change: 0 additions & 1 deletion lib/config/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ func (cfg *Config) Clone() *Config {
}

func (cfg *Config) Check() error {

if cfg.Workdir == "" {
d, err := os.Getwd()
if err != nil {
Expand Down
9 changes: 5 additions & 4 deletions pkg/manager/config/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,16 +136,17 @@ func (e *ConfigManager) Close() error {
e.cancel()
e.cancel = nil
}
if e.wch != nil {
wcherr = e.wch.Close()
e.wch = nil
}
e.sts.Lock()
for _, ch := range e.sts.listeners {
close(ch)
}
e.sts.listeners = nil
e.sts.Unlock()
e.wg.Wait()
// close after all goroutines are done
if e.wch != nil {
wcherr = e.wch.Close()
e.wch = nil
}
return wcherr
}
3 changes: 2 additions & 1 deletion pkg/manager/infosync/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ func (is *InfoSyncer) getTopologyInfo(cfg *config.Config) (*TopologyInfo, error)
s = ""
}
dir := path.Dir(s)
ip, port, err := net.SplitHostPort(cfg.Proxy.Addr)
addrs := strings.Split(cfg.Proxy.Addr, ",")
ip, port, err := net.SplitHostPort(addrs[0])
if err != nil {
return nil, errors.WithStack(err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/metrics/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
// Test that the metrics are pushed or not pushed with different configurations.
func TestPushMetrics(t *testing.T) {
proxyAddr := "0.0.0.0:6000"
labelName := fmt.Sprintf("%s_%s_connections", ModuleProxy, LabelServer)
labelName := fmt.Sprintf("%s_%s_maxprocs", ModuleProxy, LabelServer)
hostname, err := os.Hostname()
require.NoError(t, err)
expectedPath := fmt.Sprintf("/metrics/job/tiproxy/instance/%s_6000", hostname)
Expand Down
1 change: 1 addition & 0 deletions pkg/proxy/backend/handshake_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type ConnContextKey string
const (
ConnContextKeyTLSState ConnContextKey = "tls-state"
ConnContextKeyConnID ConnContextKey = "conn-id"
ConnContextKeyConnAddr ConnContextKey = "conn-addr"
)

type ErrorSource int
Expand Down
3 changes: 2 additions & 1 deletion pkg/proxy/client/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ type ClientConnection struct {
}

func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *tls.Config, backendTLSConfig *tls.Config,
hsHandler backend.HandshakeHandler, connID uint64, bcConfig *backend.BCConfig) *ClientConnection {
hsHandler backend.HandshakeHandler, connID uint64, addr string, bcConfig *backend.BCConfig) *ClientConnection {
bemgr := backend.NewBackendConnManager(logger.Named("be"), hsHandler, connID, bcConfig)
bemgr.SetValue(backend.ConnContextKeyConnAddr, addr)
opts := make([]pnet.PacketIOption, 0, 2)
opts = append(opts, pnet.WithWrapError(backend.ErrClientConn))
if bcConfig.ProxyProtocol {
Expand Down
63 changes: 36 additions & 27 deletions pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package proxy
import (
"context"
"net"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -37,7 +38,8 @@ type serverState struct {
}

type SQLServer struct {
listener net.Listener
listeners []net.Listener
addrs []string
logger *zap.Logger
certMgr *cert.CertManager
hsHandler backend.HandshakeHandler
Expand Down Expand Up @@ -65,9 +67,13 @@ func NewSQLServer(logger *zap.Logger, cfg config.ProxyServer, certMgr *cert.Cert

s.reset(&cfg.ProxyServerOnline)

s.listener, err = net.Listen("tcp", cfg.Addr)
if err != nil {
return nil, err
s.addrs = strings.Split(cfg.Addr, ",")
s.listeners = make([]net.Listener, len(s.addrs))
for i, addr := range s.addrs {
s.listeners[i], err = net.Listen("tcp", addr)
if err != nil {
return nil, err
}
}

return s, nil
Expand Down Expand Up @@ -104,31 +110,34 @@ func (s *SQLServer) Run(ctx context.Context, cfgch <-chan *config.Config) {
}
})

s.wg.Run(func() {
for {
select {
case <-ctx.Done():
return
default:
conn, err := s.listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
for i := range s.listeners {
j := i
s.wg.Run(func() {
for {
select {
case <-ctx.Done():
return
default:
conn, err := s.listeners[j].Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}

s.logger.Error("accept failed", zap.Error(err))
continue
}

s.logger.Error("accept failed", zap.Error(err))
continue
s.wg.Run(func() {
util.WithRecovery(func() { s.onConn(ctx, conn, s.addrs[j]) }, nil, s.logger)
})
}

s.wg.Run(func() {
util.WithRecovery(func() { s.onConn(ctx, conn) }, nil, s.logger)
})
}
}
})
})
}
}

func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) {
func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) {
s.mu.Lock()
conns := uint64(len(s.mu.clients))
maxConns := s.mu.maxConnections
Expand All @@ -149,9 +158,9 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) {
connID := s.mu.connID
s.mu.connID++
logger := s.logger.With(zap.Uint64("connID", connID), zap.String("client_addr", conn.RemoteAddr().String()),
zap.Bool("proxy-protocol", s.mu.proxyProtocol))
zap.Bool("proxy-protocol", s.mu.proxyProtocol), zap.String("addr", addr))
clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerTLS(), s.certMgr.SQLTLS(),
s.hsHandler, connID, &backend.BCConfig{
s.hsHandler, connID, addr, &backend.BCConfig{
ProxyProtocol: s.mu.proxyProtocol,
RequireBackendTLS: s.requireBackendTLS,
HealthyKeepAlive: s.mu.healthyKeepAlive,
Expand Down Expand Up @@ -232,8 +241,8 @@ func (s *SQLServer) Close() error {
s.cancelFunc = nil
}
errs := make([]error, 0, 4)
if s.listener != nil {
errs = append(errs, s.listener.Close())
for i := range s.listeners {
errs = append(errs, s.listeners[i].Close())
}

s.mu.RLock()
Expand Down
37 changes: 30 additions & 7 deletions pkg/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package proxy
import (
"context"
"database/sql"
"fmt"
"net"
"strings"
"testing"
Expand Down Expand Up @@ -48,13 +49,13 @@ func TestGracefulShutdown(t *testing.T) {
createClientConn := func() *client.ClientConnection {
server.mu.Lock()
go func() {
conn, err := net.Dial("tcp", server.listener.Addr().String())
conn, err := net.Dial("tcp", server.listeners[0].Addr().String())
require.NoError(t, err)
require.NoError(t, conn.Close())
}()
conn, err := server.listener.Accept()
conn, err := server.listeners[0].Accept()
require.NoError(t, err)
clientConn := client.NewClientConnection(lg, conn, nil, nil, hsHandler, 0, &backend.BCConfig{})
clientConn := client.NewClientConnection(lg, conn, nil, nil, hsHandler, 0, "", &backend.BCConfig{})
server.mu.clients[1] = clientConn
server.mu.Unlock()
return clientConn
Expand Down Expand Up @@ -107,18 +108,40 @@ func TestGracefulShutdown(t *testing.T) {
}
}

func TestRecoverPanic(t *testing.T) {
lg, text := logger.CreateLoggerForTest(t)
func TestMultiAddr(t *testing.T) {
lg, _ := logger.CreateLoggerForTest(t)
certManager := cert.NewCertManager()
err := certManager.Init(&config.Config{}, lg, nil)
require.NoError(t, err)
server, err := NewSQLServer(lg, config.ProxyServer{
Addr: "0.0.0.0:6000",
Addr: "0.0.0.0:0,0.0.0.0:0",
}, certManager, &panicHsHandler{})
require.NoError(t, err)
server.Run(context.Background(), nil)

mdb, err := sql.Open("mysql", "root@tcp(localhost:6000)/test")
require.Len(t, server.listeners, 2)
for _, listener := range server.listeners {
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
require.NoError(t, conn.Close())
}

require.NoError(t, server.Close())
certManager.Close()
}

func TestRecoverPanic(t *testing.T) {
lg, text := logger.CreateLoggerForTest(t)
certManager := cert.NewCertManager()
err := certManager.Init(&config.Config{}, lg, nil)
require.NoError(t, err)
server, err := NewSQLServer(lg, config.ProxyServer{}, certManager, &panicHsHandler{})
require.NoError(t, err)
server.Run(context.Background(), nil)

_, port, err := net.SplitHostPort(server.listeners[0].Addr().String())
require.NoError(t, err)
mdb, err := sql.Open("mysql", fmt.Sprintf("root@tcp(localhost:%s)/test", port))
require.NoError(t, err)
// The first connection encounters panic.
require.ErrorContains(t, mdb.Ping(), "invalid connection")
Expand Down

0 comments on commit 9dac468

Please sign in to comment.