Commit 5f739d9d authored by Marko Tiikkaja's avatar Marko Tiikkaja Committed by Brad Fitzpatrick

database/sql: Close per-tx prepared statements when the associated tx ends

LGTM=bradfitz
R=golang-codereviews, bradfitz, mattn.jp
CC=golang-codereviews
https://golang.org/cl/131650043
parent 93e5cc22
......@@ -1043,6 +1043,13 @@ type Tx struct {
// or Rollback. once done, all operations fail with
// ErrTxDone.
done bool
// All Stmts prepared for this transaction. These will be closed after the
// transaction has been committed or rolled back.
stmts struct {
sync.Mutex
v []*Stmt
}
}
var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
......@@ -1064,6 +1071,15 @@ func (tx *Tx) grabConn() (*driverConn, error) {
return tx.dc, nil
}
// Closes all Stmts prepared for this transaction.
func (tx *Tx) closePrepared() {
tx.stmts.Lock()
for _, stmt := range tx.stmts.v {
stmt.Close()
}
tx.stmts.Unlock()
}
// Commit commits the transaction.
func (tx *Tx) Commit() error {
if tx.done {
......@@ -1071,8 +1087,12 @@ func (tx *Tx) Commit() error {
}
defer tx.close()
tx.dc.Lock()
defer tx.dc.Unlock()
return tx.txi.Commit()
err := tx.txi.Commit()
tx.dc.Unlock()
if err != driver.ErrBadConn {
tx.closePrepared()
}
return err
}
// Rollback aborts the transaction.
......@@ -1082,8 +1102,12 @@ func (tx *Tx) Rollback() error {
}
defer tx.close()
tx.dc.Lock()
defer tx.dc.Unlock()
return tx.txi.Rollback()
err := tx.txi.Rollback()
tx.dc.Unlock()
if err != driver.ErrBadConn {
tx.closePrepared()
}
return err
}
// Prepare creates a prepared statement for use within a transaction.
......@@ -1127,6 +1151,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
},
query: query,
}
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, stmt)
tx.stmts.Unlock()
return stmt, nil
}
......@@ -1155,7 +1182,7 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
dc.Lock()
si, err := dc.ci.Prepare(stmt.query)
dc.Unlock()
return &Stmt{
txs := &Stmt{
db: tx.db,
tx: tx,
txsi: &driverStmt{
......@@ -1165,6 +1192,10 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
query: stmt.query,
stickyErr: err,
}
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, txs)
tx.stmts.Unlock()
return txs
}
// Exec executes a query that doesn't return rows.
......
......@@ -441,6 +441,33 @@ func TestExec(t *testing.T) {
}
}
func TestTxPrepare(t *testing.T) {
db := newTestDB(t, "")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin = %v", err)
}
stmt, err := tx.Prepare("INSERT|t1|name=?,age=?")
if err != nil {
t.Fatalf("Stmt, err = %v, %v", stmt, err)
}
defer stmt.Close()
_, err = stmt.Exec("Bobby", 7)
if err != nil {
t.Fatalf("Exec = %v", err)
}
err = tx.Commit()
if err != nil {
t.Fatalf("Commit = %v", err)
}
// Commit() should have closed the statement
if !stmt.closed {
t.Fatal("Stmt not closed after Commit")
}
}
func TestTxStmt(t *testing.T) {
db := newTestDB(t, "")
defer closeDB(t, db)
......@@ -464,6 +491,10 @@ func TestTxStmt(t *testing.T) {
if err != nil {
t.Fatalf("Commit = %v", err)
}
// Commit() should have closed the statement
if !txs.closed {
t.Fatal("Stmt not closed after Commit")
}
}
// Issue: http://golang.org/issue/2784
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment