Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Tcp keep alive and provide keep alive period setting #13434

Merged
merged 7 commits into from
Jul 11, 2023
1 change: 1 addition & 0 deletions go/flags/endtoend/vtgate.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ Usage of vtgate:
--max_payload_size int The threshold for query payloads in bytes. A payload greater than this threshold will result in a failure to handle the query.
--message_stream_grace_period duration the amount of time to give for a vttablet to resume if it ends a message stream, usually because of a reparent. (default 30s)
--min_number_serving_vttablets int The minimum number of vttablets for each replicating tablet_type (e.g. replica, rdonly) that will be continue to be used even with replication lag above discovery_low_replication_lag, but still below discovery_high_replication_lag_minimum_serving. (default 2)
--mysql-server-keepalive-period duration TCP period between keep-alives
--mysql-server-pool-conn-read-buffers If set, the server will pool incoming connection read buffers
--mysql_allow_clear_text_without_tls If set, the server will allow the use of a clear text password over non-SSL connections.
--mysql_auth_server_impl string Which auth server implementation to use. Options: none, ldap, clientcert, static, vault. (default "static")
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestValidCert(t *testing.T) {
authServer := newAuthServerClientCert()

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -114,7 +114,7 @@ func TestNoCert(t *testing.T) {
authServer := newAuthServerClientCert()

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down
10 changes: 5 additions & 5 deletions go/mysql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func TestTLSClientDisabled(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -221,7 +221,7 @@ func TestTLSClientPreferredDefault(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -294,7 +294,7 @@ func TestTLSClientRequired(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -341,7 +341,7 @@ func TestTLSClientVerifyCA(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err)
defer l.Close()

Expand Down Expand Up @@ -424,7 +424,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err)
defer l.Close()

Expand Down
31 changes: 31 additions & 0 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,10 @@ type Conn struct {
// See: ConnParams.EnableQueryInfo
enableQueryInfo bool

// keepAliveOn marks when keep alive is active on the connection.
// This is currently used for testing.
keepAliveOn bool

// mu protects the fields below
mu sync.Mutex
// cancel keep the cancel function for the current executing query.
Expand Down Expand Up @@ -254,10 +258,21 @@ func newConn(conn net.Conn) *Conn {
// the server is shutting down, and has the ability to control buffer
// size for reads.
func newServerConn(conn net.Conn, listener *Listener) *Conn {
// Enable KeepAlive on TCP connections and change keep-alive period if provided.
enabledKeepAlive := false
if tcpConn, ok := conn.(*net.TCPConn); ok {
if err := setTcpConnProperties(tcpConn, listener.connKeepAlivePeriod); err != nil {
log.Errorf("error in setting tcp properties: %v", err)
} else {
enabledKeepAlive = true
}
}

c := &Conn{
conn: conn,
listener: listener,
PrepareData: make(map[uint32]*PrepareData),
keepAliveOn: enabledKeepAlive,
}

if listener.connReadBufferSize > 0 {
Expand All @@ -275,6 +290,22 @@ func newServerConn(conn net.Conn, listener *Listener) *Conn {
return c
}

func setTcpConnProperties(conn *net.TCPConn, keepAlivePeriod time.Duration) error {
if err := conn.SetKeepAlive(true); err != nil {
return vterrors.Wrapf(err, "unable to enable keepalive on tcp connection")
}

if keepAlivePeriod <= 0 {
return nil
}

if err := conn.SetKeepAlivePeriod(keepAlivePeriod); err != nil {
return vterrors.Wrapf(err, "unable to set keepalive period on tcp connection")
}

return nil
}

// startWriterBuffering starts using buffered writes. This should
// be terminated by a call to endWriteBuffering.
func (c *Conn) startWriterBuffering() {
Expand Down
6 changes: 3 additions & 3 deletions go/mysql/fakesqldb/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func New(t testing.TB) *DB {
authServer := mysql.NewAuthServerNone()

// Start listening.
db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false)
db.listener, err = mysql.NewListener("unix", socketFile, authServer, db, 0, 0, false, false, 0)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
Expand Down Expand Up @@ -382,7 +382,7 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R
if db.shouldClose.Load() {
c.Close()

//log error
// log error
if err := callback(&sqltypes.Result{}); err != nil {
log.Errorf("callback failed : %v", err)
}
Expand All @@ -393,7 +393,7 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R
// The driver may send this at connection time, and we don't want it to
// interfere.
if key == "set names utf8" || strings.HasPrefix(key, "set collation_connection = ") {
//log error
// log error
if err := callback(&sqltypes.Result{}); err != nil {
log.Errorf("callback failed : %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestClearTextClientAuth(t *testing.T) {
defer authServer.close()

// Create the listener.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down Expand Up @@ -99,7 +99,7 @@ func TestSSLConnection(t *testing.T) {
defer authServer.close()

// Create the listener, so we can get its host.
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false)
l, err := NewListener("tcp", "127.0.0.1:", authServer, th, 0, 0, false, false, 0)
require.NoError(t, err, "NewListener failed: %v", err)
defer l.Close()
host := l.Addr().(*net.TCPAddr).IP.String()
Expand Down
62 changes: 35 additions & 27 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ type Listener struct {
// connBufferPooling configures if vtgate server pools connection buffers
connBufferPooling bool

// connKeepAlivePeriod is period between tcp keep-alives.
connKeepAlivePeriod time.Duration

// shutdown indicates that Shutdown method was called.
shutdown atomic.Bool

Expand All @@ -216,15 +219,17 @@ func NewFromListener(
connReadTimeout time.Duration,
connWriteTimeout time.Duration,
connBufferPooling bool,
keepAlivePeriod time.Duration,
) (*Listener, error) {
cfg := ListenerConfig{
Listener: l,
AuthServer: authServer,
Handler: handler,
ConnReadTimeout: connReadTimeout,
ConnWriteTimeout: connWriteTimeout,
ConnReadBufferSize: connBufferSize,
ConnBufferPooling: connBufferPooling,
Listener: l,
AuthServer: authServer,
Handler: handler,
ConnReadTimeout: connReadTimeout,
ConnWriteTimeout: connWriteTimeout,
ConnReadBufferSize: connBufferSize,
ConnBufferPooling: connBufferPooling,
ConnKeepAlivePeriod: keepAlivePeriod,
}
return NewListenerWithConfig(cfg)
}
Expand All @@ -238,31 +243,33 @@ func NewListener(
connWriteTimeout time.Duration,
proxyProtocol bool,
connBufferPooling bool,
keepAlivePeriod time.Duration,
) (*Listener, error) {
listener, err := net.Listen(protocol, address)
if err != nil {
return nil, err
}
if proxyProtocol {
proxyListener := &proxyproto.Listener{Listener: listener}
return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling)
return NewFromListener(proxyListener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod)
}

return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling)
return NewFromListener(listener, authServer, handler, connReadTimeout, connWriteTimeout, connBufferPooling, keepAlivePeriod)
}

// ListenerConfig should be used with NewListenerWithConfig to specify listener parameters.
type ListenerConfig struct {
// Protocol-Address pair and Listener are mutually exclusive parameters
Protocol string
Address string
Listener net.Listener
AuthServer AuthServer
Handler Handler
ConnReadTimeout time.Duration
ConnWriteTimeout time.Duration
ConnReadBufferSize int
ConnBufferPooling bool
Protocol string
Address string
Listener net.Listener
AuthServer AuthServer
Handler Handler
ConnReadTimeout time.Duration
ConnWriteTimeout time.Duration
ConnReadBufferSize int
ConnBufferPooling bool
ConnKeepAlivePeriod time.Duration
}

// NewListenerWithConfig creates new listener using provided config. There are
Expand All @@ -280,15 +287,16 @@ func NewListenerWithConfig(cfg ListenerConfig) (*Listener, error) {
}

return &Listener{
authServer: cfg.AuthServer,
handler: cfg.Handler,
listener: l,
ServerVersion: servenv.AppVersion.MySQLVersion(),
connectionID: 1,
connReadTimeout: cfg.ConnReadTimeout,
connWriteTimeout: cfg.ConnWriteTimeout,
connReadBufferSize: cfg.ConnReadBufferSize,
connBufferPooling: cfg.ConnBufferPooling,
authServer: cfg.AuthServer,
handler: cfg.Handler,
listener: l,
ServerVersion: servenv.AppVersion.MySQLVersion(),
connectionID: 1,
connReadTimeout: cfg.ConnReadTimeout,
connWriteTimeout: cfg.ConnWriteTimeout,
connReadBufferSize: cfg.ConnReadBufferSize,
connBufferPooling: cfg.ConnBufferPooling,
connKeepAlivePeriod: cfg.ConnKeepAlivePeriod,
}, nil
}

Expand Down
Loading
Loading