Commit 1126d148 authored by Daniel Theophanes's avatar Daniel Theophanes

database/sql: ensure all driver interfaces are called under single lock

Russ pointed out in a previous CL golang.org/cl/65731 that not only
was the locking incomplete, previous changes did not correctly
lock driver calls in other sections. After inspecting
driverConn, driverStmt, driverResult, Tx, and Rows structs
where driver interfaces are stored, I discovered a few more places
that failed to lock driver calls. The largest of these
was the parameter type converter "driverArgs".

driverArgs was typically called right before another call to the
driver in a locked region, so I made the entire driverArgs expect
a locked driver mutex and combined the region. This should not
be a problem because the connection is pulled out of the connection
pool either way so there shouldn't be contention.

Fixes #21117

Change-Id: I88d46f74dca25fb11a30f0bf8e79785a73133d23
Reviewed-on: https://go-review.googlesource.com/71433
Run-TryBot: Daniel Theophanes <kardianos@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: 's avatarRuss Cox <rsc@golang.org>
parent 98658212
...@@ -12,7 +12,6 @@ import ( ...@@ -12,7 +12,6 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"sync"
"time" "time"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
...@@ -38,17 +37,10 @@ func validateNamedValueName(name string) error { ...@@ -38,17 +37,10 @@ func validateNamedValueName(name string) error {
return fmt.Errorf("name %q does not begin with a letter", name) return fmt.Errorf("name %q does not begin with a letter", name)
} }
func driverNumInput(ds *driverStmt) int {
ds.Lock()
defer ds.Unlock() // in case NumInput panics
return ds.si.NumInput()
}
// ccChecker wraps the driver.ColumnConverter and allows it to be used // ccChecker wraps the driver.ColumnConverter and allows it to be used
// as if it were a NamedValueChecker. If the driver ColumnConverter // as if it were a NamedValueChecker. If the driver ColumnConverter
// is not present then the NamedValueChecker will return driver.ErrSkip. // is not present then the NamedValueChecker will return driver.ErrSkip.
type ccChecker struct { type ccChecker struct {
sync.Locker
cci driver.ColumnConverter cci driver.ColumnConverter
want int want int
} }
...@@ -88,9 +80,7 @@ func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error { ...@@ -88,9 +80,7 @@ func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
// same error. // same error.
var err error var err error
arg := nv.Value arg := nv.Value
c.Lock()
nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg) nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
c.Unlock()
if err != nil { if err != nil {
return err return err
} }
...@@ -112,7 +102,7 @@ func defaultCheckNamedValue(nv *driver.NamedValue) (err error) { ...@@ -112,7 +102,7 @@ func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
// Stmt.Query into driver Values. // Stmt.Query into driver Values.
// //
// The statement ds may be nil, if no statement is available. // The statement ds may be nil, if no statement is available.
func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
nvargs := make([]driver.NamedValue, len(args)) nvargs := make([]driver.NamedValue, len(args))
// -1 means the driver doesn't know how to count the number of // -1 means the driver doesn't know how to count the number of
...@@ -124,8 +114,7 @@ func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.Na ...@@ -124,8 +114,7 @@ func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.Na
var cc ccChecker var cc ccChecker
if ds != nil { if ds != nil {
si = ds.si si = ds.si
want = driverNumInput(ds) want = ds.si.NumInput()
cc.Locker = ds.Locker
cc.want = want cc.want = want
} }
......
...@@ -481,7 +481,7 @@ func TestDriverArgs(t *testing.T) { ...@@ -481,7 +481,7 @@ func TestDriverArgs(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}} ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}}
got, err := driverArgs(nil, ds, tt.args) got, err := driverArgsConnLocked(nil, ds, tt.args)
if err != nil { if err != nil {
t.Errorf("test[%d]: %v", i, err) t.Errorf("test[%d]: %v", i, err)
continue continue
......
...@@ -1005,6 +1005,7 @@ type rowsCursor struct { ...@@ -1005,6 +1005,7 @@ type rowsCursor struct {
} }
func (rc *rowsCursor) touchMem() { func (rc *rowsCursor) touchMem() {
rc.parentMem.touchMem()
rc.line++ rc.line++
} }
......
...@@ -1368,12 +1368,12 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q ...@@ -1368,12 +1368,12 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
} }
if ok { if ok {
var nvdargs []driver.NamedValue var nvdargs []driver.NamedValue
nvdargs, err = driverArgs(dc.ci, nil, args)
if err != nil {
return nil, err
}
var resi driver.Result var resi driver.Result
withLock(dc, func() { withLock(dc, func() {
nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
if err != nil {
return
}
resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs) resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
}) })
if err != driver.ErrSkip { if err != driver.ErrSkip {
...@@ -1439,13 +1439,14 @@ func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn fu ...@@ -1439,13 +1439,14 @@ func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn fu
queryer, ok = dc.ci.(driver.Queryer) queryer, ok = dc.ci.(driver.Queryer)
} }
if ok { if ok {
nvdargs, err := driverArgs(dc.ci, nil, args) var nvdargs []driver.NamedValue
if err != nil {
releaseConn(err)
return nil, err
}
var rowsi driver.Rows var rowsi driver.Rows
var err error
withLock(dc, func() { withLock(dc, func() {
nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
if err != nil {
return
}
rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs) rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
}) })
if err != driver.ErrSkip { if err != driver.ErrSkip {
...@@ -2034,11 +2035,14 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { ...@@ -2034,11 +2035,14 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
stmt.mu.Unlock() stmt.mu.Unlock()
if si == nil { if si == nil {
cs, err := stmt.prepareOnConnLocked(ctx, dc) withLock(dc, func() {
var ds *driverStmt
ds, err = stmt.prepareOnConnLocked(ctx, dc)
si = ds.si
})
if err != nil { if err != nil {
return &Stmt{stickyErr: err} return &Stmt{stickyErr: err}
} }
si = cs.si
} }
parentStmt = stmt parentStmt = stmt
} }
...@@ -2230,14 +2234,14 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { ...@@ -2230,14 +2234,14 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
} }
func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) { func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) {
dargs, err := driverArgs(ci, ds, args) ds.Lock()
defer ds.Unlock()
dargs, err := driverArgsConnLocked(ci, ds, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ds.Lock()
defer ds.Unlock()
resi, err := ctxDriverStmtExec(ctx, ds.si, dargs) resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -2401,10 +2405,10 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { ...@@ -2401,10 +2405,10 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
} }
func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) { func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
var want int ds.Lock()
withLock(ds, func() { defer ds.Unlock()
want = ds.si.NumInput()
}) want := ds.si.NumInput()
// -1 means the driver doesn't know how to count the number of // -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the // placeholders, so we won't sanity check input here and instead let the
...@@ -2413,14 +2417,11 @@ func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, arg ...@@ -2413,14 +2417,11 @@ func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, arg
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
} }
dargs, err := driverArgs(ci, ds, args) dargs, err := driverArgsConnLocked(ci, ds, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ds.Lock()
defer ds.Unlock()
rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs) rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -2583,9 +2584,16 @@ func (rs *Rows) nextLocked() (doClose, ok bool) { ...@@ -2583,9 +2584,16 @@ func (rs *Rows) nextLocked() (doClose, ok bool) {
if rs.closed { if rs.closed {
return false, false return false, false
} }
// Lock the driver connection before calling the driver interface
// rowsi to prevent a Tx from rolling back the connection at the same time.
rs.dc.Lock()
defer rs.dc.Unlock()
if rs.lastcols == nil { if rs.lastcols == nil {
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
} }
rs.lasterr = rs.rowsi.Next(rs.lastcols) rs.lasterr = rs.rowsi.Next(rs.lastcols)
if rs.lasterr != nil { if rs.lasterr != nil {
// Close the connection if there is a driver error. // Close the connection if there is a driver error.
...@@ -2635,6 +2643,12 @@ func (rs *Rows) NextResultSet() bool { ...@@ -2635,6 +2643,12 @@ func (rs *Rows) NextResultSet() bool {
doClose = true doClose = true
return false return false
} }
// Lock the driver connection before calling the driver interface
// rowsi to prevent a Tx from rolling back the connection at the same time.
rs.dc.Lock()
defer rs.dc.Unlock()
rs.lasterr = nextResultSet.NextResultSet() rs.lasterr = nextResultSet.NextResultSet()
if rs.lasterr != nil { if rs.lasterr != nil {
doClose = true doClose = true
...@@ -2666,6 +2680,9 @@ func (rs *Rows) Columns() ([]string, error) { ...@@ -2666,6 +2680,9 @@ func (rs *Rows) Columns() ([]string, error) {
if rs.rowsi == nil { if rs.rowsi == nil {
return nil, errors.New("sql: no Rows available") return nil, errors.New("sql: no Rows available")
} }
rs.dc.Lock()
defer rs.dc.Unlock()
return rs.rowsi.Columns(), nil return rs.rowsi.Columns(), nil
} }
...@@ -2680,7 +2697,10 @@ func (rs *Rows) ColumnTypes() ([]*ColumnType, error) { ...@@ -2680,7 +2697,10 @@ func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
if rs.rowsi == nil { if rs.rowsi == nil {
return nil, errors.New("sql: no Rows available") return nil, errors.New("sql: no Rows available")
} }
return rowsColumnInfoSetup(rs.rowsi), nil rs.dc.Lock()
defer rs.dc.Unlock()
return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
} }
// ColumnType contains the name and type of a column. // ColumnType contains the name and type of a column.
...@@ -2741,7 +2761,7 @@ func (ci *ColumnType) DatabaseTypeName() string { ...@@ -2741,7 +2761,7 @@ func (ci *ColumnType) DatabaseTypeName() string {
return ci.databaseType return ci.databaseType
} }
func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType { func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
names := rowsi.Columns() names := rowsi.Columns()
list := make([]*ColumnType, len(names)) list := make([]*ColumnType, len(names))
......
...@@ -3157,6 +3157,9 @@ func TestIssue6081(t *testing.T) { ...@@ -3157,6 +3157,9 @@ func TestIssue6081(t *testing.T) {
// In the test, a context is canceled while the query is in process so // In the test, a context is canceled while the query is in process so
// the internal rollback will run concurrently with the explicitly called // the internal rollback will run concurrently with the explicitly called
// Tx.Rollback. // Tx.Rollback.
//
// The addition of calling rows.Next also tests
// Issue 21117.
func TestIssue18429(t *testing.T) { func TestIssue18429(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db) defer closeDB(t, db)
...@@ -3189,6 +3192,9 @@ func TestIssue18429(t *testing.T) { ...@@ -3189,6 +3192,9 @@ func TestIssue18429(t *testing.T) {
// reported. // reported.
rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|") rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
if rows != nil { if rows != nil {
// Call Next to test Issue 21117 and check for races.
for rows.Next() {
}
rows.Close() rows.Close()
} }
// This call will race with the context cancel rollback to complete // This call will race with the context cancel rollback to complete
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment