diff --git a/go/vt/vtgateproxy/mysql_server.go b/go/vt/vtgateproxy/mysql_server.go index 4f00c01bb0d..26c4422f28b 100644 --- a/go/vt/vtgateproxy/mysql_server.go +++ b/go/vt/vtgateproxy/mysql_server.go @@ -100,17 +100,7 @@ func (ph *proxyHandler) NewConnection(c *mysql.Conn) { func (ph *proxyHandler) ComResetConnection(c *mysql.Conn) { ctx := context.Background() - session, err := ph.getSession(ctx, c) - if err != nil { - return - } - if session.SessionPb().InTransaction { - defer atomic.AddInt32(&busyConnections, -1) - } - err = ph.proxy.CloseSession(ctx, session) - if err != nil { - log.Errorf("Error happened in transaction rollback: %v", err) - } + ph.closeSession(ctx, c) } func (ph *proxyHandler) ConnectionClosed(c *mysql.Conn) { @@ -128,14 +118,7 @@ func (ph *proxyHandler) ConnectionClosed(c *mysql.Conn) { } else { ctx = context.Background() } - session, err := ph.getSession(ctx, c) - if err != nil { - return - } - if session.SessionPb().InTransaction { - defer atomic.AddInt32(&busyConnections, -1) - } - _ = ph.proxy.CloseSession(ctx, session) + ph.closeSession(ctx, c) } // Regexp to extract parent span id over the sql query @@ -378,6 +361,23 @@ func (ph *proxyHandler) getSession(ctx context.Context, c *mysql.Conn) (*vtgatec return session, nil } +func (ph *proxyHandler) closeSession(ctx context.Context, c *mysql.Conn) { + session, _ := c.ClientData.(*vtgateconn.VTGateSession) + if session == nil { + return // no active session + } + + if session.SessionPb().InTransaction { + defer atomic.AddInt32(&busyConnections, -1) + } + err := ph.proxy.CloseSession(ctx, session) + if err != nil { + log.Errorf("Error happened in transaction rollback: %v", err) + } + + c.ClientData = nil +} + var mysqlListener *mysql.Listener var mysqlUnixListener *mysql.Listener var sigChan chan os.Signal