Commit 90b8a0ca authored by Daniel Theophanes's avatar Daniel Theophanes Committed by Brad Fitzpatrick

database/sql: ensure all driver Stmt are closed once

Previously  driver.Stmt could could be closed multiple times in
edge cases that drivers may not test for initially. Make their
job easier by ensuring the driver is only closed a single time.

Fixes #16019

Change-Id: I1e4777ef70697a849602e6ef9da73054a8feb4cd
Reviewed-on: https://go-review.googlesource.com/33352Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent e0942b76
......@@ -330,7 +330,7 @@ type driverConn struct {
ci driver.Conn
closed bool
finalClosed bool // ci.Close has been called
openStmt map[driver.Stmt]bool
openStmt map[*driverStmt]bool
// guarded by db.mu
inUse bool
......@@ -342,10 +342,10 @@ func (dc *driverConn) releaseConn(err error) {
dc.db.putConn(dc, err)
}
func (dc *driverConn) removeOpenStmt(si driver.Stmt) {
func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
dc.Lock()
defer dc.Unlock()
delete(dc.openStmt, si)
delete(dc.openStmt, ds)
}
func (dc *driverConn) expired(timeout time.Duration) bool {
......@@ -355,28 +355,23 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
return dc.createdAt.Add(timeout).Before(nowFunc())
}
func (dc *driverConn) prepareLocked(ctx context.Context, query string) (driver.Stmt, error) {
func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) {
si, err := ctxDriverPrepare(ctx, dc.ci, query)
if err == nil {
if err != nil {
return nil, err
}
// Track each driverConn's open statements, so we can close them
// before closing the conn.
//
// TODO(bradfitz): let drivers opt out of caring about
// stmt closes if the conn is about to close anyway? For now
// do the safe thing, in case stmts need to be closed.
//
// TODO(bradfitz): after Go 1.2, closing driver.Stmts
// should be moved to driverStmt, using unique
// *driverStmts everywhere (including from
// *Stmt.connStmt, instead of returning a
// driver.Stmt), using driverStmt as a pointer
// everywhere, and making it a finalCloser.
// Wrap all driver.Stmt is *driverStmt to ensure they are only closed once.
if dc.openStmt == nil {
dc.openStmt = make(map[driver.Stmt]bool)
}
dc.openStmt[si] = true
dc.openStmt = make(map[*driverStmt]bool)
}
return si, err
ds := &driverStmt{Locker: dc, si: si}
dc.openStmt[ds] = true
return ds, nil
}
// the dc.db's Mutex is held.
......@@ -409,17 +404,24 @@ func (dc *driverConn) Close() error {
func (dc *driverConn) finalClose() error {
var err error
// Each *driverStmt has a lock to the dc. Copy the list out of the dc
// before calling close on each stmt.
var openStmt []*driverStmt
withLock(dc, func() {
defer func() { // In case si.Close panics.
openStmt = make([]*driverStmt, 0, len(dc.openStmt))
for ds := range dc.openStmt {
openStmt = append(openStmt, ds)
}
dc.openStmt = nil
})
for _, ds := range openStmt {
ds.Close()
}
withLock(dc, func() {
dc.finalClosed = true
err = dc.ci.Close()
dc.ci = nil
}()
for si := range dc.openStmt {
si.Close()
}
})
dc.db.mu.Lock()
......@@ -437,12 +439,21 @@ func (dc *driverConn) finalClose() error {
type driverStmt struct {
sync.Locker // the *driverConn
si driver.Stmt
closed bool
closeErr error // return value of previous Close call
}
// Close ensures dirver.Stmt is only closed once any always returns the same
// result.
func (ds *driverStmt) Close() error {
ds.Lock()
defer ds.Unlock()
return ds.si.Close()
if ds.closed {
return ds.closeErr
}
ds.closed = true
ds.closeErr = ds.si.Close()
return ds.closeErr
}
// depSet is a finalCloser's outstanding dependencies
......@@ -933,21 +944,22 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
// putConnHook is a hook for testing.
var putConnHook func(*DB, *driverConn)
// noteUnusedDriverStatement notes that si is no longer used and should
// noteUnusedDriverStatement notes that ds is no longer used and should
// be closed whenever possible (when c is next not in use), unless c is
// already closed.
func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) {
func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
db.mu.Lock()
defer db.mu.Unlock()
if c.inUse {
c.onPut = append(c.onPut, func() {
si.Close()
ds.Close()
})
} else {
c.Lock()
defer c.Unlock()
if !c.finalClosed {
si.Close()
fc := c.finalClosed
c.Unlock()
if !fc {
ds.Close()
}
}
}
......@@ -1084,9 +1096,9 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat
if err != nil {
return nil, err
}
var si driver.Stmt
var ds *driverStmt
withLock(dc, func() {
si, err = dc.prepareLocked(ctx, query)
ds, err = dc.prepareLocked(ctx, query)
})
if err != nil {
db.putConn(dc, err)
......@@ -1095,7 +1107,7 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat
stmt := &Stmt{
db: db,
query: query,
css: []connStmt{{dc, si}},
css: []connStmt{{dc, ds}},
lastNumClosed: atomic.LoadUint64(&db.numClosed),
}
db.addDep(stmt, stmt)
......@@ -1160,8 +1172,9 @@ func (db *DB) exec(ctx context.Context, query string, args []interface{}, strate
if err != nil {
return nil, err
}
defer withLock(dc, func() { si.Close() })
return resultFromStatement(ctx, driverStmt{dc, si}, args...)
ds := &driverStmt{Locker: dc, si: si}
defer ds.Close()
return resultFromStatement(ctx, ds, args...)
}
// QueryContext executes a query that returns rows, typically a SELECT.
......@@ -1236,12 +1249,10 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
return nil, err
}
ds := driverStmt{dc, si}
ds := &driverStmt{Locker: dc, si: si}
rowsi, err := rowsiFromStatement(ctx, ds, args...)
if err != nil {
withLock(dc, func() {
si.Close()
})
ds.Close()
releaseConn(err)
return nil, err
}
......@@ -1252,7 +1263,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
closeStmt: si,
closeStmt: ds,
}
rows.initContextClose(ctx)
return rows, nil
......@@ -1476,7 +1487,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
stmt := &Stmt{
db: tx.db,
tx: tx,
txsi: &driverStmt{
txds: &driverStmt{
Locker: dc,
si: si,
},
......@@ -1530,7 +1541,7 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
txs := &Stmt{
db: tx.db,
tx: tx,
txsi: &driverStmt{
txds: &driverStmt{
Locker: dc,
si: si,
},
......@@ -1591,9 +1602,10 @@ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}
if err != nil {
return nil, err
}
defer withLock(dc, func() { si.Close() })
ds := &driverStmt{Locker: dc, si: si}
defer ds.Close()
return resultFromStatement(ctx, driverStmt{dc, si}, args...)
return resultFromStatement(ctx, ds, args...)
}
// Exec executes a query that doesn't return rows.
......@@ -1635,7 +1647,7 @@ func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
// connStmt is a prepared statement on a particular connection.
type connStmt struct {
dc *driverConn
si driver.Stmt
ds *driverStmt
}
// Stmt is a prepared statement.
......@@ -1650,7 +1662,7 @@ type Stmt struct {
// If in a transaction, else both nil:
tx *Tx
txsi *driverStmt
txds *driverStmt
mu sync.Mutex // protects the rest of the fields
closed bool
......@@ -1674,7 +1686,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
var res Result
for i := 0; i < maxBadConnRetries; i++ {
dc, releaseConn, si, err := s.connStmt(ctx)
_, releaseConn, ds, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
......@@ -1682,7 +1694,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
return nil, err
}
res, err = resultFromStatement(ctx, driverStmt{dc, si}, args...)
res, err = resultFromStatement(ctx, ds, args...)
releaseConn(err)
if err != driver.ErrBadConn {
return res, err
......@@ -1697,13 +1709,13 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return s.ExecContext(context.Background(), args...)
}
func driverNumInput(ds driverStmt) int {
func driverNumInput(ds *driverStmt) int {
ds.Lock()
defer ds.Unlock() // in case NumInput panics
return ds.si.NumInput()
}
func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (Result, error) {
func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) {
want := driverNumInput(ds)
// -1 means the driver doesn't know how to count the number of
......@@ -1713,7 +1725,7 @@ func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
}
dargs, err := driverArgs(&ds, args)
dargs, err := driverArgs(ds, args)
if err != nil {
return nil, err
}
......@@ -1757,7 +1769,7 @@ func (s *Stmt) removeClosedStmtLocked() {
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) {
if err = s.stickyErr; err != nil {
return
}
......@@ -1777,7 +1789,7 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e
return
}
releaseConn = func(error) {}
return ci, releaseConn, s.txsi.si, nil
return ci, releaseConn, s.txds, nil
}
s.removeClosedStmtLocked()
......@@ -1792,25 +1804,25 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e
for _, v := range s.css {
if v.dc == dc {
s.mu.Unlock()
return dc, dc.releaseConn, v.si, nil
return dc, dc.releaseConn, v.ds, nil
}
}
s.mu.Unlock()
// No luck; we need to prepare the statement on this connection
withLock(dc, func() {
si, err = dc.prepareLocked(ctx, s.query)
ds, err = dc.prepareLocked(ctx, s.query)
})
if err != nil {
s.db.putConn(dc, err)
return nil, nil, nil, err
}
s.mu.Lock()
cs := connStmt{dc, si}
cs := connStmt{dc, ds}
s.css = append(s.css, cs)
s.mu.Unlock()
return dc, dc.releaseConn, si, nil
return dc, dc.releaseConn, ds, nil
}
// QueryContext executes a prepared query statement with the given arguments
......@@ -1821,7 +1833,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
var rowsi driver.Rows
for i := 0; i < maxBadConnRetries; i++ {
dc, releaseConn, si, err := s.connStmt(ctx)
dc, releaseConn, ds, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
......@@ -1829,7 +1841,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
return nil, err
}
rowsi, err = rowsiFromStatement(ctx, driverStmt{dc, si}, args...)
rowsi, err = rowsiFromStatement(ctx, ds, args...)
if err == nil {
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
......@@ -1861,7 +1873,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return s.QueryContext(context.Background(), args...)
}
func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (driver.Rows, error) {
func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
var want int
withLock(ds, func() {
want = ds.si.NumInput()
......@@ -1874,7 +1886,7 @@ func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{})
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
}
dargs, err := driverArgs(&ds, args)
dargs, err := driverArgs(ds, args)
if err != nil {
return nil, err
}
......@@ -1937,12 +1949,11 @@ func (s *Stmt) Close() error {
return nil
}
s.closed = true
s.mu.Unlock()
if s.tx != nil {
defer s.mu.Unlock()
return s.txsi.Close()
return s.txds.Close()
}
s.mu.Unlock()
return s.db.removeDep(s, s)
}
......@@ -1952,8 +1963,8 @@ func (s *Stmt) finalClose() error {
defer s.mu.Unlock()
if s.css != nil {
for _, v := range s.css {
s.db.noteUnusedDriverStatement(v.dc, v.si)
v.dc.removeOpenStmt(v.si)
s.db.noteUnusedDriverStatement(v.dc, v.ds)
v.dc.removeOpenStmt(v.ds)
}
s.css = nil
}
......@@ -1985,7 +1996,7 @@ type Rows struct {
ctxClose chan struct{} // closed when Rows is closed, may be null.
lastcols []driver.Value
lasterr error // non-nil only if closed is true
closeStmt driver.Stmt // if non-nil, statement to Close on close
closeStmt *driverStmt // if non-nil, statement to Close on close
}
func (rs *Rows) initContextClose(ctx context.Context) {
......
......@@ -672,7 +672,7 @@ func TestStatementClose(t *testing.T) {
msg string
}{
{&Stmt{stickyErr: want}, "stickyErr not propagated"},
{&Stmt{tx: &Tx{}, txsi: &driverStmt{&sync.Mutex{}, stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
{&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
}
for _, test := range tests {
if err := test.stmt.Close(); err != want {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment