Unverified Commit e22a5143 authored by astaxie's avatar astaxie Committed by GitHub

Merge pull request #3403 from nlimpid/develop

add context for db operation
parents a17eb545 d5cf1050
...@@ -762,7 +762,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -762,7 +762,13 @@ func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
} }
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, values...) var err error
var res sql.Result
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, values...)
} else {
res, err = q.Exec(query, values...)
}
if err == nil { if err == nil {
return res.RowsAffected() return res.RowsAffected()
} }
...@@ -851,11 +857,16 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -851,11 +857,16 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
for i := range marks { for i := range marks {
marks[i] = "?" marks[i] = "?"
} }
sql := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sql) query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn)
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
res, err := q.Exec(query, args...) var res sql.Result
if qs != nil && qs.forContext {
res, err = q.ExecContext(qs.ctx, query, args...)
} else {
res, err = q.Exec(query, args...)
}
if err == nil { if err == nil {
num, err := res.RowsAffected() num, err := res.RowsAffected()
if err != nil { if err != nil {
...@@ -978,11 +989,18 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi ...@@ -978,11 +989,18 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
var rs *sql.Rows var rs *sql.Rows
r, err := q.Query(query, args...) var err error
if err != nil { if qs != nil && qs.forContext {
return 0, err rs, err = q.QueryContext(qs.ctx, query, args...)
if err != nil {
return 0, err
}
} else {
rs, err = q.Query(query, args...)
if err != nil {
return 0, err
}
} }
rs = r
refs := make([]interface{}, colsNum) refs := make([]interface{}, colsNum)
for i := range refs { for i := range refs {
...@@ -1111,8 +1129,12 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition ...@@ -1111,8 +1129,12 @@ func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition
d.ins.ReplaceMarks(&query) d.ins.ReplaceMarks(&query)
row := q.QueryRow(query, args...) var row *sql.Row
if qs != nil && qs.forContext {
row = q.QueryRowContext(qs.ctx, query, args...)
} else {
row = q.QueryRow(query, args...)
}
err = row.Scan(&cnt) err = row.Scan(&cnt)
return return
} }
......
...@@ -123,6 +123,13 @@ func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) { ...@@ -123,6 +123,13 @@ func (d *dbQueryLog) Prepare(query string) (*sql.Stmt, error) {
return stmt, err return stmt, err
} }
func (d *dbQueryLog) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
a := time.Now()
stmt, err := d.db.PrepareContext(ctx, query)
debugLogQueies(d.alias, "db.Prepare", query, a, err)
return stmt, err
}
func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) { func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) {
a := time.Now() a := time.Now()
res, err := d.db.Exec(query, args...) res, err := d.db.Exec(query, args...)
...@@ -130,6 +137,13 @@ func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error) ...@@ -130,6 +137,13 @@ func (d *dbQueryLog) Exec(query string, args ...interface{}) (sql.Result, error)
return res, err return res, err
} }
func (d *dbQueryLog) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
a := time.Now()
res, err := d.db.ExecContext(ctx, query, args...)
debugLogQueies(d.alias, "db.Exec", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) { func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) {
a := time.Now() a := time.Now()
res, err := d.db.Query(query, args...) res, err := d.db.Query(query, args...)
...@@ -137,6 +151,13 @@ func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error) ...@@ -137,6 +151,13 @@ func (d *dbQueryLog) Query(query string, args ...interface{}) (*sql.Rows, error)
return res, err return res, err
} }
func (d *dbQueryLog) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
a := time.Now()
res, err := d.db.QueryContext(ctx, query, args...)
debugLogQueies(d.alias, "db.Query", query, a, err, args...)
return res, err
}
func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row {
a := time.Now() a := time.Now()
res := d.db.QueryRow(query, args...) res := d.db.QueryRow(query, args...)
...@@ -144,6 +165,13 @@ func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row { ...@@ -144,6 +165,13 @@ func (d *dbQueryLog) QueryRow(query string, args ...interface{}) *sql.Row {
return res return res
} }
func (d *dbQueryLog) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
a := time.Now()
res := d.db.QueryRowContext(ctx, query, args...)
debugLogQueies(d.alias, "db.QueryRow", query, a, nil, args...)
return res
}
func (d *dbQueryLog) Begin() (*sql.Tx, error) { func (d *dbQueryLog) Begin() (*sql.Tx, error) {
a := time.Now() a := time.Now()
tx, err := d.db.(txer).Begin() tx, err := d.db.(txer).Begin()
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
package orm package orm
import ( import (
"context"
"fmt" "fmt"
) )
...@@ -55,17 +56,19 @@ func ColValue(opt operator, value interface{}) interface{} { ...@@ -55,17 +56,19 @@ func ColValue(opt operator, value interface{}) interface{} {
// real query struct // real query struct
type querySet struct { type querySet struct {
mi *modelInfo mi *modelInfo
cond *Condition cond *Condition
related []string related []string
relDepth int relDepth int
limit int64 limit int64
offset int64 offset int64
groups []string groups []string
orders []string orders []string
distinct bool distinct bool
forupdate bool forupdate bool
orm *orm orm *orm
ctx context.Context
forContext bool
} }
var _ QuerySeter = new(querySet) var _ QuerySeter = new(querySet)
...@@ -275,6 +278,13 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string) ...@@ -275,6 +278,13 @@ func (o *querySet) RowsToStruct(ptrStruct interface{}, keyCol, valueCol string)
panic(ErrNotImplement) panic(ErrNotImplement)
} }
// set context to QuerySeter.
func (o querySet) WithContext(ctx context.Context) QuerySeter {
o.ctx = ctx
o.forContext = true
return &o
}
// create new QuerySeter. // create new QuerySeter.
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter { func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
o := new(querySet) o := new(querySet)
......
...@@ -395,16 +395,23 @@ type RawSeter interface { ...@@ -395,16 +395,23 @@ type RawSeter interface {
type stmtQuerier interface { type stmtQuerier interface {
Close() error Close() error
Exec(args ...interface{}) (sql.Result, error) Exec(args ...interface{}) (sql.Result, error)
//ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
Query(args ...interface{}) (*sql.Rows, error) Query(args ...interface{}) (*sql.Rows, error)
//QueryContext(args ...interface{}) (*sql.Rows, error)
QueryRow(args ...interface{}) *sql.Row QueryRow(args ...interface{}) *sql.Row
//QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row
} }
// db querier // db querier
type dbQuerier interface { type dbQuerier interface {
Prepare(query string) (*sql.Stmt, error) Prepare(query string) (*sql.Stmt, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error) Query(query string, args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row QueryRow(query string, args ...interface{}) *sql.Row
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
} }
// type DB interface { // type DB interface {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment