Commit c3954dd5 authored by Gwenael Treguier's avatar Gwenael Treguier Committed by Brad Fitzpatrick

database/sql: ensure Stmts are correctly closed.

To make sure that there is no resource leak,
I suggest to fix the 'fakedb' driver such as it fails when any
Stmt is not closed.
First, add a check in fakeConn.Close().
Then, fix all missing Stmt.Close()/Rows.Close().
I am not sure that the strategy choose in fakeConn.Prepare/prepare* is ok.
The weak point in this patch is the change in Tx.Query:
  - Tests pass without this change,
  - I found it by manually analyzing the code,
  - I just try to make Tx.Query look like DB.Query.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/5759050
parent 959d0c7a
...@@ -214,6 +214,9 @@ func (c *fakeConn) Close() error { ...@@ -214,6 +214,9 @@ func (c *fakeConn) Close() error {
if c.db == nil { if c.db == nil {
return errors.New("can't close fakeConn; already closed") return errors.New("can't close fakeConn; already closed")
} }
if c.stmtsMade > c.stmtsClosed {
return errors.New("can't close; dangling statement(s)")
}
c.db = nil c.db = nil
return nil return nil
} }
...@@ -250,6 +253,7 @@ func errf(msg string, args ...interface{}) error { ...@@ -250,6 +253,7 @@ func errf(msg string, args ...interface{}) error {
// just a limitation for fakedb) // just a limitation for fakedb)
func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) { func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
if len(parts) != 3 { if len(parts) != 3 {
stmt.Close()
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
} }
stmt.table = parts[0] stmt.table = parts[0]
...@@ -260,14 +264,17 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e ...@@ -260,14 +264,17 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
} }
nameVal := strings.Split(colspec, "=") nameVal := strings.Split(colspec, "=")
if len(nameVal) != 2 { if len(nameVal) != 2 {
stmt.Close()
return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
} }
column, value := nameVal[0], nameVal[1] column, value := nameVal[0], nameVal[1]
_, ok := c.db.columnType(stmt.table, column) _, ok := c.db.columnType(stmt.table, column)
if !ok { if !ok {
stmt.Close()
return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
} }
if value != "?" { if value != "?" {
stmt.Close()
return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
stmt.table, column) stmt.table, column)
} }
...@@ -280,12 +287,14 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e ...@@ -280,12 +287,14 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
// parts are table|col=type,col2=type2 // parts are table|col=type,col2=type2
func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) { func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
if len(parts) != 2 { if len(parts) != 2 {
stmt.Close()
return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
} }
stmt.table = parts[0] stmt.table = parts[0]
for n, colspec := range strings.Split(parts[1], ",") { for n, colspec := range strings.Split(parts[1], ",") {
nameType := strings.Split(colspec, "=") nameType := strings.Split(colspec, "=")
if len(nameType) != 2 { if len(nameType) != 2 {
stmt.Close()
return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
} }
stmt.colName = append(stmt.colName, nameType[0]) stmt.colName = append(stmt.colName, nameType[0])
...@@ -297,17 +306,20 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e ...@@ -297,17 +306,20 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e
// parts are table|col=?,col2=val // parts are table|col=?,col2=val
func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) { func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
if len(parts) != 2 { if len(parts) != 2 {
stmt.Close()
return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
} }
stmt.table = parts[0] stmt.table = parts[0]
for n, colspec := range strings.Split(parts[1], ",") { for n, colspec := range strings.Split(parts[1], ",") {
nameVal := strings.Split(colspec, "=") nameVal := strings.Split(colspec, "=")
if len(nameVal) != 2 { if len(nameVal) != 2 {
stmt.Close()
return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
} }
column, value := nameVal[0], nameVal[1] column, value := nameVal[0], nameVal[1]
ctype, ok := c.db.columnType(stmt.table, column) ctype, ok := c.db.columnType(stmt.table, column)
if !ok { if !ok {
stmt.Close()
return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
} }
stmt.colName = append(stmt.colName, column) stmt.colName = append(stmt.colName, column)
...@@ -323,10 +335,12 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e ...@@ -323,10 +335,12 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
case "int32": case "int32":
i, err := strconv.Atoi(value) i, err := strconv.Atoi(value)
if err != nil { if err != nil {
stmt.Close()
return nil, errf("invalid conversion to int32 from %q", value) return nil, errf("invalid conversion to int32 from %q", value)
} }
subsetVal = int64(i) // int64 is a subset type, but not int32 subsetVal = int64(i) // int64 is a subset type, but not int32
default: default:
stmt.Close()
return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
} }
stmt.colValue = append(stmt.colValue, subsetVal) stmt.colValue = append(stmt.colValue, subsetVal)
...@@ -362,6 +376,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { ...@@ -362,6 +376,7 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
case "INSERT": case "INSERT":
return c.prepareInsert(stmt, parts) return c.prepareInsert(stmt, parts)
default: default:
stmt.Close()
return nil, errf("unsupported command type %q", cmd) return nil, errf("unsupported command type %q", cmd)
} }
return stmt, nil return stmt, nil
......
...@@ -612,9 +612,11 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { ...@@ -612,9 +612,11 @@ func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
return nil, err return nil, err
} }
rows, err := stmt.Query(args...) rows, err := stmt.Query(args...)
if err == nil { if err != nil {
rows.closeStmt = stmt stmt.Close()
return nil, err
} }
rows.closeStmt = stmt
return rows, err return rows, err
} }
......
...@@ -251,6 +251,7 @@ func TestStatementQueryRow(t *testing.T) { ...@@ -251,6 +251,7 @@ func TestStatementQueryRow(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Prepare: %v", err) t.Fatalf("Prepare: %v", err)
} }
defer stmt.Close()
var age int var age int
for n, tt := range []struct { for n, tt := range []struct {
name string name string
...@@ -291,6 +292,7 @@ func TestExec(t *testing.T) { ...@@ -291,6 +292,7 @@ func TestExec(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Stmt, err = %v, %v", stmt, err) t.Errorf("Stmt, err = %v, %v", stmt, err)
} }
defer stmt.Close()
type execTest struct { type execTest struct {
args []interface{} args []interface{}
...@@ -332,11 +334,14 @@ func TestTxStmt(t *testing.T) { ...@@ -332,11 +334,14 @@ func TestTxStmt(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Stmt, err = %v, %v", stmt, err) t.Fatalf("Stmt, err = %v, %v", stmt, err)
} }
defer stmt.Close()
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
t.Fatalf("Begin = %v", err) t.Fatalf("Begin = %v", err)
} }
_, err = tx.Stmt(stmt).Exec("Bobby", 7) txs := tx.Stmt(stmt)
defer txs.Close()
_, err = txs.Exec("Bobby", 7)
if err != nil { if err != nil {
t.Fatalf("Exec = %v", err) t.Fatalf("Exec = %v", err)
} }
...@@ -365,6 +370,7 @@ func TestTxQuery(t *testing.T) { ...@@ -365,6 +370,7 @@ func TestTxQuery(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer r.Close()
if !r.Next() { if !r.Next() {
if r.Err() != nil { if r.Err() != nil {
...@@ -561,6 +567,7 @@ func nullTestRun(t *testing.T, spec nullTestSpec) { ...@@ -561,6 +567,7 @@ func nullTestRun(t *testing.T, spec nullTestSpec) {
if err != nil { if err != nil {
t.Fatalf("prepare: %v", err) t.Fatalf("prepare: %v", err)
} }
defer stmt.Close()
if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil { if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil {
t.Errorf("exec insert chris: %v", err) t.Errorf("exec insert chris: %v", err)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment