Commit 46668b81 authored by slene's avatar slene

some fix / add test

parent 10f4e822
This diff is collapsed.
...@@ -9,24 +9,37 @@ import ( ...@@ -9,24 +9,37 @@ import (
const defaultMaxIdle = 30 const defaultMaxIdle = 30
type driverType int type DriverType int
const ( const (
_ driverType = iota _ DriverType = iota
DR_MySQL DR_MySQL
DR_Sqlite DR_Sqlite
DR_Oracle DR_Oracle
DR_Postgres DR_Postgres
) )
type driver string
func (d driver) Type() DriverType {
a, _ := dataBaseCache.get(string(d))
return a.Driver
}
func (d driver) Name() string {
return string(d)
}
var _ Driver = new(driver)
var ( var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)} dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
drivers = map[string]driverType{ drivers = map[string]DriverType{
"mysql": DR_MySQL, "mysql": DR_MySQL,
"postgres": DR_Postgres, "postgres": DR_Postgres,
"sqlite3": DR_Sqlite, "sqlite3": DR_Sqlite,
} }
dbBasers = map[driverType]dbBaser{ dbBasers = map[DriverType]dbBaser{
DR_MySQL: newdbBaseMysql(), DR_MySQL: newdbBaseMysql(),
DR_Sqlite: newdbBaseSqlite(), DR_Sqlite: newdbBaseSqlite(),
DR_Oracle: newdbBaseMysql(), DR_Oracle: newdbBaseMysql(),
...@@ -63,6 +76,7 @@ func (ac *_dbCache) getDefault() (al *alias) { ...@@ -63,6 +76,7 @@ func (ac *_dbCache) getDefault() (al *alias) {
type alias struct { type alias struct {
Name string Name string
Driver DriverType
DriverName string DriverName string
DataSource string DataSource string
MaxIdle int MaxIdle int
...@@ -87,6 +101,7 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) { ...@@ -87,6 +101,7 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) {
if dr, ok := drivers[driverName]; ok { if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr] al.DbBaser = dbBasers[dr]
al.Driver = dr
} else { } else {
err = fmt.Errorf("driver name `%s` have not registered", driverName) err = fmt.Errorf("driver name `%s` have not registered", driverName)
goto end goto end
...@@ -116,7 +131,7 @@ end: ...@@ -116,7 +131,7 @@ end:
} }
} }
func RegisterDriver(name string, typ driverType) { func RegisterDriver(name string, typ DriverType) {
if t, ok := drivers[name]; ok == false { if t, ok := drivers[name]; ok == false {
drivers[name] = typ drivers[name] = typ
} else { } else {
......
...@@ -49,6 +49,7 @@ type _modelCache struct { ...@@ -49,6 +49,7 @@ type _modelCache struct {
sync.RWMutex sync.RWMutex
orders []string orders []string
cache map[string]*modelInfo cache map[string]*modelInfo
done bool
} }
func (mc *_modelCache) all() map[string]*modelInfo { func (mc *_modelCache) all() map[string]*modelInfo {
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"strings" "strings"
) )
func RegisterModel(model Modeler) { func registerModel(model Modeler) {
info := newModelInfo(model) info := newModelInfo(model)
model.Init(model) model.Init(model)
table := model.GetTableName() table := model.GetTableName()
...@@ -27,9 +27,10 @@ func RegisterModel(model Modeler) { ...@@ -27,9 +27,10 @@ func RegisterModel(model Modeler) {
modelCache.set(table, info) modelCache.set(table, info)
} }
func BootStrap() { func bootStrap() {
modelCache.Lock() if modelCache.done {
defer modelCache.Unlock() return
}
var ( var (
err error err error
...@@ -59,14 +60,6 @@ func BootStrap() { ...@@ -59,14 +60,6 @@ func BootStrap() {
} }
fi.relModelInfo = mii fi.relModelInfo = mii
if fi.rel {
if mii.fields.pk.IsMulti() {
err = fmt.Errorf("field `%s` unsupport rel to multi primary key field", fi.fullName)
goto end
}
}
switch fi.fieldType { switch fi.fieldType {
case RelManyToMany: case RelManyToMany:
if fi.relThrough != "" { if fi.relThrough != "" {
...@@ -207,6 +200,25 @@ end: ...@@ -207,6 +200,25 @@ end:
fmt.Println(err) fmt.Println(err)
os.Exit(2) os.Exit(2)
} }
}
runCommand() func RegisterModel(models ...Modeler) {
if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run begore BootStrap"))
}
for _, model := range models {
registerModel(model)
}
}
func BootStrap() {
if modelCache.done {
return
}
modelCache.Lock()
defer modelCache.Unlock()
bootStrap()
modelCache.done = true
} }
...@@ -32,32 +32,8 @@ func (f *fieldChoices) Clone() fieldChoices { ...@@ -32,32 +32,8 @@ func (f *fieldChoices) Clone() fieldChoices {
return *f return *f
} }
type primaryKeys []*fieldInfo
func (p *primaryKeys) Add(fi *fieldInfo) {
*p = append(*p, fi)
}
func (p primaryKeys) Exist(fi *fieldInfo) (int, bool) {
for i, v := range p {
if v == fi {
return i, true
}
}
return -1, false
}
func (p primaryKeys) IsMulti() bool {
return len(p) > 1
}
func (p primaryKeys) IsEmpty() bool {
return len(p) == 0
}
type fields struct { type fields struct {
pk primaryKeys pk *fieldInfo
auto *fieldInfo
columns map[string]*fieldInfo columns map[string]*fieldInfo
fields map[string]*fieldInfo fields map[string]*fieldInfo
fieldsLow map[string]*fieldInfo fieldsLow map[string]*fieldInfo
......
...@@ -50,41 +50,31 @@ func newModelInfo(model Modeler) (info *modelInfo) { ...@@ -50,41 +50,31 @@ func newModelInfo(model Modeler) (info *modelInfo) {
if err != nil { if err != nil {
break break
} }
added := info.fields.Add(fi) added := info.fields.Add(fi)
if added == false { if added == false {
err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column)) err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column))
break break
} }
if fi.pk { if fi.pk {
if info.fields.pk != nil { if info.fields.pk != nil {
err = errors.New(fmt.Sprintf("one model must have one pk field only")) err = errors.New(fmt.Sprintf("one model must have one pk field only"))
break break
} else { } else {
info.fields.pk.Add(fi) info.fields.pk = fi
} }
} }
if fi.auto {
info.fields.auto = fi
}
fi.fieldIndex = i fi.fieldIndex = i
fi.mi = info fi.mi = info
} }
if _, ok := info.fields.pk.Exist(info.fields.auto); info.fields.auto != nil && ok == false {
err = errors.New(fmt.Sprintf("when auto field exists, you cannot set other pk field"))
goto end
}
if err != nil { if err != nil {
fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err)) fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
os.Exit(2) os.Exit(2)
} }
end:
if err != nil {
fmt.Println(err)
os.Exit(2)
}
return return
} }
...@@ -125,6 +115,6 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) { ...@@ -125,6 +115,6 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
info.fields.Add(fa) info.fields.Add(fa)
info.fields.Add(f1) info.fields.Add(f1)
info.fields.Add(f2) info.fields.Add(f2)
info.fields.pk.Add(fa) info.fields.pk = fa
return return
} }
package orm
import (
"fmt"
"os"
"time"
_ "github.com/bmizerany/pq"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
)
type User struct {
Id int `orm:"auto"`
UserName string `orm:"size(30);unique"`
Email string `orm:"size(100)"`
Password string `orm:"size(100)"`
Status int16 `orm:"choices(0,1,2,3);defalut(0)"`
IsStaff bool `orm:"default(false)"`
IsActive bool `orm:"default(1)"`
Created time.Time `orm:"auto_now_add;type(date)"`
Updated time.Time `orm:"auto_now"`
Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
Posts []*Post `orm:"reverse(many)" json:"-"`
Manager `json:"-"`
}
func NewUser() *User {
obj := new(User)
obj.Manager.Init(obj)
return obj
}
type Profile struct {
Id int `orm:"auto"`
Age int16 ``
Money float64 ``
User *User `orm:"reverse(one)" json:"-"`
Manager `json:"-"`
}
func (u *Profile) TableName() string {
return "user_profile"
}
func NewProfile() *Profile {
obj := new(Profile)
obj.Manager.Init(obj)
return obj
}
type Post struct {
Id int `orm:"auto"`
User *User `orm:"rel(fk)"` //
Title string `orm:"size(60)"`
Content string ``
Created time.Time `orm:"auto_now_add"`
Updated time.Time `orm:"auto_now"`
Tags []*Tag `orm:"rel(m2m)"`
Manager `json:"-"`
}
func NewPost() *Post {
obj := new(Post)
obj.Manager.Init(obj)
return obj
}
type Tag struct {
Id int `orm:"auto"`
Name string `orm:"size(30)"`
Posts []*Post `orm:"reverse(many)" json:"-"`
Manager `json:"-"`
}
func NewTag() *Tag {
obj := new(Tag)
obj.Manager.Init(obj)
return obj
}
type Comment struct {
Id int `orm:"auto"`
Post *Post `orm:"rel(fk)"`
Content string ``
Parent *Comment `orm:"null;rel(fk)"`
Created time.Time `orm:"auto_now_add"`
Manager `json:"-"`
}
func NewComment() *Comment {
obj := new(Comment)
obj.Manager.Init(obj)
return obj
}
var DBARGS = struct {
Driver string
Source string
}{
os.Getenv("ORM_DRIVER"),
os.Getenv("ORM_SOURCE"),
}
var dORM Ormer
func init() {
RegisterModel(new(User))
RegisterModel(new(Profile))
RegisterModel(new(Post))
RegisterModel(new(Tag))
RegisterModel(new(Comment))
if DBARGS.Driver == "" || DBARGS.Source == "" {
fmt.Println(`need driver and source!
Default DB Drivers.
driver: url
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/bmizerany/pq
eg: mysql
ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/astaxie/beego/orm
`)
os.Exit(2)
}
RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20)
BootStrap()
truncateTables()
dORM = NewOrm()
}
func truncateTables() {
logs := "truncate tables for test\n"
o := NewOrm()
for _, m := range modelCache.allOrdered() {
query := fmt.Sprintf("truncate table `%s`", m.table)
_, err := o.Raw(query).Exec()
logs += query + "\n"
if err != nil {
fmt.Println(logs)
fmt.Println(err)
os.Exit(2)
}
}
}
...@@ -9,13 +9,15 @@ import ( ...@@ -9,13 +9,15 @@ import (
) )
var ( var (
ErrTXHasBegin = errors.New("<Ormer.Begin> transaction already begin")
ErrTXNotBegin = errors.New("<Ormer.Commit/Rollback> transaction not begin")
ErrMultiRows = errors.New("<QuerySeter.One> return multi rows")
ErrStmtClosed = errors.New("<QuerySeter.Insert> stmt already closed")
DefaultRowsLimit = 1000 DefaultRowsLimit = 1000
DefaultRelsDepth = 5 DefaultRelsDepth = 5
DefaultTimeLoc = time.Local DefaultTimeLoc = time.Local
ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin")
ErrTxDone = errors.New("<Ormer.Commit/Rollback> transaction not begin")
ErrMultiRows = errors.New("<QuerySeter> return multi rows")
ErrNoRows = errors.New("<QuerySeter> not row found")
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
ErrNotImplement = errors.New("have not implement")
) )
type Params map[string]interface{} type Params map[string]interface{}
...@@ -27,13 +29,15 @@ type orm struct { ...@@ -27,13 +29,15 @@ type orm struct {
isTx bool isTx bool
} }
var _ Ormer = new(orm)
func (o *orm) getMiInd(md Modeler) (mi *modelInfo, ind reflect.Value) { func (o *orm) getMiInd(md Modeler) (mi *modelInfo, ind reflect.Value) {
md.Init(md, true) md.Init(md, true)
name := md.GetTableName() name := md.GetTableName()
if mi, ok := modelCache.get(name); ok { if mi, ok := modelCache.get(name); ok {
return mi, reflect.Indirect(reflect.ValueOf(md)) return mi, reflect.Indirect(reflect.ValueOf(md))
} }
panic(fmt.Sprintf("<orm.Object> table name: `%s` not exists", name)) panic(fmt.Sprintf("<orm> table name: `%s` not exists", name))
} }
func (o *orm) Read(md Modeler) error { func (o *orm) Read(md Modeler) error {
...@@ -52,8 +56,8 @@ func (o *orm) Insert(md Modeler) (int64, error) { ...@@ -52,8 +56,8 @@ func (o *orm) Insert(md Modeler) (int64, error) {
return id, err return id, err
} }
if id > 0 { if id > 0 {
if mi.fields.auto != nil { if mi.fields.pk.auto {
ind.Field(mi.fields.auto.fieldIndex).SetInt(id) ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
} }
} }
return id, nil return id, nil
...@@ -75,13 +79,31 @@ func (o *orm) Delete(md Modeler) (int64, error) { ...@@ -75,13 +79,31 @@ func (o *orm) Delete(md Modeler) (int64, error) {
return num, err return num, err
} }
if num > 0 { if num > 0 {
if mi.fields.auto != nil { if mi.fields.pk.auto {
ind.Field(mi.fields.auto.fieldIndex).SetInt(0) ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
} }
} }
return num, nil return num, nil
} }
func (o *orm) M2mAdd(md Modeler, name string, mds ...interface{}) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
}
func (o *orm) M2mDel(md Modeler, name string, mds ...interface{}) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
}
func (o *orm) LoadRel(md Modeler, name string) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
}
func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
name := "" name := ""
if table, ok := ptrStructOrTableName.(string); ok { if table, ok := ptrStructOrTableName.(string); ok {
...@@ -111,7 +133,7 @@ func (o *orm) Using(name string) error { ...@@ -111,7 +133,7 @@ func (o *orm) Using(name string) error {
func (o *orm) Begin() error { func (o *orm) Begin() error {
if o.isTx { if o.isTx {
return ErrTXHasBegin return ErrTxHasBegan
} }
tx, err := o.alias.DB.Begin() tx, err := o.alias.DB.Begin()
if err != nil { if err != nil {
...@@ -124,24 +146,28 @@ func (o *orm) Begin() error { ...@@ -124,24 +146,28 @@ func (o *orm) Begin() error {
func (o *orm) Commit() error { func (o *orm) Commit() error {
if o.isTx == false { if o.isTx == false {
return ErrTXNotBegin return ErrTxDone
} }
err := o.db.(*sql.Tx).Commit() err := o.db.(*sql.Tx).Commit()
if err == nil { if err == nil {
o.isTx = false o.isTx = false
o.db = o.alias.DB o.db = o.alias.DB
} else if err == sql.ErrTxDone {
return ErrTxDone
} }
return err return err
} }
func (o *orm) Rollback() error { func (o *orm) Rollback() error {
if o.isTx == false { if o.isTx == false {
return ErrTXNotBegin return ErrTxDone
} }
err := o.db.(*sql.Tx).Rollback() err := o.db.(*sql.Tx).Rollback()
if err == nil { if err == nil {
o.isTx = false o.isTx = false
o.db = o.alias.DB o.db = o.alias.DB
} else if err == sql.ErrTxDone {
return ErrTxDone
} }
return err return err
} }
...@@ -150,7 +176,13 @@ func (o *orm) Raw(query string, args ...interface{}) RawSeter { ...@@ -150,7 +176,13 @@ func (o *orm) Raw(query string, args ...interface{}) RawSeter {
return newRawSet(o, query, args) return newRawSet(o, query, args)
} }
func (o *orm) Driver() Driver {
return driver(o.alias.Name)
}
func NewOrm() Ormer { func NewOrm() Ormer {
BootStrap() // execute only once
o := new(orm) o := new(orm)
err := o.Using("default") err := o.Using("default")
if err != nil { if err != nil {
......
...@@ -26,23 +26,24 @@ func NewCondition() *Condition { ...@@ -26,23 +26,24 @@ func NewCondition() *Condition {
return c return c
} }
func (c *Condition) And(expr string, args ...interface{}) *Condition { func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic("<Condition.And> args cannot empty") panic("<Condition.And> args cannot empty")
} }
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args}) c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args})
return c return &c
} }
func (c *Condition) AndNot(expr string, args ...interface{}) *Condition { func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic("<Condition.AndNot> args cannot empty") panic("<Condition.AndNot> args cannot empty")
} }
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true}) c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true})
return c return &c
} }
func (c *Condition) AndCond(cond *Condition) *Condition { func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone()
if c == cond { if c == cond {
panic("cannot use self as sub cond") panic("cannot use self as sub cond")
} }
...@@ -52,23 +53,24 @@ func (c *Condition) AndCond(cond *Condition) *Condition { ...@@ -52,23 +53,24 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c return c
} }
func (c *Condition) Or(expr string, args ...interface{}) *Condition { func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic("<Condition.Or> args cannot empty") panic("<Condition.Or> args cannot empty")
} }
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true}) c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true})
return c return &c
} }
func (c *Condition) OrNot(expr string, args ...interface{}) *Condition { func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 { if expr == "" || len(args) == 0 {
panic("<Condition.OrNot> args cannot empty") panic("<Condition.OrNot> args cannot empty")
} }
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true}) c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true})
return c return &c
} }
func (c *Condition) OrCond(cond *Condition) *Condition { func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone()
if c == cond { if c == cond {
panic("cannot use self as sub cond") panic("cannot use self as sub cond")
} }
...@@ -82,13 +84,6 @@ func (c *Condition) IsEmpty() bool { ...@@ -82,13 +84,6 @@ func (c *Condition) IsEmpty() bool {
return len(c.params) == 0 return len(c.params) == 0
} }
func (c Condition) Clone() *Condition { func (c Condition) clone() *Condition {
params := c.params
c.params = make([]condValue, len(params))
copy(c.params, params)
return &c return &c
} }
func (c *Condition) Merge() (expr string, args []interface{}) {
return expr, args
}
...@@ -13,6 +13,8 @@ type insertSet struct { ...@@ -13,6 +13,8 @@ type insertSet struct {
closed bool closed bool
} }
var _ Inserter = new(insertSet)
func (o *insertSet) Insert(md Modeler) (int64, error) { func (o *insertSet) Insert(md Modeler) (int64, error) {
if o.closed { if o.closed {
return 0, ErrStmtClosed return 0, ErrStmtClosed
...@@ -28,14 +30,17 @@ func (o *insertSet) Insert(md Modeler) (int64, error) { ...@@ -28,14 +30,17 @@ func (o *insertSet) Insert(md Modeler) (int64, error) {
return id, err return id, err
} }
if id > 0 { if id > 0 {
if o.mi.fields.auto != nil { if o.mi.fields.pk.auto {
ind.Field(o.mi.fields.auto.fieldIndex).SetInt(id) ind.Field(o.mi.fields.pk.fieldIndex).SetInt(id)
} }
} }
return id, nil return id, nil
} }
func (o *insertSet) Close() error { func (o *insertSet) Close() error {
if o.closed {
return ErrStmtClosed
}
o.closed = true o.closed = true
return o.stmt.Close() return o.stmt.Close()
} }
......
...@@ -15,47 +15,43 @@ type querySet struct { ...@@ -15,47 +15,43 @@ type querySet struct {
orm *orm orm *orm
} }
func (o *querySet) Filter(expr string, args ...interface{}) QuerySeter { var _ QuerySeter = new(querySet)
o = o.clone()
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
if o.cond == nil { if o.cond == nil {
o.cond = NewCondition() o.cond = NewCondition()
} }
o.cond.And(expr, args...) o.cond = o.cond.And(expr, args...)
return o return &o
} }
func (o *querySet) Exclude(expr string, args ...interface{}) QuerySeter { func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
o = o.clone()
if o.cond == nil { if o.cond == nil {
o.cond = NewCondition() o.cond = NewCondition()
} }
o.cond.AndNot(expr, args...) o.cond = o.cond.AndNot(expr, args...)
return o return &o
} }
func (o *querySet) Limit(limit int, args ...int64) QuerySeter { func (o querySet) Limit(limit int, args ...int64) QuerySeter {
o = o.clone()
o.limit = limit o.limit = limit
if len(args) > 0 { if len(args) > 0 {
o.offset = args[0] o.offset = args[0]
} }
return o return &o
} }
func (o *querySet) Offset(offset int64) QuerySeter { func (o querySet) Offset(offset int64) QuerySeter {
o = o.clone()
o.offset = offset o.offset = offset
return o return &o
} }
func (o *querySet) OrderBy(exprs ...string) QuerySeter { func (o querySet) OrderBy(exprs ...string) QuerySeter {
o = o.clone()
o.orders = exprs o.orders = exprs
return o return &o
} }
func (o *querySet) RelatedSel(params ...interface{}) QuerySeter { func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
o = o.clone()
var related []string var related []string
if len(params) == 0 { if len(params) == 0 {
o.relDepth = DefaultRelsDepth o.relDepth = DefaultRelsDepth
...@@ -72,13 +68,6 @@ func (o *querySet) RelatedSel(params ...interface{}) QuerySeter { ...@@ -72,13 +68,6 @@ func (o *querySet) RelatedSel(params ...interface{}) QuerySeter {
} }
} }
o.related = related o.related = related
return o
}
func (o querySet) clone() *querySet {
if o.cond != nil {
o.cond = o.cond.Clone()
}
return &o return &o
} }
...@@ -115,6 +104,9 @@ func (o *querySet) One(container Modeler) error { ...@@ -115,6 +104,9 @@ func (o *querySet) One(container Modeler) error {
if num > 1 { if num > 1 {
return ErrMultiRows return ErrMultiRows
} }
if num == 0 {
return ErrNoRows
}
return nil return nil
} }
......
...@@ -63,6 +63,8 @@ type rawSet struct { ...@@ -63,6 +63,8 @@ type rawSet struct {
orm *orm orm *orm
} }
var _ RawSeter = new(rawSet)
func (o rawSet) SetArgs(args ...interface{}) RawSeter { func (o rawSet) SetArgs(args ...interface{}) RawSeter {
o.args = args o.args = args
return &o return &o
...@@ -76,7 +78,12 @@ func (o *rawSet) Exec() (int64, error) { ...@@ -76,7 +78,12 @@ func (o *rawSet) Exec() (int64, error) {
return getResult(res) return getResult(res)
} }
func (o *rawSet) Mapper(...interface{}) (int64, error) { func (o *rawSet) QueryRow(...interface{}) error {
//TODO
return nil
}
func (o *rawSet) QueryRows(...interface{}) (int64, error) {
//TODO //TODO
return 0, nil return 0, nil
} }
...@@ -120,7 +127,7 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { ...@@ -120,7 +127,7 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
cols = columns cols = columns
refs = make([]interface{}, len(cols)) refs = make([]interface{}, len(cols))
for i, _ := range refs { for i, _ := range refs {
var ref string var ref sql.NullString
refs[i] = &ref refs[i] = &ref
} }
} }
...@@ -134,21 +141,21 @@ func (o *rawSet) readValues(container interface{}) (int64, error) { ...@@ -134,21 +141,21 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
case 1: case 1:
params := make(Params, len(cols)) params := make(Params, len(cols))
for i, ref := range refs { for i, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface() value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
params[cols[i]] = value params[cols[i]] = value.String
} }
maps = append(maps, params) maps = append(maps, params)
case 2: case 2:
params := make(ParamsList, 0, len(cols)) params := make(ParamsList, 0, len(cols))
for _, ref := range refs { for _, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface() value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
params = append(params, value) params = append(params, value.String)
} }
lists = append(lists, params) lists = append(lists, params)
case 3: case 3:
for _, ref := range refs { for _, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface() value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
list = append(list, value) list = append(list, value.String)
} }
} }
......
This diff is collapsed.
...@@ -5,6 +5,11 @@ import ( ...@@ -5,6 +5,11 @@ import (
"reflect" "reflect"
) )
type Driver interface {
Name() string
Type() DriverType
}
type Fielder interface { type Fielder interface {
String() string String() string
FieldType() int FieldType() int
...@@ -26,12 +31,16 @@ type Ormer interface { ...@@ -26,12 +31,16 @@ type Ormer interface {
Insert(Modeler) (int64, error) Insert(Modeler) (int64, error)
Update(Modeler) (int64, error) Update(Modeler) (int64, error)
Delete(Modeler) (int64, error) Delete(Modeler) (int64, error)
M2mAdd(Modeler, string, ...interface{}) (int64, error)
M2mDel(Modeler, string, ...interface{}) (int64, error)
LoadRel(Modeler, string) (int64, error)
QueryTable(interface{}) QuerySeter QueryTable(interface{}) QuerySeter
Using(string) error Using(string) error
Begin() error Begin() error
Commit() error Commit() error
Rollback() error Rollback() error
Raw(string, ...interface{}) RawSeter Raw(string, ...interface{}) RawSeter
Driver() Driver
} }
type Inserter interface { type Inserter interface {
...@@ -42,16 +51,15 @@ type Inserter interface { ...@@ -42,16 +51,15 @@ type Inserter interface {
type QuerySeter interface { type QuerySeter interface {
Filter(string, ...interface{}) QuerySeter Filter(string, ...interface{}) QuerySeter
Exclude(string, ...interface{}) QuerySeter Exclude(string, ...interface{}) QuerySeter
SetCond(*Condition) QuerySeter
Limit(int, ...int64) QuerySeter Limit(int, ...int64) QuerySeter
Offset(int64) QuerySeter Offset(int64) QuerySeter
OrderBy(...string) QuerySeter OrderBy(...string) QuerySeter
RelatedSel(...interface{}) QuerySeter RelatedSel(...interface{}) QuerySeter
SetCond(*Condition) QuerySeter
Count() (int64, error) Count() (int64, error)
Update(Params) (int64, error) Update(Params) (int64, error)
Delete() (int64, error) Delete() (int64, error)
PrepareInsert() (Inserter, error) PrepareInsert() (Inserter, error)
All(interface{}) (int64, error) All(interface{}) (int64, error)
One(Modeler) error One(Modeler) error
Values(*[]Params, ...string) (int64, error) Values(*[]Params, ...string) (int64, error)
...@@ -60,12 +68,15 @@ type QuerySeter interface { ...@@ -60,12 +68,15 @@ type QuerySeter interface {
} }
type RawPreparer interface { type RawPreparer interface {
Exec(...interface{}) (int64, error)
Close() error Close() error
} }
type RawSeter interface { type RawSeter interface {
Exec() (int64, error) Exec() (int64, error)
Mapper(...interface{}) (int64, error) QueryRow(...interface{}) error
QueryRows(...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter
Values(*[]Params) (int64, error) Values(*[]Params) (int64, error)
ValuesList(*[]ParamsList) (int64, error) ValuesList(*[]ParamsList) (int64, error)
ValuesFlat(*ParamsList) (int64, error) ValuesFlat(*ParamsList) (int64, error)
......
...@@ -171,6 +171,18 @@ func (a argInt) Get(i int, args ...int) (r int) { ...@@ -171,6 +171,18 @@ func (a argInt) Get(i int, args ...int) (r int) {
return return
} }
type argAny []interface{}
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
func timeParse(dateString, format string) (time.Time, error) { func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc) tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err return tp, err
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment