diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go index 316e7cea375417..a0ba7ecf694bfb 100644 --- a/src/database/sql/driver/driver.go +++ b/src/database/sql/driver/driver.go @@ -255,12 +255,9 @@ type ConnBeginTx interface { // SessionResetter may be implemented by Conn to allow drivers to reset the // session state associated with the connection and to signal a bad connection. type SessionResetter interface { - // ResetSession is called while a connection is in the connection - // pool. No queries will run on this connection until this method returns. - // - // If the connection is bad this should return driver.ErrBadConn to prevent - // the connection from being returned to the connection pool. Any other - // error will be discarded. + // ResetSession is called prior to executing a query on the connection + // if the connection has been used before. If the driver returns ErrBadConn + // the connection is discarded. ResetSession(ctx context.Context) error } diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index c0371f3e784886..378ed77c597be1 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -393,12 +393,19 @@ func setStrictFakeConnClose(t *testing.T) { func (c *fakeConn) ResetSession(ctx context.Context) error { c.dirtySession = false + c.currTx = nil if c.isBad() { return driver.ErrBadConn } return nil } +var _ validator = (*fakeConn)(nil) + +func (c *fakeConn) IsValid() bool { + return !c.isBad() +} + func (c *fakeConn) Close() (err error) { drv := fdriver.(*fakeDriver) defer func() { @@ -731,6 +738,9 @@ var hookExecBadConn func() bool func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { panic("Using ExecContext") } + +var errFakeConnSessionDirty = errors.New("fakedb: session is dirty") + func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { if s.panic == "Exec" { panic(s.panic) @@ -743,7 +753,7 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d return nil, driver.ErrBadConn } if s.c.isDirtyAndMark() { - return nil, errors.New("fakedb: session is dirty") + return nil, errFakeConnSessionDirty } err := checkSubsetTypes(s.c.db.allowAny, args) @@ -857,7 +867,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( return nil, driver.ErrBadConn } if s.c.isDirtyAndMark() { - return nil, errors.New("fakedb: session is dirty") + return nil, errFakeConnSessionDirty } err := checkSubsetTypes(s.c.db.allowAny, args) @@ -890,6 +900,37 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( } } } + if s.table == "tx_status" && s.colName[0] == "tx_status" { + txStatus := "autocommit" + if s.c.currTx != nil { + txStatus = "transaction" + } + cursor := &rowsCursor{ + parentMem: s.c, + posRow: -1, + rows: [][]*row{ + []*row{ + { + cols: []interface{}{ + txStatus, + }, + }, + }, + }, + cols: [][]string{ + []string{ + "tx_status", + }, + }, + colType: [][]string{ + []string{ + "string", + }, + }, + errPos: -1, + } + return cursor, nil + } t.mu.Lock() diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 5c5b7dc7e973e9..25e241859cbba9 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -421,7 +421,6 @@ type DB struct { // It is closed during db.Close(). The close tells the connectionOpener // goroutine to exit. openerCh chan struct{} - resetterCh chan *driverConn closed bool dep map[finalCloser]depSet lastPut map[*driverConn]string // stacktrace of last conn's put; debug only @@ -458,10 +457,10 @@ type driverConn struct { sync.Mutex // guards following ci driver.Conn + needReset bool // The connection session should be reset before use if true. closed bool finalClosed bool // ci.Close has been called openStmt map[*driverStmt]bool - lastErr error // lastError captures the result of the session resetter. // guarded by db.mu inUse bool @@ -486,6 +485,41 @@ func (dc *driverConn) expired(timeout time.Duration) bool { return dc.createdAt.Add(timeout).Before(nowFunc()) } +// resetSession checks if the driver connection needs the +// session to be reset and if required, resets it. +func (dc *driverConn) resetSession(ctx context.Context) error { + dc.Lock() + defer dc.Unlock() + + if !dc.needReset { + return nil + } + if cr, ok := dc.ci.(driver.SessionResetter); ok { + return cr.ResetSession(ctx) + } + return nil +} + +// validator was introduced for Go1.15, but backported to Go1.14. +type validator interface { + IsValid() bool +} + +// validateConnection checks if the connection is valid and can +// still be used. It also marks the session for reset if required. +func (dc *driverConn) validateConnection(needsReset bool) bool { + dc.Lock() + defer dc.Unlock() + + if needsReset { + dc.needReset = true + } + if cv, ok := dc.ci.(validator); ok { + return cv.IsValid() + } + return true +} + // prepareLocked prepares the query on dc. When cg == nil the dc must keep track of // the prepared statements in a pool. func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) { @@ -511,19 +545,6 @@ func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, que return ds, nil } -// resetSession resets the connection session and sets the lastErr -// that is checked before returning the connection to another query. -// -// resetSession assumes that the embedded mutex is locked when the connection -// was returned to the pool. This unlocks the mutex. -func (dc *driverConn) resetSession(ctx context.Context) { - defer dc.Unlock() // In case of panic. - if dc.closed { // Check if the database has been closed. - return - } - dc.lastErr = dc.ci.(driver.SessionResetter).ResetSession(ctx) -} - // the dc.db's Mutex is held. func (dc *driverConn) closeDBLocked() func() error { dc.Lock() @@ -713,14 +734,12 @@ func OpenDB(c driver.Connector) *DB { db := &DB{ connector: c, openerCh: make(chan struct{}, connectionRequestQueueSize), - resetterCh: make(chan *driverConn, 50), lastPut: make(map[*driverConn]string), connRequests: make(map[uint64]chan connRequest), stop: cancel, } go db.connectionOpener(ctx) - go db.connectionResetter(ctx) return db } @@ -1058,23 +1077,6 @@ func (db *DB) connectionOpener(ctx context.Context) { } } -// connectionResetter runs in a separate goroutine to reset connections async -// to exported API. -func (db *DB) connectionResetter(ctx context.Context) { - for { - select { - case <-ctx.Done(): - close(db.resetterCh) - for dc := range db.resetterCh { - dc.Unlock() - } - return - case dc := <-db.resetterCh: - dc.resetSession(ctx) - } - } -} - // Open one new connection func (db *DB) openNewConnection(ctx context.Context) { // maybeOpenNewConnctions has already executed db.numOpen++ before it sent @@ -1155,14 +1157,13 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn conn.Close() return nil, driver.ErrBadConn } - // Lock around reading lastErr to ensure the session resetter finished. - conn.Lock() - err := conn.lastErr - conn.Unlock() - if err == driver.ErrBadConn { + + // Reset the session if required. + if err := conn.resetSession(ctx); err == driver.ErrBadConn { conn.Close() return nil, driver.ErrBadConn } + return conn, nil } @@ -1204,18 +1205,22 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn if !ok { return nil, errDBClosed } - if ret.err == nil && ret.conn.expired(lifetime) { + // Only check if the connection is expired if the strategy is cachedOrNewConns. + // If we require a new connection, just re-use the connection without looking + // at the expiry time. If it is expired, it will be checked when it is placed + // back into the connection pool. + // This prioritizes giving a valid connection to a client over the exact connection + // lifetime, which could expire exactly after this point anyway. + if strategy == cachedOrNewConn && ret.err == nil && ret.conn.expired(lifetime) { ret.conn.Close() return nil, driver.ErrBadConn } if ret.conn == nil { return nil, ret.err } - // Lock around reading lastErr to ensure the session resetter finished. - ret.conn.Lock() - err := ret.conn.lastErr - ret.conn.Unlock() - if err == driver.ErrBadConn { + + // Reset the session if required. + if err := ret.conn.resetSession(ctx); err == driver.ErrBadConn { ret.conn.Close() return nil, driver.ErrBadConn } @@ -1275,13 +1280,23 @@ const debugGetPut = false // putConn adds a connection to the db's free pool. // err is optionally the last error that occurred on this connection. func (db *DB) putConn(dc *driverConn, err error, resetSession bool) { + if err != driver.ErrBadConn { + if !dc.validateConnection(resetSession) { + err = driver.ErrBadConn + } + } db.mu.Lock() if !dc.inUse { + db.mu.Unlock() if debugGetPut { fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc]) } panic("sql: connection returned that was never out") } + + if err != driver.ErrBadConn && dc.expired(db.maxLifetime) { + err = driver.ErrBadConn + } if debugGetPut { db.lastPut[dc] = stack() } @@ -1305,41 +1320,13 @@ func (db *DB) putConn(dc *driverConn, err error, resetSession bool) { if putConnHook != nil { putConnHook(db, dc) } - if db.closed { - // Connections do not need to be reset if they will be closed. - // Prevents writing to resetterCh after the DB has closed. - resetSession = false - } - if resetSession { - if _, resetSession = dc.ci.(driver.SessionResetter); resetSession { - // Lock the driverConn here so it isn't released until - // the connection is reset. - // The lock must be taken before the connection is put into - // the pool to prevent it from being taken out before it is reset. - dc.Lock() - } - } added := db.putConnDBLocked(dc, nil) db.mu.Unlock() if !added { - if resetSession { - dc.Unlock() - } dc.Close() return } - if !resetSession { - return - } - select { - default: - // If the resetterCh is blocking then mark the connection - // as bad and continue on. - dc.lastErr = driver.ErrBadConn - dc.Unlock() - case db.resetterCh <- dc: - } } // Satisfy a connRequest or put the driverConn in the idle pool and return true @@ -1701,7 +1688,11 @@ func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStra // beginDC starts a transaction. The provided dc must be valid and ready to use. func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) { var txi driver.Tx + keepConnOnRollback := false withLock(dc, func() { + _, hasSessionResetter := dc.ci.(driver.SessionResetter) + _, hasConnectionValidator := dc.ci.(validator) + keepConnOnRollback = hasSessionResetter && hasConnectionValidator txi, err = ctxDriverBegin(ctx, opts, dc.ci) }) if err != nil { @@ -1713,12 +1704,13 @@ func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), // The cancel function in Tx will be called after done is set to true. ctx, cancel := context.WithCancel(ctx) tx = &Tx{ - db: db, - dc: dc, - releaseConn: release, - txi: txi, - cancel: cancel, - ctx: ctx, + db: db, + dc: dc, + releaseConn: release, + txi: txi, + cancel: cancel, + keepConnOnRollback: keepConnOnRollback, + ctx: ctx, } go tx.awaitDone() return tx, nil @@ -1980,6 +1972,11 @@ type Tx struct { // Use atomic operations on value when checking value. done int32 + // keepConnOnRollback is true if the driver knows + // how to reset the connection's session and if need be discard + // the connection. + keepConnOnRollback bool + // All Stmts prepared for this transaction. These will be closed after the // transaction has been committed or rolled back. stmts struct { @@ -2005,7 +2002,10 @@ func (tx *Tx) awaitDone() { // transaction is closed and the resources are released. This // rollback does nothing if the transaction has already been // committed or rolled back. - tx.rollback(true) + // Do not discard the connection if the connection knows + // how to reset the session. + discardConnection := !tx.keepConnOnRollback + tx.rollback(discardConnection) } func (tx *Tx) isDone() bool { @@ -2016,14 +2016,10 @@ func (tx *Tx) isDone() bool { // that has already been committed or rolled back. var ErrTxDone = errors.New("sql: transaction has already been committed or rolled back") -// close returns the connection to the pool and -// must only be called by Tx.rollback or Tx.Commit. -func (tx *Tx) close(err error) { - tx.cancel() - - tx.closemu.Lock() - defer tx.closemu.Unlock() - +// closeLocked returns the connection to the pool and +// must only be called by Tx.rollback or Tx.Commit while +// closemu is Locked and tx already canceled. +func (tx *Tx) closeLocked(err error) { tx.releaseConn(err) tx.dc = nil tx.txi = nil @@ -2090,6 +2086,15 @@ func (tx *Tx) Commit() error { if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { return ErrTxDone } + + // Cancel the Tx to release any active R-closemu locks. + // This is safe to do because tx.done has already transitioned + // from 0 to 1. Hold the W-closemu lock prior to rollback + // to ensure no other connection has an active query. + tx.cancel() + tx.closemu.Lock() + defer tx.closemu.Unlock() + var err error withLock(tx.dc, func() { err = tx.txi.Commit() @@ -2097,16 +2102,31 @@ func (tx *Tx) Commit() error { if err != driver.ErrBadConn { tx.closePrepared() } - tx.close(err) + tx.closeLocked(err) return err } +var rollbackHook func() + // rollback aborts the transaction and optionally forces the pool to discard // the connection. func (tx *Tx) rollback(discardConn bool) error { if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { return ErrTxDone } + + if rollbackHook != nil { + rollbackHook() + } + + // Cancel the Tx to release any active R-closemu locks. + // This is safe to do because tx.done has already transitioned + // from 0 to 1. Hold the W-closemu lock prior to rollback + // to ensure no other connection has an active query. + tx.cancel() + tx.closemu.Lock() + defer tx.closemu.Unlock() + var err error withLock(tx.dc, func() { err = tx.txi.Rollback() @@ -2117,7 +2137,7 @@ func (tx *Tx) rollback(discardConn bool) error { if discardConn { err = driver.ErrBadConn } - tx.close(err) + tx.closeLocked(err) return err } diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index f68cefe43aed5a..a799093ff9e7e3 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -80,6 +80,11 @@ func newTestDBConnector(t testing.TB, fc *fakeConnector, name string) *DB { exec(t, db, "CREATE|magicquery|op=string,millis=int32") exec(t, db, "INSERT|magicquery|op=sleep,millis=10") } + if name == "tx_status" { + // Magic table name and column, known by fakedb_test.go. + exec(t, db, "CREATE|tx_status|tx_status=string") + exec(t, db, "INSERT|tx_status|tx_status=invalid") + } return db } @@ -437,6 +442,7 @@ func TestTxContextWait(t *testing.T) { } t.Fatal(err) } + tx.keepConnOnRollback = false // This will trigger the *fakeConn.Prepare method which will take time // performing the query. The ctxDriverPrepare func will check the context @@ -449,6 +455,35 @@ func TestTxContextWait(t *testing.T) { waitForFree(t, db, 5*time.Second, 0) } +// TestTxContextWaitNoDiscard is the same as TestTxContextWait, but should not discard +// the final connection. +func TestTxContextWaitNoDiscard(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Millisecond) + defer cancel() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + // Guard against the context being canceled before BeginTx completes. + if err == context.DeadlineExceeded { + t.Skip("tx context canceled prior to first use") + } + t.Fatal(err) + } + + // This will trigger the *fakeConn.Prepare method which will take time + // performing the query. The ctxDriverPrepare func will check the context + // after this and close the rows and return an error. + _, err = tx.QueryContext(ctx, "WAIT|1s|SELECT|people|age,name|") + if err != context.DeadlineExceeded { + t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err) + } + + waitForFree(t, db, 5*time.Second, 1) +} + // TestUnsupportedOptions checks that the database fails when a driver that // doesn't implement ConnBeginTx is used with non-default options and an // un-cancellable context. @@ -1521,6 +1556,37 @@ func TestConnTx(t *testing.T) { } } +// TestConnIsValid verifies that a database connection that should be discarded, +// is actually discarded and does not re-enter the connection pool. +// If the IsValid method from *fakeConn is removed, this test will fail. +func TestConnIsValid(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + db.SetMaxOpenConns(1) + + ctx := context.Background() + + c, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + err = c.Raw(func(raw interface{}) error { + dc := raw.(*fakeConn) + dc.stickyBad = true + return nil + }) + if err != nil { + t.Fatal(err) + } + c.Close() + + if len(db.freeConn) > 0 && db.freeConn[0].ci.(*fakeConn).stickyBad { + t.Fatal("bad connection returned to pool; expected bad connection to be discarded") + } +} + // Tests fix for issue 2542, that we release a lock when querying on // a closed connection. func TestIssue2542Deadlock(t *testing.T) { @@ -2654,6 +2720,159 @@ func TestManyErrBadConn(t *testing.T) { } } +// Issue 34755: Ensure that a Tx cannot commit after a rollback. +func TestTxCannotCommitAfterRollback(t *testing.T) { + db := newTestDB(t, "tx_status") + defer closeDB(t, db) + + // First check query reporting is correct. + var txStatus string + err := db.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus) + if err != nil { + t.Fatal(err) + } + if g, w := txStatus, "autocommit"; g != w { + t.Fatalf("tx_status=%q, wanted %q", g, w) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + + // Ignore dirty session for this test. + // A failing test should trigger the dirty session flag as well, + // but that isn't exactly what this should test for. + tx.txi.(*fakeTx).c.skipDirtySession = true + + defer tx.Rollback() + + err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus) + if err != nil { + t.Fatal(err) + } + if g, w := txStatus, "transaction"; g != w { + t.Fatalf("tx_status=%q, wanted %q", g, w) + } + + // 1. Begin a transaction. + // 2. (A) Start a query, (B) begin Tx rollback through a ctx cancel. + // 3. Check if 2.A has committed in Tx (pass) or outside of Tx (fail). + sendQuery := make(chan struct{}) + hookTxGrabConn = func() { + cancel() + <-sendQuery + } + rollbackHook = func() { + close(sendQuery) + } + defer func() { + hookTxGrabConn = nil + rollbackHook = nil + }() + + err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus) + if err != nil { + // A failure here would be expected if skipDirtySession was not set to true above. + t.Fatal(err) + } + if g, w := txStatus, "transaction"; g != w { + t.Fatalf("tx_status=%q, wanted %q", g, w) + } +} + +// Issue32530 encounters an issue where a connection may +// expire right after it comes out of a used connection pool +// even when a new connection is requested. +func TestConnExpiresFreshOutOfPool(t *testing.T) { + execCases := []struct { + expired bool + badReset bool + }{ + {false, false}, + {true, false}, + {false, true}, + } + + t0 := time.Unix(1000000, 0) + offset := time.Duration(0) + offsetMu := sync.RWMutex{} + + nowFunc = func() time.Time { + offsetMu.RLock() + defer offsetMu.RUnlock() + return t0.Add(offset) + } + defer func() { nowFunc = time.Now }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db := newTestDB(t, "magicquery") + defer closeDB(t, db) + + db.SetMaxOpenConns(1) + + for _, ec := range execCases { + ec := ec + name := fmt.Sprintf("expired=%t,badReset=%t", ec.expired, ec.badReset) + t.Run(name, func(t *testing.T) { + db.clearAllConns(t) + + db.SetMaxIdleConns(1) + db.SetConnMaxLifetime(10 * time.Second) + + conn, err := db.conn(ctx, alwaysNewConn) + if err != nil { + t.Fatal(err) + } + + afterPutConn := make(chan struct{}) + waitingForConn := make(chan struct{}) + + go func() { + conn, err := db.conn(ctx, alwaysNewConn) + if err != nil { + t.Fatal(err) + } + db.putConn(conn, err, false) + close(afterPutConn) + }() + go func() { + for { + db.mu.Lock() + ct := len(db.connRequests) + db.mu.Unlock() + if ct > 0 { + close(waitingForConn) + return + } + time.Sleep(10 * time.Millisecond) + } + }() + + <-waitingForConn + + offsetMu.Lock() + if ec.expired { + offset = 11 * time.Second + } else { + offset = time.Duration(0) + } + offsetMu.Unlock() + + conn.ci.(*fakeConn).stickyBad = ec.badReset + + db.putConn(conn, err, true) + + <-afterPutConn + }) + } +} + // TestIssue20575 ensures the Rows from query does not block // closing a transaction. Ensure Rows is closed while closing a trasaction. func TestIssue20575(t *testing.T) {