Commit d2df8498 authored by Daniel Theophanes's avatar Daniel Theophanes Committed by Brad Fitzpatrick

database/sql: close Rows when context is cancelled

To prevent leaking connections, close any open Rows when the
context is cancelled. Also enforce context cancel while reading
rows off of the wire.

Change-Id: I62237ecdb7d250d6734f6ce3d2b0bcb16dc6fda7
Reviewed-on: https://go-review.googlesource.com/29957Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 6dc356a7
......@@ -14,6 +14,10 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver
if ciCtx, is := ci.(driver.ConnPrepareContext); is {
return ciCtx.PrepareContext(ctx, query)
}
if ctx.Done() == context.Background().Done() {
return ci.Prepare(query)
}
type R struct {
err error
panic interface{}
......@@ -50,6 +54,10 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, darg
if execerCtx, is := execer.(driver.ExecerContext); is {
return execerCtx.ExecContext(ctx, query, dargs)
}
if ctx.Done() == context.Background().Done() {
return execer.Exec(query, dargs)
}
type R struct {
err error
panic interface{}
......@@ -86,6 +94,10 @@ func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, d
if queryerCtx, is := queryer.(driver.QueryerContext); is {
return queryerCtx.QueryContext(ctx, query, dargs)
}
if ctx.Done() == context.Background().Done() {
return queryer.Query(query, dargs)
}
type R struct {
err error
panic interface{}
......@@ -122,6 +134,10 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value
if siCtx, is := si.(driver.StmtExecContext); is {
return siCtx.ExecContext(ctx, dargs)
}
if ctx.Done() == context.Background().Done() {
return si.Exec(dargs)
}
type R struct {
err error
panic interface{}
......@@ -158,6 +174,10 @@ func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, dargs []driver.Valu
if siCtx, is := si.(driver.StmtQueryContext); is {
return siCtx.QueryContext(ctx, dargs)
}
if ctx.Done() == context.Background().Done() {
return si.Query(dargs)
}
type R struct {
err error
panic interface{}
......@@ -196,6 +216,10 @@ func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) {
if ciCtx, is := ci.(driver.ConnBeginContext); is {
return ciCtx.BeginContext(ctx)
}
if ctx.Done() == context.Background().Done() {
return ci.Begin()
}
// TODO(kardianos): check the transaction level in ctx. If set and non-default
// then return an error here as the BeginContext driver value is not supported.
......
......@@ -974,7 +974,8 @@ const maxBadConnRetries = 2
// returned statement.
// The caller must call the statement's Close method
// when the statement is no longer needed.
// Context is for the preparation of the statment, not for the execution of
//
// The provided context is for the preparation of the statment, not for the execution of
// the statement.
func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
var stmt *Stmt
......@@ -1148,6 +1149,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
releaseConn: releaseConn,
rowsi: rowsi,
}
rows.initContextClose(ctx)
return rows, nil
}
}
......@@ -1180,6 +1182,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
rowsi: rowsi,
closeStmt: si,
}
rows.initContextClose(ctx)
return rows, nil
}
......@@ -1364,7 +1367,8 @@ func (tx *Tx) Rollback() error {
// be used once the transaction has been committed or rolled back.
//
// To use an existing prepared statement on this transaction, see Tx.Stmt.
// Context will be used for the preparation of the context, not
//
// The provided context will be used for the preparation of the context, not
// for the execution of the returned statement. The returned statement
// will run in the transaction context.
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
......@@ -1759,6 +1763,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
rowsi: rowsi,
// releaseConn set below
}
rows.initContextClose(ctx)
s.db.addDep(s, rows)
rows.releaseConn = func(err error) {
releaseConn(err)
......@@ -1899,12 +1904,30 @@ type Rows struct {
releaseConn func(error)
rowsi driver.Rows
closed bool
// closed value is 1 when the Rows is closed.
// Use atomic operations on value when checking value.
closed int32
ctxClose chan struct{} // closed when Rows is closed, may be null.
lastcols []driver.Value
lasterr error // non-nil only if closed is true
closeStmt driver.Stmt // if non-nil, statement to Close on close
}
func (rs *Rows) initContextClose(ctx context.Context) {
if ctx.Done() == context.Background().Done() {
return
}
rs.ctxClose = make(chan struct{})
go func() {
select {
case <-ctx.Done():
rs.Close()
case <-rs.ctxClose:
}
}()
}
// Next prepares the next result row for reading with the Scan method. It
// returns true on success, or false if there is no next result row or an error
// happened while preparing it. Err should be consulted to distinguish between
......@@ -1912,7 +1935,7 @@ type Rows struct {
//
// Every call to Scan, even the first one, must be preceded by a call to Next.
func (rs *Rows) Next() bool {
if rs.closed {
if rs.isClosed() {
return false
}
if rs.lastcols == nil {
......@@ -1939,7 +1962,7 @@ func (rs *Rows) Err() error {
// Columns returns an error if the rows are closed, or if the rows
// are from QueryRow and there was a deferred error.
func (rs *Rows) Columns() ([]string, error) {
if rs.closed {
if rs.isClosed() {
return nil, errors.New("sql: Rows are closed")
}
if rs.rowsi == nil {
......@@ -2000,7 +2023,7 @@ func (rs *Rows) Columns() ([]string, error) {
// For scanning into *bool, the source may be true, false, 1, 0, or
// string inputs parseable by strconv.ParseBool.
func (rs *Rows) Scan(dest ...interface{}) error {
if rs.closed {
if rs.isClosed() {
return errors.New("sql: Rows are closed")
}
if rs.lastcols == nil {
......@@ -2020,14 +2043,20 @@ func (rs *Rows) Scan(dest ...interface{}) error {
var rowsCloseHook func(*Rows, *error)
func (rs *Rows) isClosed() bool {
return atomic.LoadInt32(&rs.closed) != 0
}
// Close closes the Rows, preventing further enumeration. If Next returns
// false, the Rows are closed automatically and it will suffice to check the
// result of Err. Close is idempotent and does not affect the result of Err.
func (rs *Rows) Close() error {
if rs.closed {
if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) {
return nil
}
rs.closed = true
if rs.ctxClose != nil {
close(rs.ctxClose)
}
err := rs.rowsi.Close()
if fn := rowsCloseHook; fn != nil {
fn(rs, &err)
......
......@@ -261,6 +261,64 @@ func TestQuery(t *testing.T) {
}
}
func TestQueryContext(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
prepares0 := numPrepares(t, db)
ctx, cancel := context.WithCancel(context.Background())
rows, err := db.QueryContext(ctx, "SELECT|people|age,name|")
if err != nil {
t.Fatalf("Query: %v", err)
}
type row struct {
age int
name string
}
got := []row{}
index := 0
for rows.Next() {
if index == 2 {
cancel()
time.Sleep(10 * time.Millisecond)
}
var r row
err = rows.Scan(&r.age, &r.name)
if err != nil {
if index == 2 {
break
}
t.Fatalf("Scan: %v", err)
}
if index == 2 && err == nil {
t.Fatal("expected an error on last scan")
}
got = append(got, r)
index++
}
err = rows.Err()
if err != nil {
t.Fatalf("Err: %v", err)
}
want := []row{
{age: 1, name: "Alice"},
{age: 2, name: "Bob"},
}
if !reflect.DeepEqual(got, want) {
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
}
// And verify that the final rows.Next() call, which hit EOF,
// also closed the rows connection.
if n := db.numFreeConns(); n != 1 {
t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
}
if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
t.Errorf("executed %d Prepare statements; want 1", prepares)
}
}
func TestByteOwnership(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment