diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go index f5a2e7c16c..6113af79c5 100644 --- a/src/database/sql/driver/driver.go +++ b/src/database/sql/driver/driver.go @@ -222,6 +222,18 @@ type ConnBeginTx interface { BeginTx(ctx context.Context, opts TxOptions) (Tx, error) } +// ResetSessioner may be implemented by Conn to allow drivers to reset the +// session state associated with the connection and to signal a bad connection. +type ResetSessioner interface { + // ResetSession is called while a connection is in the connection + // pool. No queries will run on this connection until this method returns. + // + // If the connection is bad this should return driver.ErrBadConn to prevent + // the connection from being returned to the connection pool. Any other + // error will be discarded. + ResetSession(ctx context.Context) error +} + // Result is the result of a query execution. type Result interface { // LastInsertId returns the database's auto-generated ID diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index 4dcd096ca4..070b783453 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -55,6 +55,22 @@ type fakeDriver struct { dbs map[string]*fakeDB } +type fakeConnector struct { + name string + + waiter func(context.Context) +} + +func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { + conn, err := fdriver.Open(c.name) + conn.(*fakeConn).waiter = c.waiter + return conn, err +} + +func (c *fakeConnector) Driver() driver.Driver { + return fdriver +} + type fakeDB struct { name string @@ -107,6 +123,16 @@ type fakeConn struct { // bad connection tests; see isBad() bad bool stickyBad bool + + skipDirtySession bool // tests that use Conn should set this to true. + + // dirtySession tests ResetSession, true if a query has executed + // until ResetSession is called. + dirtySession bool + + // The waiter is called before each query. May be used in place of the "WAIT" + // directive. + waiter func(context.Context) } func (c *fakeConn) touchMem() { @@ -298,6 +324,9 @@ func (c *fakeConn) isBad() bool { if c.stickyBad { return true } else if c.bad { + if c.db == nil { + return false + } // alternate between bad conn and not bad conn c.db.badConn = !c.db.badConn return c.db.badConn @@ -306,6 +335,21 @@ func (c *fakeConn) isBad() bool { } } +func (c *fakeConn) isDirtyAndMark() bool { + if c.skipDirtySession { + return false + } + if c.currTx != nil { + c.dirtySession = true + return false + } + if c.dirtySession { + return true + } + c.dirtySession = true + return false +} + func (c *fakeConn) Begin() (driver.Tx, error) { if c.isBad() { return nil, driver.ErrBadConn @@ -337,6 +381,14 @@ func setStrictFakeConnClose(t *testing.T) { testStrictClose = t } +func (c *fakeConn) ResetSession(ctx context.Context) error { + c.dirtySession = false + if c.isBad() { + return driver.ErrBadConn + } + return nil +} + func (c *fakeConn) Close() (err error) { drv := fdriver.(*fakeDriver) defer func() { @@ -572,6 +624,10 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm stmt.cmd = cmd parts = parts[1:] + if c.waiter != nil { + c.waiter(ctx) + } + if stmt.wait > 0 { wait := time.NewTimer(stmt.wait) select { @@ -662,6 +718,9 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { return nil, driver.ErrBadConn } + if s.c.isDirtyAndMark() { + return nil, errors.New("session is dirty") + } err := checkSubsetTypes(s.c.db.allowAny, args) if err != nil { @@ -774,6 +833,9 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { return nil, driver.ErrBadConn } + if s.c.isDirtyAndMark() { + return nil, errors.New("session is dirty") + } err := checkSubsetTypes(s.c.db.allowAny, args) if err != nil { diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 7c35710688..c17b2b543b 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -334,6 +334,7 @@ type DB struct { // It is closed during db.Close(). The close tells the connectionOpener // goroutine to exit. openerCh chan struct{} + resetterCh chan *driverConn closed bool dep map[finalCloser]depSet lastPut map[*driverConn]string // stacktrace of last conn's put; debug only @@ -341,6 +342,8 @@ type DB struct { maxOpen int // <= 0 means unlimited maxLifetime time.Duration // maximum amount of time a connection may be reused cleanerCh chan struct{} + + stop func() // stop cancels the connection opener and the session resetter. } // connReuseStrategy determines how (*DB).conn returns database connections. @@ -368,6 +371,7 @@ type driverConn struct { closed bool finalClosed bool // ci.Close has been called openStmt map[*driverStmt]bool + lastErr error // lastError captures the result of the session resetter. // guarded by db.mu inUse bool @@ -376,7 +380,7 @@ type driverConn struct { } func (dc *driverConn) releaseConn(err error) { - dc.db.putConn(dc, err) + dc.db.putConn(dc, err, true) } func (dc *driverConn) removeOpenStmt(ds *driverStmt) { @@ -417,6 +421,19 @@ func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, que return ds, nil } +// resetSession resets the connection session and sets the lastErr +// that is checked before returning the connection to another query. +// +// resetSession assumes that the embedded mutex is locked when the connection +// was returned to the pool. This unlocks the mutex. +func (dc *driverConn) resetSession(ctx context.Context) { + defer dc.Unlock() // In case of panic. + if dc.closed { // Check if the database has been closed. + return + } + dc.lastErr = dc.ci.(driver.ResetSessioner).ResetSession(ctx) +} + // the dc.db's Mutex is held. func (dc *driverConn) closeDBLocked() func() error { dc.Lock() @@ -604,14 +621,18 @@ func (t dsnConnector) Driver() driver.Driver { // function should be called just once. It is rarely necessary to // close a DB. func OpenDB(c driver.Connector) *DB { + ctx, cancel := context.WithCancel(context.Background()) db := &DB{ connector: c, openerCh: make(chan struct{}, connectionRequestQueueSize), + resetterCh: make(chan *driverConn, 50), lastPut: make(map[*driverConn]string), connRequests: make(map[uint64]chan connRequest), + stop: cancel, } - go db.connectionOpener() + go db.connectionOpener(ctx) + go db.connectionResetter(ctx) return db } @@ -693,7 +714,6 @@ func (db *DB) Close() error { db.mu.Unlock() return nil } - close(db.openerCh) if db.cleanerCh != nil { close(db.cleanerCh) } @@ -714,6 +734,7 @@ func (db *DB) Close() error { err = err1 } } + db.stop() return err } @@ -901,18 +922,39 @@ func (db *DB) maybeOpenNewConnections() { } // Runs in a separate goroutine, opens new connections when requested. -func (db *DB) connectionOpener() { - for range db.openerCh { - db.openNewConnection() +func (db *DB) connectionOpener(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-db.openerCh: + db.openNewConnection(ctx) + } + } +} + +// connectionResetter runs in a separate goroutine to reset connections async +// to exported API. +func (db *DB) connectionResetter(ctx context.Context) { + for { + select { + case <-ctx.Done(): + for dc := range db.resetterCh { + dc.Unlock() + } + return + case dc := <-db.resetterCh: + dc.resetSession(ctx) + } } } // Open one new connection -func (db *DB) openNewConnection() { +func (db *DB) openNewConnection(ctx context.Context) { // maybeOpenNewConnctions has already executed db.numOpen++ before it sent // on db.openerCh. This function must execute db.numOpen-- if the // connection fails or is closed before returning. - ci, err := db.connector.Connect(context.Background()) + ci, err := db.connector.Connect(ctx) db.mu.Lock() defer db.mu.Unlock() if db.closed { @@ -987,6 +1029,14 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn conn.Close() return nil, driver.ErrBadConn } + // Lock around reading lastErr to ensure the session resetter finished. + conn.Lock() + err := conn.lastErr + conn.Unlock() + if err == driver.ErrBadConn { + conn.Close() + return nil, driver.ErrBadConn + } return conn, nil } @@ -1012,7 +1062,7 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn default: case ret, ok := <-req: if ok { - db.putConn(ret.conn, ret.err) + db.putConn(ret.conn, ret.err, false) } } return nil, ctx.Err() @@ -1024,6 +1074,17 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn ret.conn.Close() return nil, driver.ErrBadConn } + if ret.conn == nil { + return nil, ret.err + } + // Lock around reading lastErr to ensure the session resetter finished. + ret.conn.Lock() + err := ret.conn.lastErr + ret.conn.Unlock() + if err == driver.ErrBadConn { + ret.conn.Close() + return nil, driver.ErrBadConn + } return ret.conn, ret.err } } @@ -1079,7 +1140,7 @@ const debugGetPut = false // putConn adds a connection to the db's free pool. // err is optionally the last error that occurred on this connection. -func (db *DB) putConn(dc *driverConn, err error) { +func (db *DB) putConn(dc *driverConn, err error, resetSession bool) { db.mu.Lock() if !dc.inUse { if debugGetPut { @@ -1110,11 +1171,35 @@ func (db *DB) putConn(dc *driverConn, err error) { if putConnHook != nil { putConnHook(db, dc) } + if resetSession { + if _, resetSession = dc.ci.(driver.ResetSessioner); resetSession { + // Lock the driverConn here so it isn't released until + // the connection is reset. + // The lock must be taken before the connection is put into + // the pool to prevent it from being taken out before it is reset. + dc.Lock() + } + } added := db.putConnDBLocked(dc, nil) db.mu.Unlock() if !added { + if resetSession { + dc.Unlock() + } dc.Close() + return + } + if !resetSession { + return + } + select { + default: + // If the resetterCh is blocking then mark the connection + // as bad and continue on. + dc.lastErr = driver.ErrBadConn + dc.Unlock() + case db.resetterCh <- dc: } } diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 3551366369..7100d000c7 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -60,10 +60,12 @@ const fakeDBName = "foo" var chrisBirthday = time.Unix(123456789, 0) func newTestDB(t testing.TB, name string) *DB { - db, err := Open("test", fakeDBName) - if err != nil { - t.Fatalf("Open: %v", err) - } + return newTestDBConnector(t, &fakeConnector{name: fakeDBName}, name) +} + +func newTestDBConnector(t testing.TB, fc *fakeConnector, name string) *DB { + fc.name = fakeDBName + db := OpenDB(fc) if _, err := db.Exec("WIPE"); err != nil { t.Fatalf("exec wipe: %v", err) } @@ -585,24 +587,46 @@ func TestPoolExhaustOnCancel(t *testing.T) { if testing.Short() { t.Skip("long test") } - db := newTestDB(t, "people") - defer closeDB(t, db) max := 3 + var saturate, saturateDone sync.WaitGroup + saturate.Add(max) + saturateDone.Add(max) + + donePing := make(chan bool) + state := 0 + + // waiter will be called for all queries, including + // initial setup queries. The state is only assigned when no + // no queries are made. + // + // Only allow the first batch of queries to finish once the + // second batch of Ping queries have finished. + waiter := func(ctx context.Context) { + switch state { + case 0: + // Nothing. Initial database setup. + case 1: + saturate.Done() + select { + case <-ctx.Done(): + case <-donePing: + } + case 2: + } + } + db := newTestDBConnector(t, &fakeConnector{waiter: waiter}, "people") + defer closeDB(t, db) db.SetMaxOpenConns(max) // First saturate the connection pool. // Then start new requests for a connection that is cancelled after it is requested. - var saturate, saturateDone sync.WaitGroup - saturate.Add(max) - saturateDone.Add(max) - + state = 1 for i := 0; i < max; i++ { go func() { - saturate.Done() - rows, err := db.Query("WAIT|500ms|SELECT|people|name,photo|") + rows, err := db.Query("SELECT|people|name,photo|") if err != nil { t.Fatalf("Query: %v", err) } @@ -612,6 +636,7 @@ func TestPoolExhaustOnCancel(t *testing.T) { } saturate.Wait() + state = 2 // Now cancel the request while it is waiting. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) @@ -628,7 +653,7 @@ func TestPoolExhaustOnCancel(t *testing.T) { t.Fatalf("PingContext (Exhaust): %v", err) } } - + close(donePing) saturateDone.Wait() // Now try to open a normal connection. @@ -1332,6 +1357,7 @@ func TestConnQuery(t *testing.T) { if err != nil { t.Fatal(err) } + conn.dc.ci.(*fakeConn).skipDirtySession = true defer conn.Close() var name string @@ -1359,6 +1385,7 @@ func TestConnTx(t *testing.T) { if err != nil { t.Fatal(err) } + conn.dc.ci.(*fakeConn).skipDirtySession = true defer conn.Close() tx, err := conn.BeginTx(ctx, nil) @@ -2384,7 +2411,9 @@ func TestManyErrBadConn(t *testing.T) { t.Fatalf("unexpected len(db.freeConn) %d (was expecting %d)", len(db.freeConn), nconn) } for _, conn := range db.freeConn { + conn.Lock() conn.ci.(*fakeConn).stickyBad = true + conn.Unlock() } return db } @@ -2474,6 +2503,7 @@ func TestManyErrBadConn(t *testing.T) { if err != nil { t.Fatal(err) } + conn.dc.ci.(*fakeConn).skipDirtySession = true err = conn.Close() if err != nil { t.Fatal(err) @@ -3238,9 +3268,8 @@ func TestIssue18719(t *testing.T) { // This call will grab the connection and cancel the context // after it has done so. Code after must deal with the canceled state. - rows, err := tx.QueryContext(ctx, "SELECT|people|name|") + _, err = tx.QueryContext(ctx, "SELECT|people|name|") if err != nil { - rows.Close() t.Fatalf("expected error %v but got %v", nil, err) } @@ -3263,6 +3292,7 @@ func TestIssue20647(t *testing.T) { if err != nil { t.Fatal(err) } + conn.dc.ci.(*fakeConn).skipDirtySession = true defer conn.Close() stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|")