mirror of
https://github.com/golang/go.git
synced 2025-05-18 22:04:38 +00:00
database/sql: fix race when canceling queries immediately
Previously the following could happen, though in practice it would be rare. Goroutine 1: (*Tx).QueryContext begins a query, passing in userContext Goroutine 2: (*Tx).awaitDone starts to wait on the context derived from the passed in context Goroutine 1: (*Tx).grabConn returns a valid (*driverConn) The (*driverConn) passes to (*DB).queryConn Goroutine 3: userContext is canceled Goroutine 2: (*Tx).awaitDone unblocks and calls (*Tx).rollback (*driverConn).finalClose obtains dc.Mutex (*driverConn).finalClose sets dc.ci = nil Goroutine 1: (*DB).queryConn obtains dc.Mutex in withLock ctxDriverPrepare accepts dc.ci which is now nil ctxCriverPrepare panics on the nil ci The fix for this is to guard the Tx methods with a RWLock holding it exclusivly when closing the Tx and holding a read lock when executing a query. Fixes #18719 Change-Id: I37aa02c37083c9793dabd28f7f934a1c5cbc05ea Reviewed-on: https://go-review.googlesource.com/35550 Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
1cf08182f9
commit
2b283cedef
@ -1357,16 +1357,7 @@ func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStra
|
||||
cancel: cancel,
|
||||
ctx: ctx,
|
||||
}
|
||||
go func(tx *Tx) {
|
||||
select {
|
||||
case <-tx.ctx.Done():
|
||||
if !tx.isDone() {
|
||||
// Discard and close the connection used to ensure the transaction
|
||||
// is closed and the resources are released.
|
||||
tx.rollback(true)
|
||||
}
|
||||
}
|
||||
}(tx)
|
||||
go tx.awaitDone()
|
||||
return tx, nil
|
||||
}
|
||||
|
||||
@ -1388,6 +1379,11 @@ func (db *DB) Driver() driver.Driver {
|
||||
type Tx struct {
|
||||
db *DB
|
||||
|
||||
// closemu prevents the transaction from closing while there
|
||||
// is an active query. It is held for read during queries
|
||||
// and exclusively during close.
|
||||
closemu sync.RWMutex
|
||||
|
||||
// dc is owned exclusively until Commit or Rollback, at which point
|
||||
// it's returned with putConn.
|
||||
dc *driverConn
|
||||
@ -1413,6 +1409,20 @@ type Tx struct {
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// awaitDone blocks until the context in Tx is canceled and rolls back
|
||||
// the transaction if it's not already done.
|
||||
func (tx *Tx) awaitDone() {
|
||||
// Wait for either the transaction to be committed or rolled
|
||||
// back, or for the associated context to be closed.
|
||||
<-tx.ctx.Done()
|
||||
|
||||
// Discard and close the connection used to ensure the
|
||||
// transaction is closed and the resources are released. This
|
||||
// rollback does nothing if the transaction has already been
|
||||
// committed or rolled back.
|
||||
tx.rollback(true)
|
||||
}
|
||||
|
||||
func (tx *Tx) isDone() bool {
|
||||
return atomic.LoadInt32(&tx.done) != 0
|
||||
}
|
||||
@ -1424,16 +1434,31 @@ var ErrTxDone = errors.New("sql: Transaction has already been committed or rolle
|
||||
// close returns the connection to the pool and
|
||||
// must only be called by Tx.rollback or Tx.Commit.
|
||||
func (tx *Tx) close(err error) {
|
||||
tx.closemu.Lock()
|
||||
defer tx.closemu.Unlock()
|
||||
|
||||
tx.db.putConn(tx.dc, err)
|
||||
tx.cancel()
|
||||
tx.dc = nil
|
||||
tx.txi = nil
|
||||
}
|
||||
|
||||
// hookTxGrabConn specifies an optional hook to be called on
|
||||
// a successful call to (*Tx).grabConn. For tests.
|
||||
var hookTxGrabConn func()
|
||||
|
||||
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
if tx.isDone() {
|
||||
return nil, ErrTxDone
|
||||
}
|
||||
if hookTxGrabConn != nil { // test hook
|
||||
hookTxGrabConn()
|
||||
}
|
||||
return tx.dc, nil
|
||||
}
|
||||
|
||||
@ -1503,6 +1528,9 @@ func (tx *Tx) Rollback() error {
|
||||
// for the execution of the returned statement. The returned statement
|
||||
// will run in the transaction context.
|
||||
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
|
||||
tx.closemu.RLock()
|
||||
defer tx.closemu.RUnlock()
|
||||
|
||||
// TODO(bradfitz): We could be more efficient here and either
|
||||
// provide a method to take an existing Stmt (created on
|
||||
// perhaps a different Conn), and re-create it on this Conn if
|
||||
@ -1567,6 +1595,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
|
||||
// The returned statement operates within the transaction and will be closed
|
||||
// when the transaction has been committed or rolled back.
|
||||
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
|
||||
tx.closemu.RLock()
|
||||
defer tx.closemu.RUnlock()
|
||||
|
||||
// TODO(bradfitz): optimize this. Currently this re-prepares
|
||||
// each time. This is fine for now to illustrate the API but
|
||||
// we should really cache already-prepared statements
|
||||
@ -1618,6 +1649,9 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
|
||||
// ExecContext executes a query that doesn't return rows.
|
||||
// For example: an INSERT and UPDATE.
|
||||
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
|
||||
tx.closemu.RLock()
|
||||
defer tx.closemu.RUnlock()
|
||||
|
||||
dc, err := tx.grabConn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -1661,6 +1695,9 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
|
||||
|
||||
// QueryContext executes a query that returns rows, typically a SELECT.
|
||||
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
|
||||
tx.closemu.RLock()
|
||||
defer tx.closemu.RUnlock()
|
||||
|
||||
dc, err := tx.grabConn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -2038,25 +2075,21 @@ type Rows struct {
|
||||
// closed value is 1 when the Rows is closed.
|
||||
// Use atomic operations on value when checking value.
|
||||
closed int32
|
||||
ctxClose chan struct{} // closed when Rows is closed, may be null.
|
||||
cancel func() // called when Rows is closed, may be nil.
|
||||
lastcols []driver.Value
|
||||
lasterr error // non-nil only if closed is true
|
||||
closeStmt *driverStmt // if non-nil, statement to Close on close
|
||||
}
|
||||
|
||||
func (rs *Rows) initContextClose(ctx context.Context) {
|
||||
if ctx.Done() == context.Background().Done() {
|
||||
return
|
||||
}
|
||||
ctx, rs.cancel = context.WithCancel(ctx)
|
||||
go rs.awaitDone(ctx)
|
||||
}
|
||||
|
||||
rs.ctxClose = make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
rs.Close()
|
||||
case <-rs.ctxClose:
|
||||
}
|
||||
}()
|
||||
// awaitDone blocks until the rows are closed or the context canceled.
|
||||
func (rs *Rows) awaitDone(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
rs.Close()
|
||||
}
|
||||
|
||||
// Next prepares the next result row for reading with the Scan method. It
|
||||
@ -2314,7 +2347,9 @@ func (rs *Rows) Scan(dest ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var rowsCloseHook func(*Rows, *error)
|
||||
// rowsCloseHook returns a function so tests may install the
|
||||
// hook throug a test only mutex.
|
||||
var rowsCloseHook = func() func(*Rows, *error) { return nil }
|
||||
|
||||
func (rs *Rows) isClosed() bool {
|
||||
return atomic.LoadInt32(&rs.closed) != 0
|
||||
@ -2328,13 +2363,15 @@ func (rs *Rows) Close() error {
|
||||
if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) {
|
||||
return nil
|
||||
}
|
||||
if rs.ctxClose != nil {
|
||||
close(rs.ctxClose)
|
||||
}
|
||||
|
||||
err := rs.rowsi.Close()
|
||||
if fn := rowsCloseHook; fn != nil {
|
||||
if fn := rowsCloseHook(); fn != nil {
|
||||
fn(rs, &err)
|
||||
}
|
||||
if rs.cancel != nil {
|
||||
rs.cancel()
|
||||
}
|
||||
|
||||
if rs.closeStmt != nil {
|
||||
rs.closeStmt.Close()
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@ -1135,6 +1136,24 @@ func TestQueryRowClosingStmt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var atomicRowsCloseHook atomic.Value // of func(*Rows, *error)
|
||||
|
||||
func init() {
|
||||
rowsCloseHook = func() func(*Rows, *error) {
|
||||
fn, _ := atomicRowsCloseHook.Load().(func(*Rows, *error))
|
||||
return fn
|
||||
}
|
||||
}
|
||||
|
||||
func setRowsCloseHook(fn func(*Rows, *error)) {
|
||||
if fn == nil {
|
||||
// Can't change an atomic.Value back to nil, so set it to this
|
||||
// no-op func instead.
|
||||
fn = func(*Rows, *error) {}
|
||||
}
|
||||
atomicRowsCloseHook.Store(fn)
|
||||
}
|
||||
|
||||
// Test issue 6651
|
||||
func TestIssue6651(t *testing.T) {
|
||||
db := newTestDB(t, "people")
|
||||
@ -1147,6 +1166,7 @@ func TestIssue6651(t *testing.T) {
|
||||
return fmt.Errorf(want)
|
||||
}
|
||||
defer func() { rowsCursorNextHook = nil }()
|
||||
|
||||
err := db.QueryRow("SELECT|people|name|").Scan(&v)
|
||||
if err == nil || err.Error() != want {
|
||||
t.Errorf("error = %q; want %q", err, want)
|
||||
@ -1154,10 +1174,10 @@ func TestIssue6651(t *testing.T) {
|
||||
rowsCursorNextHook = nil
|
||||
|
||||
want = "error in rows.Close"
|
||||
rowsCloseHook = func(rows *Rows, err *error) {
|
||||
setRowsCloseHook(func(rows *Rows, err *error) {
|
||||
*err = fmt.Errorf(want)
|
||||
}
|
||||
defer func() { rowsCloseHook = nil }()
|
||||
})
|
||||
defer setRowsCloseHook(nil)
|
||||
err = db.QueryRow("SELECT|people|name|").Scan(&v)
|
||||
if err == nil || err.Error() != want {
|
||||
t.Errorf("error = %q; want %q", err, want)
|
||||
@ -1830,7 +1850,9 @@ func TestStmtCloseDeps(t *testing.T) {
|
||||
db.dumpDeps(t)
|
||||
}
|
||||
|
||||
if len(stmt.css) > nquery {
|
||||
if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
|
||||
return len(stmt.css) <= nquery
|
||||
}) {
|
||||
t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery)
|
||||
}
|
||||
|
||||
@ -2576,10 +2598,10 @@ func TestIssue6081(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rowsCloseHook = func(rows *Rows, err *error) {
|
||||
setRowsCloseHook(func(rows *Rows, err *error) {
|
||||
*err = driver.ErrBadConn
|
||||
}
|
||||
defer func() { rowsCloseHook = nil }()
|
||||
})
|
||||
defer setRowsCloseHook(nil)
|
||||
for i := 0; i < 10; i++ {
|
||||
rows, err := stmt.Query()
|
||||
if err != nil {
|
||||
@ -2642,7 +2664,10 @@ func TestIssue18429(t *testing.T) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rows, err := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
|
||||
// This is expected to give a cancel error many, but not all the time.
|
||||
// Test failure will happen with a panic or other race condition being
|
||||
// reported.
|
||||
rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
|
||||
if rows != nil {
|
||||
rows.Close()
|
||||
}
|
||||
@ -2655,6 +2680,56 @@ func TestIssue18429(t *testing.T) {
|
||||
time.Sleep(milliWait * 3 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestIssue18719 closes the context right before use. The sql.driverConn
|
||||
// will nil out the ci on close in a lock, but if another process uses it right after
|
||||
// it will panic with on the nil ref.
|
||||
//
|
||||
// See https://golang.org/cl/35550 .
|
||||
func TestIssue18719(t *testing.T) {
|
||||
db := newTestDB(t, "people")
|
||||
defer closeDB(t, db)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
hookTxGrabConn = func() {
|
||||
cancel()
|
||||
|
||||
// Wait for the context to cancel and tx to rollback.
|
||||
for tx.isDone() == false {
|
||||
time.Sleep(time.Millisecond * 3)
|
||||
}
|
||||
}
|
||||
defer func() { hookTxGrabConn = nil }()
|
||||
|
||||
// This call will grab the connection and cancel the context
|
||||
// after it has done so. Code after must deal with the canceled state.
|
||||
rows, err := tx.QueryContext(ctx, "SELECT|people|name|")
|
||||
if err != nil {
|
||||
rows.Close()
|
||||
t.Fatalf("expected error %v but got %v", nil, err)
|
||||
}
|
||||
|
||||
// Rows may be ignored because it will be closed when the context is canceled.
|
||||
|
||||
// Do not explicitly rollback. The rollback will happen from the
|
||||
// canceled context.
|
||||
|
||||
// Wait for connections to return to pool.
|
||||
var numOpen int
|
||||
if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
|
||||
numOpen = db.numOpenConns()
|
||||
return numOpen == 0
|
||||
}) {
|
||||
t.Fatalf("open conns after hitting EOF = %d; want 0", numOpen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
doConcurrentTest(t, new(concurrentDBQueryTest))
|
||||
doConcurrentTest(t, new(concurrentDBExecTest))
|
||||
|
Loading…
x
Reference in New Issue
Block a user