Commit f45b271b authored by astaxie's avatar astaxie

Merge pull request #1723 from miraclesu/feature/orm_inline_struct

orm: inline struct support
parents 589616b3 64e0858d
......@@ -113,7 +113,7 @@ func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Val
if fi.pk {
_, value, _ = getExistPk(mi, ind)
} else {
field := ind.Field(fi.fieldIndex)
field := ind.FieldByIndex(fi.fieldIndex)
if fi.isFielder {
f := field.Addr().Interface().(Fielder)
value = f.RawValue()
......@@ -517,9 +517,9 @@ func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
if num > 0 {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(0)
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0)
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0)
}
}
err := d.deleteRels(q, mi, []interface{}{pkValue}, tz)
......@@ -859,13 +859,13 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
mmi = fi.relModelInfo
field := last
if last.Kind() != reflect.Invalid {
field = reflect.Indirect(last.Field(fi.fieldIndex))
field = reflect.Indirect(last.FieldByIndex(fi.fieldIndex))
if field.IsValid() {
d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz)
for _, fi := range mmi.fields.fieldsReverse {
if fi.inModel && fi.reverseFieldInfo.mi == lastm {
if fi.reverseFieldInfo != nil {
f := field.Field(fi.fieldIndex)
f := field.FieldByIndex(fi.fieldIndex)
if f.Kind() == reflect.Ptr {
f.Set(last.Addr())
}
......@@ -1014,7 +1014,7 @@ func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string,
fi := mi.fields.GetByColumn(column)
field := ind.Field(fi.fieldIndex)
field := ind.FieldByIndex(fi.fieldIndex)
value, err := d.convertValueFromDB(fi, val, tz)
if err != nil {
......@@ -1350,7 +1350,7 @@ setValue:
fieldType = fi.relModelInfo.fields.pk.fieldType
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
field.Set(mf)
f := mf.Elem().Field(fi.relModelInfo.fields.pk.fieldIndex)
f := mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex)
field = f
goto setValue
}
......
......@@ -32,7 +32,7 @@ func getDbAlias(name string) *alias {
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk
v := ind.Field(fi.fieldIndex)
v := ind.FieldByIndex(fi.fieldIndex)
if fi.fieldType&IsPostiveIntegerField > 0 {
vu := v.Uint()
exist = vu > 0
......
......@@ -102,7 +102,7 @@ func newFields() *fields {
// single field info
type fieldInfo struct {
mi *modelInfo
fieldIndex int
fieldIndex []int
fieldType int
dbcol bool
inModel bool
......@@ -138,7 +138,7 @@ type fieldInfo struct {
}
// new field info
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) {
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mName string) (fi *fieldInfo, err error) {
var (
tag string
tagValue string
......@@ -278,7 +278,7 @@ checkType:
fi.column = getColumnName(fieldType, addrField, sf, tags["column"])
fi.addrValue = addrField
fi.sf = sf
fi.fullName = mi.fullName + "." + sf.Name
fi.fullName = mi.fullName + mName + "." + sf.Name
fi.null = attrs["null"]
fi.index = attrs["index"]
......
......@@ -36,11 +36,6 @@ type modelInfo struct {
// new model info
func newModelInfo(val reflect.Value) (info *modelInfo) {
var (
err error
fi *fieldInfo
sf reflect.StructField
)
info = &modelInfo{}
info.fields = newFields()
......@@ -53,13 +48,31 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
info.name = typ.Name()
info.fullName = getFullName(typ)
addModelFields(info, ind, "", []int{})
return
}
func addModelFields(info *modelInfo, ind reflect.Value, mName string, index []int) {
var (
err error
fi *fieldInfo
sf reflect.StructField
)
for i := 0; i < ind.NumField(); i++ {
field := ind.Field(i)
sf = ind.Type().Field(i)
if sf.PkgPath != "" {
continue
}
fi, err = newFieldInfo(info, field, sf)
// add anonymous struct fields
if sf.Anonymous {
addModelFields(info, field, mName+"."+sf.Name, append(index, i))
continue
}
fi, err = newFieldInfo(info, field, sf, mName)
if err != nil {
if err == errSkipField {
......@@ -84,7 +97,7 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
}
}
fi.fieldIndex = i
fi.fieldIndex = append(index, i)
fi.mi = info
fi.inModel = true
}
......@@ -93,8 +106,6 @@ func newModelInfo(val reflect.Value) (info *modelInfo) {
fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
os.Exit(2)
}
return
}
// combine related model info to new model info.
......
......@@ -25,7 +25,6 @@ import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
// As tidb can't use go get, so disable the tidb testing now
// _ "github.com/pingcap/tidb"
)
......@@ -352,6 +351,30 @@ type GroupPermissions struct {
Permission *Permission `orm:"rel(fk)"`
}
type ModelID struct {
Id int64
}
type ModelBase struct {
ModelID
Created time.Time `orm:"auto_now_add;type(datetime)"`
Updated time.Time `orm:"auto_now;type(datetime)"`
}
type InLine struct {
// Common Fields
ModelBase
// Other Fields
Name string `orm:"unique"`
Email string
}
func NewInLine() *InLine {
return new(InLine)
}
var DBARGS = struct {
Driver string
Source string
......
......@@ -140,7 +140,7 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i
return (err == nil), id, err
}
return false, ind.Field(mi.fields.pk.fieldIndex).Int(), err
return false, ind.FieldByIndex(mi.fields.pk.fieldIndex).Int(), err
}
// insert model data to database
......@@ -160,9 +160,9 @@ func (o *orm) Insert(md interface{}) (int64, error) {
func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
if mi.fields.pk.auto {
if mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(mi.fields.pk.fieldIndex).SetUint(uint64(id))
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else {
ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id)
}
}
}
......@@ -290,7 +290,7 @@ func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int
qs.orders = []string{order}
}
find := ind.Field(fi.fieldIndex)
find := ind.FieldByIndex(fi.fieldIndex)
var nums int64
var err error
......
......@@ -51,9 +51,9 @@ func (o *insertSet) Insert(md interface{}) (int64, error) {
if id > 0 {
if o.mi.fields.pk.auto {
if o.mi.fields.pk.fieldType&IsPostiveIntegerField > 0 {
ind.Field(o.mi.fields.pk.fieldIndex).SetUint(uint64(id))
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetUint(uint64(id))
} else {
ind.Field(o.mi.fields.pk.fieldIndex).SetInt(id)
ind.FieldByIndex(o.mi.fields.pk.fieldIndex).SetInt(id)
}
}
}
......
......@@ -342,7 +342,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
o.setFieldValue(ind.FieldByIndex(fi.fieldIndex), value)
}
}
} else {
......@@ -480,7 +480,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
for _, col := range columns {
if fi := sMi.fields.GetByColumn(col); fi != nil {
value := reflect.ValueOf(columnsMp[col]).Elem().Interface()
o.setFieldValue(ind.FieldByIndex([]int{fi.fieldIndex}), value)
o.setFieldValue(ind.FieldByIndex(fi.fieldIndex), value)
}
}
} else {
......
......@@ -187,6 +187,7 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(Group))
RegisterModel(new(Permission))
RegisterModel(new(GroupPermissions))
RegisterModel(new(InLine))
err := RunSyncdb("default", true, Debug)
throwFail(t, err)
......@@ -206,6 +207,7 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(Group))
RegisterModel(new(Permission))
RegisterModel(new(GroupPermissions))
RegisterModel(new(InLine))
BootStrap()
......@@ -1928,3 +1930,25 @@ func TestReadOrCreate(t *testing.T) {
dORM.Delete(u)
}
func TestInLine(t *testing.T) {
name := "inline"
email := "hello@go.com"
inline := NewInLine()
inline.Name = name
inline.Email = email
id, err := dORM.Insert(inline)
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
il := NewInLine()
il.Id = 1
err = dORM.Read(il)
throwFail(t, err)
throwFail(t, AssertIs(il.Name, name))
throwFail(t, AssertIs(il.Email, email))
throwFail(t, AssertIs(il.Created.In(DefaultTimeLoc), inline.Created.In(DefaultTimeLoc), testDate))
throwFail(t, AssertIs(il.Updated.In(DefaultTimeLoc), inline.Updated.In(DefaultTimeLoc), testDateTime))
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment