@@ -1357,16 +1357,7 @@ func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStra
1357
1357
cancel : cancel ,
1358
1358
ctx : ctx ,
1359
1359
}
1360
- go func (tx * Tx ) {
1361
- select {
1362
- case <- tx .ctx .Done ():
1363
- if ! tx .isDone () {
1364
- // Discard and close the connection used to ensure the transaction
1365
- // is closed and the resources are released.
1366
- tx .rollback (true )
1367
- }
1368
- }
1369
- }(tx )
1360
+ go tx .awaitDone ()
1370
1361
return tx , nil
1371
1362
}
1372
1363
@@ -1388,6 +1379,11 @@ func (db *DB) Driver() driver.Driver {
1388
1379
type Tx struct {
1389
1380
db * DB
1390
1381
1382
+ // closemu prevents the transaction from closing while there
1383
+ // is an active query. It is held for read during queries
1384
+ // and exclusively during close.
1385
+ closemu sync.RWMutex
1386
+
1391
1387
// dc is owned exclusively until Commit or Rollback, at which point
1392
1388
// it's returned with putConn.
1393
1389
dc * driverConn
@@ -1413,6 +1409,20 @@ type Tx struct {
1413
1409
ctx context.Context
1414
1410
}
1415
1411
1412
+ // awaitDone blocks until the context in Tx is canceled and rolls back
1413
+ // the transaction if it's not already done.
1414
+ func (tx * Tx ) awaitDone () {
1415
+ // Wait for either the transaction to be committed or rolled
1416
+ // back, or for the associated context to be closed.
1417
+ <- tx .ctx .Done ()
1418
+
1419
+ // Discard and close the connection used to ensure the
1420
+ // transaction is closed and the resources are released. This
1421
+ // rollback does nothing if the transaction has already been
1422
+ // committed or rolled back.
1423
+ tx .rollback (true )
1424
+ }
1425
+
1416
1426
func (tx * Tx ) isDone () bool {
1417
1427
return atomic .LoadInt32 (& tx .done ) != 0
1418
1428
}
@@ -1424,16 +1434,31 @@ var ErrTxDone = errors.New("sql: Transaction has already been committed or rolle
1424
1434
// close returns the connection to the pool and
1425
1435
// must only be called by Tx.rollback or Tx.Commit.
1426
1436
func (tx * Tx ) close (err error ) {
1437
+ tx .closemu .Lock ()
1438
+ defer tx .closemu .Unlock ()
1439
+
1427
1440
tx .db .putConn (tx .dc , err )
1428
1441
tx .cancel ()
1429
1442
tx .dc = nil
1430
1443
tx .txi = nil
1431
1444
}
1432
1445
1446
+ // hookTxGrabConn specifies an optional hook to be called on
1447
+ // a successful call to (*Tx).grabConn. For tests.
1448
+ var hookTxGrabConn func ()
1449
+
1433
1450
func (tx * Tx ) grabConn (ctx context.Context ) (* driverConn , error ) {
1451
+ select {
1452
+ default :
1453
+ case <- ctx .Done ():
1454
+ return nil , ctx .Err ()
1455
+ }
1434
1456
if tx .isDone () {
1435
1457
return nil , ErrTxDone
1436
1458
}
1459
+ if hookTxGrabConn != nil { // test hook
1460
+ hookTxGrabConn ()
1461
+ }
1437
1462
return tx .dc , nil
1438
1463
}
1439
1464
@@ -1503,6 +1528,9 @@ func (tx *Tx) Rollback() error {
1503
1528
// for the execution of the returned statement. The returned statement
1504
1529
// will run in the transaction context.
1505
1530
func (tx * Tx ) PrepareContext (ctx context.Context , query string ) (* Stmt , error ) {
1531
+ tx .closemu .RLock ()
1532
+ defer tx .closemu .RUnlock ()
1533
+
1506
1534
// TODO(bradfitz): We could be more efficient here and either
1507
1535
// provide a method to take an existing Stmt (created on
1508
1536
// perhaps a different Conn), and re-create it on this Conn if
@@ -1567,6 +1595,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
1567
1595
// The returned statement operates within the transaction and will be closed
1568
1596
// when the transaction has been committed or rolled back.
1569
1597
func (tx * Tx ) StmtContext (ctx context.Context , stmt * Stmt ) * Stmt {
1598
+ tx .closemu .RLock ()
1599
+ defer tx .closemu .RUnlock ()
1600
+
1570
1601
// TODO(bradfitz): optimize this. Currently this re-prepares
1571
1602
// each time. This is fine for now to illustrate the API but
1572
1603
// we should really cache already-prepared statements
@@ -1618,6 +1649,9 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
1618
1649
// ExecContext executes a query that doesn't return rows.
1619
1650
// For example: an INSERT and UPDATE.
1620
1651
func (tx * Tx ) ExecContext (ctx context.Context , query string , args ... interface {}) (Result , error ) {
1652
+ tx .closemu .RLock ()
1653
+ defer tx .closemu .RUnlock ()
1654
+
1621
1655
dc , err := tx .grabConn (ctx )
1622
1656
if err != nil {
1623
1657
return nil , err
@@ -1661,6 +1695,9 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
1661
1695
1662
1696
// QueryContext executes a query that returns rows, typically a SELECT.
1663
1697
func (tx * Tx ) QueryContext (ctx context.Context , query string , args ... interface {}) (* Rows , error ) {
1698
+ tx .closemu .RLock ()
1699
+ defer tx .closemu .RUnlock ()
1700
+
1664
1701
dc , err := tx .grabConn (ctx )
1665
1702
if err != nil {
1666
1703
return nil , err
@@ -2038,25 +2075,21 @@ type Rows struct {
2038
2075
// closed value is 1 when the Rows is closed.
2039
2076
// Use atomic operations on value when checking value.
2040
2077
closed int32
2041
- ctxClose chan struct {} // closed when Rows is closed, may be null .
2078
+ cancel func () // called when Rows is closed, may be nil .
2042
2079
lastcols []driver.Value
2043
2080
lasterr error // non-nil only if closed is true
2044
2081
closeStmt * driverStmt // if non-nil, statement to Close on close
2045
2082
}
2046
2083
2047
2084
func (rs * Rows ) initContextClose (ctx context.Context ) {
2048
- if ctx . Done () == context .Background (). Done () {
2049
- return
2050
- }
2085
+ ctx , rs . cancel = context .WithCancel ( ctx )
2086
+ go rs . awaitDone ( ctx )
2087
+ }
2051
2088
2052
- rs .ctxClose = make (chan struct {})
2053
- go func () {
2054
- select {
2055
- case <- ctx .Done ():
2056
- rs .Close ()
2057
- case <- rs .ctxClose :
2058
- }
2059
- }()
2089
+ // awaitDone blocks until the rows are closed or the context canceled.
2090
+ func (rs * Rows ) awaitDone (ctx context.Context ) {
2091
+ <- ctx .Done ()
2092
+ rs .Close ()
2060
2093
}
2061
2094
2062
2095
// Next prepares the next result row for reading with the Scan method. It
@@ -2314,7 +2347,9 @@ func (rs *Rows) Scan(dest ...interface{}) error {
2314
2347
return nil
2315
2348
}
2316
2349
2317
- var rowsCloseHook func (* Rows , * error )
2350
+ // rowsCloseHook returns a function so tests may install the
2351
+ // hook throug a test only mutex.
2352
+ var rowsCloseHook = func () func (* Rows , * error ) { return nil }
2318
2353
2319
2354
func (rs * Rows ) isClosed () bool {
2320
2355
return atomic .LoadInt32 (& rs .closed ) != 0
@@ -2328,13 +2363,15 @@ func (rs *Rows) Close() error {
2328
2363
if ! atomic .CompareAndSwapInt32 (& rs .closed , 0 , 1 ) {
2329
2364
return nil
2330
2365
}
2331
- if rs .ctxClose != nil {
2332
- close (rs .ctxClose )
2333
- }
2366
+
2334
2367
err := rs .rowsi .Close ()
2335
- if fn := rowsCloseHook ; fn != nil {
2368
+ if fn := rowsCloseHook () ; fn != nil {
2336
2369
fn (rs , & err )
2337
2370
}
2371
+ if rs .cancel != nil {
2372
+ rs .cancel ()
2373
+ }
2374
+
2338
2375
if rs .closeStmt != nil {
2339
2376
rs .closeStmt .Close ()
2340
2377
}
0 commit comments