Commit b0d592c3 authored by Daniel Theophanes's avatar Daniel Theophanes

database/sql: prevent race on Rows close with Tx Rollback

In addition to adding a guard to the Rows close, add a var
in the fakeConn that gets read and written to on each
operation, simulating writing or reading from the server.

TestConcurrency/TxStmt* tests have been commented out
as they now fail after checking for races on the fakeConn.
See issue #20646 for more information.

Fixes #20622

Change-Id: I80b36ea33d776e5b4968be1683ff8c61728ee1ea
Reviewed-on: https://go-review.googlesource.com/45275
Run-TryBot: Daniel Theophanes <kardianos@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 38201918
...@@ -89,6 +89,10 @@ type fakeConn struct { ...@@ -89,6 +89,10 @@ type fakeConn struct {
currTx *fakeTx currTx *fakeTx
// Every operation writes to line to enable the race detector
// check for data races.
line int64
// Stats for tests: // Stats for tests:
mu sync.Mutex mu sync.Mutex
stmtsMade int stmtsMade int
...@@ -299,6 +303,7 @@ func (c *fakeConn) Begin() (driver.Tx, error) { ...@@ -299,6 +303,7 @@ func (c *fakeConn) Begin() (driver.Tx, error) {
if c.currTx != nil { if c.currTx != nil {
return nil, errors.New("already in a transaction") return nil, errors.New("already in a transaction")
} }
c.line++
c.currTx = &fakeTx{c: c} c.currTx = &fakeTx{c: c}
return c.currTx, nil return c.currTx, nil
} }
...@@ -340,6 +345,7 @@ func (c *fakeConn) Close() (err error) { ...@@ -340,6 +345,7 @@ func (c *fakeConn) Close() (err error) {
drv.mu.Unlock() drv.mu.Unlock()
} }
}() }()
c.line++
if c.currTx != nil { if c.currTx != nil {
return errors.New("can't close fakeConn; in a Transaction") return errors.New("can't close fakeConn; in a Transaction")
} }
...@@ -527,6 +533,7 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm ...@@ -527,6 +533,7 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
c.line++
var firstStmt, prev *fakeStmt var firstStmt, prev *fakeStmt
for _, query := range strings.Split(query, ";") { for _, query := range strings.Split(query, ";") {
parts := strings.Split(query, "|") parts := strings.Split(query, "|")
...@@ -615,6 +622,7 @@ func (s *fakeStmt) Close() error { ...@@ -615,6 +622,7 @@ func (s *fakeStmt) Close() error {
if s.c.db == nil { if s.c.db == nil {
panic("in fakeStmt.Close, conn's db is nil (already closed)") panic("in fakeStmt.Close, conn's db is nil (already closed)")
} }
s.c.line++
if !s.closed { if !s.closed {
s.c.incrStat(&s.c.stmtsClosed) s.c.incrStat(&s.c.stmtsClosed)
s.closed = true s.closed = true
...@@ -649,6 +657,7 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d ...@@ -649,6 +657,7 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.c.line++
if s.wait > 0 { if s.wait > 0 {
time.Sleep(s.wait) time.Sleep(s.wait)
...@@ -761,6 +770,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( ...@@ -761,6 +770,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
return nil, err return nil, err
} }
s.c.line++
db := s.c.db db := s.c.db
if len(args) != s.placeholders { if len(args) != s.placeholders {
panic("error in pkg db; should only get here if size is correct") panic("error in pkg db; should only get here if size is correct")
...@@ -856,6 +866,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( ...@@ -856,6 +866,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
} }
cursor := &rowsCursor{ cursor := &rowsCursor{
c: s.c,
posRow: -1, posRow: -1,
rows: setMRows, rows: setMRows,
cols: setColumns, cols: setColumns,
...@@ -880,6 +891,7 @@ func (tx *fakeTx) Commit() error { ...@@ -880,6 +891,7 @@ func (tx *fakeTx) Commit() error {
if hookCommitBadConn != nil && hookCommitBadConn() { if hookCommitBadConn != nil && hookCommitBadConn() {
return driver.ErrBadConn return driver.ErrBadConn
} }
tx.c.line++
return nil return nil
} }
...@@ -891,10 +903,12 @@ func (tx *fakeTx) Rollback() error { ...@@ -891,10 +903,12 @@ func (tx *fakeTx) Rollback() error {
if hookRollbackBadConn != nil && hookRollbackBadConn() { if hookRollbackBadConn != nil && hookRollbackBadConn() {
return driver.ErrBadConn return driver.ErrBadConn
} }
tx.c.line++
return nil return nil
} }
type rowsCursor struct { type rowsCursor struct {
c *fakeConn
cols [][]string cols [][]string
colType [][]string colType [][]string
posSet int posSet int
...@@ -918,6 +932,7 @@ func (rc *rowsCursor) Close() error { ...@@ -918,6 +932,7 @@ func (rc *rowsCursor) Close() error {
bs[0] = 255 // first byte corrupted bs[0] = 255 // first byte corrupted
} }
} }
rc.c.line++
rc.closed = true rc.closed = true
return nil return nil
} }
...@@ -940,6 +955,7 @@ func (rc *rowsCursor) Next(dest []driver.Value) error { ...@@ -940,6 +955,7 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
if rc.closed { if rc.closed {
return errors.New("fakedb: cursor is closed") return errors.New("fakedb: cursor is closed")
} }
rc.c.line++
rc.posRow++ rc.posRow++
if rc.posRow == rc.errPos { if rc.posRow == rc.errPos {
return rc.err return rc.err
...@@ -973,10 +989,12 @@ func (rc *rowsCursor) Next(dest []driver.Value) error { ...@@ -973,10 +989,12 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
} }
func (rc *rowsCursor) HasNextResultSet() bool { func (rc *rowsCursor) HasNextResultSet() bool {
rc.c.line++
return rc.posSet < len(rc.rows)-1 return rc.posSet < len(rc.rows)-1
} }
func (rc *rowsCursor) NextResultSet() error { func (rc *rowsCursor) NextResultSet() error {
rc.c.line++
if rc.HasNextResultSet() { if rc.HasNextResultSet() {
rc.posSet++ rc.posSet++
rc.posRow = -1 rc.posRow = -1
......
...@@ -2700,7 +2700,9 @@ func (rs *Rows) close(err error) error { ...@@ -2700,7 +2700,9 @@ func (rs *Rows) close(err error) error {
rs.lasterr = err rs.lasterr = err
} }
err = rs.rowsi.Close() withLock(rs.dc, func() {
err = rs.rowsi.Close()
})
if fn := rowsCloseHook(); fn != nil { if fn := rowsCloseHook(); fn != nil {
fn(rs, &err) fn(rs, &err)
} }
......
...@@ -2471,6 +2471,8 @@ func TestManyErrBadConn(t *testing.T) { ...@@ -2471,6 +2471,8 @@ func TestManyErrBadConn(t *testing.T) {
// closing a transaction. Ensure Rows is closed while closing a trasaction. // closing a transaction. Ensure Rows is closed while closing a trasaction.
func TestIssue20575(t *testing.T) { func TestIssue20575(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db)
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -2493,6 +2495,43 @@ func TestIssue20575(t *testing.T) { ...@@ -2493,6 +2495,43 @@ func TestIssue20575(t *testing.T) {
} }
} }
// TestIssue20622 tests closing the transaction before rows is closed, requires
// the race detector to fail.
func TestIssue20622(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
t.Fatal(err)
}
rows, err := tx.Query("SELECT|people|age,name|")
if err != nil {
t.Fatal(err)
}
count := 0
for rows.Next() {
count++
var age int
var name string
if err := rows.Scan(&age, &name); err != nil {
t.Fatal("scan failed", err)
}
if count == 1 {
cancel()
}
time.Sleep(100 * time.Millisecond)
}
rows.Close()
tx.Commit()
}
// golang.org/issue/5718 // golang.org/issue/5718
func TestErrBadConnReconnect(t *testing.T) { func TestErrBadConnReconnect(t *testing.T) {
db := newTestDB(t, "foo") db := newTestDB(t, "foo")
...@@ -2956,8 +2995,9 @@ func (c *concurrentRandomTest) init(t testing.TB, db *DB) { ...@@ -2956,8 +2995,9 @@ func (c *concurrentRandomTest) init(t testing.TB, db *DB) {
new(concurrentStmtExecTest), new(concurrentStmtExecTest),
new(concurrentTxQueryTest), new(concurrentTxQueryTest),
new(concurrentTxExecTest), new(concurrentTxExecTest),
new(concurrentTxStmtQueryTest), // golang.org/issue/20646
new(concurrentTxStmtExecTest), // new(concurrentTxStmtQueryTest),
// new(concurrentTxStmtExecTest),
} }
for _, ct := range c.tests { for _, ct := range c.tests {
ct.init(t, db) ct.init(t, db)
...@@ -3193,15 +3233,26 @@ func TestIssue18719(t *testing.T) { ...@@ -3193,15 +3233,26 @@ func TestIssue18719(t *testing.T) {
} }
func TestConcurrency(t *testing.T) { func TestConcurrency(t *testing.T) {
doConcurrentTest(t, new(concurrentDBQueryTest)) list := []struct {
doConcurrentTest(t, new(concurrentDBExecTest)) name string
doConcurrentTest(t, new(concurrentStmtQueryTest)) ct concurrentTest
doConcurrentTest(t, new(concurrentStmtExecTest)) }{
doConcurrentTest(t, new(concurrentTxQueryTest)) {"Query", new(concurrentDBQueryTest)},
doConcurrentTest(t, new(concurrentTxExecTest)) {"Exec", new(concurrentDBExecTest)},
doConcurrentTest(t, new(concurrentTxStmtQueryTest)) {"StmtQuery", new(concurrentStmtQueryTest)},
doConcurrentTest(t, new(concurrentTxStmtExecTest)) {"StmtExec", new(concurrentStmtExecTest)},
doConcurrentTest(t, new(concurrentRandomTest)) {"TxQuery", new(concurrentTxQueryTest)},
{"TxExec", new(concurrentTxExecTest)},
// golang.org/issue/20646
// {"TxStmtQuery", new(concurrentTxStmtQueryTest)},
// {"TxStmtExec", new(concurrentTxStmtExecTest)},
{"Random", new(concurrentRandomTest)},
}
for _, item := range list {
t.Run(item.name, func(t *testing.T) {
doConcurrentTest(t, item.ct)
})
}
} }
func TestConnectionLeak(t *testing.T) { func TestConnectionLeak(t *testing.T) {
...@@ -3531,6 +3582,7 @@ func BenchmarkConcurrentTxExec(b *testing.B) { ...@@ -3531,6 +3582,7 @@ func BenchmarkConcurrentTxExec(b *testing.B) {
} }
func BenchmarkConcurrentTxStmtQuery(b *testing.B) { func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
b.Skip("golang.org/issue/20646")
b.ReportAllocs() b.ReportAllocs()
ct := new(concurrentTxStmtQueryTest) ct := new(concurrentTxStmtQueryTest)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
...@@ -3539,6 +3591,7 @@ func BenchmarkConcurrentTxStmtQuery(b *testing.B) { ...@@ -3539,6 +3591,7 @@ func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
} }
func BenchmarkConcurrentTxStmtExec(b *testing.B) { func BenchmarkConcurrentTxStmtExec(b *testing.B) {
b.Skip("golang.org/issue/20646")
b.ReportAllocs() b.ReportAllocs()
ct := new(concurrentTxStmtExecTest) ct := new(concurrentTxStmtExecTest)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment