Commit 99ed71a0 authored by Russ Cox's avatar Russ Cox

database/sql: guard against panics in driver.Stmt implementation

For #13677, but there is more to do.

Change-Id: Id1af999dc972d07cdfc771e5855a1a7dca47ca96
Reviewed-on: https://go-review.googlesource.com/18046Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 81adfa50
...@@ -33,6 +33,9 @@ var _ = log.Printf ...@@ -33,6 +33,9 @@ var _ = log.Printf
// INSERT|<tablename>|col=val,col2=val2,col3=? // INSERT|<tablename>|col=val,col2=val2,col3=?
// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? // SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
// //
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
// When opening a fakeDriver's database, it starts empty with no // When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only. // tables. All tables and data are stored in memory only.
type fakeDriver struct { type fakeDriver struct {
...@@ -111,6 +114,7 @@ type fakeStmt struct { ...@@ -111,6 +114,7 @@ type fakeStmt struct {
cmd string cmd string
table string table string
panic string
closed bool closed bool
...@@ -499,9 +503,15 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { ...@@ -499,9 +503,15 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
if len(parts) < 1 { if len(parts) < 1 {
return nil, errf("empty query") return nil, errf("empty query")
} }
stmt := &fakeStmt{q: query, c: c}
if len(parts) >= 3 && parts[0] == "PANIC" {
stmt.panic = parts[1]
parts = parts[2:]
}
cmd := parts[0] cmd := parts[0]
stmt.cmd = cmd
parts = parts[1:] parts = parts[1:]
stmt := &fakeStmt{q: query, c: c, cmd: cmd}
c.incrStat(&c.stmtsMade) c.incrStat(&c.stmtsMade)
switch cmd { switch cmd {
case "WIPE": case "WIPE":
...@@ -524,6 +534,9 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { ...@@ -524,6 +534,9 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
} }
func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
if s.panic == "ColumnConverter" {
panic(s.panic)
}
if len(s.placeholderConverter) == 0 { if len(s.placeholderConverter) == 0 {
return driver.DefaultParameterConverter return driver.DefaultParameterConverter
} }
...@@ -531,6 +544,9 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { ...@@ -531,6 +544,9 @@ func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
} }
func (s *fakeStmt) Close() error { func (s *fakeStmt) Close() error {
if s.panic == "Close" {
panic(s.panic)
}
if s.c == nil { if s.c == nil {
panic("nil conn in fakeStmt.Close") panic("nil conn in fakeStmt.Close")
} }
...@@ -550,6 +566,9 @@ var errClosed = errors.New("fakedb: statement has been closed") ...@@ -550,6 +566,9 @@ var errClosed = errors.New("fakedb: statement has been closed")
var hookExecBadConn func() bool var hookExecBadConn func() bool
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
if s.panic == "Exec" {
panic(s.panic)
}
if s.closed { if s.closed {
return nil, errClosed return nil, errClosed
} }
...@@ -634,6 +653,9 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result ...@@ -634,6 +653,9 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result
var hookQueryBadConn func() bool var hookQueryBadConn func() bool
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
if s.panic == "Query" {
panic(s.panic)
}
if s.closed { if s.closed {
return nil, errClosed return nil, errClosed
} }
...@@ -716,6 +738,9 @@ rows: ...@@ -716,6 +738,9 @@ rows:
} }
func (s *fakeStmt) NumInput() int { func (s *fakeStmt) NumInput() int {
if s.panic == "NumInput" {
panic(s.panic)
}
return s.placeholders return s.placeholders
} }
......
...@@ -1477,10 +1477,14 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { ...@@ -1477,10 +1477,14 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { func driverNumInput(ds driverStmt) int {
ds.Lock() ds.Lock()
want := ds.si.NumInput() defer ds.Unlock() // in case NumInput panics
ds.Unlock() return ds.si.NumInput()
}
func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
want := driverNumInput(ds)
// -1 means the driver doesn't know how to count the number of // -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the // placeholders, so we won't sanity check input here and instead let the
...@@ -1495,8 +1499,8 @@ func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { ...@@ -1495,8 +1499,8 @@ func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
} }
ds.Lock() ds.Lock()
defer ds.Unlock()
resi, err := ds.si.Exec(dargs) resi, err := ds.si.Exec(dargs)
ds.Unlock()
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -1927,6 +1931,6 @@ func stack() string { ...@@ -1927,6 +1931,6 @@ func stack() string {
// withLock runs while holding lk. // withLock runs while holding lk.
func withLock(lk sync.Locker, fn func()) { func withLock(lk sync.Locker, fn func()) {
lk.Lock() lk.Lock()
defer lk.Unlock() // in case fn panics
fn() fn()
lk.Unlock()
} }
...@@ -68,6 +68,46 @@ func newTestDB(t testing.TB, name string) *DB { ...@@ -68,6 +68,46 @@ func newTestDB(t testing.TB, name string) *DB {
return db return db
} }
func TestDriverPanic(t *testing.T) {
// Test that if driver panics, database/sql does not deadlock.
db, err := Open("test", fakeDBName)
if err != nil {
t.Fatalf("Open: %v", err)
}
expectPanic := func(name string, f func()) {
defer func() {
err := recover()
if err == nil {
t.Fatalf("%s did not panic", name)
}
}()
f()
}
expectPanic("Exec Exec", func() { db.Exec("PANIC|Exec|WIPE") })
exec(t, db, "WIPE") // check not deadlocked
expectPanic("Exec NumInput", func() { db.Exec("PANIC|NumInput|WIPE") })
exec(t, db, "WIPE") // check not deadlocked
expectPanic("Exec Close", func() { db.Exec("PANIC|Close|WIPE") })
exec(t, db, "WIPE") // check not deadlocked
exec(t, db, "PANIC|Query|WIPE") // should run successfully: Exec does not call Query
exec(t, db, "WIPE") // check not deadlocked
exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime")
expectPanic("Query Query", func() { db.Query("PANIC|Query|SELECT|people|age,name|") })
expectPanic("Query NumInput", func() { db.Query("PANIC|NumInput|SELECT|people|age,name|") })
expectPanic("Query Close", func() {
rows, err := db.Query("PANIC|Close|SELECT|people|age,name|")
if err != nil {
t.Fatal(err)
}
rows.Close()
})
db.Query("PANIC|Exec|SELECT|people|age,name|") // should run successfully: Query does not call Exec
exec(t, db, "WIPE") // check not deadlocked
}
func exec(t testing.TB, db *DB, query string, args ...interface{}) { func exec(t testing.TB, db *DB, query string, args ...interface{}) {
_, err := db.Exec(query, args...) _, err := db.Exec(query, args...)
if err != nil { if err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment