Skip to content

Commit 2b283ce

Browse files
kardianos authored and bradfitz committed
database/sql: fix race when canceling queries immediately
Previously the following could happen, though in practice it would be rare.

Goroutine 1:
    (*Tx).QueryContext begins a query, passing in userContext

Goroutine 2:
    (*Tx).awaitDone starts to wait on the context derived from the passed in context

Goroutine 1:
    (*Tx).grabConn returns a valid (*driverConn)
    The (*driverConn) passes to (*DB).queryConn

Goroutine 3:
    userContext is canceled

Goroutine 2:
    (*Tx).awaitDone unblocks and calls (*Tx).rollback
    (*driverConn).finalClose obtains dc.Mutex
    (*driverConn).finalClose sets dc.ci = nil

Goroutine 1:
    (*DB).queryConn obtains dc.Mutex in withLock
    ctxDriverPrepare accepts dc.ci which is now nil
    ctxDriverPrepare panics on the nil ci

The fix for this is to guard the Tx methods with a RWLock, holding it exclusively when closing the Tx and holding a read lock when executing a query.

Fixes #18719

Change-Id: I37aa02c37083c9793dabd28f7f934a1c5cbc05ea
Reviewed-on: https://go-review.googlesource.com/35550
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
1 parent 1cf0818 commit 2b283ce

File tree

2 files changed

+147
-35
lines changed

2 files changed

+147
-35
lines changed

src/database/sql/sql.go

+64-27
Original file line number | Diff line number | Diff line change
@@ -1357,16 +1357,7 @@ func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStra
13571357
cancel: cancel,
13581358
ctx: ctx,
13591359
}
1360-
go func(tx *Tx) {
1361-
select {
1362-
case <-tx.ctx.Done():
1363-
if !tx.isDone() {
1364-
// Discard and close the connection used to ensure the transaction
1365-
// is closed and the resources are released.
1366-
tx.rollback(true)
1367-
}
1368-
}
1369-
}(tx)
1360+
go tx.awaitDone()
13701361
return tx, nil
13711362
}
13721363

@@ -1388,6 +1379,11 @@ func (db *DB) Driver() driver.Driver {
13881379
type Tx struct {
13891380
db *DB
13901381

1382+
// closemu prevents the transaction from closing while there
1383+
// is an active query. It is held for read during queries
1384+
// and exclusively during close.
1385+
closemu sync.RWMutex
1386+
13911387
// dc is owned exclusively until Commit or Rollback, at which point
13921388
// it's returned with putConn.
13931389
dc *driverConn
@@ -1413,6 +1409,20 @@ type Tx struct {
14131409
ctx context.Context
14141410
}
14151411

1412+
// awaitDone blocks until the context in Tx is canceled and rolls back
1413+
// the transaction if it's not already done.
1414+
func (tx *Tx) awaitDone() {
1415+
// Wait for either the transaction to be committed or rolled
1416+
// back, or for the associated context to be closed.
1417+
<-tx.ctx.Done()
1418+
1419+
// Discard and close the connection used to ensure the
1420+
// transaction is closed and the resources are released. This
1421+
// rollback does nothing if the transaction has already been
1422+
// committed or rolled back.
1423+
tx.rollback(true)
1424+
}
1425+
14161426
func (tx *Tx) isDone() bool {
14171427
return atomic.LoadInt32(&tx.done) != 0
14181428
}
@@ -1424,16 +1434,31 @@ var ErrTxDone = errors.New("sql: Transaction has already been committed or rolle
14241434
// close returns the connection to the pool and
14251435
// must only be called by Tx.rollback or Tx.Commit.
14261436
func (tx *Tx) close(err error) {
1437+
tx.closemu.Lock()
1438+
defer tx.closemu.Unlock()
1439+
14271440
tx.db.putConn(tx.dc, err)
14281441
tx.cancel()
14291442
tx.dc = nil
14301443
tx.txi = nil
14311444
}
14321445

1446+
// hookTxGrabConn specifies an optional hook to be called on
1447+
// a successful call to (*Tx).grabConn. For tests.
1448+
var hookTxGrabConn func()
1449+
14331450
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
1451+
select {
1452+
default:
1453+
case <-ctx.Done():
1454+
return nil, ctx.Err()
1455+
}
14341456
if tx.isDone() {
14351457
return nil, ErrTxDone
14361458
}
1459+
if hookTxGrabConn != nil { // test hook
1460+
hookTxGrabConn()
1461+
}
14371462
return tx.dc, nil
14381463
}
14391464

@@ -1503,6 +1528,9 @@ func (tx *Tx) Rollback() error {
15031528
// for the execution of the returned statement. The returned statement
15041529
// will run in the transaction context.
15051530
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
1531+
tx.closemu.RLock()
1532+
defer tx.closemu.RUnlock()
1533+
15061534
// TODO(bradfitz): We could be more efficient here and either
15071535
// provide a method to take an existing Stmt (created on
15081536
// perhaps a different Conn), and re-create it on this Conn if
@@ -1567,6 +1595,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
15671595
// The returned statement operates within the transaction and will be closed
15681596
// when the transaction has been committed or rolled back.
15691597
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
1598+
tx.closemu.RLock()
1599+
defer tx.closemu.RUnlock()
1600+
15701601
// TODO(bradfitz): optimize this. Currently this re-prepares
15711602
// each time. This is fine for now to illustrate the API but
15721603
// we should really cache already-prepared statements
@@ -1618,6 +1649,9 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
16181649
// ExecContext executes a query that doesn't return rows.
16191650
// For example: an INSERT and UPDATE.
16201651
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
1652+
tx.closemu.RLock()
1653+
defer tx.closemu.RUnlock()
1654+
16211655
dc, err := tx.grabConn(ctx)
16221656
if err != nil {
16231657
return nil, err
@@ -1661,6 +1695,9 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
16611695

16621696
// QueryContext executes a query that returns rows, typically a SELECT.
16631697
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
1698+
tx.closemu.RLock()
1699+
defer tx.closemu.RUnlock()
1700+
16641701
dc, err := tx.grabConn(ctx)
16651702
if err != nil {
16661703
return nil, err
@@ -2038,25 +2075,21 @@ type Rows struct {
20382075
// closed value is 1 when the Rows is closed.
20392076
// Use atomic operations on value when checking value.
20402077
closed int32
2041-
ctxClose chan struct{} // closed when Rows is closed, may be null.
2078+
cancel func() // called when Rows is closed, may be nil.
20422079
lastcols []driver.Value
20432080
lasterr error // non-nil only if closed is true
20442081
closeStmt *driverStmt // if non-nil, statement to Close on close
20452082
}
20462083

20472084
func (rs *Rows) initContextClose(ctx context.Context) {
2048-
if ctx.Done() == context.Background().Done() {
2049-
return
2050-
}
2085+
ctx, rs.cancel = context.WithCancel(ctx)
2086+
go rs.awaitDone(ctx)
2087+
}
20512088

2052-
rs.ctxClose = make(chan struct{})
2053-
go func() {
2054-
select {
2055-
case <-ctx.Done():
2056-
rs.Close()
2057-
case <-rs.ctxClose:
2058-
}
2059-
}()
2089+
// awaitDone blocks until the rows are closed or the context canceled.
2090+
func (rs *Rows) awaitDone(ctx context.Context) {
2091+
<-ctx.Done()
2092+
rs.Close()
20602093
}
20612094

20622095
// Next prepares the next result row for reading with the Scan method. It
@@ -2314,7 +2347,9 @@ func (rs *Rows) Scan(dest ...interface{}) error {
23142347
return nil
23152348
}
23162349

2317-
var rowsCloseHook func(*Rows, *error)
2350+
// rowsCloseHook returns a function so tests may install the
2351+
// hook through a test-only mutex.
2352+
var rowsCloseHook = func() func(*Rows, *error) { return nil }
23182353

23192354
func (rs *Rows) isClosed() bool {
23202355
return atomic.LoadInt32(&rs.closed) != 0
@@ -2328,13 +2363,15 @@ func (rs *Rows) Close() error {
23282363
if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) {
23292364
return nil
23302365
}
2331-
if rs.ctxClose != nil {
2332-
close(rs.ctxClose)
2333-
}
2366+
23342367
err := rs.rowsi.Close()
2335-
if fn := rowsCloseHook; fn != nil {
2368+
if fn := rowsCloseHook(); fn != nil {
23362369
fn(rs, &err)
23372370
}
2371+
if rs.cancel != nil {
2372+
rs.cancel()
2373+
}
2374+
23382375
if rs.closeStmt != nil {
23392376
rs.closeStmt.Close()
23402377
}

src/database/sql/sql_test.go

+83-8
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ import (
1414
"runtime"
1515
"strings"
1616
"sync"
17+
"sync/atomic"
1718
"testing"
1819
"time"
1920
)
@@ -1135,6 +1136,24 @@ func TestQueryRowClosingStmt(t *testing.T) {
11351136
}
11361137
}
11371138

1139+
var atomicRowsCloseHook atomic.Value // of func(*Rows, *error)
1140+
1141+
func init() {
1142+
rowsCloseHook = func() func(*Rows, *error) {
1143+
fn, _ := atomicRowsCloseHook.Load().(func(*Rows, *error))
1144+
return fn
1145+
}
1146+
}
1147+
1148+
func setRowsCloseHook(fn func(*Rows, *error)) {
1149+
if fn == nil {
1150+
// Can't change an atomic.Value back to nil, so set it to this
1151+
// no-op func instead.
1152+
fn = func(*Rows, *error) {}
1153+
}
1154+
atomicRowsCloseHook.Store(fn)
1155+
}
1156+
11381157
// Test issue 6651
11391158
func TestIssue6651(t *testing.T) {
11401159
db := newTestDB(t, "people")
@@ -1147,17 +1166,18 @@ func TestIssue6651(t *testing.T) {
11471166
return fmt.Errorf(want)
11481167
}
11491168
defer func() { rowsCursorNextHook = nil }()
1169+
11501170
err := db.QueryRow("SELECT|people|name|").Scan(&v)
11511171
if err == nil || err.Error() != want {
11521172
t.Errorf("error = %q; want %q", err, want)
11531173
}
11541174
rowsCursorNextHook = nil
11551175

11561176
want = "error in rows.Close"
1157-
rowsCloseHook = func(rows *Rows, err *error) {
1177+
setRowsCloseHook(func(rows *Rows, err *error) {
11581178
*err = fmt.Errorf(want)
1159-
}
1160-
defer func() { rowsCloseHook = nil }()
1179+
})
1180+
defer setRowsCloseHook(nil)
11611181
err = db.QueryRow("SELECT|people|name|").Scan(&v)
11621182
if err == nil || err.Error() != want {
11631183
t.Errorf("error = %q; want %q", err, want)
@@ -1830,7 +1850,9 @@ func TestStmtCloseDeps(t *testing.T) {
18301850
db.dumpDeps(t)
18311851
}
18321852

1833-
if len(stmt.css) > nquery {
1853+
if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
1854+
return len(stmt.css) <= nquery
1855+
}) {
18341856
t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery)
18351857
}
18361858

@@ -2576,10 +2598,10 @@ func TestIssue6081(t *testing.T) {
25762598
if err != nil {
25772599
t.Fatal(err)
25782600
}
2579-
rowsCloseHook = func(rows *Rows, err *error) {
2601+
setRowsCloseHook(func(rows *Rows, err *error) {
25802602
*err = driver.ErrBadConn
2581-
}
2582-
defer func() { rowsCloseHook = nil }()
2603+
})
2604+
defer setRowsCloseHook(nil)
25832605
for i := 0; i < 10; i++ {
25842606
rows, err := stmt.Query()
25852607
if err != nil {
@@ -2642,7 +2664,10 @@ func TestIssue18429(t *testing.T) {
26422664
if err != nil {
26432665
return
26442666
}
2645-
rows, err := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
2667+
// This is expected to give a cancel error often, but not always.
2668+
// Test failure will happen with a panic or other race condition being
2669+
// reported.
2670+
rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
26462671
if rows != nil {
26472672
rows.Close()
26482673
}
@@ -2655,6 +2680,56 @@ func TestIssue18429(t *testing.T) {
26552680
time.Sleep(milliWait * 3 * time.Millisecond)
26562681
}
26572682

2683+
// TestIssue18719 closes the context right before use. The sql.driverConn
2684+
// will nil out the ci on close in a lock, but if another process uses it right after
2685+
// it will panic on the nil ref.
2686+
//
2687+
// See https://golang.org/cl/35550 .
2688+
func TestIssue18719(t *testing.T) {
2689+
db := newTestDB(t, "people")
2690+
defer closeDB(t, db)
2691+
2692+
ctx, cancel := context.WithCancel(context.Background())
2693+
defer cancel()
2694+
2695+
tx, err := db.BeginTx(ctx, nil)
2696+
if err != nil {
2697+
t.Fatal(err)
2698+
}
2699+
2700+
hookTxGrabConn = func() {
2701+
cancel()
2702+
2703+
// Wait for the context to cancel and tx to rollback.
2704+
for tx.isDone() == false {
2705+
time.Sleep(time.Millisecond * 3)
2706+
}
2707+
}
2708+
defer func() { hookTxGrabConn = nil }()
2709+
2710+
// This call will grab the connection and cancel the context
2711+
// after it has done so. Code after must deal with the canceled state.
2712+
rows, err := tx.QueryContext(ctx, "SELECT|people|name|")
2713+
if err != nil {
2714+
rows.Close()
2715+
t.Fatalf("expected error %v but got %v", nil, err)
2716+
}
2717+
2718+
// Rows may be ignored because it will be closed when the context is canceled.
2719+
2720+
// Do not explicitly rollback. The rollback will happen from the
2721+
// canceled context.
2722+
2723+
// Wait for connections to return to pool.
2724+
var numOpen int
2725+
if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
2726+
numOpen = db.numOpenConns()
2727+
return numOpen == 0
2728+
}) {
2729+
t.Fatalf("open conns after hitting EOF = %d; want 0", numOpen)
2730+
}
2731+
}
2732+
26582733
func TestConcurrency(t *testing.T) {
26592734
doConcurrentTest(t, new(concurrentDBQueryTest))
26602735
doConcurrentTest(t, new(concurrentDBExecTest))

0 commit comments

Comments
 (0)