Commit f28c8fba authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

database/sql: associate a mutex with each driver interface

The database/sql/driver docs make this promise:

   "Conn is a connection to a database. It is not used
   concurrently by multiple goroutines."

That promises exists as part of database/sql's overall
goal of making drivers relatively easy to write.

So far this promise has been kept without the use of locks by
being careful in the database/sql package, but sometimes too
careful. (cf. golang.org/issue/3857)

The CL associates a Mutex with each driver.Conn, and with the
interface value progeny thereof. (e.g. each driver.Tx,
driver.Stmt, driver.Rows, driver.Result, etc) Then whenever
those interface values are used, the Locker is locked.

This CL should be a no-op (aside from some new Lock/Unlock
pairs) and doesn't attempt to fix Issue 3857 or Issue 4459,
but should make it much easier in a subsequent CL.

Update #3857

R=golang-dev, adg
CC=golang-dev
https://golang.org/cl/7803043
parent eb80431b
...@@ -19,9 +19,13 @@ var errNilPtr = errors.New("destination pointer is nil") // embedded in descript ...@@ -19,9 +19,13 @@ var errNilPtr = errors.New("destination pointer is nil") // embedded in descript
// driverArgs converts arguments from callers of Stmt.Exec and // driverArgs converts arguments from callers of Stmt.Exec and
// Stmt.Query into driver Values. // Stmt.Query into driver Values.
// //
// The statement si may be nil, if no statement is available. // The statement ds may be nil, if no statement is available.
func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) { func driverArgs(ds *driverStmt, args []interface{}) ([]driver.Value, error) {
dargs := make([]driver.Value, len(args)) dargs := make([]driver.Value, len(args))
var si driver.Stmt
if ds != nil {
si = ds.si
}
cc, ok := si.(driver.ColumnConverter) cc, ok := si.(driver.ColumnConverter)
// Normal path, for a driver.Stmt that is not a ColumnConverter. // Normal path, for a driver.Stmt that is not a ColumnConverter.
...@@ -60,7 +64,9 @@ func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) { ...@@ -60,7 +64,9 @@ func driverArgs(si driver.Stmt, args []interface{}) ([]driver.Value, error) {
// column before going across the network to get the // column before going across the network to get the
// same error. // same error.
var err error var err error
ds.Lock()
dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg) dargs[n], err = cc.ColumnConverter(n).ConvertValue(arg)
ds.Unlock()
if err != nil { if err != nil {
return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err) return nil, fmt.Errorf("sql: converting argument #%d's type: %v", n, err)
} }
......
This diff is collapsed.
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
package sql package sql
import ( import (
"database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
...@@ -16,10 +15,10 @@ import ( ...@@ -16,10 +15,10 @@ import (
func init() { func init() {
type dbConn struct { type dbConn struct {
db *DB db *DB
c driver.Conn c *driverConn
} }
freedFrom := make(map[dbConn]string) freedFrom := make(map[dbConn]string)
putConnHook = func(db *DB, c driver.Conn) { putConnHook = func(db *DB, c *driverConn) {
for _, oc := range db.freeConn { for _, oc := range db.freeConn {
if oc == c { if oc == c {
// print before panic, as panic may get lost due to conflicting panic // print before panic, as panic may get lost due to conflicting panic
...@@ -78,7 +77,7 @@ func numPrepares(t *testing.T, db *DB) int { ...@@ -78,7 +77,7 @@ func numPrepares(t *testing.T, db *DB) int {
if n := len(db.freeConn); n != 1 { if n := len(db.freeConn); n != 1 {
t.Fatalf("free conns = %d; want 1", n) t.Fatalf("free conns = %d; want 1", n)
} }
return db.freeConn[0].(*fakeConn).numPrepare return db.freeConn[0].ci.(*fakeConn).numPrepare
} }
func TestQuery(t *testing.T) { func TestQuery(t *testing.T) {
...@@ -576,7 +575,7 @@ func TestQueryRowClosingStmt(t *testing.T) { ...@@ -576,7 +575,7 @@ func TestQueryRowClosingStmt(t *testing.T) {
if len(db.freeConn) != 1 { if len(db.freeConn) != 1 {
t.Fatalf("expected 1 free conn") t.Fatalf("expected 1 free conn")
} }
fakeConn := db.freeConn[0].(*fakeConn) fakeConn := db.freeConn[0].ci.(*fakeConn)
if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed { if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
t.Errorf("statement close mismatch: made %d, closed %d", made, closed) t.Errorf("statement close mismatch: made %d, closed %d", made, closed)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment