Commit 8c37a07a authored by astaxie's avatar astaxie

optimize the ORM

parent 161c0613
......@@ -154,7 +154,7 @@ outFor:
typ := val.Type()
name := getFullName(typ)
var value interface{}
if mmi, ok := modelCache.getByFN(name); ok {
if mmi, ok := modelCache.getByFullName(name); ok {
if _, vu, exist := getExistPk(mmi, val); exist {
value = vu
}
......
......@@ -29,39 +29,18 @@ const (
var (
modelCache = &_modelCache{
cache: make(map[string]*modelInfo),
cacheByFN: make(map[string]*modelInfo),
}
supportTag = map[string]int{
"-": 1,
"null": 1,
"index": 1,
"unique": 1,
"pk": 1,
"auto": 1,
"auto_now": 1,
"auto_now_add": 1,
"size": 2,
"column": 2,
"default": 2,
"rel": 2,
"reverse": 2,
"rel_table": 2,
"rel_through": 2,
"digits": 2,
"decimals": 2,
"on_delete": 2,
"type": 2,
cache: make(map[string]*modelInfo),
cacheByFullName: make(map[string]*modelInfo),
}
)
// model info collection
type _modelCache struct {
sync.RWMutex // only used outsite for bootStrap
orders []string
cache map[string]*modelInfo
cacheByFN map[string]*modelInfo
done bool
sync.RWMutex // only used outsite for bootStrap
orders []string
cache map[string]*modelInfo
cacheByFullName map[string]*modelInfo
done bool
}
// get all model info
......@@ -88,9 +67,9 @@ func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
return
}
// get model info by field name
func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
mi, ok = mc.cacheByFN[name]
// get model info by full name
func (mc *_modelCache) getByFullName(name string) (mi *modelInfo, ok bool) {
mi, ok = mc.cacheByFullName[name]
return
}
......@@ -98,7 +77,7 @@ func (mc *_modelCache) getByFN(name string) (mi *modelInfo, ok bool) {
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
mii := mc.cache[table]
mc.cache[table] = mi
mc.cacheByFN[mi.fullName] = mi
mc.cacheByFullName[mi.fullName] = mi
if mii == nil {
mc.orders = append(mc.orders, table)
}
......@@ -109,7 +88,7 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
func (mc *_modelCache) clean() {
mc.orders = make([]string, 0)
mc.cache = make(map[string]*modelInfo)
mc.cacheByFN = make(map[string]*modelInfo)
mc.cacheByFullName = make(map[string]*modelInfo)
mc.done = false
}
......
......@@ -26,12 +26,14 @@ import (
// prefix means table name prefix.
func registerModel(prefix string, model interface{}) {
val := reflect.ValueOf(model)
ind := reflect.Indirect(val)
typ := ind.Type()
typ := reflect.Indirect(val).Type()
if val.Kind() != reflect.Ptr {
panic(fmt.Errorf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ)))
}
// For this case:
// u := &User{}
// registerModel(&u)
if typ.Kind() == reflect.Ptr {
panic(fmt.Errorf("<orm.RegisterModel> only allow ptr model struct, it looks you use two reference to the struct `%s`", typ))
}
......@@ -41,9 +43,9 @@ func registerModel(prefix string, model interface{}) {
if prefix != "" {
table = prefix + table
}
// models's fullname is pkgpath + struct name
name := getFullName(typ)
if _, ok := modelCache.getByFN(name); ok {
if _, ok := modelCache.getByFullName(name); ok {
fmt.Printf("<orm.RegisterModel> model `%s` repeat register, must be unique\n", name)
os.Exit(2)
}
......@@ -110,7 +112,7 @@ func bootStrap() {
}
name := getFullName(elm)
mii, ok := modelCache.getByFN(name)
mii, ok := modelCache.getByFullName(name)
if ok == false || mii.pkg != elm.PkgPath() {
err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
goto end
......@@ -123,7 +125,7 @@ func bootStrap() {
msg := fmt.Sprintf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
pn := fi.relThrough[:i]
rmi, ok := modelCache.getByFN(fi.relThrough)
rmi, ok := modelCache.getByFullName(fi.relThrough)
if ok == false || pn != rmi.pkg {
err = errors.New(msg + " cannot find table")
goto end
......
......@@ -152,6 +152,10 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN
fi = new(fieldInfo)
// if field which CanAddr is the follow type
// A value is addressable if it is an element of a slice,
// an element of an addressable array, a field of an
// addressable struct, or the result of dereferencing a pointer.
addrField = field
if field.CanAddr() && field.Kind() != reflect.Ptr {
addrField = field.Addr()
......@@ -162,7 +166,7 @@ func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField, mN
}
}
parseStructTag(sf.Tag.Get(defaultStructTagName), &attrs, &tags)
attrs, tags = parseStructTag(sf.Tag.Get(defaultStructTagName))
if _, ok := attrs["-"]; ok {
return nil, errSkipField
......@@ -188,7 +192,7 @@ checkType:
}
fieldType = f.FieldType()
if fieldType&IsRelField > 0 {
err = fmt.Errorf("unsupport rel type custom field")
err = fmt.Errorf("unsupport type custom field, please refer to https://github.com/astaxie/beego/blob/master/orm/models_fields.go#L24-L42")
goto end
}
default:
......
......@@ -29,31 +29,25 @@ type modelInfo struct {
model interface{}
fields *fields
manual bool
addrField reflect.Value
addrField reflect.Value //store the original struct value
uniques []string
isThrough bool
}
// new model info
func newModelInfo(val reflect.Value) (info *modelInfo) {
info = &modelInfo{}
info.fields = newFields()
func newModelInfo(val reflect.Value) (mi *modelInfo) {
mi = &modelInfo{}
mi.fields = newFields()
ind := reflect.Indirect(val)
typ := ind.Type()
info.addrField = val
info.name = typ.Name()
info.fullName = getFullName(typ)
addModelFields(info, ind, "", []int{})
mi.addrField = val
mi.name = ind.Type().Name()
mi.fullName = getFullName(ind.Type())
addModelFields(mi, ind, "", []int{})
return
}
func addModelFields(info *modelInfo, ind reflect.Value, mName string, index []int) {
// index: FieldByIndex returns the nested field corresponding to index
func addModelFields(mi *modelInfo, ind reflect.Value, mName string, index []int) {
var (
err error
fi *fieldInfo
......@@ -63,43 +57,39 @@ func addModelFields(info *modelInfo, ind reflect.Value, mName string, index []in
for i := 0; i < ind.NumField(); i++ {
field := ind.Field(i)
sf = ind.Type().Field(i)
// if the field is unexported skip
if sf.PkgPath != "" {
continue
}
// add anonymous struct fields
if sf.Anonymous {
addModelFields(info, field, mName+"."+sf.Name, append(index, i))
addModelFields(mi, field, mName+"."+sf.Name, append(index, i))
continue
}
fi, err = newFieldInfo(info, field, sf, mName)
if err != nil {
if err == errSkipField {
err = nil
continue
}
fi, err = newFieldInfo(mi, field, sf, mName)
if err == errSkipField {
err = nil
continue
} else if err != nil {
break
}
added := info.fields.Add(fi)
if added == false {
//record current field index
fi.fieldIndex = append(index, i)
fi.mi = mi
fi.inModel = true
if mi.fields.Add(fi) == false {
err = fmt.Errorf("duplicate column name: %s", fi.column)
break
}
if fi.pk {
if info.fields.pk != nil {
if mi.fields.pk != nil {
err = fmt.Errorf("one model must have one pk field only")
break
} else {
info.fields.pk = fi
mi.fields.pk = fi
}
}
fi.fieldIndex = append(index, i)
fi.mi = info
fi.inModel = true
}
if err != nil {
......@@ -110,12 +100,12 @@ func addModelFields(info *modelInfo, ind reflect.Value, mName string, index []in
// combine related model info to new model info.
// prepare for relation models query.
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
info = new(modelInfo)
info.fields = newFields()
info.table = m1.table + "_" + m2.table + "s"
info.name = camelString(info.table)
info.fullName = m1.pkg + "." + info.name
func newM2MModelInfo(m1, m2 *modelInfo) (mi *modelInfo) {
mi = new(modelInfo)
mi.fields = newFields()
mi.table = m1.table + "_" + m2.table + "s"
mi.name = camelString(mi.table)
mi.fullName = m1.pkg + "." + mi.name
fa := new(fieldInfo)
f1 := new(fieldInfo)
......@@ -126,7 +116,7 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
fa.dbcol = true
fa.name = "Id"
fa.column = "id"
fa.fullName = info.fullName + "." + fa.name
fa.fullName = mi.fullName + "." + fa.name
f1.dbcol = true
f2.dbcol = true
......@@ -134,8 +124,8 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
f2.fieldType = RelForeignKey
f1.name = camelString(m1.table)
f2.name = camelString(m2.table)
f1.fullName = info.fullName + "." + f1.name
f2.fullName = info.fullName + "." + f2.name
f1.fullName = mi.fullName + "." + f1.name
f2.fullName = mi.fullName + "." + f2.name
f1.column = m1.table + "_id"
f2.column = m2.table + "_id"
f1.rel = true
......@@ -144,14 +134,14 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
f2.relTable = m2.table
f1.relModelInfo = m1
f2.relModelInfo = m2
f1.mi = info
f2.mi = info
f1.mi = mi
f2.mi = mi
info.fields.Add(fa)
info.fields.Add(f1)
info.fields.Add(f2)
info.fields.pk = fa
mi.fields.Add(fa)
mi.fields.Add(f1)
mi.fields.Add(f2)
mi.fields.pk = fa
info.uniques = []string{f1.column, f2.column}
mi.uniques = []string{f1.column, f2.column}
return
}
......@@ -22,25 +22,47 @@ import (
"time"
)
// 1 is attr
// 2 is tag
var supportTag = map[string]int{
"-": 1,
"null": 1,
"index": 1,
"unique": 1,
"pk": 1,
"auto": 1,
"auto_now": 1,
"auto_now_add": 1,
"size": 2,
"column": 2,
"default": 2,
"rel": 2,
"reverse": 2,
"rel_table": 2,
"rel_through": 2,
"digits": 2,
"decimals": 2,
"on_delete": 2,
"type": 2,
}
// get reflect.Type name with package path.
func getFullName(typ reflect.Type) string {
return typ.PkgPath() + "." + typ.Name()
}
// get table name. method, or field name. auto snaked.
// getTableName get struct table name.
// If the struct implement the TableName, then get the result as tablename
// else use the struct name which will apply snakeString.
func getTableName(val reflect.Value) string {
ind := reflect.Indirect(val)
fun := val.MethodByName("TableName")
if fun.IsValid() {
if fun := val.MethodByName("TableName"); fun.IsValid() {
vals := fun.Call([]reflect.Value{})
if len(vals) > 0 {
val := vals[0]
if val.Kind() == reflect.String {
return val.String()
}
// has return and the first val is string
if len(vals) > 0 && vals[0].Kind() == reflect.String {
return vals[0].String()
}
}
return snakeString(ind.Type().Name())
return snakeString(reflect.Indirect(val).Type().Name())
}
// get table engine, mysiam or innodb.
......@@ -189,21 +211,25 @@ func getFieldType(val reflect.Value) (ft int, err error) {
}
// parse struct tag string
func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
attr := make(map[string]bool)
tag := make(map[string]string)
func parseStructTag(data string) (attrs map[string]bool, tags map[string]string) {
attrs = make(map[string]bool)
tags = make(map[string]string)
for _, v := range strings.Split(data, defaultStructTagDelim) {
if v == "" {
continue
}
v = strings.TrimSpace(v)
if t := strings.ToLower(v); supportTag[t] == 1 {
attr[t] = true
attrs[t] = true
} else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 {
name := t[:i]
if supportTag[name] == 2 {
v = v[i+1 : len(v)-1]
tag[name] = v
tags[name] = v
}
} else {
DebugLog.Println("unsupport orm tag", v)
}
}
*attrs = attr
*tags = tag
return
}
......@@ -104,7 +104,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect
panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
}
name := getFullName(typ)
if mi, ok := modelCache.getByFN(name); ok {
if mi, ok := modelCache.getByFullName(name); ok {
return mi, ind
}
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
......@@ -427,7 +427,7 @@ func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
}
} else {
name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName)))
if mi, ok := modelCache.getByFN(name); ok {
if mi, ok := modelCache.getByFullName(name); ok {
qs = newQuerySet(o, mi)
}
}
......
......@@ -286,7 +286,7 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
if mi, ok := modelCache.getByFullName(fn); ok {
sMi = mi
}
} else {
......@@ -355,12 +355,9 @@ func (o *rawSet) QueryRow(containers ...interface{}) error {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var col string
if col = tags["column"]; len(col) == 0 {
if col = tags["column"]; col == "" {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
......@@ -422,7 +419,7 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
structMode = true
fn := getFullName(typ)
if mi, ok := modelCache.getByFN(fn); ok {
if mi, ok := modelCache.getByFullName(fn); ok {
sMi = mi
}
} else {
......@@ -499,12 +496,9 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
for i := 0; i < ind.NumField(); i++ {
f := ind.Field(i)
fe := ind.Type().Field(i)
var attrs map[string]bool
var tags map[string]string
parseStructTag(fe.Tag.Get("orm"), &attrs, &tags)
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var col string
if col = tags["column"]; len(col) == 0 {
if col = tags["column"]; col == "" {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
......
......@@ -227,7 +227,7 @@ func TestModelSyntax(t *testing.T) {
user := &User{}
ind := reflect.ValueOf(user).Elem()
fn := getFullName(ind.Type())
mi, ok := modelCache.getByFN(fn)
mi, ok := modelCache.getByFullName(fn)
throwFail(t, AssertIs(ok, true))
mi, ok = modelCache.get("user")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment