Commit f2b359d8 authored by slene's avatar slene

orm full remove orm.Manager for simple use, add struct tag `-` for skip struct field

parent 402932aa
...@@ -33,7 +33,6 @@ import ( ...@@ -33,7 +33,6 @@ import (
type User struct { type User struct {
Id int `orm:"auto"` Id int `orm:"auto"`
Name string `orm:"size(100)"` Name string `orm:"size(100)"`
orm.Manager
} }
func init() { func init() {
...@@ -72,7 +71,6 @@ type Post struct { ...@@ -72,7 +71,6 @@ type Post struct {
Id int `orm:"auto"` Id int `orm:"auto"`
Title string `orm:"size(100)"` Title string `orm:"size(100)"`
User *User `orm:"rel(fk)"` User *User `orm:"rel(fk)"`
orm.Manager
} }
var posts []*Post var posts []*Post
......
...@@ -582,8 +582,6 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error { ...@@ -582,8 +582,6 @@ func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value) error {
return err return err
} else { } else {
elm := reflect.New(mi.addrField.Elem().Type()) elm := reflect.New(mi.addrField.Elem().Type())
md := elm.Interface().(Modeler)
md.Init(md)
mind := reflect.Indirect(elm) mind := reflect.Indirect(elm)
d.setColsValues(mi, &mind, mi.fields.dbcols, refs) d.setColsValues(mi, &mind, mi.fields.dbcols, refs)
...@@ -803,25 +801,27 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi ...@@ -803,25 +801,27 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
val := reflect.ValueOf(container) val := reflect.ValueOf(container)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
typ := ind.Type()
errTyp := true errTyp := true
one := true one := true
if val.Kind() == reflect.Ptr { if val.Kind() == reflect.Ptr {
tp := typ fn := ""
if ind.Kind() == reflect.Slice { if ind.Kind() == reflect.Slice {
one = false one = false
if ind.Type().Elem().Kind() == reflect.Ptr { if ind.Type().Elem().Kind() == reflect.Ptr {
tp = ind.Type().Elem().Elem() typ := ind.Type().Elem().Elem()
fn = getFullName(typ)
} }
} else {
fn = getFullName(ind.Type())
} }
errTyp = tp.PkgPath()+"."+tp.Name() != mi.fullName errTyp = fn != mi.fullName
} }
if errTyp { if errTyp {
panic(fmt.Sprintf("wrong object type `%s` for rows scan, need *[]*%s or *%s", val.Type(), mi.fullName, mi.fullName)) panic(fmt.Sprintf("wrong object type `%s` for rows scan, need *[]*%s or *%s", ind.Type(), mi.fullName, mi.fullName))
} }
rlimit := qs.limit rlimit := qs.limit
...@@ -873,8 +873,6 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi ...@@ -873,8 +873,6 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
} }
elm := reflect.New(mi.addrField.Elem().Type()) elm := reflect.New(mi.addrField.Elem().Type())
md := elm.Interface().(Modeler)
md.Init(md)
mind := reflect.Indirect(elm) mind := reflect.Indirect(elm)
cacheV := make(map[string]*reflect.Value) cacheV := make(map[string]*reflect.Value)
...@@ -989,9 +987,9 @@ func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params ...@@ -989,9 +987,9 @@ func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params
if ind.Kind() == reflect.Struct { if ind.Kind() == reflect.Struct {
typ := ind.Type() typ := ind.Type()
fullName := typ.PkgPath() + "." + typ.Name() name := getFullName(typ)
var value interface{} var value interface{}
if mmi, ok := modelCache.get(fullName); ok { if mmi, ok := modelCache.getByFN(name); ok {
if _, vu, exist := d.existPk(mmi, ind); exist { if _, vu, exist := d.existPk(mmi, ind); exist {
value = vu value = vu
} }
...@@ -999,7 +997,7 @@ func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params ...@@ -999,7 +997,7 @@ func (d *dbBase) getOperatorParams(operator string, args []interface{}) (params
arg = value arg = value
if arg == nil { if arg == nil {
panic(fmt.Sprintf("`%s` operator need a valid args value, unknown table or value `%v`", operator, val.Type())) panic(fmt.Sprintf("`%s` operator need a valid args value, unknown table or value `%s`", operator, name))
} }
} else { } else {
arg = ind.Interface() arg = ind.Interface()
...@@ -1266,8 +1264,6 @@ setValue: ...@@ -1266,8 +1264,6 @@ setValue:
if value != nil { if value != nil {
fieldType = fi.relModelInfo.fields.pk.fieldType fieldType = fi.relModelInfo.fields.pk.fieldType
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
md := mf.Interface().(Modeler)
md.Init(md)
field.Set(mf) field.Set(mf)
f := mf.Elem().Field(fi.relModelInfo.fields.pk.fieldIndex) f := mf.Elem().Field(fi.relModelInfo.fields.pk.fieldIndex)
field = &f field = &f
......
...@@ -11,6 +11,17 @@ orm:"null;rel(fk)" ...@@ -11,6 +11,17 @@ orm:"null;rel(fk)"
多个设置间使用 `;` 分隔,设置的值如果是多个,使用 `,` 分隔。 多个设置间使用 `;` 分隔,设置的值如果是多个,使用 `,` 分隔。
#### 忽略字段
设置 `-` 即可忽略 struct 中的字段
```go
type User struct {
...
AnyField string `orm:"-"`
...
```
#### auto #### auto
设置为 Autoincrement Primary Key 设置为 Autoincrement Primary Key
...@@ -49,23 +60,6 @@ type User struct { ...@@ -49,23 +60,6 @@ type User struct {
... ...
Status int `orm:"default(1)"` Status int `orm:"default(1)"`
``` ```
仅当进行 orm.Manager 初始化时才会赋值
```go
func NewUser() *User {
obj := new(User)
obj.Manager.Init(obj)
return obj
}
u := NewUser()
fmt.Println(u.Status) // 1
```
#### choices
为字段设置一组可选的值,类型必须符合。其他值 clean 会返回错误
```go
Status int `orm:"choices(1,2,3,4)"`
```
#### size (string) #### size (string)
string 类型字段设置 size 以后,db type 将使用 varchar string 类型字段设置 size 以后,db type 将使用 varchar
......
...@@ -17,14 +17,12 @@ type User struct { ...@@ -17,14 +17,12 @@ type User struct {
Id int `orm:"auto"` // 设置为auto主键 Id int `orm:"auto"` // 设置为auto主键
Name string Name string
Profile *Profile `orm:"rel(one)"` // OneToOne relation Profile *Profile `orm:"rel(one)"` // OneToOne relation
orm.Manager // 每个model都需要定义orm.Manager
} }
type Profile struct { type Profile struct {
Id int `orm:"auto"` Id int `orm:"auto"`
Age int16 Age int16
User *User `orm:"reverse(one)"` // 设置反向关系(可选) User *User `orm:"reverse(one)"` // 设置反向关系(可选)
orm.Manager
} }
func init() { func init() {
......
...@@ -18,6 +18,7 @@ var ( ...@@ -18,6 +18,7 @@ var (
cacheByFN: make(map[string]*modelInfo), cacheByFN: make(map[string]*modelInfo),
} }
supportTag = map[string]int{ supportTag = map[string]int{
"-": 1,
"null": 1, "null": 1,
"blank": 1, "blank": 1,
"index": 1, "index": 1,
...@@ -27,7 +28,6 @@ var ( ...@@ -27,7 +28,6 @@ var (
"auto_now": 1, "auto_now": 1,
"auto_now_add": 1, "auto_now_add": 1,
"size": 2, "size": 2,
"choices": 2,
"column": 2, "column": 2,
"default": 2, "default": 2,
"rel": 2, "rel": 2,
...@@ -67,9 +67,11 @@ func (mc *_modelCache) allOrdered() []*modelInfo { ...@@ -67,9 +67,11 @@ func (mc *_modelCache) allOrdered() []*modelInfo {
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) { func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
mi, ok = mc.cache[table] mi, ok = mc.cache[table]
if ok == false { return
mi, ok = mc.cacheByFN[table] }
}
func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
mi, ok = mc.cacheByFN[name]
return return
} }
......
...@@ -8,20 +8,36 @@ import ( ...@@ -8,20 +8,36 @@ import (
"strings" "strings"
) )
func registerModel(model Modeler) { func registerModel(model interface{}) {
info := newModelInfo(model) val := reflect.ValueOf(model)
model.Init(model) ind := reflect.Indirect(val)
table := model.GetTableName() typ := ind.Type()
if val.Kind() != reflect.Ptr {
panic(fmt.Sprintf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ)))
}
info := newModelInfo(val)
name := getFullName(typ)
if _, ok := modelCache.getByFN(name); ok {
fmt.Printf("<orm.RegisterModel> model `%s` redeclared, must be unique\n", name)
os.Exit(2)
}
table := getTableName(val)
if _, ok := modelCache.get(table); ok { if _, ok := modelCache.get(table); ok {
fmt.Printf("model <%T> redeclared, must be unique\n", model) fmt.Printf("<orm.RegisterModel> table name `%s` redeclared, must be unique\n", table)
os.Exit(2) os.Exit(2)
} }
if info.fields.pk == nil { if info.fields.pk == nil {
fmt.Printf("model <%T> need a primary key field\n", model) fmt.Printf("<orm.RegisterModel> `%s` need a primary key field\n", name)
os.Exit(2) os.Exit(2)
} }
info.table = table info.table = table
info.pkg = getPkgPath(model) info.pkg = typ.PkgPath()
info.model = model info.model = model
info.manual = true info.manual = true
modelCache.set(table, info) modelCache.set(table, info)
...@@ -52,8 +68,8 @@ func bootStrap() { ...@@ -52,8 +68,8 @@ func bootStrap() {
elm = elm.Elem() elm = elm.Elem()
} }
tn := getTableName(reflect.New(elm).Interface().(Modeler)) name := getFullName(elm)
mii, ok := modelCache.get(tn) mii, ok := modelCache.getByFN(name)
if ok == false || mii.pkg != elm.PkgPath() { if ok == false || mii.pkg != elm.PkgPath() {
err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
goto end goto end
...@@ -202,7 +218,7 @@ end: ...@@ -202,7 +218,7 @@ end:
} }
} }
func RegisterModel(models ...Modeler) { func RegisterModel(models ...interface{}) {
if modelCache.done { if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run begore BootStrap")) panic(fmt.Errorf("RegisterModel must be run begore BootStrap"))
} }
......
...@@ -7,30 +7,7 @@ import ( ...@@ -7,30 +7,7 @@ import (
"strings" "strings"
) )
type fieldChoices []StrTo var errSkipField = errors.New("skip field")
func (f *fieldChoices) Add(s StrTo) {
if f.Have(s) == false {
*f = append(*f, s)
}
}
func (f *fieldChoices) Clear() {
*f = fieldChoices([]StrTo{})
}
func (f *fieldChoices) Have(s StrTo) bool {
for _, v := range *f {
if v == s {
return true
}
}
return false
}
func (f *fieldChoices) Clone() fieldChoices {
return *f
}
type fields struct { type fields struct {
pk *fieldInfo pk *fieldInfo
...@@ -111,7 +88,7 @@ type fieldInfo struct { ...@@ -111,7 +88,7 @@ type fieldInfo struct {
name string name string
fullName string fullName string
column string column string
addrValue *reflect.Value addrValue reflect.Value
sf *reflect.StructField sf *reflect.StructField
auto bool auto bool
pk bool pk bool
...@@ -120,7 +97,6 @@ type fieldInfo struct { ...@@ -120,7 +97,6 @@ type fieldInfo struct {
index bool index bool
unique bool unique bool
initial StrTo initial StrTo
choices fieldChoices
size int size int
auto_now bool auto_now bool
auto_now_add bool auto_now_add bool
...@@ -142,13 +118,10 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (f ...@@ -142,13 +118,10 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (f
var ( var (
tag string tag string
tagValue string tagValue string
choices fieldChoices
values fieldChoices
initial StrTo initial StrTo
fieldType int fieldType int
attrs map[string]bool attrs map[string]bool
tags map[string]string tags map[string]string
parts []string
addrField reflect.Value addrField reflect.Value
) )
...@@ -162,11 +135,20 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (f ...@@ -162,11 +135,20 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (f
parseStructTag(sf.Tag.Get(defaultStructTagName), &attrs, &tags) parseStructTag(sf.Tag.Get(defaultStructTagName), &attrs, &tags)
if _, ok := attrs["-"]; ok {
return nil, errSkipField
}
digits := tags["digits"] digits := tags["digits"]
decimals := tags["decimals"] decimals := tags["decimals"]
size := tags["size"] size := tags["size"]
onDelete := tags["on_delete"] onDelete := tags["on_delete"]
initial.Clear()
if v, ok := tags["default"]; ok {
initial.Set(v)
}
checkType: checkType:
switch f := addrField.Interface().(type) { switch f := addrField.Interface().(type) {
case Fielder: case Fielder:
...@@ -237,10 +219,6 @@ checkType: ...@@ -237,10 +219,6 @@ checkType:
switch fieldType { switch fieldType {
case RelForeignKey, RelOneToOne, RelReverseOne: case RelForeignKey, RelOneToOne, RelReverseOne:
if _, ok := addrField.Interface().(Modeler); ok == false {
err = fmt.Errorf("rel/reverse:one field must be implements Modeler")
goto end
}
if field.Kind() != reflect.Ptr { if field.Kind() != reflect.Ptr {
err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name()) err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name())
goto end goto end
...@@ -254,10 +232,6 @@ checkType: ...@@ -254,10 +232,6 @@ checkType:
err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name()) err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name())
goto end goto end
} }
if _, ok := reflect.New(field.Type().Elem()).Elem().Interface().(Modeler); ok == false {
err = fmt.Errorf("rel/reverse:many slice element must be implements Modeler")
goto end
}
} }
} }
...@@ -269,7 +243,7 @@ checkType: ...@@ -269,7 +243,7 @@ checkType:
fi.fieldType = fieldType fi.fieldType = fieldType
fi.name = sf.Name fi.name = sf.Name
fi.column = getColumnName(fieldType, addrField, sf, tags["column"]) fi.column = getColumnName(fieldType, addrField, sf, tags["column"])
fi.addrValue = &addrField fi.addrValue = addrField
fi.sf = &sf fi.sf = &sf
fi.fullName = mi.fullName + "." + sf.Name fi.fullName = mi.fullName + "." + sf.Name
...@@ -306,7 +280,7 @@ checkType: ...@@ -306,7 +280,7 @@ checkType:
switch onDelete { switch onDelete {
case od_CASCADE, od_DO_NOTHING: case od_CASCADE, od_DO_NOTHING:
case od_SET_DEFAULT: case od_SET_DEFAULT:
if tags["default"] == "" { if initial.Exist() == false {
err = errors.New("on_delete: set_default need set field a default value") err = errors.New("on_delete: set_default need set field a default value")
goto end goto end
} }
...@@ -397,31 +371,13 @@ checkType: ...@@ -397,31 +371,13 @@ checkType:
fi.index = false fi.index = false
} }
parts = strings.Split(tags["choices"], ",")
if len(parts) > 1 {
for _, v := range parts {
choices.Add(StrTo(strings.TrimSpace(v)))
}
}
initial.Clear()
if v, ok := tags["default"]; ok {
initial.Set(v)
}
if fi.auto || fi.pk || fi.unique || fieldType == TypeDateField || fieldType == TypeDateTimeField { if fi.auto || fi.pk || fi.unique || fieldType == TypeDateField || fieldType == TypeDateTimeField {
// can not set default // can not set default
choices.Clear()
initial.Clear() initial.Clear()
} }
values = choices.Clone()
if initial.Exist() { if initial.Exist() {
values.Add(initial) v := initial
}
for i, v := range values {
switch fieldType { switch fieldType {
case TypeBooleanField: case TypeBooleanField:
_, err = v.Bool() _, err = v.Bool()
...@@ -441,23 +397,11 @@ checkType: ...@@ -441,23 +397,11 @@ checkType:
_, err = v.Uint64() _, err = v.Uint64()
} }
if err != nil { if err != nil {
if initial.Exist() && len(values) == i {
tag, tagValue = "default", tags["default"] tag, tagValue = "default", tags["default"]
} else {
tag, tagValue = "choices", tags["choices"]
}
goto wrongTag goto wrongTag
} }
} }
if len(choices) > 0 && initial.Exist() {
if choices.Have(initial) == false {
err = fmt.Errorf("default value `%s` not in choices `%s`", tags["default"], tags["choices"])
goto end
}
}
fi.choices = choices
fi.initial = initial fi.initial = initial
end: end:
if err != nil { if err != nil {
......
...@@ -12,13 +12,13 @@ type modelInfo struct { ...@@ -12,13 +12,13 @@ type modelInfo struct {
name string name string
fullName string fullName string
table string table string
model Modeler model interface{}
fields *fields fields *fields
manual bool manual bool
addrField reflect.Value addrField reflect.Value
} }
func newModelInfo(model Modeler) (info *modelInfo) { func newModelInfo(val reflect.Value) (info *modelInfo) {
var ( var (
err error err error
fi *fieldInfo fi *fieldInfo
...@@ -28,26 +28,24 @@ func newModelInfo(model Modeler) (info *modelInfo) { ...@@ -28,26 +28,24 @@ func newModelInfo(model Modeler) (info *modelInfo) {
info = &modelInfo{} info = &modelInfo{}
info.fields = newFields() info.fields = newFields()
val := reflect.ValueOf(model)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
typ := ind.Type() typ := ind.Type()
info.addrField = ind.Addr() info.addrField = ind.Addr()
info.name = typ.Name() info.name = typ.Name()
info.fullName = typ.PkgPath() + "." + typ.Name() info.fullName = getFullName(typ)
for i := 0; i < ind.NumField(); i++ { for i := 0; i < ind.NumField(); i++ {
field := ind.Field(i) field := ind.Field(i)
sf = ind.Type().Field(i) sf = ind.Type().Field(i)
if field.CanAddr() {
addr := field.Addr()
if _, ok := addr.Interface().(*Manager); ok {
continue
}
}
fi, err = newFieldInfo(info, field, sf) fi, err = newFieldInfo(info, field, sf)
if err != nil { if err != nil {
if err == errSkipField {
err = nil
continue
}
break break
} }
......
package orm
import ()
type fieldError struct {
name string
err error
}
func (f *fieldError) Name() string {
return f.name
}
func (f *fieldError) Error() error {
return f.err
}
func NewFieldError(name string, err error) IFieldError {
return &fieldError{name, err}
}
// non cleaned field errors
type fieldErrors struct {
errors map[string]IFieldError
errorList []IFieldError
}
func (fe *fieldErrors) Get(name string) IFieldError {
return fe.errors[name]
}
func (fe *fieldErrors) Set(name string, value IFieldError) {
fe.errors[name] = value
}
func (fe *fieldErrors) List() []IFieldError {
return fe.errorList
}
func NewFieldErrors() IFieldErrors {
return &fieldErrors{errors: make(map[string]IFieldError)}
}
type Manager struct {
ins Modeler
inited bool
}
// func (m *Manager) init(model reflect.Value) {
// elm := model.Elem()
// for i := 0; i < elm.NumField(); i++ {
// field := elm.Field(i)
// if _, ok := field.Interface().(Fielder); ok && field.CanSet() {
// if field.Elem().Kind() != reflect.Struct {
// field.Set(reflect.New(field.Type().Elem()))
// }
// }
// }
// }
func (m *Manager) Init(model Modeler, args ...interface{}) Modeler {
if m.inited {
return m.ins
}
m.inited = true
m.ins = model
skipInitial := false
if len(args) > 0 {
if b, ok := args[0].(bool); ok && b {
skipInitial = true
}
}
_ = skipInitial
return model
}
func (m *Manager) IsInited() bool {
return m.inited
}
func (m *Manager) Clean() IFieldErrors {
return nil
}
func (m *Manager) CleanFields(name string) IFieldErrors {
return nil
}
func (m *Manager) GetTableName() string {
return getTableName(m.ins)
}
...@@ -15,19 +15,18 @@ type User struct { ...@@ -15,19 +15,18 @@ type User struct {
UserName string `orm:"size(30);unique"` UserName string `orm:"size(30);unique"`
Email string `orm:"size(100)"` Email string `orm:"size(100)"`
Password string `orm:"size(100)"` Password string `orm:"size(100)"`
Status int16 `orm:"choices(0,1,2,3);defalut(0)"` Status int16
IsStaff bool `orm:"default(false)"` IsStaff bool
IsActive bool `orm:"default(1)"` IsActive bool `orm:"default(1)"`
Created time.Time `orm:"auto_now_add;type(date)"` Created time.Time `orm:"auto_now_add;type(date)"`
Updated time.Time `orm:"auto_now"` Updated time.Time `orm:"auto_now"`
Profile *Profile `orm:"null;rel(one);on_delete(set_null)"` Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
Posts []*Post `orm:"reverse(many)" json:"-"` Posts []*Post `orm:"reverse(many)" json:"-"`
Manager `json:"-"` ShouldSkip string `orm:"-"`
} }
func NewUser() *User { func NewUser() *User {
obj := new(User) obj := new(User)
obj.Manager.Init(obj)
return obj return obj
} }
...@@ -36,7 +35,6 @@ type Profile struct { ...@@ -36,7 +35,6 @@ type Profile struct {
Age int16 `` Age int16 ``
Money float64 `` Money float64 ``
User *User `orm:"reverse(one)" json:"-"` User *User `orm:"reverse(one)" json:"-"`
Manager `json:"-"`
} }
func (u *Profile) TableName() string { func (u *Profile) TableName() string {
...@@ -45,7 +43,6 @@ func (u *Profile) TableName() string { ...@@ -45,7 +43,6 @@ func (u *Profile) TableName() string {
func NewProfile() *Profile { func NewProfile() *Profile {
obj := new(Profile) obj := new(Profile)
obj.Manager.Init(obj)
return obj return obj
} }
...@@ -57,12 +54,10 @@ type Post struct { ...@@ -57,12 +54,10 @@ type Post struct {
Created time.Time `orm:"auto_now_add"` Created time.Time `orm:"auto_now_add"`
Updated time.Time `orm:"auto_now"` Updated time.Time `orm:"auto_now"`
Tags []*Tag `orm:"rel(m2m)"` Tags []*Tag `orm:"rel(m2m)"`
Manager `json:"-"`
} }
func NewPost() *Post { func NewPost() *Post {
obj := new(Post) obj := new(Post)
obj.Manager.Init(obj)
return obj return obj
} }
...@@ -70,12 +65,10 @@ type Tag struct { ...@@ -70,12 +65,10 @@ type Tag struct {
Id int `orm:"auto"` Id int `orm:"auto"`
Name string `orm:"size(30)"` Name string `orm:"size(30)"`
Posts []*Post `orm:"reverse(many)" json:"-"` Posts []*Post `orm:"reverse(many)" json:"-"`
Manager `json:"-"`
} }
func NewTag() *Tag { func NewTag() *Tag {
obj := new(Tag) obj := new(Tag)
obj.Manager.Init(obj)
return obj return obj
} }
...@@ -85,12 +78,10 @@ type Comment struct { ...@@ -85,12 +78,10 @@ type Comment struct {
Content string `` Content string ``
Parent *Comment `orm:"null;rel(fk)"` Parent *Comment `orm:"null;rel(fk)"`
Created time.Time `orm:"auto_now_add"` Created time.Time `orm:"auto_now_add"`
Manager `json:"-"`
} }
func NewComment() *Comment { func NewComment() *Comment {
obj := new(Comment) obj := new(Comment)
obj.Manager.Init(obj)
return obj return obj
} }
......
...@@ -7,8 +7,11 @@ import ( ...@@ -7,8 +7,11 @@ import (
"time" "time"
) )
func getTableName(model Modeler) string { func getFullName(typ reflect.Type) string {
val := reflect.ValueOf(model) return typ.PkgPath() + "." + typ.Name()
}
func getTableName(val reflect.Value) string {
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
fun := val.MethodByName("TableName") fun := val.MethodByName("TableName")
if fun.IsValid() { if fun.IsValid() {
...@@ -23,11 +26,6 @@ func getTableName(model Modeler) string { ...@@ -23,11 +26,6 @@ func getTableName(model Modeler) string {
return snakeString(ind.Type().Name()) return snakeString(ind.Type().Name())
} }
func getPkgPath(model Modeler) string {
val := reflect.ValueOf(model)
return val.Type().Elem().PkgPath()
}
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string { func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
column := strings.ToLower(col) column := strings.ToLower(col)
if column == "" { if column == "" {
......
...@@ -39,16 +39,21 @@ type orm struct { ...@@ -39,16 +39,21 @@ type orm struct {
var _ Ormer = new(orm) var _ Ormer = new(orm)
func (o *orm) getMiInd(md Modeler) (mi *modelInfo, ind reflect.Value) { func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
md.Init(md, true) val := reflect.ValueOf(md)
name := md.GetTableName() ind = reflect.Indirect(val)
if mi, ok := modelCache.get(name); ok { typ := ind.Type()
return mi, reflect.Indirect(reflect.ValueOf(md)) if val.Kind() != reflect.Ptr {
panic(fmt.Sprintf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
}
name := getFullName(typ)
if mi, ok := modelCache.getByFN(name); ok {
return mi, ind
} }
panic(fmt.Sprintf("<orm> table name: `%s` not exists", name)) panic(fmt.Sprintf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
} }
func (o *orm) Read(md Modeler) error { func (o *orm) Read(md interface{}) error {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md)
err := o.alias.DbBaser.Read(o.db, mi, ind) err := o.alias.DbBaser.Read(o.db, mi, ind)
if err != nil { if err != nil {
...@@ -57,7 +62,7 @@ func (o *orm) Read(md Modeler) error { ...@@ -57,7 +62,7 @@ func (o *orm) Read(md Modeler) error {
return nil return nil
} }
func (o *orm) Insert(md Modeler) (int64, error) { func (o *orm) Insert(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md)
id, err := o.alias.DbBaser.Insert(o.db, mi, ind) id, err := o.alias.DbBaser.Insert(o.db, mi, ind)
if err != nil { if err != nil {
...@@ -71,7 +76,7 @@ func (o *orm) Insert(md Modeler) (int64, error) { ...@@ -71,7 +76,7 @@ func (o *orm) Insert(md Modeler) (int64, error) {
return id, nil return id, nil
} }
func (o *orm) Update(md Modeler) (int64, error) { func (o *orm) Update(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md)
num, err := o.alias.DbBaser.Update(o.db, mi, ind) num, err := o.alias.DbBaser.Update(o.db, mi, ind)
if err != nil { if err != nil {
...@@ -80,7 +85,7 @@ func (o *orm) Update(md Modeler) (int64, error) { ...@@ -80,7 +85,7 @@ func (o *orm) Update(md Modeler) (int64, error) {
return num, nil return num, nil
} }
func (o *orm) Delete(md Modeler) (int64, error) { func (o *orm) Delete(md interface{}) (int64, error) {
mi, ind := o.getMiInd(md) mi, ind := o.getMiInd(md)
num, err := o.alias.DbBaser.Delete(o.db, mi, ind) num, err := o.alias.DbBaser.Delete(o.db, mi, ind)
if err != nil { if err != nil {
...@@ -94,41 +99,48 @@ func (o *orm) Delete(md Modeler) (int64, error) { ...@@ -94,41 +99,48 @@ func (o *orm) Delete(md Modeler) (int64, error) {
return num, nil return num, nil
} }
func (o *orm) M2mAdd(md Modeler, name string, mds ...interface{}) (int64, error) { func (o *orm) M2mAdd(md interface{}, name string, mds ...interface{}) (int64, error) {
// TODO // TODO
panic(ErrNotImplement) panic(ErrNotImplement)
return 0, nil return 0, nil
} }
func (o *orm) M2mDel(md Modeler, name string, mds ...interface{}) (int64, error) { func (o *orm) M2mDel(md interface{}, name string, mds ...interface{}) (int64, error) {
// TODO // TODO
panic(ErrNotImplement) panic(ErrNotImplement)
return 0, nil return 0, nil
} }
func (o *orm) LoadRel(md Modeler, name string) (int64, error) { func (o *orm) LoadRel(md interface{}, name string) (int64, error) {
// TODO // TODO
panic(ErrNotImplement) panic(ErrNotImplement)
return 0, nil return 0, nil
} }
func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter { func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
name := "" name := ""
if table, ok := ptrStructOrTableName.(string); ok { if table, ok := ptrStructOrTableName.(string); ok {
name = snakeString(table) name = snakeString(table)
} else if md, ok := ptrStructOrTableName.(Modeler); ok {
md.Init(md, true)
name = md.GetTableName()
}
if mi, ok := modelCache.get(name); ok { if mi, ok := modelCache.get(name); ok {
return newQuerySet(o, mi) qs = newQuerySet(o, mi)
}
} else {
val := reflect.ValueOf(ptrStructOrTableName)
ind := reflect.Indirect(val)
name = getFullName(ind.Type())
if mi, ok := modelCache.getByFN(name); ok {
qs = newQuerySet(o, mi)
}
}
if qs == nil {
panic(fmt.Sprintf("<Ormer.QueryTable> table name: `%s` not exists", name))
} }
panic(fmt.Sprintf("<orm.SetTable> table name: `%s` not exists", name)) return
} }
func (o *orm) Using(name string) error { func (o *orm) Using(name string) error {
if o.isTx { if o.isTx {
panic("<orm.Using> transaction has been start, cannot change db") panic("<Ormer.Using> transaction has been start, cannot change db")
} }
if al, ok := dataBaseCache.get(name); ok { if al, ok := dataBaseCache.get(name); ok {
o.alias = al o.alias = al
...@@ -138,7 +150,7 @@ func (o *orm) Using(name string) error { ...@@ -138,7 +150,7 @@ func (o *orm) Using(name string) error {
o.db = al.DB o.db = al.DB
} }
} else { } else {
return errors.New(fmt.Sprintf("<orm.Using> unknown db alias name `%s`", name)) return errors.New(fmt.Sprintf("<Ormer.Using> unknown db alias name `%s`", name))
} }
return nil return nil
} }
......
...@@ -14,15 +14,19 @@ type insertSet struct { ...@@ -14,15 +14,19 @@ type insertSet struct {
var _ Inserter = new(insertSet) var _ Inserter = new(insertSet)
func (o *insertSet) Insert(md Modeler) (int64, error) { func (o *insertSet) Insert(md interface{}) (int64, error) {
if o.closed { if o.closed {
return 0, ErrStmtClosed return 0, ErrStmtClosed
} }
md.Init(md, true)
val := reflect.ValueOf(md) val := reflect.ValueOf(md)
ind := reflect.Indirect(val) ind := reflect.Indirect(val)
if val.Type() != o.mi.addrField.Type() { typ := ind.Type()
panic(fmt.Sprintf("<Inserter.Insert> need type `%s` but found `%s`", o.mi.addrField.Type(), val.Type())) name := getFullName(typ)
if val.Kind() != reflect.Ptr {
panic(fmt.Sprintf("<Inserter.Insert> cannot use non-ptr model struct `%s`", name))
}
if name != o.mi.fullName {
panic(fmt.Sprintf("<Inserter.Insert> need model `%s` but found `%s`", o.mi.fullName, name))
} }
id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind) id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind)
if err != nil { if err != nil {
......
...@@ -63,7 +63,7 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter { ...@@ -63,7 +63,7 @@ func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
case int: case int:
o.relDepth = val o.relDepth = val
default: default:
panic(fmt.Sprintf("<querySet.RelatedSel> wrong param kind: %v", val)) panic(fmt.Sprintf("<QuerySeter.RelatedSel> wrong param kind: %v", val))
} }
} }
} }
...@@ -96,7 +96,7 @@ func (o *querySet) All(container interface{}) (int64, error) { ...@@ -96,7 +96,7 @@ func (o *querySet) All(container interface{}) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container) return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container)
} }
func (o *querySet) One(container Modeler) error { func (o *querySet) One(container interface{}) error {
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container) num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container)
if err != nil { if err != nil {
return err return err
......
...@@ -152,6 +152,14 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) { ...@@ -152,6 +152,14 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) {
} }
} }
func TestModelSyntax(t *testing.T) {
mi, ok := modelCache.get("user")
throwFail(t, AssertIs(ok, T_Equal, true))
if ok {
throwFail(t, AssertIs(mi.fields.GetByName("ShouldSkip") == nil, T_Equal, true))
}
}
func TestCRUD(t *testing.T) { func TestCRUD(t *testing.T) {
profile := NewProfile() profile := NewProfile()
profile.Age = 30 profile.Age = 30
......
...@@ -18,22 +18,14 @@ type Fielder interface { ...@@ -18,22 +18,14 @@ type Fielder interface {
Clean() error Clean() error
} }
type Modeler interface {
Init(Modeler, ...interface{}) Modeler
IsInited() bool
Clean() IFieldErrors
CleanFields(string) IFieldErrors
GetTableName() string
}
type Ormer interface { type Ormer interface {
Read(Modeler) error Read(interface{}) error
Insert(Modeler) (int64, error) Insert(interface{}) (int64, error)
Update(Modeler) (int64, error) Update(interface{}) (int64, error)
Delete(Modeler) (int64, error) Delete(interface{}) (int64, error)
M2mAdd(Modeler, string, ...interface{}) (int64, error) M2mAdd(interface{}, string, ...interface{}) (int64, error)
M2mDel(Modeler, string, ...interface{}) (int64, error) M2mDel(interface{}, string, ...interface{}) (int64, error)
LoadRel(Modeler, string) (int64, error) LoadRel(interface{}, string) (int64, error)
QueryTable(interface{}) QuerySeter QueryTable(interface{}) QuerySeter
Using(string) error Using(string) error
Begin() error Begin() error
...@@ -44,7 +36,7 @@ type Ormer interface { ...@@ -44,7 +36,7 @@ type Ormer interface {
} }
type Inserter interface { type Inserter interface {
Insert(Modeler) (int64, error) Insert(interface{}) (int64, error)
Close() error Close() error
} }
...@@ -61,7 +53,7 @@ type QuerySeter interface { ...@@ -61,7 +53,7 @@ type QuerySeter interface {
Delete() (int64, error) Delete() (int64, error)
PrepareInsert() (Inserter, error) PrepareInsert() (Inserter, error)
All(interface{}) (int64, error) All(interface{}) (int64, error)
One(Modeler) error One(interface{}) error
Values(*[]Params, ...string) (int64, error) Values(*[]Params, ...string) (int64, error)
ValuesList(*[]ParamsList, ...string) (int64, error) ValuesList(*[]ParamsList, ...string) (int64, error)
ValuesFlat(*ParamsList, string) (int64, error) ValuesFlat(*ParamsList, string) (int64, error)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment