Skip to content

Commit

Permalink
Merge pull request #5563 from planetscale/jacques_proxy_protocol
Browse files Browse the repository at this point in the history
Add proxy protocol support for vtgate.
  • Loading branch information
deepthi authored Dec 19, 2019
2 parents 2ba9c70 + ca800b3 commit d774709
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 25 deletions.
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ require (
github.com/golang/mock v1.3.1
github.com/golang/protobuf v1.3.2
github.com/golang/snappy v0.0.0-20170215233205-553a64147049
github.com/google/btree v1.0.0 // indirect
github.com/google/shlex v0.0.0-20181106134648-c34317bd91bf // indirect
github.com/gorilla/websocket v0.0.0-20160912153041-2d1e4548da23
github.com/grpc-ecosystem/go-grpc-middleware v1.1.0
Expand All @@ -49,10 +50,13 @@ require (
github.com/minio/minio-go v0.0.0-20190131015406-c8a261de75c1
github.com/mitchellh/go-testing-interface v1.0.0 // indirect
github.com/mitchellh/mapstructure v1.1.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.1 // indirect
github.com/olekukonko/tablewriter v0.0.0-20160115111002-cca8bbc07984
github.com/opentracing-contrib/go-grpc v0.0.0-20180928155321-4b5a12d3ff02
github.com/opentracing/opentracing-go v1.1.0
github.com/pborman/uuid v0.0.0-20160824210600-b984ec7fa9ff
github.com/pires/go-proxyproto v0.0.0-20191211124218-517ecdf5bb2b
github.com/pkg/errors v0.8.1
github.com/prometheus/client_golang v1.1.0
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0Mw
github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/pborman/uuid v0.0.0-20160824210600-b984ec7fa9ff h1:pTiDfW+iOjIxjZeCm88gKn/AmR09UGZYZdqif2yPRrM=
github.com/pborman/uuid v0.0.0-20160824210600-b984ec7fa9ff/go.mod h1:VyrYX9gd7irzKovcSS6BIIEwPRkP2Wm2m9ufcdFSJ34=
github.com/pires/go-proxyproto v0.0.0-20191211124218-517ecdf5bb2b h1:JPLdtNmpXbWytipbGwYz7zXZzlQNASEiFw5aGAM75us=
github.com/pires/go-proxyproto v0.0.0-20191211124218-517ecdf5bb2b/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestValidCert(t *testing.T) {
}

// Create the listener, so we can get its host.
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -122,7 +122,7 @@ func TestNoCert(t *testing.T) {
}

// Create the listener, so we can get its host.
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion go/mysql/fakesqldb/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func New(t *testing.T) *DB {
authServer := &mysql.AuthServerNone{}

// Start listening.
db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0)
db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestClearTextClientAuth(t *testing.T) {
}

// Create the listener.
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -101,7 +101,7 @@ func TestSSLConnection(t *testing.T) {
}

// Create the listener, so we can get its host.
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down
7 changes: 6 additions & 1 deletion go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"strings"
"time"

proxyproto "github.com/pires/go-proxyproto"
"vitess.io/vitess/go/netutil"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/stats"
Expand Down Expand Up @@ -184,11 +185,15 @@ func NewFromListener(l net.Listener, authServer AuthServer, handler Handler, con
}

// NewListener creates a new Listener.
func NewListener(protocol, address string, authServer AuthServer, handler Handler, connReadTimeout time.Duration, connWriteTimeout time.Duration) (*Listener, error) {
func NewListener(protocol, address string, authServer AuthServer, handler Handler, connReadTimeout time.Duration, connWriteTimeout time.Duration, proxyProtocol bool) (*Listener, error) {
listener, err := net.Listen(protocol, address)
if err != nil {
return nil, err
}
if proxyProtocol {
proxyListener := &proxyproto.Listener{Listener: listener}
return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout)
}

return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout)
}
Expand Down
26 changes: 13 additions & 13 deletions go/mysql/server_flaky_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func TestConnectionWithoutSourceHost(t *testing.T) {
Password: "password1",
UserData: "userData1",
}}
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -287,7 +287,7 @@ func TestConnectionWithSourceHost(t *testing.T) {
},
}

l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -324,7 +324,7 @@ func TestConnectionUseMysqlNativePasswordWithSourceHost(t *testing.T) {
},
}

l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -367,7 +367,7 @@ func TestConnectionUnixSocket(t *testing.T) {
}
os.Remove(unixSocket.Name())

l, err := NewListener("unix", unixSocket.Name(), authServer, th, 0, 0)
l, err := NewListener("unix", unixSocket.Name(), authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -396,7 +396,7 @@ func TestClientFoundRows(t *testing.T) {
Password: "password1",
UserData: "userData1",
}}
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -453,7 +453,7 @@ func TestConnCounts(t *testing.T) {
Password: passwd,
UserData: "userData1",
}}
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -525,7 +525,7 @@ func TestServer(t *testing.T) {
Password: "password1",
UserData: "userData1",
}}
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -732,7 +732,7 @@ func TestClearTextServer(t *testing.T) {
UserData: "userData1",
}}
authServer.Method = MysqlClearPassword
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -817,7 +817,7 @@ func TestDialogServer(t *testing.T) {
UserData: "userData1",
}}
authServer.Method = MysqlDialog
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -864,7 +864,7 @@ func TestTLSServer(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -966,7 +966,7 @@ func TestTLSRequired(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -1052,7 +1052,7 @@ func TestErrorCodes(t *testing.T) {
Password: "password1",
UserData: "userData1",
}}
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -1235,7 +1235,7 @@ func TestListenerShutdown(t *testing.T) {
Password: "password1",
UserData: "userData1",
}}
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down
7 changes: 4 additions & 3 deletions go/vt/vtgate/plugin_mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ var (
mysqlAuthServerImpl = flag.String("mysql_auth_server_impl", "static", "Which auth server implementation to use.")
mysqlAllowClearTextWithoutTLS = flag.Bool("mysql_allow_clear_text_without_tls", false, "If set, the server will allow the use of a clear text password over non-SSL connections.")
mysqlServerVersion = flag.String("mysql_server_version", mysql.DefaultServerVersion, "MySQL server version to advertise.")
mysqlProxyProtocol = flag.Bool("proxy_protocol", false, "Enable HAProxy PROXY protocol on MySQL listener socket")

mysqlServerRequireSecureTransport = flag.Bool("mysql_server_require_secure_transport", false, "Reject insecure connections but only if mysql_server_ssl_cert and mysql_server_ssl_key are provided")

Expand Down Expand Up @@ -339,7 +340,7 @@ func initMySQLProtocol() {
var err error
vh := newVtgateHandler(rpcVTGate)
if *mysqlServerPort >= 0 {
mysqlListener, err = mysql.NewListener(*mysqlTCPVersion, net.JoinHostPort(*mysqlServerBindAddress, fmt.Sprintf("%v", *mysqlServerPort)), authServer, vh, *mysqlConnReadTimeout, *mysqlConnWriteTimeout)
mysqlListener, err = mysql.NewListener(*mysqlTCPVersion, net.JoinHostPort(*mysqlServerBindAddress, fmt.Sprintf("%v", *mysqlServerPort)), authServer, vh, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, *mysqlProxyProtocol)
if err != nil {
log.Exitf("mysql.NewListener failed: %v", err)
}
Expand Down Expand Up @@ -382,7 +383,7 @@ func initMySQLProtocol() {
// newMysqlUnixSocket creates a new unix socket mysql listener. If a socket file already exists, attempts
// to clean it up.
func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mysql.Handler) (*mysql.Listener, error) {
listener, err := mysql.NewListener("unix", address, authServer, handler, *mysqlConnReadTimeout, *mysqlConnWriteTimeout)
listener, err := mysql.NewListener("unix", address, authServer, handler, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, false)
switch err := err.(type) {
case nil:
return listener, nil
Expand All @@ -403,7 +404,7 @@ func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mys
log.Errorf("Couldn't remove existent socket file: %s", address)
return nil, err
}
listener, listenerErr := mysql.NewListener("unix", address, authServer, handler, *mysqlConnReadTimeout, *mysqlConnWriteTimeout)
listener, listenerErr := mysql.NewListener("unix", address, authServer, handler, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, false)
return listener, listenerErr
default:
return nil, err
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtqueryserver/plugin_mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func initMySQLProtocol() {
var err error
mh := newProxyHandler(mysqlProxy)
if *mysqlServerPort >= 0 {
mysqlListener, err = mysql.NewListener("tcp", net.JoinHostPort(*mysqlServerBindAddress, fmt.Sprintf("%v", *mysqlServerPort)), authServer, mh, *mysqlConnReadTimeout, *mysqlConnWriteTimeout)
mysqlListener, err = mysql.NewListener("tcp", net.JoinHostPort(*mysqlServerBindAddress, fmt.Sprintf("%v", *mysqlServerPort)), authServer, mh, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, false)
if err != nil {
log.Exitf("mysql.NewListener failed: %v", err)
}
Expand Down Expand Up @@ -223,7 +223,7 @@ func initMySQLProtocol() {
// newMysqlUnixSocket creates a new unix socket mysql listener. If a socket file already exists, attempts
// to clean it up.
func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mysql.Handler) (*mysql.Listener, error) {
listener, err := mysql.NewListener("unix", address, authServer, handler, *mysqlConnReadTimeout, *mysqlConnWriteTimeout)
listener, err := mysql.NewListener("unix", address, authServer, handler, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, false)
switch err := err.(type) {
case nil:
return listener, nil
Expand All @@ -244,7 +244,7 @@ func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mys
log.Errorf("Couldn't remove existent socket file: %s", address)
return nil, err
}
listener, listenerErr := mysql.NewListener("unix", address, authServer, handler, *mysqlConnReadTimeout, *mysqlConnWriteTimeout)
listener, listenerErr := mysql.NewListener("unix", address, authServer, handler, *mysqlConnReadTimeout, *mysqlConnWriteTimeout, false)
return listener, listenerErr
default:
return nil, err
Expand Down

0 comments on commit d774709

Please sign in to comment.