Commit 86b2f296 authored by Daniel Theophanes's avatar Daniel Theophanes Committed by Brad Fitzpatrick

database/sql: add support for multiple result sets

Many database systems allow returning multiple result sets
in a single query. This can be useful when dealing with many
intermediate results on the server and there is a need
to return more then one arity of data to the client.

Fixes #12382

Change-Id: I480a9ac6dadfc8743e0ba8b6d868ccf8442a9ca1
Reviewed-on: https://go-review.googlesource.com/30592Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent be48aa3f
...@@ -213,6 +213,22 @@ type Rows interface { ...@@ -213,6 +213,22 @@ type Rows interface {
Next(dest []Value) error Next(dest []Value) error
} }
// RowsNextResultSet extends the Rows interface by providing a way to signal
// the driver to advance to the next result set.
type RowsNextResultSet interface {
Rows
// HasNextResultSet is called at the end of the current result set and
// reports whether there is another result set after the current one.
HasNextResultSet() bool
// NextResultSet advances the driver to the next result set even
// if there are remaining rows in the current result set.
//
// NextResultSet should return io.EOF when there are no more result sets.
NextResultSet() error
}
// Tx is a transaction. // Tx is a transaction.
type Tx interface { type Tx interface {
Commit() error Commit() error
......
...@@ -44,3 +44,65 @@ func ExampleDB_QueryRow() { ...@@ -44,3 +44,65 @@ func ExampleDB_QueryRow() {
fmt.Printf("Username is %s\n", username) fmt.Printf("Username is %s\n", username)
} }
} }
func ExampleDB_QueryMultipleResultSets() {
age := 27
q := `
create temp table uid (id bigint); -- Create temp table for queries.
insert into uid
select id from users where age < ?; -- Populate temp table.
-- First result set.
select
users.id, name
from
users
join uid on users.id = uid.id
;
-- Second result set.
select
ur.user, ur.role
from
user_roles as ur
join uid on uid.id = ur.user
;
`
rows, err := db.Query(q, age)
if err != nil {
log.Fatal(err)
}
defer rows.Close()
for rows.Next() {
var (
id int64
name string
)
if err := rows.Scan(&id, &name); err != nil {
log.Fatal(err)
}
fmt.Printf("id %d name is %s\n", id, name)
}
if !rows.NextResultSet() {
log.Fatal("expected more result sets", rows.Err())
}
var roleMap = map[int64]string{
1: "user",
2: "admin",
3: "gopher",
}
for rows.Next() {
var (
id int64
role int64
)
if err := rows.Scan(&id, &role); err != nil {
log.Fatal(err)
}
fmt.Printf("id %d has role %s\n", id, roleMap[role])
}
if err := rows.Err(); err != nil {
log.Fatal(err)
}
}
...@@ -36,6 +36,8 @@ var _ = log.Printf ...@@ -36,6 +36,8 @@ var _ = log.Printf
// Any of these can be preceded by PANIC|<method>|, to cause the // Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic. // named method on fakeStmt to panic.
// //
// Multiple of these can be combined when separated with a semicolon.
//
// When opening a fakeDriver's database, it starts empty with no // When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only. // tables. All tables and data are stored in memory only.
type fakeDriver struct { type fakeDriver struct {
...@@ -109,6 +111,8 @@ type fakeStmt struct { ...@@ -109,6 +111,8 @@ type fakeStmt struct {
table string table string
panic string panic string
next *fakeStmt // used for returning multiple results.
closed bool closed bool
colName []string // used by CREATE, INSERT, SELECT (selected columns) colName []string // used by CREATE, INSERT, SELECT (selected columns)
...@@ -377,7 +381,7 @@ func errf(msg string, args ...interface{}) error { ...@@ -377,7 +381,7 @@ func errf(msg string, args ...interface{}) error {
// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
// (note that where columns must always contain ? marks, // (note that where columns must always contain ? marks,
// just a limitation for fakedb) // just a limitation for fakedb)
func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) { func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 3 { if len(parts) != 3 {
stmt.Close() stmt.Close()
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
...@@ -411,7 +415,7 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e ...@@ -411,7 +415,7 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
} }
// parts are table|col=type,col2=type2 // parts are table|col=type,col2=type2
func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) { func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 2 { if len(parts) != 2 {
stmt.Close() stmt.Close()
return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
...@@ -430,7 +434,7 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e ...@@ -430,7 +434,7 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e
} }
// parts are table|col=?,col2=val // parts are table|col=?,col2=val
func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) { func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 2 { if len(parts) != 2 {
stmt.Close() stmt.Close()
return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
...@@ -492,38 +496,52 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { ...@@ -492,38 +496,52 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
parts := strings.Split(query, "|") var firstStmt, prev *fakeStmt
if len(parts) < 1 { for _, query := range strings.Split(query, ";") {
return nil, errf("empty query") parts := strings.Split(query, "|")
} if len(parts) < 1 {
stmt := &fakeStmt{q: query, c: c} return nil, errf("empty query")
if len(parts) >= 3 && parts[0] == "PANIC" { }
stmt.panic = parts[1] stmt := &fakeStmt{q: query, c: c}
parts = parts[2:] if firstStmt == nil {
} firstStmt = stmt
cmd := parts[0] }
stmt.cmd = cmd if len(parts) >= 3 && parts[0] == "PANIC" {
parts = parts[1:] stmt.panic = parts[1]
parts = parts[2:]
c.incrStat(&c.stmtsMade) }
switch cmd { cmd := parts[0]
case "WIPE": stmt.cmd = cmd
// Nothing parts = parts[1:]
case "SELECT":
return c.prepareSelect(stmt, parts) c.incrStat(&c.stmtsMade)
case "CREATE": var err error
return c.prepareCreate(stmt, parts) switch cmd {
case "INSERT": case "WIPE":
return c.prepareInsert(stmt, parts) // Nothing
case "NOSERT": case "SELECT":
// Do all the prep-work like for an INSERT but don't actually insert the row. stmt, err = c.prepareSelect(stmt, parts)
// Used for some of the concurrent tests. case "CREATE":
return c.prepareInsert(stmt, parts) stmt, err = c.prepareCreate(stmt, parts)
default: case "INSERT":
stmt.Close() stmt, err = c.prepareInsert(stmt, parts)
return nil, errf("unsupported command type %q", cmd) case "NOSERT":
// Do all the prep-work like for an INSERT but don't actually insert the row.
// Used for some of the concurrent tests.
stmt, err = c.prepareInsert(stmt, parts)
default:
stmt.Close()
return nil, errf("unsupported command type %q", cmd)
}
if err != nil {
return nil, err
}
if prev != nil {
prev.next = stmt
}
prev = stmt
} }
return stmt, nil return firstStmt, nil
} }
func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
...@@ -550,6 +568,9 @@ func (s *fakeStmt) Close() error { ...@@ -550,6 +568,9 @@ func (s *fakeStmt) Close() error {
s.c.incrStat(&s.c.stmtsClosed) s.c.incrStat(&s.c.stmtsClosed)
s.closed = true s.closed = true
} }
if s.next != nil {
s.next.Close()
}
return nil return nil
} }
...@@ -667,64 +688,80 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { ...@@ -667,64 +688,80 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
panic("error in pkg db; should only get here if size is correct") panic("error in pkg db; should only get here if size is correct")
} }
db.mu.Lock() setMRows := make([][]*row, 0, 1)
t, ok := db.table(s.table) setColumns := make([][]string, 0, 1)
db.mu.Unlock()
if !ok {
return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
}
if s.table == "magicquery" { for {
if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" { db.mu.Lock()
if args[0] == "sleep" { t, ok := db.table(s.table)
time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond) db.mu.Unlock()
} if !ok {
return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
} }
}
t.mu.Lock() if s.table == "magicquery" {
defer t.mu.Unlock() if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
if args[0] == "sleep" {
colIdx := make(map[string]int) // select column name -> column index in table time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
for _, name := range s.colName { }
idx := t.columnIndex(name) }
if idx == -1 {
return nil, fmt.Errorf("fakedb: unknown column name %q", name)
} }
colIdx[name] = idx
}
mrows := []*row{} t.mu.Lock()
rows:
for _, trow := range t.rows { colIdx := make(map[string]int) // select column name -> column index in table
// Process the where clause, skipping non-match rows. This is lazy for _, name := range s.colName {
// and just uses fmt.Sprintf("%v") to test equality. Good enough idx := t.columnIndex(name)
// for test code.
for widx, wcol := range s.whereCol {
idx := t.columnIndex(wcol)
if idx == -1 { if idx == -1 {
return nil, fmt.Errorf("db: invalid where clause column %q", wcol) t.mu.Unlock()
return nil, fmt.Errorf("fakedb: unknown column name %q", name)
} }
tcol := trow.cols[idx] colIdx[name] = idx
if bs, ok := tcol.([]byte); ok { }
// lazy hack to avoid sprintf %v on a []byte
tcol = string(bs) mrows := []*row{}
rows:
for _, trow := range t.rows {
// Process the where clause, skipping non-match rows. This is lazy
// and just uses fmt.Sprintf("%v") to test equality. Good enough
// for test code.
for widx, wcol := range s.whereCol {
idx := t.columnIndex(wcol)
if idx == -1 {
t.mu.Unlock()
return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
}
tcol := trow.cols[idx]
if bs, ok := tcol.([]byte); ok {
// lazy hack to avoid sprintf %v on a []byte
tcol = string(bs)
}
if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
continue rows
}
} }
if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) { mrow := &row{cols: make([]interface{}, len(s.colName))}
continue rows for seli, name := range s.colName {
mrow.cols[seli] = trow.cols[colIdx[name]]
} }
mrows = append(mrows, mrow)
} }
mrow := &row{cols: make([]interface{}, len(s.colName))}
for seli, name := range s.colName { t.mu.Unlock()
mrow.cols[seli] = trow.cols[colIdx[name]]
setMRows = append(setMRows, mrows)
setColumns = append(setColumns, s.colName)
if s.next == nil {
break
} }
mrows = append(mrows, mrow) s = s.next
} }
cursor := &rowsCursor{ cursor := &rowsCursor{
pos: -1, posRow: -1,
rows: mrows, rows: setMRows,
cols: s.colName, cols: setColumns,
errPos: -1, errPos: -1,
} }
return cursor, nil return cursor, nil
...@@ -760,9 +797,10 @@ func (tx *fakeTx) Rollback() error { ...@@ -760,9 +797,10 @@ func (tx *fakeTx) Rollback() error {
} }
type rowsCursor struct { type rowsCursor struct {
cols []string cols [][]string
pos int posSet int
rows []*row posRow int
rows [][]*row
closed bool closed bool
// errPos and err are for making Next return early with error. // errPos and err are for making Next return early with error.
...@@ -786,7 +824,7 @@ func (rc *rowsCursor) Close() error { ...@@ -786,7 +824,7 @@ func (rc *rowsCursor) Close() error {
} }
func (rc *rowsCursor) Columns() []string { func (rc *rowsCursor) Columns() []string {
return rc.cols return rc.cols[rc.posSet]
} }
var rowsCursorNextHook func(dest []driver.Value) error var rowsCursorNextHook func(dest []driver.Value) error
...@@ -799,14 +837,14 @@ func (rc *rowsCursor) Next(dest []driver.Value) error { ...@@ -799,14 +837,14 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
if rc.closed { if rc.closed {
return errors.New("fakedb: cursor is closed") return errors.New("fakedb: cursor is closed")
} }
rc.pos++ rc.posRow++
if rc.pos == rc.errPos { if rc.posRow == rc.errPos {
return rc.err return rc.err
} }
if rc.pos >= len(rc.rows) { if rc.posRow >= len(rc.rows[rc.posSet]) {
return io.EOF // per interface spec return io.EOF // per interface spec
} }
for i, v := range rc.rows[rc.pos].cols { for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
// TODO(bradfitz): convert to subset types? naah, I // TODO(bradfitz): convert to subset types? naah, I
// think the subset types should only be input to // think the subset types should only be input to
// driver, but the sql package should be able to handle // driver, but the sql package should be able to handle
...@@ -831,6 +869,19 @@ func (rc *rowsCursor) Next(dest []driver.Value) error { ...@@ -831,6 +869,19 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
return nil return nil
} }
func (rc *rowsCursor) HasNextResultSet() bool {
return rc.posSet < len(rc.rows)-1
}
func (rc *rowsCursor) NextResultSet() error {
if rc.HasNextResultSet() {
rc.posSet++
rc.posRow = -1
return nil
}
return io.EOF // Per interface spec.
}
// fakeDriverString is like driver.String, but indirects pointers like // fakeDriverString is like driver.String, but indirects pointers like
// DefaultValueConverter. // DefaultValueConverter.
// //
......
...@@ -1942,6 +1942,47 @@ func (rs *Rows) Next() bool { ...@@ -1942,6 +1942,47 @@ func (rs *Rows) Next() bool {
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
} }
rs.lasterr = rs.rowsi.Next(rs.lastcols) rs.lasterr = rs.rowsi.Next(rs.lastcols)
if rs.lasterr != nil {
// Close the connection if there is a driver error.
if rs.lasterr != io.EOF {
rs.Close()
return false
}
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
if !ok {
rs.Close()
return false
}
// The driver is at the end of the current result set.
// Test to see if there is another result set after the current one.
// Only close Rows if there is no futher result sets to read.
if !nextResultSet.HasNextResultSet() {
rs.Close()
}
return false
}
return true
}
// NextResultSet prepares the next result set for reading. It returns true if
// there is further result sets, or false if there is no further result set
// or if there is an error advancing to it. The Err method should be consulted
// to distinguish between the two cases.
//
// After calling NextResultSet, the Next method should always be called before
// scanning. If there are further result sets they may not have rows in the result
// set.
func (rs *Rows) NextResultSet() bool {
if rs.isClosed() {
return false
}
rs.lastcols = nil
nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
if !ok {
rs.Close()
return false
}
rs.lasterr = nextResultSet.NextResultSet()
if rs.lasterr != nil { if rs.lasterr != nil {
rs.Close() rs.Close()
return false return false
...@@ -2047,7 +2088,8 @@ func (rs *Rows) isClosed() bool { ...@@ -2047,7 +2088,8 @@ func (rs *Rows) isClosed() bool {
return atomic.LoadInt32(&rs.closed) != 0 return atomic.LoadInt32(&rs.closed) != 0
} }
// Close closes the Rows, preventing further enumeration. If Next returns // Close closes the Rows, preventing further enumeration. If Next and
// NextResultSet both return
// false, the Rows are closed automatically and it will suffice to check the // false, the Rows are closed automatically and it will suffice to check the
// result of Err. Close is idempotent and does not affect the result of Err. // result of Err. Close is idempotent and does not affect the result of Err.
func (rs *Rows) Close() error { func (rs *Rows) Close() error {
......
...@@ -319,6 +319,82 @@ func TestQueryContext(t *testing.T) { ...@@ -319,6 +319,82 @@ func TestQueryContext(t *testing.T) {
} }
} }
func TestMultiResultSetQuery(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
prepares0 := numPrepares(t, db)
rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|")
if err != nil {
t.Fatalf("Query: %v", err)
}
type row1 struct {
age int
name string
}
type row2 struct {
name string
}
got1 := []row1{}
for rows.Next() {
var r row1
err = rows.Scan(&r.age, &r.name)
if err != nil {
t.Fatalf("Scan: %v", err)
}
got1 = append(got1, r)
}
err = rows.Err()
if err != nil {
t.Fatalf("Err: %v", err)
}
want1 := []row1{
{age: 1, name: "Alice"},
{age: 2, name: "Bob"},
{age: 3, name: "Chris"},
}
if !reflect.DeepEqual(got1, want1) {
t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
}
if !rows.NextResultSet() {
t.Errorf("expected another result set")
}
got2 := []row2{}
for rows.Next() {
var r row2
err = rows.Scan(&r.name)
if err != nil {
t.Fatalf("Scan: %v", err)
}
got2 = append(got2, r)
}
err = rows.Err()
if err != nil {
t.Fatalf("Err: %v", err)
}
want2 := []row2{
{name: "Alice"},
{name: "Bob"},
{name: "Chris"},
}
if !reflect.DeepEqual(got2, want2) {
t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
}
if rows.NextResultSet() {
t.Errorf("expected no more result sets")
}
// And verify that the final rows.Next() call, which hit EOF,
// also closed the rows connection.
if n := db.numFreeConns(); n != 1 {
t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
}
if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
t.Errorf("executed %d Prepare statements; want 1", prepares)
}
}
func TestByteOwnership(t *testing.T) { func TestByteOwnership(t *testing.T) {
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db) defer closeDB(t, db)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment