Commit d043ebcd authored by slene's avatar slene

orm support complete m2m operation api / auto load related api

parent e11c40ee
......@@ -151,7 +151,11 @@ func getDbCreateSql(al *alias) (sqls []string, tableIndexes map[string][]dbIndex
}
if mi.model != nil {
for _, names := range getTableUnique(mi.addrField) {
allnames := getTableUnique(mi.addrField)
if !mi.manual && len(mi.uniques) > 0 {
allnames = append(allnames, mi.uniques)
}
for _, names := range allnames {
cols := make([]string, 0, len(names))
for _, name := range names {
if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
......
......@@ -52,7 +52,6 @@ type dbBase struct {
var _ dbBaser = new(dbBase)
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, tz *time.Location) (columns []string, values []interface{}, err error) {
_, pkValue, _ := getExistPk(mi, ind)
for _, column := range cols {
var fi *fieldInfo
if fi, _ = mi.fields.GetByAny(column); fi != nil {
......@@ -63,82 +62,90 @@ func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string,
if fi.dbcol == false || fi.auto && skipAuto {
continue
}
var value interface{}
if fi.pk {
value = pkValue
value, err := d.collectFieldValue(mi, fi, ind, insert, tz)
if err != nil {
return nil, nil, err
}
columns = append(columns, column)
values = append(values, value)
}
return
}
func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) {
var value interface{}
if fi.pk {
_, value, _ = getExistPk(mi, ind)
} else {
field := ind.Field(fi.fieldIndex)
if fi.isFielder {
f := field.Addr().Interface().(Fielder)
value = f.RawValue()
} else {
field := ind.Field(fi.fieldIndex)
if fi.isFielder {
f := field.Addr().Interface().(Fielder)
value = f.RawValue()
} else {
switch fi.fieldType {
case TypeBooleanField:
value = field.Bool()
case TypeCharField, TypeTextField:
value = field.String()
case TypeFloatField, TypeDecimalField:
vu := field.Interface()
if _, ok := vu.(float32); ok {
value, _ = StrTo(ToStr(vu)).Float64()
switch fi.fieldType {
case TypeBooleanField:
value = field.Bool()
case TypeCharField, TypeTextField:
value = field.String()
case TypeFloatField, TypeDecimalField:
vu := field.Interface()
if _, ok := vu.(float32); ok {
value, _ = StrTo(ToStr(vu)).Float64()
} else {
value = field.Float()
}
case TypeDateField, TypeDateTimeField:
value = field.Interface()
if t, ok := value.(time.Time); ok {
if fi.fieldType == TypeDateField {
d.ins.TimeToDB(&t, DefaultTimeLoc)
} else {
value = field.Float()
d.ins.TimeToDB(&t, tz)
}
case TypeDateField, TypeDateTimeField:
value = field.Interface()
if t, ok := value.(time.Time); ok {
if fi.fieldType == TypeDateField {
d.ins.TimeToDB(&t, DefaultTimeLoc)
value = t
}
default:
switch {
case fi.fieldType&IsPostiveIntegerField > 0:
value = field.Uint()
case fi.fieldType&IsIntegerField > 0:
value = field.Int()
case fi.fieldType&IsRelField > 0:
if field.IsNil() {
value = nil
} else {
if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok {
value = vu
} else {
d.ins.TimeToDB(&t, tz)
}
value = t
}
default:
switch {
case fi.fieldType&IsPostiveIntegerField > 0:
value = field.Uint()
case fi.fieldType&IsIntegerField > 0:
value = field.Int()
case fi.fieldType&IsRelField > 0:
if field.IsNil() {
value = nil
} else {
if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok {
value = vu
} else {
value = nil
}
}
if fi.null == false && value == nil {
return nil, nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName))
}
}
if fi.null == false && value == nil {
return nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName))
}
}
}
switch fi.fieldType {
case TypeDateField, TypeDateTimeField:
if fi.auto_now || fi.auto_now_add && insert {
tnow := time.Now()
if fi.fieldType == TypeDateField {
d.ins.TimeToDB(&tnow, DefaultTimeLoc)
} else {
d.ins.TimeToDB(&tnow, tz)
}
value = tnow
if fi.isFielder {
f := field.Addr().Interface().(Fielder)
f.SetRaw(tnow.In(DefaultTimeLoc))
} else {
field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc)))
}
}
switch fi.fieldType {
case TypeDateField, TypeDateTimeField:
if fi.auto_now || fi.auto_now_add && insert {
tnow := time.Now()
if fi.fieldType == TypeDateField {
d.ins.TimeToDB(&tnow, DefaultTimeLoc)
} else {
d.ins.TimeToDB(&tnow, tz)
}
value = tnow
if fi.isFielder {
f := field.Addr().Interface().(Fielder)
f.SetRaw(tnow.In(DefaultTimeLoc))
} else {
field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc)))
}
}
}
columns = append(columns, column)
values = append(values, value)
}
return
return value, nil
}
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) {
......@@ -250,6 +257,10 @@ func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.
return 0, err
}
return d.InsertValue(q, mi, names, values)
}
func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote()
marks := make([]string, len(names))
......@@ -653,10 +664,12 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
trefs = refs[len(tCols):]
for _, tbl := range tables.tables {
// loop selected tables
if tbl.sel {
last := mind
names := ""
mmi := mi
// loop cascade models
for _, name := range tbl.names {
names += name
if val, ok := cacheV[names]; ok {
......@@ -665,27 +678,30 @@ func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condi
} else {
fi := mmi.fields.GetByName(name)
lastm := mmi
mmi := fi.relModelInfo
field := reflect.Indirect(last.Field(fi.fieldIndex))
if field.IsValid() {
d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz)
for _, fi := range mmi.fields.fieldsReverse {
if fi.inModel && fi.reverseFieldInfo.mi == lastm {
if fi.reverseFieldInfo != nil {
f := field.Field(fi.fieldIndex)
if f.Kind() == reflect.Ptr {
f.Set(last.Addr())
mmi = fi.relModelInfo
field := last
if last.Kind() != reflect.Invalid {
field = reflect.Indirect(last.Field(fi.fieldIndex))
if field.IsValid() {
d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz)
for _, fi := range mmi.fields.fieldsReverse {
if fi.inModel && fi.reverseFieldInfo.mi == lastm {
if fi.reverseFieldInfo != nil {
f := field.Field(fi.fieldIndex)
if f.Kind() == reflect.Ptr {
f.Set(last.Addr())
}
}
}
}
last = field
}
cacheV[names] = &field
cacheM[names] = mmi
last = field
}
trefs = trefs[len(mmi.fields.dbcols):]
cacheV[names] = &field
cacheM[names] = mmi
}
}
trefs = trefs[len(mmi.fields.dbcols):]
}
}
......
......@@ -100,22 +100,29 @@ func (t *dbTables) parseRelated(rels []string, depth int) {
exs = strings.Split(s, ExprSep)
names = make([]string, 0, len(exs))
mmi = t.mi
cansel = true
cancel = true
jtl *dbTable
)
inner := true
for _, ex := range exs {
if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
names = append(names, fi.name)
mmi = fi.relModelInfo
jt := t.set(names, mmi, fi, fi.null == false)
if fi.null {
inner = false
}
jt := t.set(names, mmi, fi, inner)
jt.jtl = jtl
if fi.reverse {
cansel = false
cancel = false
}
if cansel {
if cancel {
jt.sel = depth > 0
if i < relsNum {
......@@ -178,9 +185,8 @@ func (t *dbTables) getJoinSql() (join string) {
return
}
func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
var (
ffi *fieldInfo
jtl *dbTable
mmi = mi
)
......@@ -188,73 +194,67 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
num := len(exprs) - 1
names := make([]string, 0)
inner := true
for i, ex := range exprs {
exist := false
check:
fi, ok := mmi.fields.GetByAny(ex)
if ok {
if num != i {
names = append(names, fi.name)
isRel := fi.rel || fi.reverse
switch {
case fi.rel:
mmi = fi.relModelInfo
if fi.fieldType == RelManyToMany {
mmi = fi.relThroughModelInfo
}
case fi.reverse:
mmi = fi.reverseFieldInfo.mi
if fi.reverseFieldInfo.fieldType == RelManyToMany {
mmi = fi.reverseFieldInfo.relThroughModelInfo
}
default:
return
}
jt, _ := d.add(names, mmi, fi, fi.null == false)
jt.jtl = jtl
jtl = jt
names = append(names, fi.name)
if fi.rel && fi.fieldType == RelManyToMany {
ex = fi.relModelInfo.name
goto check
switch {
case fi.rel:
mmi = fi.relModelInfo
if fi.fieldType == RelManyToMany {
mmi = fi.relThroughModelInfo
}
case fi.reverse:
mmi = fi.reverseFieldInfo.mi
}
if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany {
ex = fi.reverseFieldInfo.mi.name
goto check
if isRel && (fi.mi.isThrough == false || num != i) {
if fi.null {
inner = false
}
exist = true
} else {
jt, _ := t.add(names, mmi, fi, inner)
jt.jtl = jtl
jtl = jt
}
if ffi == nil {
if num == i {
if i == 0 || jtl == nil {
index = "T0"
} else {
index = jtl.index
}
info = fi
if jtl != nil {
name = jtl.name + ExprSep + fi.name
} else {
if jtl == nil {
name = fi.name
} else {
name = jtl.name + ExprSep + fi.name
}
switch fi.fieldType {
case RelManyToMany, RelReverseMany:
default:
exist = true
switch {
case fi.rel:
case fi.reverse:
switch fi.reverseFieldInfo.fieldType {
case RelOneToOne, RelForeignKey:
index = jtl.index
info = fi.reverseFieldInfo.mi.fields.pk
name = info.name
}
}
}
ffi = fi
}
if exist == false {
} else {
index = ""
name = ""
info = nil
......@@ -267,16 +267,15 @@ func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string
return
}
func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
func (t *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
}
Q := d.base.TableQuote()
Q := t.base.TableQuote()
mi := d.mi
mi := t.mi
// outFor:
for i, p := range cond.params {
if i > 0 {
if p.isOr {
......@@ -289,7 +288,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
where += "NOT "
}
if p.isCond {
w, ps := d.getCondSql(p.cond, true, tz)
w, ps := t.getCondSql(p.cond, true, tz)
if w != "" {
w = fmt.Sprintf("( %s) ", w)
}
......@@ -305,7 +304,7 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
exprs = exprs[:num]
}
index, _, fi, suc := d.parseExprs(mi, exprs)
index, _, fi, suc := t.parseExprs(mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
}
......@@ -314,10 +313,10 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
operator = "exact"
}
operSql, args := d.base.GenerateOperatorSql(mi, fi, operator, p.args, tz)
operSql, args := t.base.GenerateOperatorSql(mi, fi, operator, p.args, tz)
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
d.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
where += fmt.Sprintf("%s %s ", leftCol, operSql)
params = append(params, args...)
......@@ -332,12 +331,12 @@ func (d *dbTables) getCondSql(cond *Condition, sub bool, tz *time.Location) (whe
return
}
func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
func (t *dbTables) getOrderSql(orders []string) (orderSql string) {
if len(orders) == 0 {
return
}
Q := d.base.TableQuote()
Q := t.base.TableQuote()
orderSqls := make([]string, 0, len(orders))
for _, order := range orders {
......@@ -348,7 +347,7 @@ func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
}
exprs := strings.Split(order, ExprSep)
index, _, fi, suc := d.parseExprs(d.mi, exprs)
index, _, fi, suc := t.parseExprs(t.mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
......@@ -360,14 +359,14 @@ func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
return
}
func (d *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
func (t *dbTables) getLimitSql(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 {
limit = int64(DefaultRowsLimit)
}
if limit < 0 {
// no limit
if offset > 0 {
maxLimit := d.base.MaxLimit()
maxLimit := t.base.MaxLimit()
if maxLimit == 0 {
limits = fmt.Sprintf("OFFSET %d", offset)
} else {
......
......@@ -121,7 +121,6 @@ func bootStrap() {
err = errors.New(msg)
goto end
}
err = nil
} else {
i := newM2MModelInfo(mi, mii)
if fi.relTable != "" {
......@@ -135,6 +134,8 @@ func bootStrap() {
fi.relTable = i.table
fi.relThroughModelInfo = i
}
fi.relThroughModelInfo.isThrough = true
}
}
}
......@@ -152,6 +153,7 @@ func bootStrap() {
break
}
}
if inModel == false {
rmi := fi.relModelInfo
ffi := new(fieldInfo)
......@@ -185,9 +187,34 @@ func bootStrap() {
}
}
models = modelCache.all()
for _, mi := range models {
if fields, ok := mi.fields.fieldsByType[RelReverseOne]; ok {
for _, fi := range fields {
for _, fi := range mi.fields.fieldsRel {
switch fi.fieldType {
case RelManyToMany:
for _, ffi := range fi.relThroughModelInfo.fields.fieldsRel {
switch ffi.fieldType {
case RelOneToOne, RelForeignKey:
if ffi.relModelInfo == fi.relModelInfo {
fi.reverseFieldInfoTwo = ffi
}
if ffi.relModelInfo == mi {
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
}
}
}
if fi.reverseFieldInfoTwo == nil {
err = fmt.Errorf("can not find m2m field for m2m model `%s`, ensure your m2m model defined correct",
fi.relThroughModelInfo.fullName)
goto end
}
}
}
for _, fi := range mi.fields.fieldsReverse {
switch fi.fieldType {
case RelReverseOne:
found := false
mForA:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] {
......@@ -195,6 +222,9 @@ func bootStrap() {
found = true
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
ffi.reverseField = fi.name
ffi.reverseFieldInfo = fi
break mForA
}
}
......@@ -202,10 +232,7 @@ func bootStrap() {
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
goto end
}
}
}
if fields, ok := mi.fields.fieldsByType[RelReverseMany]; ok {
for _, fi := range fields {
case RelReverseMany:
found := false
mForB:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] {
......@@ -213,6 +240,10 @@ func bootStrap() {
found = true
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
ffi.reverseField = fi.name
ffi.reverseFieldInfo = fi
break mForB
}
}
......@@ -221,14 +252,20 @@ func bootStrap() {
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
if ffi.relModelInfo == mi {
found = true
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
fi.reverseField = ffi.reverseFieldInfoTwo.name
fi.reverseFieldInfo = ffi.reverseFieldInfoTwo
fi.relThroughModelInfo = ffi.relThroughModelInfo
fi.reverseFieldInfoTwo = ffi.reverseFieldInfo
fi.reverseFieldInfoM2M = ffi
ffi.reverseFieldInfoM2M = fi
break mForC
}
}
}
if found == false {
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
err = fmt.Errorf("reverse field for `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
goto end
}
}
......
......@@ -103,6 +103,8 @@ type fieldInfo struct {
reverse bool
reverseField string
reverseFieldInfo *fieldInfo
reverseFieldInfoTwo *fieldInfo
reverseFieldInfoM2M *fieldInfo
relTable string
relThrough string
relThroughModelInfo *modelInfo
......
......@@ -16,6 +16,8 @@ type modelInfo struct {
fields *fields
manual bool
addrField reflect.Value
uniques []string
isThrough bool
}
func newModelInfo(val reflect.Value) (info *modelInfo) {
......@@ -118,5 +120,7 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
info.fields.Add(f1)
info.fields.Add(f2)
info.fields.pk = fa
info.uniques = []string{f1.column, f2.column}
return
}
......@@ -99,10 +99,11 @@ func NewUser() *User {
}
type Profile struct {
Id int
Age int16
Money float64
User *User `orm:"reverse(one)" json:"-"`
Id int
Age int16
Money float64
User *User `orm:"reverse(one)" json:"-"`
BestPost *Post `orm:"rel(one);null"`
}
func (u *Profile) TableName() string {
......@@ -136,9 +137,10 @@ func NewPost() *Post {
}
type Tag struct {
Id int
Name string `orm:"size(30)"`
Posts []*Post `orm:"reverse(many)" json:"-"`
Id int
Name string `orm:"size(30)"`
BestPost *Post `orm:"rel(one);null"`
Posts []*Post `orm:"reverse(many)" json:"-"`
}
func NewTag() *Tag {
......
......@@ -18,7 +18,7 @@ var (
Debug = false
DebugLog = NewLog(os.Stderr)
DefaultRowsLimit = 1000
DefaultRelsDepth = 5
DefaultRelsDepth = 2
DefaultTimeLoc = time.Local
ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin")
ErrTxDone = errors.New("<Ormer.Commit/Rollback> transaction not begin")
......@@ -53,6 +53,14 @@ func (o *orm) getMiInd(md interface{}) (mi *modelInfo, ind reflect.Value) {
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
}
func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
fi, ok := mi.fields.GetByAny(name)
if !ok {
panic(fmt.Errorf("<Ormer> cannot find field `%s` for model `%s`", name, mi.fullName))
}
return fi
}
func (o *orm) Read(md interface{}, cols ...string) error {
mi, ind := o.getMiInd(md)
err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols)
......@@ -107,22 +115,152 @@ func (o *orm) Delete(md interface{}) (int64, error) {
return num, nil
}
func (o *orm) M2mAdd(md interface{}, name string, mds ...interface{}) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
mi, ind := o.getMiInd(md)
fi := o.getFieldInfo(mi, name)
if fi.fieldType != RelManyToMany {
panic(fmt.Errorf("<Ormer.QueryM2M> name `%s` for model `%s` is not a m2m field", fi.name, mi.fullName))
}
return newQueryM2M(md, o, mi, fi, ind)
}
func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
_, fi, ind, qseter := o.queryRelated(md, name)
qs := qseter.(*querySet)
var relDepth int
var limit, offset int64
var order string
for i, arg := range args {
switch i {
case 0:
if v, ok := arg.(bool); ok {
if v {
relDepth = DefaultRelsDepth
}
} else if v, ok := arg.(int); ok {
relDepth = v
}
case 1:
limit = ToInt64(arg)
case 2:
offset = ToInt64(arg)
case 3:
order, _ = arg.(string)
}
}
switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelReverseOne:
limit = 1
offset = 0
}
qs.limit = limit
qs.offset = offset
qs.relDepth = relDepth
if len(order) > 0 {
qs.orders = []string{order}
}
find := ind.Field(fi.fieldIndex)
var nums int64
var err error
switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelReverseOne:
val := reflect.New(find.Type().Elem())
container := val.Interface()
err = qs.One(container)
if err == nil {
find.Set(val)
nums = 1
}
default:
nums, err = qs.All(find.Addr().Interface())
}
return nums, err
}
func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
// is this api needed ?
_, _, _, qs := o.queryRelated(md, name)
return qs
}
func (o *orm) M2mDel(md interface{}, name string, mds ...interface{}) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
mi, ind := o.getMiInd(md)
fi := o.getFieldInfo(mi, name)
_, _, exist := getExistPk(mi, ind)
if exist == false {
panic(ErrMissPK)
}
var qs *querySet
switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelManyToMany:
if !fi.inModel {
break
}
qs = o.getRelQs(md, mi, fi)
case RelReverseOne, RelReverseMany:
if !fi.inModel {
break
}
qs = o.getReverseQs(md, mi, fi)
}
if qs == nil {
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel/reverse field"))
}
return mi, fi, ind, qs
}
func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType {
case RelReverseOne, RelReverseMany:
default:
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName))
}
var q *querySet
if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough {
q = newQuerySet(o, fi.relModelInfo).(*querySet)
q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
} else {
q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet)
q.cond = NewCondition().And(fi.reverseFieldInfo.column, md)
}
return q
}
func (o *orm) LoadRel(md interface{}, name string) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
switch fi.fieldType {
case RelOneToOne, RelForeignKey, RelManyToMany:
default:
panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName))
}
q := newQuerySet(o, fi.relModelInfo).(*querySet)
q.cond = NewCondition()
if fi.fieldType == RelManyToMany {
q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
} else {
q.cond = q.cond.And(fi.reverseFieldInfo.column, md)
}
return q
}
func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
......
package orm
import (
"reflect"
)
type queryM2M struct {
md interface{}
mi *modelInfo
fi *fieldInfo
qs *querySet
ind reflect.Value
}
func (o *queryM2M) Add(mds ...interface{}) (int64, error) {
fi := o.fi
mi := fi.relThroughModelInfo
mfi := fi.reverseFieldInfo
rfi := fi.reverseFieldInfoTwo
orm := o.qs.orm
dbase := orm.alias.DbBaser
var models []interface{}
for _, md := range mds {
val := reflect.ValueOf(md)
if val.Kind() == reflect.Slice || val.Kind() == reflect.Array {
for i := 0; i < val.Len(); i++ {
v := val.Index(i)
if v.CanInterface() {
models = append(models, v.Interface())
}
}
} else {
models = append(models, md)
}
}
_, v1, exist := getExistPk(o.mi, o.ind)
if exist == false {
panic(ErrMissPK)
}
names := []string{mfi.column, rfi.column}
var nums int64
for _, md := range models {
ind := reflect.Indirect(reflect.ValueOf(md))
var v2 interface{}
if ind.Kind() != reflect.Struct {
v2 = ind.Interface()
} else {
_, v2, exist = getExistPk(fi.relModelInfo, ind)
if exist == false {
panic(ErrMissPK)
}
}
values := []interface{}{v1, v2}
_, err := dbase.InsertValue(orm.db, mi, names, values)
if err != nil {
return nums, err
}
nums += 1
}
return nums, nil
}
func (o *queryM2M) Remove(mds ...interface{}) (int64, error) {
fi := o.fi
qs := o.qs.Filter(fi.reverseFieldInfo.name, o.md)
nums, err := qs.Filter(fi.reverseFieldInfoTwo.name+ExprSep+"in", mds).Delete()
if err != nil {
return nums, err
}
return nums, nil
}
func (o *queryM2M) Exist(md interface{}) bool {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).
Filter(fi.reverseFieldInfoTwo.name, md).Exist()
}
func (o *queryM2M) Clear() (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Delete()
}
func (o *queryM2M) Count() (int64, error) {
fi := o.fi
return o.qs.Filter(fi.reverseFieldInfo.name, o.md).Count()
}
var _ QueryM2Mer = new(queryM2M)
func newQueryM2M(md interface{}, o *orm, mi *modelInfo, fi *fieldInfo, ind reflect.Value) QueryM2Mer {
qm2m := new(queryM2M)
qm2m.md = md
qm2m.mi = mi
qm2m.fi = fi
qm2m.ind = ind
qm2m.qs = newQuerySet(o, fi.relThroughModelInfo).(*querySet)
return qm2m
}
This diff is collapsed.
......@@ -24,9 +24,8 @@ type Ormer interface {
Insert(interface{}) (int64, error)
Update(interface{}, ...string) (int64, error)
Delete(interface{}) (int64, error)
M2mAdd(interface{}, string, ...interface{}) (int64, error)
M2mDel(interface{}, string, ...interface{}) (int64, error)
LoadRel(interface{}, string) (int64, error)
LoadRelated(interface{}, string, ...interface{}) (int64, error)
QueryM2M(interface{}, string) QueryM2Mer
QueryTable(interface{}) QuerySeter
Using(string) error
Begin() error
......@@ -61,6 +60,14 @@ type QuerySeter interface {
ValuesFlat(*ParamsList, string) (int64, error)
}
type QueryM2Mer interface {
Add(...interface{}) (int64, error)
Remove(...interface{}) (int64, error)
Exist(interface{}) bool
Clear() (int64, error)
Count() (int64, error)
}
type RawPreparer interface {
Exec(...interface{}) (sql.Result, error)
Close() error
......@@ -114,6 +121,7 @@ type txEnder interface {
type dbBaser interface {
Read(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) error
Insert(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
InsertValue(dbQuerier, *modelInfo, []string, []interface{}) (int64, error)
InsertStmt(stmtQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value, *time.Location, []string) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value, *time.Location) (int64, error)
......@@ -139,4 +147,5 @@ type dbBaser interface {
ShowTablesQuery() string
ShowColumnsQuery(string) string
IndexExists(dbQuerier, string, string) bool
collectFieldValue(*modelInfo, *fieldInfo, reflect.Value, bool, *time.Location) (interface{}, error)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment