Commit 6c41e6dd authored by slene

orm: add sqlite3 support; postgres support may follow in a later commit

parent 9631c663
package orm
var mysqlOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ?",
"contains": "LIKE BINARY ?",
"icontains": "LIKE ?",
// "regex": "REGEXP BINARY ?",
// "iregex": "REGEXP ?",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"startswith": "LIKE BINARY ?",
"endswith": "LIKE BINARY ?",
"istartswith": "LIKE ?",
"iendswith": "LIKE ?",
}
type dbBaseMysql struct {
dbBase
}
var _ dbBaser = new(dbBaseMysql)
-func (d *dbBaseMysql) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (sql string, params []interface{}) {
-return d.dbBase.GetOperatorSql(mi, operator, args)
-}
func (d *dbBaseMysql) OperatorSql(operator string) string {
return mysqlOperators[operator]
}
func newdbBaseMysql() dbBaser {
......
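A quick orientation note (not part of the diff): the keys in mysqlOperators are the lookup suffixes of filter expressions, and OperatorSql only returns the raw SQL fragment; the column itself is stitched in later by getCondSql in db_tables.go. A minimal standalone sketch, assuming the usual double-underscore separator:

package main

import (
	"fmt"
	"strings"
)

// illustrative subset of the operator table above
var mysqlOps = map[string]string{
	"exact":     "= ?",
	"icontains": "LIKE ?",
	"gte":       ">= ?",
}

func main() {
	lookup := "user_name__icontains" // column + separator + operator key
	parts := strings.Split(lookup, "__")
	column, op := parts[0], parts[1]
	fmt.Printf("`%s` %s\n", column, mysqlOps[op]) // `user_name` LIKE ?
}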
@@ -4,6 +4,12 @@ type dbBaseOracle struct {
dbBase
}
var _ dbBaser = new(dbBaseOracle)
func (d *dbBase) OperatorSql(operator string) string {
return ""
}
func newdbBaseOracle() dbBaser {
b := new(dbBaseOracle)
b.ins = b
......
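The `var _ dbBaser = new(...)` line added to every backend is the standard Go compile-time interface assertion: if a backend stops satisfying dbBaser (for instance after the interface grows a method, as it does in types.go below), the package no longer builds. A generic sketch of the idiom with made-up names:

package main

type speaker interface{ Speak() string }

type dog struct{}

func (dog) Speak() string { return "woof" }

// Compiles to nothing at runtime; exists only so the compiler checks
// that dog implements speaker.
var _ speaker = dog{}

func main() {}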
package orm
import (
"strconv"
)
var postgresOperators = map[string]string{
"exact": "= ?",
"iexact": "= UPPER(?)",
"contains": "LIKE ?",
"icontains": "LIKE UPPER(?)",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"startswith": "LIKE ?",
"endswith": "LIKE ?",
"istartswith": "LIKE UPPER(?)",
"iendswith": "LIKE UPPER(?)",
}
type dbBasePostgres struct {
dbBase
}
var _ dbBaser = new(dbBasePostgres)
func (d *dbBasePostgres) OperatorSql(operator string) string {
return postgresOperators[operator]
}
func (d *dbBasePostgres) TableQuote() string {
return `"`
}
func (d *dbBasePostgres) ReplaceMarks(query *string) {
q := *query
num := 0
for _, c := range q {
if c == '?' {
num += 1
}
}
if num == 0 {
return
}
data := make([]byte, 0, len(q)+num)
num = 1
for i := 0; i < len(q); i++ {
c := q[i]
if c == '?' {
data = append(data, '$')
data = append(data, []byte(strconv.Itoa(num))...)
num += 1
} else {
data = append(data, c)
}
}
*query = string(data)
}
// func (d *dbBasePostgres)
func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres)
b.ins = b
......
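ReplaceMarks is what lets the shared `?`-style SQL builder drive lib/pq, which only understands positional $1, $2, ... placeholders. A self-contained copy of the logic with a sample query (the query text is illustrative):

package main

import (
	"fmt"
	"strconv"
)

// replaceMarks mirrors dbBasePostgres.ReplaceMarks above.
func replaceMarks(query *string) {
	q := *query
	num := 0
	for _, c := range q {
		if c == '?' {
			num++
		}
	}
	if num == 0 {
		return
	}
	data := make([]byte, 0, len(q)+num)
	num = 1
	for i := 0; i < len(q); i++ {
		c := q[i]
		if c == '?' {
			data = append(data, '$')
			data = append(data, []byte(strconv.Itoa(num))...)
			num++
		} else {
			data = append(data, c)
		}
	}
	*query = string(data)
}

func main() {
	q := `SELECT * FROM "user" WHERE "status" = ? AND "age" >= ?`
	replaceMarks(&q)
	fmt.Println(q) // SELECT * FROM "user" WHERE "status" = $1 AND "age" >= $2
}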
package orm
var sqliteOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ? ESCAPE '\\'",
"contains": "LIKE ? ESCAPE '\\'",
"icontains": "LIKE ? ESCAPE '\\'",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"startswith": "LIKE ? ESCAPE '\\'",
"endswith": "LIKE ? ESCAPE '\\'",
"istartswith": "LIKE ? ESCAPE '\\'",
"iendswith": "LIKE ? ESCAPE '\\'",
}
type dbBaseSqlite struct {
dbBase
}
var _ dbBaser = new(dbBaseSqlite)
func (d *dbBaseSqlite) OperatorSql(operator string) string {
return sqliteOperators[operator]
}
func (d *dbBaseSqlite) SupportUpdateJoin() bool {
return false
}
func (d *dbBaseSqlite) MaxLimit() uint64 {
return 9223372036854775807
}
func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite)
b.ins = b
......
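A note on the constant returned by MaxLimit: 9223372036854775807 is simply math.MaxInt64, the largest LIMIT value SQLite accepts; getLimitSql in db_tables.go below falls back to it when a query sets an offset but no limit, since SQL has no OFFSET-only form. Quick check:

package main

import (
	"fmt"
	"math"
)

func main() {
	// the SQLite MaxLimit above is exactly 1<<63 - 1
	fmt.Println(uint64(math.MaxInt64) == uint64(9223372036854775807)) // true
}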
package orm
import (
"fmt"
"strings"
)
type dbTable struct {
id int
index string
name string
names []string
sel bool
inner bool
mi *modelInfo
fi *fieldInfo
jtl *dbTable
}
type dbTables struct {
tablesM map[string]*dbTable
tables []*dbTable
mi *modelInfo
base dbBaser
}
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
name := strings.Join(names, ExprSep)
if j, ok := t.tablesM[name]; ok {
j.name = name
j.mi = mi
j.fi = fi
j.inner = inner
} else {
i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt
t.tables = append(t.tables, jt)
}
return t.tablesM[name]
}
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; ok == false {
i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt
t.tables = append(t.tables, jt)
return jt, true
}
return t.tablesM[name], false
}
func (t *dbTables) get(name string) (*dbTable, bool) {
j, ok := t.tablesM[name]
return j, ok
}
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
if depth < 0 || fi.fieldType == RelManyToMany {
return related
}
if prefix == "" {
prefix = fi.name
} else {
prefix = prefix + ExprSep + fi.name
}
related = append(related, prefix)
depth--
for _, fi := range fi.relModelInfo.fields.fieldsRel {
related = t.loopDepth(depth, prefix, fi, related)
}
return related
}
func (t *dbTables) parseRelated(rels []string, depth int) {
relsNum := len(rels)
related := make([]string, relsNum)
copy(related, rels)
relDepth := depth
if relsNum != 0 {
relDepth = 0
}
relDepth--
for _, fi := range t.mi.fields.fieldsRel {
related = t.loopDepth(relDepth, "", fi, related)
}
for i, s := range related {
var (
exs = strings.Split(s, ExprSep)
names = make([]string, 0, len(exs))
mmi = t.mi
cansel = true
jtl *dbTable
)
for _, ex := range exs {
if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
names = append(names, fi.name)
mmi = fi.relModelInfo
jt := t.set(names, mmi, fi, fi.null == false)
jt.jtl = jtl
if fi.reverse {
cansel = false
}
if cansel {
jt.sel = depth > 0
if i < relsNum {
jt.sel = true
}
}
jtl = jt
} else {
panic(fmt.Sprintf("unknown model/table name `%s`", ex))
}
}
}
}
func (t *dbTables) getJoinSql() (join string) {
Q := t.base.TableQuote()
for _, jt := range t.tables {
if jt.inner {
join += "INNER JOIN "
} else {
join += "LEFT OUTER JOIN "
}
var (
table string
t1, t2 string
c1, c2 string
)
t1 = "T0"
if jt.jtl != nil {
t1 = jt.jtl.index
}
t2 = jt.index
table = jt.mi.table
switch {
case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk.column
for _, ffi := range jt.mi.fields.fieldsRel {
if jt.fi.mi == ffi.relModelInfo {
c2 = ffi.column
break
}
}
default:
c1 = jt.fi.column
c2 = jt.fi.relModelInfo.fields.pk.column
if jt.fi.reverse {
c1 = jt.mi.fields.pk.column
c2 = jt.fi.reverseFieldInfo.column
}
}
join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2,
t2, Q, c2, Q, t1, Q, c1, Q)
}
return
}
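// Illustration (not part of the original file): for a single forward
// relation such as User -> Profile on MySQL (Q == "`"), the Sprintf above
// produces roughly
//
//	LEFT OUTER JOIN `user_profile` T1 ON T1.`id` = T0.`profile_id`
//
// T0 is always the root table, T1, T2, ... follow the order in which the
// joins were added, and INNER JOIN is used instead when the relation is
// declared non-null (jt.inner).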
func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column, name string, info *fieldInfo, success bool) {
var (
ffi *fieldInfo
jtl *dbTable
mmi = mi
)
num := len(exprs) - 1
names := make([]string, 0)
for i, ex := range exprs {
exist := false
check:
fi, ok := mmi.fields.GetByAny(ex)
if ok {
if num != i {
names = append(names, fi.name)
switch {
case fi.rel:
mmi = fi.relModelInfo
if fi.fieldType == RelManyToMany {
mmi = fi.relThroughModelInfo
}
case fi.reverse:
mmi = fi.reverseFieldInfo.mi
if fi.reverseFieldInfo.fieldType == RelManyToMany {
mmi = fi.reverseFieldInfo.relThroughModelInfo
}
default:
return
}
jt, _ := d.add(names, mmi, fi, fi.null == false)
jt.jtl = jtl
jtl = jt
if fi.rel && fi.fieldType == RelManyToMany {
ex = fi.relModelInfo.name
goto check
}
if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany {
ex = fi.reverseFieldInfo.mi.name
goto check
}
exist = true
} else {
if ffi == nil {
index = "T0"
} else {
index = jtl.index
}
column = fi.column
info = fi
if jtl != nil {
name = jtl.name + ExprSep + fi.name
} else {
name = fi.name
}
switch fi.fieldType {
case RelManyToMany, RelReverseMany:
default:
exist = true
}
}
ffi = fi
}
if exist == false {
index = ""
column = ""
name = ""
success = false
return
}
}
success = index != "" && column != ""
return
}
func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
}
Q := d.base.TableQuote()
mi := d.mi
// outFor:
for i, p := range cond.params {
if i > 0 {
if p.isOr {
where += "OR "
} else {
where += "AND "
}
}
if p.isNot {
where += "NOT "
}
if p.isCond {
w, ps := d.getCondSql(p.cond, true)
if w != "" {
w = fmt.Sprintf("( %s) ", w)
}
where += w
params = append(params, ps...)
} else {
exprs := p.exprs
num := len(exprs) - 1
operator := ""
if operators[exprs[num]] {
operator = exprs[num]
exprs = exprs[:num]
}
index, column, _, _, suc := d.parseExprs(mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
}
if operator == "" {
operator = "exact"
}
operSql, args := d.base.GenerateOperatorSql(mi, operator, p.args)
where += fmt.Sprintf("%s.%s%s%s %s ", index, Q, column, Q, operSql)
params = append(params, args...)
}
}
if sub == false && where != "" {
where = "WHERE " + where
}
return
}
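// Illustration (not part of the original file): on MySQL a condition on
// "profile__age__gte" with one argument resolves through parseExprs and
// the operator table to roughly
//
//	WHERE T1.`age` >= ?
//
// with the argument appended to params; on postgres ReplaceMarks later
// rewrites the ? into $1.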
func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
if len(orders) == 0 {
return
}
Q := d.base.TableQuote()
orderSqls := make([]string, 0, len(orders))
for _, order := range orders {
asc := "ASC"
if order[0] == '-' {
asc = "DESC"
order = order[1:]
}
exprs := strings.Split(order, ExprSep)
index, column, _, _, suc := d.parseExprs(d.mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, column, Q, asc))
}
orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
return
}
func (d *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int) (limits string) {
if limit == 0 {
limit = DefaultRowsLimit
}
if limit < 0 {
// no limit
if offset > 0 {
maxLimit := d.base.MaxLimit()
limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
}
} else if offset <= 0 {
limits = fmt.Sprintf("LIMIT %d", limit)
} else {
limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
}
return
}
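// Summary of the branches above (annotation, not original code):
//	limit == 0               -> treated as DefaultRowsLimit
//	limit <  0, offset <= 0  -> no clause at all
//	limit <  0, offset >  0  -> "LIMIT <MaxLimit()> OFFSET <offset>" (no OFFSET-only syntax)
//	limit >  0, offset <= 0  -> "LIMIT <limit>"
//	limit >  0, offset >  0  -> "LIMIT <limit> OFFSET <offset>"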
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
tables := &dbTables{}
tables.tablesM = make(map[string]*dbTable)
tables.mi = mi
tables.base = base
return tables
}
@@ -79,7 +79,7 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
info = new(modelInfo)
info.fields = newFields()
-info.table = m1.table + "_" + m2.table + "_rel"
info.table = m1.table + "_" + m2.table + "s"
info.name = camelString(info.table)
info.fullName = m1.pkg + "." + info.name
......
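The naming change above matters for the test schema below: with the new rule a many-to-many pair like Post and Tag maps to the `post_tags` table rather than `post_tag_rel`. A one-line check of the new rule:

package main

import "fmt"

func main() {
	m1Table, m2Table := "post", "tag"
	// new rule from newM2MModelInfo: m1.table + "_" + m2.table + "s"
	fmt.Println(m1Table + "_" + m2Table + "s") // post_tags (old rule gave post_tag_rel)
}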
@@ -3,10 +3,11 @@ package orm
import (
"fmt"
"os"
"strings"
"time"
-_ "github.com/bmizerany/pq"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
@@ -95,8 +96,178 @@ var DBARGS = struct {
os.Getenv("ORM_DEBUG"),
}
var (
IsMysql = DBARGS.Driver == "mysql"
IsSqlite = DBARGS.Driver == "sqlite3"
IsPostgres = DBARGS.Driver == "postgres"
)
var dORM Ormer
var initSQLs = map[string]string{
"mysql": "DROP TABLE IF EXISTS `user_profile`;\n" +
"DROP TABLE IF EXISTS `user`;\n" +
"DROP TABLE IF EXISTS `post`;\n" +
"DROP TABLE IF EXISTS `tag`;\n" +
"DROP TABLE IF EXISTS `post_tags`;\n" +
"DROP TABLE IF EXISTS `comment`;\n" +
"CREATE TABLE `user_profile` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `age` smallint NOT NULL,\n" +
" `money` double precision NOT NULL\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `user` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `user_name` varchar(30) NOT NULL UNIQUE,\n" +
" `email` varchar(100) NOT NULL,\n" +
" `password` varchar(100) NOT NULL,\n" +
" `status` smallint NOT NULL,\n" +
" `is_staff` bool NOT NULL,\n" +
" `is_active` bool NOT NULL,\n" +
" `created` date NOT NULL,\n" +
" `updated` datetime NOT NULL,\n" +
" `profile_id` integer\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `post` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `user_id` integer NOT NULL,\n" +
" `title` varchar(60) NOT NULL,\n" +
" `content` longtext NOT NULL,\n" +
" `created` datetime NOT NULL,\n" +
" `updated` datetime NOT NULL\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `tag` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `name` varchar(30) NOT NULL\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `post_tags` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `post_id` integer NOT NULL,\n" +
" `tag_id` integer NOT NULL,\n" +
" UNIQUE (`post_id`, `tag_id`)\n" +
") ENGINE=INNODB;\n" +
"CREATE TABLE `comment` (\n" +
" `id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY,\n" +
" `post_id` integer NOT NULL,\n" +
" `content` longtext NOT NULL,\n" +
" `parent_id` integer,\n" +
" `created` datetime NOT NULL\n" +
") ENGINE=INNODB;\n" +
"CREATE INDEX `user_141c6eec` ON `user` (`profile_id`);\n" +
"CREATE INDEX `post_fbfc09f1` ON `post` (`user_id`);\n" +
"CREATE INDEX `comment_699ae8ca` ON `comment` (`post_id`);\n" +
"CREATE INDEX `comment_63f17a16` ON `comment` (`parent_id`);",
"sqlite3": `
DROP TABLE IF EXISTS "user_profile";
DROP TABLE IF EXISTS "user";
DROP TABLE IF EXISTS "post";
DROP TABLE IF EXISTS "tag";
DROP TABLE IF EXISTS "post_tags";
DROP TABLE IF EXISTS "comment";
CREATE TABLE "user_profile" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"age" smallint NOT NULL,
"money" real NOT NULL
);
CREATE TABLE "user" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"user_name" varchar(30) NOT NULL UNIQUE,
"email" varchar(100) NOT NULL,
"password" varchar(100) NOT NULL,
"status" smallint NOT NULL,
"is_staff" bool NOT NULL,
"is_active" bool NOT NULL,
"created" date NOT NULL,
"updated" datetime NOT NULL,
"profile_id" integer
);
CREATE TABLE "post" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"user_id" integer NOT NULL,
"title" varchar(60) NOT NULL,
"content" text NOT NULL,
"created" datetime NOT NULL,
"updated" datetime NOT NULL
);
CREATE TABLE "tag" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"name" varchar(30) NOT NULL
);
CREATE TABLE "post_tags" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"post_id" integer NOT NULL,
"tag_id" integer NOT NULL,
UNIQUE ("post_id", "tag_id")
);
CREATE TABLE "comment" (
"id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"post_id" integer NOT NULL,
"content" text NOT NULL,
"parent_id" integer,
"created" datetime NOT NULL
);
CREATE INDEX "user_141c6eec" ON "user" ("profile_id");
CREATE INDEX "post_fbfc09f1" ON "post" ("user_id");
CREATE INDEX "comment_699ae8ca" ON "comment" ("post_id");
CREATE INDEX "comment_63f17a16" ON "comment" ("parent_id");
`,
"postgres": `
DROP TABLE IF EXISTS "user_profile";
DROP TABLE IF EXISTS "user";
DROP TABLE IF EXISTS "post";
DROP TABLE IF EXISTS "tag";
DROP TABLE IF EXISTS "post_tags";
DROP TABLE IF EXISTS "comment";
CREATE TABLE "user_profile" (
"id" serial NOT NULL PRIMARY KEY,
"age" smallint NOT NULL,
"money" double precision NOT NULL
);
CREATE TABLE "user" (
"id" serial NOT NULL PRIMARY KEY,
"user_name" varchar(30) NOT NULL UNIQUE,
"email" varchar(100) NOT NULL,
"password" varchar(100) NOT NULL,
"status" smallint NOT NULL,
"is_staff" boolean NOT NULL,
"is_active" boolean NOT NULL,
"created" date NOT NULL,
"updated" timestamp with time zone NOT NULL,
"profile_id" integer
);
CREATE TABLE "post" (
"id" serial NOT NULL PRIMARY KEY,
"user_id" integer NOT NULL,
"title" varchar(60) NOT NULL,
"content" text NOT NULL,
"created" timestamp with time zone NOT NULL,
"updated" timestamp with time zone NOT NULL
);
CREATE TABLE "tag" (
"id" serial NOT NULL PRIMARY KEY,
"name" varchar(30) NOT NULL
);
CREATE TABLE "post_tags" (
"id" serial NOT NULL PRIMARY KEY,
"post_id" integer NOT NULL,
"tag_id" integer NOT NULL,
UNIQUE ("post_id", "tag_id")
);
CREATE TABLE "comment" (
"id" serial NOT NULL PRIMARY KEY,
"post_id" integer NOT NULL,
"content" text NOT NULL,
"parent_id" integer,
"created" timestamp with time zone NOT NULL
);
CREATE INDEX "user_profile_id" ON "user" ("profile_id");
CREATE INDEX "post_user_id" ON "post" ("user_id");
CREATE INDEX "comment_post_id" ON "comment" ("post_id");
CREATE INDEX "comment_parent_id" ON "comment" ("parent_id");
`}
func init() {
RegisterModel(new(User))
RegisterModel(new(Profile))
@@ -114,7 +285,7 @@ Default DB Drivers.
driver: url
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
-postgres: https://github.com/bmizerany/pq
postgres: https://github.com/lib/pq
eg: mysql
ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/astaxie/beego/orm
@@ -126,20 +297,16 @@ ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/astaxie/beego/orm
BootStrap()
-truncateTables()
dORM = NewOrm()
-}
-func truncateTables() {
queries := strings.Split(initSQLs[DBARGS.Driver], ";")
-logs := "truncate tables for test\n"
-o := NewOrm()
for _, query := range queries {
-for _, m := range modelCache.allOrdered() {
if strings.TrimSpace(query) == "" {
-query := fmt.Sprintf("truncate table `%s`", m.table)
continue
-_, err := o.Raw(query).Exec()
}
-logs += query + "\n"
_, err := dORM.Raw(query).Exec()
if err != nil {
-fmt.Println(logs)
fmt.Println(err)
os.Exit(2)
}
......
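The new test bootstrap simply splits the per-driver schema script on ";" and executes each non-empty piece; the TrimSpace check is there because the trailing semicolon and newline leave empty fragments. A standalone sketch of that splitting step with a shortened script (note it would misbehave if a statement ever contained a literal ';' inside a string):

package main

import (
	"fmt"
	"strings"
)

func main() {
	script := `DROP TABLE IF EXISTS "tag";
CREATE TABLE "tag" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(30) NOT NULL);
`
	for _, q := range strings.Split(script, ";") {
		if strings.TrimSpace(q) == "" {
			continue // skip the empty tail after the last ";"
		}
		fmt.Printf("would exec: %s\n", strings.TrimSpace(q))
	}
}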
@@ -135,7 +135,7 @@ func (d *dbQueryLog) Commit() error {
func (d *dbQueryLog) Rollback() error {
a := time.Now()
-err := d.db.(txEnder).Commit()
err := d.db.(txEnder).Rollback()
debugLogQueies(d.alias, "tx.Rollback", "ROLLBACK", a, err)
return err
}
......
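The one-line change above fixes a real bug: the logging wrapper's Rollback was delegating to Commit, so a transaction the caller tried to discard was actually committed. A minimal standalone sketch of the delegation pattern (types here are stand-ins, not beego's):

package main

import "fmt"

type txEnder interface {
	Commit() error
	Rollback() error
}

type fakeTx struct{}

func (fakeTx) Commit() error   { fmt.Println("COMMIT"); return nil }
func (fakeTx) Rollback() error { fmt.Println("ROLLBACK"); return nil }

// loggedTx plays the role of dbQueryLog: it must forward Rollback to the
// wrapped transaction's Rollback, not its Commit.
type loggedTx struct{ db txEnder }

func (l *loggedTx) Rollback() error { return l.db.Rollback() } // was l.db.Commit()

func main() {
	tx := &loggedTx{db: fakeTx{}}
	_ = tx.Rollback() // prints ROLLBACK
}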
@@ -6,39 +6,17 @@ import (
"reflect"
)
-func getResult(res sql.Result) (int64, error) {
-if num, err := res.LastInsertId(); err != nil {
-return 0, err
-} else {
-if num > 0 {
-return num, nil
-}
-}
-if num, err := res.RowsAffected(); err != nil {
-return num, err
-} else {
-if num > 0 {
-return num, nil
-}
-}
-return 0, nil
-}
type rawPrepare struct {
rs *rawSet
stmt stmtQuerier
closed bool
}
-func (o *rawPrepare) Exec(args ...interface{}) (int64, error) {
func (o *rawPrepare) Exec(args ...interface{}) (sql.Result, error) {
if o.closed {
-return 0, ErrStmtClosed
return nil, ErrStmtClosed
}
-res, err := o.stmt.Exec(args...)
-if err != nil {
-return 0, err
-}
-return getResult(res)
return o.stmt.Exec(args...)
}
func (o *rawPrepare) Close() error {
@@ -74,12 +52,8 @@ func (o rawSet) SetArgs(args ...interface{}) RawSeter {
return &o
}
-func (o *rawSet) Exec() (int64, error) {
func (o *rawSet) Exec() (sql.Result, error) {
-res, err := o.orm.db.Exec(o.query, o.args...)
-if err != nil {
-return 0, err
-}
-return getResult(res)
return o.orm.db.Exec(o.query, o.args...)
}
func (o *rawSet) QueryRow(...interface{}) error {
......
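With Exec now returning sql.Result instead of a pre-digested int64 (and the getResult helper gone), callers decide for themselves whether they want LastInsertId or RowsAffected. A hedged caller-side sketch; rawExecer below is a stand-in for the RawSeter interface:

package main

import "database/sql"

type rawExecer interface {
	Exec() (sql.Result, error)
}

func countAffected(r rawExecer) (int64, error) {
	res, err := r.Exec()
	if err != nil {
		return 0, err
	}
	// the caller, not the ORM, now decides between RowsAffected and LastInsertId
	return res.RowsAffected()
}

func main() {}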
@@ -60,12 +60,12 @@ type QuerySeter interface {
}
type RawPreparer interface {
-Exec(...interface{}) (int64, error)
Exec(...interface{}) (sql.Result, error)
Close() error
}
type RawSeter interface {
-Exec() (int64, error)
Exec() (sql.Result, error)
QueryRow(...interface{}) error
QueryRows(...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter
@@ -116,10 +116,15 @@ type dbBaser interface {
Update(dbQuerier, *modelInfo, reflect.Value) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value) (int64, error)
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}) (int64, error)
SupportUpdateJoin() bool
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
-GetOperatorSql(*modelInfo, string, []interface{}) (string, []interface{})
OperatorSql(string) string
GenerateOperatorSql(*modelInfo, string, []interface{}) (string, []interface{})
PrepareInsert(dbQuerier, *modelInfo) (stmtQuerier, string, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error)
MaxLimit() uint64
TableQuote() string
ReplaceMarks(*string)
}
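The growing dbBaser interface works because every backend embeds dbBase and overrides only the driver-specific pieces (TableQuote, ReplaceMarks, MaxLimit, SupportUpdateJoin, OperatorSql). A generic sketch of that embedding/override pattern with illustrative names and values:

package main

import "fmt"

type quoter interface {
	TableQuote() string
	SupportUpdateJoin() bool
}

// base supplies the defaults shared by all drivers.
type base struct{}

func (base) TableQuote() string      { return "`" } // backtick, MySQL-style
func (base) SupportUpdateJoin() bool { return true }

// sqliteLike embeds base and overrides only what differs.
type sqliteLike struct{ base }

func (sqliteLike) SupportUpdateJoin() bool { return false }

func main() {
	var q quoter = sqliteLike{}
	fmt.Println(q.TableQuote(), q.SupportUpdateJoin()) // ` false
}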