From 0899b7bc111778bea74a49244d14881690cfacdd Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Tue, 4 Jul 2023 10:35:30 +0530 Subject: [PATCH 1/5] tcp keep alive and period settings Signed-off-by: Harshit Gangal --- go/flags/endtoend/vtgate.txt | 1 + go/mysql/auth_server_clientcert_test.go | 4 +- go/mysql/client_test.go | 10 +-- go/mysql/fakesqldb/server.go | 6 +- go/mysql/handshake_test.go | 4 +- go/mysql/server.go | 83 +++++++++++++++++-------- go/mysql/server_flaky_test.go | 48 +++++++------- go/vt/vtgate/plugin_mysql_server.go | 5 ++ 8 files changed, 98 insertions(+), 63 deletions(-) diff --git a/go/flags/endtoend/vtgate.txt b/go/flags/endtoend/vtgate.txt index 469e6506eef..86d116cfd87 100644 --- a/go/flags/endtoend/vtgate.txt +++ b/go/flags/endtoend/vtgate.txt @@ -89,6 +89,7 @@ Usage of vtgate: --max_payload_size int The threshold for query payloads in bytes. A payload greater than this threshold will result in a failure to handle the query. --message_stream_grace_period duration the amount of time to give for a vttablet to resume if it ends a message stream, usually because of a reparent. (default 30s) --min_number_serving_vttablets int The minimum number of vttablets for each replicating tablet_type (e.g. replica, rdonly) that will be continue to be used even with replication lag above discovery_low_replication_lag, but still below discovery_high_replication_lag_minimum_serving. (default 2) + --mysql-server-keepalive-period duration TCP period between keep-alives --mysql-server-pool-conn-read-buffers If set, the server will pool incoming connection read buffers --mysql_allow_clear_text_without_tls If set, the server will allow the use of a clear text password over non-SSL connections. --mysql_auth_server_impl string Which auth server implementation to use. Options: none, ldap, clientcert, static, vault. (default "static") diff --git a/go/mysql/auth_server_clientcert_test.go b/go/mysql/auth_server_clientcert_test.go index 4528ee5dbf4..28ed19fd9c5 100644 --- a/go/mysql/auth_server_clientcert_test.go +++ b/go/mysql/auth_server_clientcert_test.go @@ -45,7 +45,7 @@ func TestValidCert(t *testing.T) { authServer := newAuthServerClientCert() // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() @@ -114,7 +114,7 @@ func TestNoCert(t *testing.T) { authServer := newAuthServerClientCert() // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() diff --git a/go/mysql/client_test.go b/go/mysql/client_test.go index f9db5cee523..ddbb7f19f06 100644 --- a/go/mysql/client_test.go +++ b/go/mysql/client_test.go @@ -149,7 +149,7 @@ func TestTLSClientDisabled(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err) defer l.Close() @@ -221,7 +221,7 @@ func TestTLSClientPreferredDefault(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err) defer l.Close() @@ -294,7 +294,7 @@ func TestTLSClientRequired(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err) defer l.Close() @@ -341,7 +341,7 @@ func TestTLSClientVerifyCA(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err) defer l.Close() @@ -424,7 +424,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err) defer l.Close() diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index f43f63c0d53..746f82aed2a 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -188,7 +188,7 @@ func New(t testing.TB) *DB { authServer := mysql.NewAuthServerNone() // Start listening. - db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false) + db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false, 0) if err != nil { t.Fatalf("NewListener failed: %v", err) } @@ -382,7 +382,7 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R if db.shouldClose.Load() { c.Close() - //log error + // log error if err := callback(&sqltypes.Result{}); err != nil { log.Errorf("callback failed : %v", err) } @@ -393,7 +393,7 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R // The driver may send this at connection time, and we don't want it to // interfere. if key == "set names utf8" || strings.HasPrefix(key, "set collation_connection = ") { - //log error + // log error if err := callback(&sqltypes.Result{}); err != nil { log.Errorf("callback failed : %v", err) } diff --git a/go/mysql/handshake_test.go b/go/mysql/handshake_test.go index b6532f830b3..c2b27d6f6d4 100644 --- a/go/mysql/handshake_test.go +++ b/go/mysql/handshake_test.go @@ -45,7 +45,7 @@ func TestClearTextClientAuth(t *testing.T) { defer authServer.close() // Create the listener. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() @@ -99,7 +99,7 @@ func TestSSLConnection(t *testing.T) { defer authServer.close() // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() diff --git a/go/mysql/server.go b/go/mysql/server.go index e17bd82ef90..47cdf408feb 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -196,6 +196,9 @@ type Listener struct { // connBufferPooling configures if vtgate server pools connection buffers connBufferPooling bool + // connKeepAlivePeriod is period between tcp keep-alives. + connKeepAlivePeriod time.Duration + // shutdown indicates that Shutdown method was called. shutdown atomic.Bool @@ -218,15 +221,17 @@ func NewFromListener( connReadTimeout time.Duration, connWriteTimeout time.Duration, connBufferPooling bool, + keepAlivePeriod time.Duration, ) (*Listener, error) { cfg := ListenerConfig{ - Listener: l, - AuthServer: authServer, - Handler: handler, - ConnReadTimeout: connReadTimeout, - ConnWriteTimeout: connWriteTimeout, - ConnReadBufferSize: connBufferSize, - ConnBufferPooling: connBufferPooling, + Listener: l, + AuthServer: authServer, + Handler: handler, + ConnReadTimeout: connReadTimeout, + ConnWriteTimeout: connWriteTimeout, + ConnReadBufferSize: connBufferSize, + ConnBufferPooling: connBufferPooling, + ConnKeepAlivePeriod: keepAlivePeriod, } return NewListenerWithConfig(cfg) } @@ -240,6 +245,7 @@ func NewListener( connWriteTimeout time.Duration, proxyProtocol bool, connBufferPooling bool, + keepAlivePeriod time.Duration, ) (*Listener, error) { listener, err := net.Listen(protocol, address) if err != nil { @@ -247,24 +253,25 @@ func NewListener( } if proxyProtocol { proxyListener := &proxyproto.Listener{Listener: listener} - return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling) + return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod) } - return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling) + return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod) } // ListenerConfig should be used with NewListenerWithConfig to specify listener parameters. type ListenerConfig struct { // Protocol-Address pair and Listener are mutually exclusive parameters - Protocol string - Address string - Listener net.Listener - AuthServer AuthServer - Handler Handler - ConnReadTimeout time.Duration - ConnWriteTimeout time.Duration - ConnReadBufferSize int - ConnBufferPooling bool + Protocol string + Address string + Listener net.Listener + AuthServer AuthServer + Handler Handler + ConnReadTimeout time.Duration + ConnWriteTimeout time.Duration + ConnReadBufferSize int + ConnBufferPooling bool + ConnKeepAlivePeriod time.Duration } // NewListenerWithConfig creates new listener using provided config. There are @@ -282,15 +289,16 @@ func NewListenerWithConfig(cfg ListenerConfig) (*Listener, error) { } return &Listener{ - authServer: cfg.AuthServer, - handler: cfg.Handler, - listener: l, - ServerVersion: servenv.AppVersion.MySQLVersion(), - connectionID: 1, - connReadTimeout: cfg.ConnReadTimeout, - connWriteTimeout: cfg.ConnWriteTimeout, - connReadBufferSize: cfg.ConnReadBufferSize, - connBufferPooling: cfg.ConnBufferPooling, + authServer: cfg.AuthServer, + handler: cfg.Handler, + listener: l, + ServerVersion: servenv.AppVersion.MySQLVersion(), + connectionID: 1, + connReadTimeout: cfg.ConnReadTimeout, + connWriteTimeout: cfg.ConnWriteTimeout, + connReadBufferSize: cfg.ConnReadBufferSize, + connBufferPooling: cfg.ConnBufferPooling, + connKeepAlivePeriod: cfg.ConnKeepAlivePeriod, }, nil } @@ -336,6 +344,12 @@ func (l *Listener) Accept() { // handle is called in a go routine for each client connection. // FIXME(alainjobart) handle per-connection logs in a way that makes sense. func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Time) { + + // Enable KeepAlive on TCP connections and change keep-alive period if provided. + if tcpConn, ok := conn.(*net.TCPConn); ok { + setTcpConnProperties(tcpConn, l.connKeepAlivePeriod) + } + if l.connReadTimeout != 0 || l.connWriteTimeout != 0 { conn = netutil.NewConnWithTimeouts(conn, l.connReadTimeout, l.connWriteTimeout) } @@ -531,6 +545,21 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } } +func setTcpConnProperties(conn *net.TCPConn, keepAlivePeriod time.Duration) { + if err := conn.SetKeepAlive(true); err != nil { + log.Errorf("unable to enable keepalive on tcp connection: %v", err) + return + } + + if keepAlivePeriod <= 0 { + return + } + + if err := conn.SetKeepAlivePeriod(keepAlivePeriod); err != nil { + log.Errorf("unable to set keepalive period on tcp connection: %v", err) + } +} + // Close stops the listener, which prevents accept of any new connections. Existing connections won't be closed. func (l *Listener) Close() { l.listener.Close() diff --git a/go/mysql/server_flaky_test.go b/go/mysql/server_flaky_test.go index 7225f29a816..3255e098966 100644 --- a/go/mysql/server_flaky_test.go +++ b/go/mysql/server_flaky_test.go @@ -263,7 +263,7 @@ func TestConnectionFromListener(t *testing.T) { listener, err := net.Listen("tcp", "127.0.0.1:") require.NoError(t, err, "net.Listener failed") - l, err := NewFromListener(listener, authServer, th, 0, 0, false) + l, err := NewFromListener(listener, authServer, th, 0, 0, false, 0) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -292,7 +292,7 @@ func TestConnectionWithoutSourceHost(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -325,7 +325,7 @@ func TestConnectionWithSourceHost(t *testing.T) { } defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -358,7 +358,7 @@ func TestConnectionUseMysqlNativePasswordWithSourceHost(t *testing.T) { } defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -396,7 +396,7 @@ func TestConnectionUnixSocket(t *testing.T) { os.Remove(unixSocket.Name()) - l, err := NewListener("unix", unixSocket.Name(), authServer, th, 0, 0, false, false) + l, err := NewListener("unix", unixSocket.Name(), authServer, th, 0, 0, false, false, 0) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -422,7 +422,7 @@ func TestClientFoundRows(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -471,7 +471,7 @@ func TestConnCounts(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err, "NewListener failed") defer l.Close() go l.Accept() @@ -503,12 +503,12 @@ func TestConnCounts(t *testing.T) { // Test after closing connections. time.Sleep lets it work, but seems flakey. c.Close() - //time.Sleep(10 * time.Millisecond) - //checkCountsForUser(t, user, 1) + // time.Sleep(10 * time.Millisecond) + // checkCountsForUser(t, user, 1) c2.Close() - //time.Sleep(10 * time.Millisecond) - //checkCountsForUser(t, user, 0) + // time.Sleep(10 * time.Millisecond) + // checkCountsForUser(t, user, 0) } func checkCountsForUser(t *testing.T, user string, expected int64) { @@ -528,7 +528,7 @@ func TestServer(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err) l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) defer l.Close() @@ -628,7 +628,7 @@ func TestServerStats(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err) l.SlowConnectWarnThreshold.Store(time.Nanosecond.Nanoseconds()) defer l.Close() @@ -702,7 +702,7 @@ func TestClearTextServer(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err) defer l.Close() go l.Accept() @@ -775,7 +775,7 @@ func TestDialogServer(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err) l.AllowClearTextWithoutTLS.Store(true) defer l.Close() @@ -818,7 +818,7 @@ func TestTLSServer(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err) defer l.Close() @@ -870,7 +870,7 @@ func TestTLSServer(t *testing.T) { // Run a 'select rows' command with results. conn, err := Connect(context.Background(), params) - //output, ok := runMysql(t, params, "select rows") + // output, ok := runMysql(t, params, "select rows") require.NoError(t, err) results, err := conn.ExecuteFetch("select rows", 1000, true) require.NoError(t, err) @@ -916,7 +916,7 @@ func TestTLSRequired(t *testing.T) { // Below, we are enabling --ssl-verify-server-cert, which adds // a check that the common name of the certificate matches the // server host name we connect to. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err) defer l.Close() @@ -1005,7 +1005,7 @@ func TestCachingSha2PasswordAuthWithTLS(t *testing.T) { defer authServer.close() // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() @@ -1099,7 +1099,7 @@ func TestCachingSha2PasswordAuthWithMoreData(t *testing.T) { defer authServer.close() // Create the listener, so we can get its host. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() @@ -1168,7 +1168,7 @@ func TestCachingSha2PasswordAuthWithoutTLS(t *testing.T) { defer authServer.close() // Create the listener. - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err, "NewListener failed: %v", err) defer l.Close() host := l.Addr().(*net.TCPAddr).IP.String() @@ -1210,7 +1210,7 @@ func TestErrorCodes(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err) defer l.Close() go l.Accept() @@ -1388,7 +1388,7 @@ func TestListenerShutdown(t *testing.T) { UserData: "userData1", }} defer authServer.close() - l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0) require.NoError(t, err) defer l.Close() go l.Accept() @@ -1461,7 +1461,7 @@ func TestServerFlush(t *testing.T) { th := &testHandler{} - l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false) + l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0) require.NoError(t, err) defer l.Close() go l.Accept() diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index c7d4c53785c..4ca164039f0 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -62,6 +62,7 @@ var ( mysqlSslServerCA string mysqlTLSMinVersion string + mysqlKeepAlivePeriod time.Duration mysqlConnReadTimeout time.Duration mysqlConnWriteTimeout time.Duration mysqlQueryTimeout time.Duration @@ -94,6 +95,7 @@ func registerPluginFlags(fs *pflag.FlagSet) { fs.DurationVar(&mysqlConnWriteTimeout, "mysql_server_write_timeout", mysqlConnWriteTimeout, "connection write timeout") fs.DurationVar(&mysqlQueryTimeout, "mysql_server_query_timeout", mysqlQueryTimeout, "mysql query timeout") fs.BoolVar(&mysqlConnBufferPooling, "mysql-server-pool-conn-read-buffers", mysqlConnBufferPooling, "If set, the server will pool incoming connection read buffers") + fs.DurationVar(&mysqlKeepAlivePeriod, "mysql-server-keepalive-period", mysqlKeepAlivePeriod, "TCP period between keep-alives") fs.StringVar(&mysqlDefaultWorkloadName, "mysql_default_workload", mysqlDefaultWorkloadName, "Default session workload (OLTP, OLAP, DBA)") } @@ -475,6 +477,7 @@ func initMySQLProtocol() { mysqlConnWriteTimeout, mysqlProxyProtocol, mysqlConnBufferPooling, + mysqlKeepAlivePeriod, ) if err != nil { log.Exitf("mysql.NewListener failed: %v", err) @@ -525,6 +528,7 @@ func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mys mysqlConnWriteTimeout, false, mysqlConnBufferPooling, + mysqlKeepAlivePeriod, ) switch err := err.(type) { @@ -556,6 +560,7 @@ func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mys mysqlConnWriteTimeout, false, mysqlConnBufferPooling, + mysqlKeepAlivePeriod, ) return listener, listenerErr default: From 60de413cc15ed6c7f7a0967833c001659d9264d7 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 6 Jul 2023 14:46:44 +0530 Subject: [PATCH 2/5] added e2e test for tcp keep alive Signed-off-by: Harshit Gangal --- go.mod | 1 + go.sum | 2 + .../endtoend/vtgate/queries/misc/main_test.go | 8 ++- .../endtoend/vtgate/queries/misc/misc_test.go | 72 +++++++++++++++++-- 4 files changed, 76 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 36fe98d7ec4..a71cbf4c0d8 100644 --- a/go.mod +++ b/go.mod @@ -97,6 +97,7 @@ require ( require ( github.com/Shopify/toxiproxy/v2 v2.5.0 github.com/bndr/gotabulate v1.1.2 + github.com/google/gopacket v1.1.19 github.com/google/safehtml v0.1.0 github.com/hashicorp/go-version v1.6.0 github.com/kr/pretty v0.3.1 diff --git a/go.sum b/go.sum index 746b7012d48..b2829390e2e 100644 --- a/go.sum +++ b/go.sum @@ -255,6 +255,8 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= diff --git a/go/test/endtoend/vtgate/queries/misc/main_test.go b/go/test/endtoend/vtgate/queries/misc/main_test.go index de2d00219b6..5b175b50e97 100644 --- a/go/test/endtoend/vtgate/queries/misc/main_test.go +++ b/go/test/endtoend/vtgate/queries/misc/main_test.go @@ -61,7 +61,8 @@ func TestMain(m *testing.M) { return 1 } - clusterInstance.VtTabletExtraArgs = append(clusterInstance.VtTabletExtraArgs, "--queryserver-config-max-result-size", "1000000", + clusterInstance.VtTabletExtraArgs = append(clusterInstance.VtTabletExtraArgs, + "--queryserver-config-max-result-size", "1000000", "--queryserver-config-query-timeout", "200", "--queryserver-config-query-pool-timeout", "200") // Start Unsharded keyspace @@ -85,7 +86,10 @@ func TestMain(m *testing.M) { return 1 } - clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs, "--enable_system_settings=true", "--query-timeout=100") + clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs, + "--enable_system_settings", + "--query-timeout", "100", + "--mysql-server-keepalive-period", "3s") // Start vtgate err = clusterInstance.StartVtgate() if err != nil { diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index 667e59ed1ea..54cc698800c 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -17,17 +17,23 @@ limitations under the License. package misc import ( + "context" "database/sql" "fmt" + "net" "strconv" "strings" "testing" + "time" _ "github.com/go-sql-driver/mysql" - + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcap" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/test/endtoend/utils" ) @@ -37,7 +43,7 @@ func start(t *testing.T) (utils.MySQLCompare, func()) { require.NoError(t, err) deleteAll := func() { - tables := []string{"t1"} + tables := []string{"t1", "uks.unsharded"} for _, table := range tables { _, _ = mcmp.ExecAndIgnore("delete from " + table) } @@ -126,7 +132,7 @@ func TestQueryTimeoutWithTables(t *testing.T) { // unsharded utils.Exec(t, mcmp.VtConn, "insert /*vt+ QUERY_TIMEOUT_MS=1000 */ into uks.unsharded(id1) values (1),(2),(3),(4),(5)") for i := 0; i < 12; i++ { - utils.Exec(t, mcmp.VtConn, "insert /*vt+ QUERY_TIMEOUT_MS=1000 */ into uks.unsharded(id1) select id1+5 from uks.unsharded") + utils.Exec(t, mcmp.VtConn, "insert /*vt+ QUERY_TIMEOUT_MS=2000 */ into uks.unsharded(id1) select id1+5 from uks.unsharded") } utils.Exec(t, mcmp.VtConn, "select count(*) from uks.unsharded where id1 > 31") @@ -304,12 +310,68 @@ func TestPrepareStatements(t *testing.T) { assert.ErrorContains(t, err, "VT09011: Unknown prepared statement handler (prep_art) given to DEALLOCATE PREPARE") } +// TestBuggyOuterJoin validates inconsistencies around outer joins, adding these tests to stop regressions. func TestBuggyOuterJoin(t *testing.T) { - // We found a couple of inconsistencies around outer joins, adding these tests to stop regressions mcmp, closer := start(t) defer closer() mcmp.Exec("insert into t1(id1, id2) values (1,2), (42,5), (5, 42)") - mcmp.Exec("select t1.id1, t2.id1 from t1 left join t1 as t2 on t2.id1 = t2.id2") } + +func TestTCPKeepAlive(t *testing.T) { + conn, err := mysql.Connect(context.Background(), &vtParams) + require.NoError(t, err) + defer conn.Close() + + _, ok := conn.GetRawConn().(*net.TCPConn) + if !ok { + t.Fatalf("tcp connection expected, got: %T", conn.GetRawConn()) + } + + // finding a loopback device. + ifs, err := pcap.FindAllDevs() + require.NoError(t, err) + var localDevice string + for _, dev := range ifs { + if strings.Contains(dev.Name, "lo") { + localDevice = dev.Name + } + } + if localDevice == "" { + t.Skip("loopback device not found. Cannot continue with the test.") + } + + // Set up packet capture handle + handle, err := pcap.OpenLive(localDevice, 65536, true, pcap.BlockForever) + require.NoError(t, err) + defer handle.Close() + + // Set packet filter for vtgate mysql port. + expr := fmt.Sprintf("tcp and port %d", clusterInstance.VtgateMySQLPort) + err = handle.SetBPFFilter(expr) + require.NoError(t, err) + + // Start capturing packets + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + ch := packetSource.Packets() + + testTimeout := time.After(5 * time.Second) + for { + select { + case packet := <-ch: + // Check if packet is TCP + if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil { + tcp, _ := tcpLayer.(*layers.TCP) + + // Check if packet has only the ACK flag set (keep-alive packet) + if tcp.ACK && !tcp.SYN && !tcp.FIN && !tcp.RST && !tcp.PSH && !tcp.URG { + // received a keep-alive packet. + return + } + } + case <-testTimeout: + t.Fatal("test timeout out after 5 seconds.") + } + } +} From 2540cce1d4a8c364b76ef15802c357e797968b82 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 6 Jul 2023 17:54:54 +0530 Subject: [PATCH 3/5] added unit test and removed e2e test due to cgo dependency Signed-off-by: Harshit Gangal --- go.mod | 1 - go.sum | 2 - go/mysql/server.go | 18 +++-- go/mysql/server_flaky_test.go | 36 ++++++++++ .../endtoend/vtgate/queries/misc/main_test.go | 8 +-- .../endtoend/vtgate/queries/misc/misc_test.go | 65 ------------------- 6 files changed, 50 insertions(+), 80 deletions(-) diff --git a/go.mod b/go.mod index a71cbf4c0d8..36fe98d7ec4 100644 --- a/go.mod +++ b/go.mod @@ -97,7 +97,6 @@ require ( require ( github.com/Shopify/toxiproxy/v2 v2.5.0 github.com/bndr/gotabulate v1.1.2 - github.com/google/gopacket v1.1.19 github.com/google/safehtml v0.1.0 github.com/hashicorp/go-version v1.6.0 github.com/kr/pretty v0.3.1 diff --git a/go.sum b/go.sum index b2829390e2e..746b7012d48 100644 --- a/go.sum +++ b/go.sum @@ -255,8 +255,6 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= -github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= diff --git a/go/mysql/server.go b/go/mysql/server.go index 47cdf408feb..1aedc57ef40 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -211,6 +211,8 @@ type Listener struct { // handled further by the MySQL handler. An non-nil error will stop // processing the connection by the MySQL handler. PreHandleFunc func(context.Context, net.Conn, uint32) (net.Conn, error) + + TcpPropFunc func(*net.TCPConn, time.Duration) error } // NewFromListener creates a new mysql listener from an existing net.Listener @@ -299,6 +301,7 @@ func NewListenerWithConfig(cfg ListenerConfig) (*Listener, error) { connReadBufferSize: cfg.ConnReadBufferSize, connBufferPooling: cfg.ConnBufferPooling, connKeepAlivePeriod: cfg.ConnKeepAlivePeriod, + TcpPropFunc: setTcpConnProperties, }, nil } @@ -347,7 +350,9 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti // Enable KeepAlive on TCP connections and change keep-alive period if provided. if tcpConn, ok := conn.(*net.TCPConn); ok { - setTcpConnProperties(tcpConn, l.connKeepAlivePeriod) + if err := l.TcpPropFunc(tcpConn, l.connKeepAlivePeriod); err != nil { + log.Errorf("error in setting tcp properties: %v", err) + } } if l.connReadTimeout != 0 || l.connWriteTimeout != 0 { @@ -545,19 +550,20 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } } -func setTcpConnProperties(conn *net.TCPConn, keepAlivePeriod time.Duration) { +func setTcpConnProperties(conn *net.TCPConn, keepAlivePeriod time.Duration) error { if err := conn.SetKeepAlive(true); err != nil { - log.Errorf("unable to enable keepalive on tcp connection: %v", err) - return + return vterrors.Wrapf(err, "unable to enable keepalive on tcp connection") } if keepAlivePeriod <= 0 { - return + return nil } if err := conn.SetKeepAlivePeriod(keepAlivePeriod); err != nil { - log.Errorf("unable to set keepalive period on tcp connection: %v", err) + return vterrors.Wrapf(err, "unable to set keepalive period on tcp connection") } + + return nil } // Close stops the listener, which prevents accept of any new connections. Existing connections won't be closed. diff --git a/go/mysql/server_flaky_test.go b/go/mysql/server_flaky_test.go index 3255e098966..8a94205145e 100644 --- a/go/mysql/server_flaky_test.go +++ b/go/mysql/server_flaky_test.go @@ -1503,3 +1503,39 @@ func TestServerFlush(t *testing.T) { require.NoError(t, err) assert.Nil(t, row) } + +func TestTcpKeepAlive(t *testing.T) { + th := &testHandler{} + l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0) + require.NoError(t, err) + defer l.Close() + go l.Accept() + + host, port := getHostPort(t, l.Addr()) + params := &ConnParams{ + Host: host, + Port: port, + } + + var called bool + l.TcpPropFunc = func(conn *net.TCPConn, duration time.Duration) error { + called = true + return nil + } + + // on connect, the tcp method should be called. + c, err := Connect(context.Background(), params) + require.NoError(t, err) + defer c.Close() + require.True(t, called, "tcp property method not called") + + // move to original method + l.TcpPropFunc = setTcpConnProperties + + // close the connection + th.lastConn.Close() + + // now calling this method should fail. + err = setTcpConnProperties(th.lastConn.conn.(*net.TCPConn), 0) + require.ErrorContains(t, err, "unable to enable keepalive on tcp connection") +} diff --git a/go/test/endtoend/vtgate/queries/misc/main_test.go b/go/test/endtoend/vtgate/queries/misc/main_test.go index 5b175b50e97..de2d00219b6 100644 --- a/go/test/endtoend/vtgate/queries/misc/main_test.go +++ b/go/test/endtoend/vtgate/queries/misc/main_test.go @@ -61,8 +61,7 @@ func TestMain(m *testing.M) { return 1 } - clusterInstance.VtTabletExtraArgs = append(clusterInstance.VtTabletExtraArgs, - "--queryserver-config-max-result-size", "1000000", + clusterInstance.VtTabletExtraArgs = append(clusterInstance.VtTabletExtraArgs, "--queryserver-config-max-result-size", "1000000", "--queryserver-config-query-timeout", "200", "--queryserver-config-query-pool-timeout", "200") // Start Unsharded keyspace @@ -86,10 +85,7 @@ func TestMain(m *testing.M) { return 1 } - clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs, - "--enable_system_settings", - "--query-timeout", "100", - "--mysql-server-keepalive-period", "3s") + clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs, "--enable_system_settings=true", "--query-timeout=100") // Start vtgate err = clusterInstance.StartVtgate() if err != nil { diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index 54cc698800c..e5b50b5192c 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -17,23 +17,15 @@ limitations under the License. package misc import ( - "context" "database/sql" "fmt" - "net" "strconv" "strings" "testing" - "time" - _ "github.com/go-sql-driver/mysql" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" - "github.com/google/gopacket/pcap" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/test/endtoend/utils" ) @@ -318,60 +310,3 @@ func TestBuggyOuterJoin(t *testing.T) { mcmp.Exec("insert into t1(id1, id2) values (1,2), (42,5), (5, 42)") mcmp.Exec("select t1.id1, t2.id1 from t1 left join t1 as t2 on t2.id1 = t2.id2") } - -func TestTCPKeepAlive(t *testing.T) { - conn, err := mysql.Connect(context.Background(), &vtParams) - require.NoError(t, err) - defer conn.Close() - - _, ok := conn.GetRawConn().(*net.TCPConn) - if !ok { - t.Fatalf("tcp connection expected, got: %T", conn.GetRawConn()) - } - - // finding a loopback device. - ifs, err := pcap.FindAllDevs() - require.NoError(t, err) - var localDevice string - for _, dev := range ifs { - if strings.Contains(dev.Name, "lo") { - localDevice = dev.Name - } - } - if localDevice == "" { - t.Skip("loopback device not found. Cannot continue with the test.") - } - - // Set up packet capture handle - handle, err := pcap.OpenLive(localDevice, 65536, true, pcap.BlockForever) - require.NoError(t, err) - defer handle.Close() - - // Set packet filter for vtgate mysql port. - expr := fmt.Sprintf("tcp and port %d", clusterInstance.VtgateMySQLPort) - err = handle.SetBPFFilter(expr) - require.NoError(t, err) - - // Start capturing packets - packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) - ch := packetSource.Packets() - - testTimeout := time.After(5 * time.Second) - for { - select { - case packet := <-ch: - // Check if packet is TCP - if tcpLayer := packet.Layer(layers.LayerTypeTCP); tcpLayer != nil { - tcp, _ := tcpLayer.(*layers.TCP) - - // Check if packet has only the ACK flag set (keep-alive packet) - if tcp.ACK && !tcp.SYN && !tcp.FIN && !tcp.RST && !tcp.PSH && !tcp.URG { - // received a keep-alive packet. - return - } - } - case <-testTimeout: - t.Fatal("test timeout out after 5 seconds.") - } - } -} From e6f4a192c9849104c037708b74c98ea14e8e6619 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 6 Jul 2023 18:50:41 +0530 Subject: [PATCH 4/5] addressed review comments Signed-off-by: Harshit Gangal --- go/mysql/server_flaky_test.go | 9 +++++++-- go/test/endtoend/vtgate/queries/misc/misc_test.go | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/go/mysql/server_flaky_test.go b/go/mysql/server_flaky_test.go index 8a94205145e..6f03fd6797c 100644 --- a/go/mysql/server_flaky_test.go +++ b/go/mysql/server_flaky_test.go @@ -1505,10 +1505,15 @@ func TestServerFlush(t *testing.T) { } func TestTcpKeepAlive(t *testing.T) { + var origFunc func(*net.TCPConn, time.Duration) error th := &testHandler{} l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0) require.NoError(t, err) - defer l.Close() + origFunc = l.TcpPropFunc + defer func() { + l.TcpPropFunc = origFunc + l.Close() + }() go l.Accept() host, port := getHostPort(t, l.Addr()) @@ -1530,7 +1535,7 @@ func TestTcpKeepAlive(t *testing.T) { require.True(t, called, "tcp property method not called") // move to original method - l.TcpPropFunc = setTcpConnProperties + l.TcpPropFunc = origFunc // close the connection th.lastConn.Close() diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index e5b50b5192c..710f4934786 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -23,6 +23,7 @@ import ( "strings" "testing" + _ "github.com/go-sql-driver/mysql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" From 0c9ff600c370d55f5c44bee51e0d9cc633751630 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 6 Jul 2023 20:51:59 +0530 Subject: [PATCH 5/5] introduce a marker on conn to know if keepalive is on, used for testing Signed-off-by: Harshit Gangal --- go/mysql/conn.go | 33 ++++++++++++++++++++++++++++++++- go/mysql/server.go | 27 --------------------------- go/mysql/server_flaky_test.go | 18 ++---------------- 3 files changed, 34 insertions(+), 44 deletions(-) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 9fb47da189e..b9aa2c95b18 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -199,6 +199,10 @@ type Conn struct { // enableQueryInfo controls whether we parse the INFO field in QUERY_OK packets // See: ConnParams.EnableQueryInfo enableQueryInfo bool + + // keepAliveOn marks when keep alive is active on the connection. + // This is currently used for testing. + keepAliveOn bool } // splitStatementFunciton is the function that is used to split the statement in case of a multi-statement query. @@ -246,10 +250,21 @@ func newConn(conn net.Conn) *Conn { // the server is shutting down, and has the ability to control buffer // size for reads. func newServerConn(conn net.Conn, listener *Listener) *Conn { + // Enable KeepAlive on TCP connections and change keep-alive period if provided. + enabledKeepAlive := false + if tcpConn, ok := conn.(*net.TCPConn); ok { + if err := setTcpConnProperties(tcpConn, listener.connKeepAlivePeriod); err != nil { + log.Errorf("error in setting tcp properties: %v", err) + } else { + enabledKeepAlive = true + } + } + c := &Conn{ conn: conn, listener: listener, PrepareData: make(map[uint32]*PrepareData), + keepAliveOn: enabledKeepAlive, } if listener.connReadBufferSize > 0 { @@ -267,6 +282,22 @@ func newServerConn(conn net.Conn, listener *Listener) *Conn { return c } +func setTcpConnProperties(conn *net.TCPConn, keepAlivePeriod time.Duration) error { + if err := conn.SetKeepAlive(true); err != nil { + return vterrors.Wrapf(err, "unable to enable keepalive on tcp connection") + } + + if keepAlivePeriod <= 0 { + return nil + } + + if err := conn.SetKeepAlivePeriod(keepAlivePeriod); err != nil { + return vterrors.Wrapf(err, "unable to set keepalive period on tcp connection") + } + + return nil +} + // startWriterBuffering starts using buffered writes. This should // be terminated by a call to endWriteBuffering. func (c *Conn) startWriterBuffering() { @@ -767,7 +798,7 @@ func (c *Conn) writeOKPacketWithHeader(packetOk *PacketOK, headerType byte) erro bytes, pos := c.startEphemeralPacketWithHeader(length) data := &coder{data: bytes, pos: pos} - data.writeByte(headerType) //header - OK or EOF + data.writeByte(headerType) // header - OK or EOF data.writeLenEncInt(packetOk.affectedRows) data.writeLenEncInt(packetOk.lastInsertID) data.writeUint16(packetOk.statusFlags) diff --git a/go/mysql/server.go b/go/mysql/server.go index 1aedc57ef40..7b97860defa 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -211,8 +211,6 @@ type Listener struct { // handled further by the MySQL handler. An non-nil error will stop // processing the connection by the MySQL handler. PreHandleFunc func(context.Context, net.Conn, uint32) (net.Conn, error) - - TcpPropFunc func(*net.TCPConn, time.Duration) error } // NewFromListener creates a new mysql listener from an existing net.Listener @@ -301,7 +299,6 @@ func NewListenerWithConfig(cfg ListenerConfig) (*Listener, error) { connReadBufferSize: cfg.ConnReadBufferSize, connBufferPooling: cfg.ConnBufferPooling, connKeepAlivePeriod: cfg.ConnKeepAlivePeriod, - TcpPropFunc: setTcpConnProperties, }, nil } @@ -347,14 +344,6 @@ func (l *Listener) Accept() { // handle is called in a go routine for each client connection. // FIXME(alainjobart) handle per-connection logs in a way that makes sense. func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Time) { - - // Enable KeepAlive on TCP connections and change keep-alive period if provided. - if tcpConn, ok := conn.(*net.TCPConn); ok { - if err := l.TcpPropFunc(tcpConn, l.connKeepAlivePeriod); err != nil { - log.Errorf("error in setting tcp properties: %v", err) - } - } - if l.connReadTimeout != 0 || l.connWriteTimeout != 0 { conn = netutil.NewConnWithTimeouts(conn, l.connReadTimeout, l.connWriteTimeout) } @@ -550,22 +539,6 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } } -func setTcpConnProperties(conn *net.TCPConn, keepAlivePeriod time.Duration) error { - if err := conn.SetKeepAlive(true); err != nil { - return vterrors.Wrapf(err, "unable to enable keepalive on tcp connection") - } - - if keepAlivePeriod <= 0 { - return nil - } - - if err := conn.SetKeepAlivePeriod(keepAlivePeriod); err != nil { - return vterrors.Wrapf(err, "unable to set keepalive period on tcp connection") - } - - return nil -} - // Close stops the listener, which prevents accept of any new connections. Existing connections won't be closed. func (l *Listener) Close() { l.listener.Close() diff --git a/go/mysql/server_flaky_test.go b/go/mysql/server_flaky_test.go index 6f03fd6797c..4b0bb936fc9 100644 --- a/go/mysql/server_flaky_test.go +++ b/go/mysql/server_flaky_test.go @@ -1505,15 +1505,10 @@ func TestServerFlush(t *testing.T) { } func TestTcpKeepAlive(t *testing.T) { - var origFunc func(*net.TCPConn, time.Duration) error th := &testHandler{} l, err := NewListener("tcp", "127.0.0.1:", NewAuthServerNone(), th, 0, 0, false, false, 0) require.NoError(t, err) - origFunc = l.TcpPropFunc - defer func() { - l.TcpPropFunc = origFunc - l.Close() - }() + defer l.Close() go l.Accept() host, port := getHostPort(t, l.Addr()) @@ -1522,20 +1517,11 @@ func TestTcpKeepAlive(t *testing.T) { Port: port, } - var called bool - l.TcpPropFunc = func(conn *net.TCPConn, duration time.Duration) error { - called = true - return nil - } - // on connect, the tcp method should be called. c, err := Connect(context.Background(), params) require.NoError(t, err) defer c.Close() - require.True(t, called, "tcp property method not called") - - // move to original method - l.TcpPropFunc = origFunc + require.True(t, th.lastConn.keepAliveOn, "tcp property method not called") // close the connection th.lastConn.Close()