Commit bce35c70 authored by slene's avatar slene

init orm project, beta, unstable

parent ccbf116f
## beego orm
a powerful orm framework
now, beta, unstable, may be changing some api make your app build failed.
## TODO
- some unrealized api
- examples
- docs
- support postgres
- support sqlite
\ No newline at end of file
package orm
import (
"flag"
"fmt"
"os"
)
func printHelp() {
}
func getSqlAll() (sql string) {
for _, mi := range modelCache.allOrdered() {
_ = mi
}
return
}
func runCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" {
return
}
_ = flag.NewFlagSet("orm command", flag.ExitOnError)
args := argString(os.Args[2:])
cmd := args.Get(0)
switch cmd {
case "syncdb":
case "sqlall":
sql := getSqlAll()
fmt.Println(sql)
default:
if cmd != "" {
fmt.Printf("unknown command %s", cmd)
} else {
printHelp()
}
os.Exit(2)
}
}
package orm
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"time"
)
const (
format_Date = "2006-01-02"
format_DateTime = "2006-01-02 15:04:05"
)
var (
ErrMissPK = errors.New("missed pk value")
)
var (
operators = map[string]bool{
"exact": true,
"iexact": true,
"contains": true,
"icontains": true,
// "regex": true,
// "iregex": true,
"gt": true,
"gte": true,
"lt": true,
"lte": true,
"startswith": true,
"endswith": true,
"istartswith": true,
"iendswith": true,
"in": true,
// "range": true,
// "year": true,
// "month": true,
// "day": true,
// "week_day": true,
"isnull": true,
// "search": true,
}
operatorsSQL = map[string]string{
"exact": "= ?",
"iexact": "LIKE ?",
"contains": "LIKE BINARY ?",
"icontains": "LIKE ?",
// "regex": "REGEXP BINARY ?",
// "iregex": "REGEXP ?",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"startswith": "LIKE BINARY ?",
"endswith": "LIKE BINARY ?",
"istartswith": "LIKE ?",
"iendswith": "LIKE ?",
}
)
type dbTable struct {
id int
index string
name string
names []string
sel bool
inner bool
mi *modelInfo
fi *fieldInfo
jtl *dbTable
}
type dbTables struct {
tablesM map[string]*dbTable
tables []*dbTable
mi *modelInfo
base dbBaser
}
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
name := strings.Join(names, ExprSep)
if j, ok := t.tablesM[name]; ok {
j.name = name
j.mi = mi
j.fi = fi
j.inner = inner
} else {
i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt
t.tables = append(t.tables, jt)
}
return t.tablesM[name]
}
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; ok == false {
i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt
t.tables = append(t.tables, jt)
return jt, true
}
return t.tablesM[name], false
}
func (t *dbTables) get(name string) (*dbTable, bool) {
j, ok := t.tablesM[name]
return j, ok
}
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
if depth < 0 || fi.fieldType == RelManyToMany {
return related
}
if prefix == "" {
prefix = fi.name
} else {
prefix = prefix + ExprSep + fi.name
}
related = append(related, prefix)
depth--
for _, fi := range fi.relModelInfo.fields.fieldsRel {
related = t.loopDepth(depth, prefix, fi, related)
}
return related
}
func (t *dbTables) parseRelated(rels []string, depth int) {
relsNum := len(rels)
related := make([]string, relsNum)
copy(related, rels)
relDepth := depth
if relsNum != 0 {
relDepth = 0
}
relDepth--
for _, fi := range t.mi.fields.fieldsRel {
related = t.loopDepth(relDepth, "", fi, related)
}
for i, s := range related {
var (
exs = strings.Split(s, ExprSep)
names = make([]string, 0, len(exs))
mmi = t.mi
cansel = true
jtl *dbTable
)
for _, ex := range exs {
if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
names = append(names, fi.name)
mmi = fi.relModelInfo
jt := t.set(names, mmi, fi, fi.null == false)
jt.jtl = jtl
if fi.reverse {
cansel = false
}
if cansel {
jt.sel = depth > 0
if i < relsNum {
jt.sel = true
}
}
jtl = jt
} else {
panic(fmt.Sprintf("unknown model/table name `%s`", ex))
}
}
}
}
func (t *dbTables) getJoinSql() (join string) {
for _, jt := range t.tables {
if jt.inner {
join += "INNER JOIN "
} else {
join += "LEFT OUTER JOIN "
}
var (
table string
t1, t2 string
c1, c2 string
)
t1 = "T0"
if jt.jtl != nil {
t1 = jt.jtl.index
}
t2 = jt.index
table = jt.mi.table
switch {
case jt.fi.fieldType == RelManyToMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk[0].column
for _, ffi := range jt.mi.fields.fieldsRel {
if jt.fi.mi == ffi.relModelInfo {
c2 = ffi.column
break
}
}
default:
c1 = jt.fi.column
c2 = jt.fi.relModelInfo.fields.pk[0].column
if jt.fi.reverse {
c1 = jt.mi.fields.pk[0].column
c2 = jt.fi.reverseFieldInfo.column
}
}
join += fmt.Sprintf("`%s` %s ON %s.`%s` = %s.`%s` ", table, t2,
t2, c2, t1, c1)
}
return
}
func (d *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, column string, info *fieldInfo, success bool) {
var (
ffi *fieldInfo
jtl *dbTable
mmi = mi
)
num := len(exprs) - 1
names := make([]string, 0)
for i, ex := range exprs {
exist := false
check:
fi, ok := mmi.fields.GetByAny(ex)
if ok {
if num != i {
names = append(names, fi.name)
switch {
case fi.rel:
mmi = fi.relModelInfo
if fi.fieldType == RelManyToMany {
mmi = fi.relThroughModelInfo
}
case fi.reverse:
mmi = fi.reverseFieldInfo.mi
if fi.reverseFieldInfo.fieldType == RelManyToMany {
mmi = fi.reverseFieldInfo.relThroughModelInfo
}
}
jt, _ := d.add(names, mmi, fi, fi.null == false)
jt.jtl = jtl
jtl = jt
if fi.rel && fi.fieldType == RelManyToMany {
ex = fi.relModelInfo.name
goto check
}
if fi.reverse && fi.reverseFieldInfo.fieldType == RelManyToMany {
ex = fi.reverseFieldInfo.mi.name
goto check
}
exist = true
} else {
if ffi == nil {
index = "T0"
} else {
index = jtl.index
}
column = fi.column
info = fi
switch fi.fieldType {
case RelManyToMany, RelReverseMany:
default:
exist = true
}
}
ffi = fi
}
if exist == false {
index = ""
column = ""
success = false
return
}
}
success = index != "" && column != ""
return
}
func (d *dbTables) getCondSql(cond *Condition, sub bool) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
}
mi := d.mi
// outFor:
for i, p := range cond.params {
if i > 0 {
if p.isOr {
where += "OR "
} else {
where += "AND "
}
}
if p.isNot {
where += "NOT "
}
if p.isCond {
w, ps := d.getCondSql(p.cond, true)
if w != "" {
w = fmt.Sprintf("( %s) ", w)
}
where += w
params = append(params, ps...)
} else {
exprs := p.exprs
num := len(exprs) - 1
operator := ""
if operators[exprs[num]] {
operator = exprs[num]
exprs = exprs[:num]
}
index, column, _, suc := d.parseExprs(mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
}
if operator == "" {
operator = "exact"
}
operSql, args := d.base.GetOperatorSql(mi, operator, p.args)
where += fmt.Sprintf("%s.`%s` %s ", index, column, operSql)
params = append(params, args...)
}
}
if sub == false && where != "" {
where = "WHERE " + where
}
return
}
func (d *dbTables) getOrderSql(orders []string) (orderSql string) {
if len(orders) == 0 {
return
}
orderSqls := make([]string, 0, len(orders))
for _, order := range orders {
asc := "ASC"
if order[0] == '-' {
asc = "DESC"
order = order[1:]
}
exprs := strings.Split(order, ExprSep)
index, column, _, suc := d.parseExprs(d.mi, exprs)
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
orderSqls = append(orderSqls, fmt.Sprintf("%s.`%s` %s", index, column, asc))
}
orderSql = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
return
}
func (d *dbTables) getLimitSql(offset int64, limit int) (limits string) {
if limit == 0 {
limit = DefaultRowsLimit
}
if limit < 0 {
// no limit
if offset > 0 {
limits = fmt.Sprintf("OFFSET %d", offset)
}
} else if offset <= 0 {
limits = fmt.Sprintf("LIMIT %d", limit)
} else {
limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
}
return
}
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
tables := &dbTables{}
tables.tablesM = make(map[string]*dbTable)
tables.mi = mi
tables.base = base
return tables
}
type dbBase struct {
ins dbBaser
}
func (d *dbBase) existPk(mi *modelInfo, ind reflect.Value) ([]string, []interface{}, bool) {
exist := true
columns := make([]string, 0, len(mi.fields.pk))
values := make([]interface{}, 0, len(mi.fields.pk))
for _, fi := range mi.fields.pk {
v := ind.Field(fi.fieldIndex)
if fi.fieldType&IsIntegerField > 0 {
vu := v.Int()
if exist {
exist = vu > 0
}
values = append(values, vu)
} else {
vu := v.String()
if exist {
exist = vu != ""
}
values = append(values, vu)
}
columns = append(columns, fi.column)
}
return columns, values, exist
}
func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, skipAuto bool, insert bool) (columns []string, values []interface{}, err error) {
_, pkValues, _ := d.existPk(mi, ind)
for _, column := range mi.fields.orders {
fi := mi.fields.columns[column]
if fi.dbcol == false || fi.auto && skipAuto {
continue
}
var value interface{}
if i, ok := mi.fields.pk.Exist(fi); ok {
value = pkValues[i]
} else {
field := ind.Field(fi.fieldIndex)
if fi.isFielder {
f := field.Addr().Interface().(Fielder)
value = f.RawValue()
} else {
switch fi.fieldType {
case TypeBooleanField:
value = field.Bool()
case TypeCharField, TypeTextField:
value = field.String()
case TypeFloatField, TypeDecimalField:
value = field.Float()
case TypeDateField, TypeDateTimeField:
value = field.Interface()
default:
switch {
case fi.fieldType&IsPostiveIntegerField > 0:
value = field.Uint()
case fi.fieldType&IsIntegerField > 0:
value = field.Int()
case fi.fieldType&IsRelField > 0:
if field.IsNil() {
value = nil
} else {
_, fvalues, fok := d.existPk(fi.relModelInfo, reflect.Indirect(field))
if fok {
value = fvalues[0]
} else {
value = nil
}
}
if fi.null == false && value == nil {
return nil, nil, errors.New(fmt.Sprintf("field `%s` cannot be NULL", fi.fullName))
}
}
}
}
switch fi.fieldType {
case TypeDateField, TypeDateTimeField:
if fi.auto_now || fi.auto_now_add && insert {
tnow := time.Now()
if fi.fieldType == TypeDateField {
value = timeFormat(tnow, format_Date)
} else {
value = timeFormat(tnow, format_DateTime)
}
if fi.isFielder {
f := field.Addr().Interface().(Fielder)
f.SetRaw(tnow)
} else {
field.Set(reflect.ValueOf(tnow))
}
}
}
}
columns = append(columns, column)
values = append(values, value)
}
return
}
func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (*sql.Stmt, error) {
dbcols := make([]string, 0, len(mi.fields.dbcols))
marks := make([]string, 0, len(mi.fields.dbcols))
for _, fi := range mi.fields.fieldsDB {
if fi.auto == false {
dbcols = append(dbcols, fi.column)
marks = append(marks, "?")
}
}
qmarks := strings.Join(marks, ", ")
columns := strings.Join(dbcols, "`,`")
query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks)
return q.Prepare(query)
}
func (d *dbBase) InsertStmt(stmt *sql.Stmt, mi *modelInfo, ind reflect.Value) (int64, error) {
_, values, err := d.collectValues(mi, ind, true, true)
if err != nil {
return 0, err
}
if res, err := stmt.Exec(values...); err == nil {
return res.LastInsertId()
} else {
return 0, err
}
}
func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
names, values, err := d.collectValues(mi, ind, true, true)
if err != nil {
return 0, err
}
marks := make([]string, len(names))
for i, _ := range marks {
marks[i] = "?"
}
qmarks := strings.Join(marks, ", ")
columns := strings.Join(names, "`,`")
query := fmt.Sprintf("INSERT INTO `%s` (`%s`) VALUES (%s)", mi.table, columns, qmarks)
if res, err := q.Exec(query, values...); err == nil {
return res.LastInsertId()
} else {
return 0, err
}
}
func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
pkNames, pkValues, ok := d.existPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
setNames, setValues, err := d.collectValues(mi, ind, true, false)
if err != nil {
return 0, err
}
pkColumns := strings.Join(pkNames, "` = ? AND `")
setColumns := strings.Join(setNames, "` = ?, `")
query := fmt.Sprintf("UPDATE `%s` SET `%s` = ? WHERE `%s` = ?", mi.table, setColumns, pkColumns)
setValues = append(setValues, pkValues...)
if res, err := q.Exec(query, setValues...); err == nil {
return res.RowsAffected()
} else {
return 0, err
}
return 0, nil
}
func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value) (int64, error) {
names, values, ok := d.existPk(mi, ind)
if ok == false {
return 0, ErrMissPK
}
columns := strings.Join(names, "` = ? AND `")
query := fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
if res, err := q.Exec(query, values...); err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
if num > 0 {
if mi.fields.auto != nil {
ind.Field(mi.fields.auto.fieldIndex).SetInt(0)
}
if len(names) == 1 {
err := d.deleteRels(q, mi, values)
if err != nil {
return num, err
}
}
}
return num, err
} else {
return 0, err
}
return 0, nil
}
func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params) (int64, error) {
columns := make([]string, 0, len(params))
values := make([]interface{}, 0, len(params))
for col, val := range params {
column := snakeString(col)
if fi, ok := mi.fields.columns[column]; ok == false || fi.dbcol == false {
panic(fmt.Sprintf("wrong field/column name `%s`", column))
}
columns = append(columns, column)
values = append(values, val)
}
if len(columns) == 0 {
panic("update params cannot empty")
}
tables := newDbTables(mi, d.ins)
if qs != nil {
tables.parseRelated(qs.related, qs.relDepth)
}
where, args := tables.getCondSql(cond, false)
join := tables.getJoinSql()
query := fmt.Sprintf("UPDATE `%s` T0 %sSET T0.`%s` = ? %s", mi.table, join, strings.Join(columns, "` = ?, T0.`"), where)
values = append(values, args...)
if res, err := q.Exec(query, values...); err == nil {
return res.RowsAffected()
} else {
return 0, err
}
return 0, nil
}
func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}) error {
for _, fi := range mi.fields.fieldsReverse {
fi = fi.reverseFieldInfo
switch fi.onDelete {
case od_CASCADE:
cond := NewCondition()
cond.And(fmt.Sprintf("%s__in", fi.name), args...)
_, err := d.DeleteBatch(q, nil, fi.mi, cond)
if err != nil {
return err
}
case od_SET_DEFAULT, od_SET_NULL:
cond := NewCondition()
cond.And(fmt.Sprintf("%s__in", fi.name), args...)
params := Params{fi.column: nil}
if fi.onDelete == od_SET_DEFAULT {
params[fi.column] = fi.initial.String()
}
_, err := d.UpdateBatch(q, nil, fi.mi, cond, params)
if err != nil {
return err
}
case od_DO_NOTHING:
}
}
return nil
}
func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (int64, error) {
tables := newDbTables(mi, d.ins)
if qs != nil {
tables.parseRelated(qs.related, qs.relDepth)
}
if cond == nil || cond.IsEmpty() {
panic("delete operation cannot execute without condition")
}
where, args := tables.getCondSql(cond, false)
join := tables.getJoinSql()
colsNum := len(mi.fields.pk)
cols := make([]string, colsNum)
for i, fi := range mi.fields.pk {
cols[i] = fi.column
}
colsql := fmt.Sprintf("T0.`%s`", strings.Join(cols, "`, T0.`"))
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s", colsql, mi.table, join, where)
var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil {
return 0, err
} else {
rs = r
}
refs := make([]interface{}, colsNum)
for i, _ := range refs {
var ref string
refs[i] = &ref
}
args = make([]interface{}, 0)
cnt := 0
for rs.Next() {
if err := rs.Scan(refs...); err != nil {
return 0, err
}
for _, ref := range refs {
args = append(args, reflect.ValueOf(ref).Elem().Interface())
}
cnt++
}
if cnt == 0 {
return 0, nil
}
if colsNum > 1 {
columns := strings.Join(cols, "` = ? AND `")
query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` = ?", mi.table, columns)
} else {
var sql string
sql, args = d.ins.GetOperatorSql(mi, "in", args)
query = fmt.Sprintf("DELETE FROM `%s` WHERE `%s` %s", mi.table, cols[0], sql)
}
if res, err := q.Exec(query, args...); err == nil {
num, err := res.RowsAffected()
if err != nil {
return 0, err
}
if colsNum == 1 && num > 0 {
err := d.deleteRels(q, mi, args)
if err != nil {
return num, err
}
}
return num, nil
} else {
return 0, err
}
return 0, nil
}
func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}) (int64, error) {
val := reflect.ValueOf(container)
ind := reflect.Indirect(val)
typ := ind.Type()
errTyp := true
one := true
if val.Kind() == reflect.Ptr {
tp := typ
if ind.Kind() == reflect.Slice {
one = false
if ind.Type().Elem().Kind() == reflect.Ptr {
tp = ind.Type().Elem().Elem()
}
}
errTyp = tp.PkgPath()+"."+tp.Name() != mi.fullName
}
if errTyp {
panic(fmt.Sprintf("wrong object type `%s` for rows scan, need *[]*%s or *%s", val.Type(), mi.fullName, mi.fullName))
}
rlimit := qs.limit
offset := qs.offset
if one {
rlimit = 0
offset = 0
}
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false)
orderBy := tables.getOrderSql(qs.orders)
limit := tables.getLimitSql(offset, rlimit)
join := tables.getJoinSql()
colsNum := len(mi.fields.dbcols)
cols := fmt.Sprintf("T0.`%s`", strings.Join(mi.fields.dbcols, "`, T0.`"))
for _, tbl := range tables.tables {
if tbl.sel {
colsNum += len(tbl.mi.fields.dbcols)
cols += fmt.Sprintf(", %s.`%s`", tbl.index, strings.Join(tbl.mi.fields.dbcols, "`, "+tbl.index+".`"))
}
}
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", cols, mi.table, join, where, orderBy, limit)
var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil {
return 0, err
} else {
rs = r
}
refs := make([]interface{}, colsNum)
for i, _ := range refs {
var ref string
refs[i] = &ref
}
slice := ind
var cnt int64
for rs.Next() {
if one && cnt == 0 || one == false {
if err := rs.Scan(refs...); err != nil {
return 0, err
}
elm := reflect.New(mi.addrField.Elem().Type())
md := elm.Interface().(Modeler)
md.Init(md)
mind := reflect.Indirect(elm)
cacheV := make(map[string]*reflect.Value)
cacheM := make(map[string]*modelInfo)
trefs := refs
d.setColsValues(mi, &mind, mi.fields.dbcols, refs[:len(mi.fields.dbcols)])
trefs = refs[len(mi.fields.dbcols):]
for _, tbl := range tables.tables {
if tbl.sel {
last := mind
names := ""
mmi := mi
for _, name := range tbl.names {
names += name
if val, ok := cacheV[names]; ok {
last = *val
mmi = cacheM[names]
} else {
fi := mmi.fields.GetByName(name)
lastm := mmi
mmi := fi.relModelInfo
field := reflect.Indirect(last.Field(fi.fieldIndex))
d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)])
for _, fi := range mmi.fields.fieldsReverse {
if fi.reverseFieldInfo.mi == lastm {
if fi.reverseFieldInfo != nil {
field.Field(fi.fieldIndex).Set(last.Addr())
}
}
}
trefs = trefs[len(mmi.fields.dbcols):]
cacheV[names] = &field
cacheM[names] = mmi
last = field
}
}
}
}
if one {
ind.Set(mind)
} else {
slice = reflect.Append(slice, mind.Addr())
}
}
cnt++
}
if one == false {
ind.Set(slice)
}
return cnt, nil
}
func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition) (cnt int64, err error) {
tables := newDbTables(mi, d.ins)
tables.parseRelated(qs.related, qs.relDepth)
where, args := tables.getCondSql(cond, false)
tables.getOrderSql(qs.orders)
join := tables.getJoinSql()
query := fmt.Sprintf("SELECT COUNT(*) FROM `%s` T0 %s%s", mi.table, join, where)
row := q.QueryRow(query, args...)
err = row.Scan(&cnt)
return
}
func (d *dbBase) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (string, []interface{}) {
params := make([]interface{}, len(args))
copy(params, args)
sql := ""
for i, arg := range args {
if len(mi.fields.pk) == 1 {
if md, ok := arg.(Modeler); ok {
ind := reflect.Indirect(reflect.ValueOf(md))
if _, values, exist := d.existPk(mi, ind); exist {
arg = values[0]
} else {
panic(fmt.Sprintf("`%s` need a valid args value", operator))
}
}
}
params[i] = arg
}
if operator == "in" {
marks := make([]string, len(params))
for i, _ := range marks {
marks[i] = "?"
}
sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", "))
} else {
if len(params) > 1 {
panic(fmt.Sprintf("operator `%s` need 1 args not %d", operator, len(params)))
}
sql = operatorsSQL[operator]
arg := params[0]
switch operator {
case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith":
param := strings.Replace(ToStr(arg), `%`, `\%`, -1)
switch operator {
case "iexact", "contains", "icontains":
param = fmt.Sprintf("%%%s%%", param)
case "startswith", "istartswith":
param = fmt.Sprintf("%s%%", param)
case "endswith", "iendswith":
param = fmt.Sprintf("%%%s", param)
}
params[0] = param
case "isnull":
if b, ok := arg.(bool); ok {
if b {
sql = "IS NULL"
} else {
sql = "IS NOT NULL"
}
params = nil
} else {
panic(fmt.Sprintf("operator `%s` need a bool value not `%T`", operator, arg))
}
}
}
return sql, params
}
func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}) {
for i, column := range cols {
val := reflect.Indirect(reflect.ValueOf(values[i])).Interface()
fi := mi.fields.GetByColumn(column)
field := ind.Field(fi.fieldIndex)
value, err := d.getValue(fi, val)
if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
}
_, err = d.setValue(fi, value, &field)
if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
}
}
}
func (d *dbBase) getValue(fi *fieldInfo, val interface{}) (interface{}, error) {
if val == nil {
return nil, nil
}
var value interface{}
var str *StrTo
switch v := val.(type) {
case []byte:
s := StrTo(string(v))
str = &s
case string:
s := StrTo(v)
str = &s
}
fieldType := fi.fieldType
setValue:
switch {
case fieldType == TypeBooleanField:
if str == nil {
switch v := val.(type) {
case int64:
b := v == 1
value = b
default:
s := StrTo(ToStr(v))
str = &s
}
}
if str != nil {
b, err := str.Bool()
if err != nil {
return nil, err
}
value = b
}
case fieldType == TypeCharField || fieldType == TypeTextField:
s := str.String()
if str == nil {
s = ToStr(val)
}
value = s
case fieldType == TypeDateField || fieldType == TypeDateTimeField:
if str == nil {
switch v := val.(type) {
case time.Time:
value = v
default:
s := StrTo(ToStr(v))
str = &s
}
}
if str != nil {
format := format_DateTime
if fi.fieldType == TypeDateField {
format = format_Date
}
s := str.String()
t, err := timeParse(s, format)
if err != nil && s != "0000-00-00" && s != "0000-00-00 00:00:00" {
return nil, err
}
value = t
}
case fieldType&IsIntegerField > 0:
if str == nil {
s := StrTo(ToStr(val))
str = &s
}
if str != nil {
var err error
switch fieldType {
case TypeSmallIntegerField:
_, err = str.Int16()
case TypeIntegerField:
_, err = str.Int32()
case TypeBigIntegerField:
_, err = str.Int64()
case TypePositiveSmallIntegerField:
_, err = str.Uint16()
case TypePositiveIntegerField:
_, err = str.Uint32()
case TypePositiveBigIntegerField:
_, err = str.Uint64()
}
if err != nil {
return nil, err
}
if fieldType&IsPostiveIntegerField > 0 {
v, _ := str.Uint64()
value = v
} else {
v, _ := str.Int64()
value = v
}
}
case fieldType == TypeFloatField || fieldType == TypeDecimalField:
if str == nil {
switch v := val.(type) {
case float64:
value = v
default:
s := StrTo(ToStr(v))
str = &s
}
}
if str != nil {
v, err := str.Float64()
if err != nil {
return nil, err
}
value = v
}
case fieldType&IsRelField > 0:
fieldType = fi.relModelInfo.fields.pk[0].fieldType
goto setValue
}
return value, nil
}
func (d *dbBase) setValue(fi *fieldInfo, value interface{}, field *reflect.Value) (interface{}, error) {
fieldType := fi.fieldType
isNative := fi.isFielder == false
setValue:
switch {
case fieldType == TypeBooleanField:
if isNative {
field.SetBool(value.(bool))
}
case fieldType == TypeCharField || fieldType == TypeTextField:
if isNative {
field.SetString(value.(string))
}
case fieldType == TypeDateField || fieldType == TypeDateTimeField:
if isNative {
field.Set(reflect.ValueOf(value))
}
case fieldType&IsIntegerField > 0:
if fieldType&IsPostiveIntegerField > 0 {
if isNative {
field.SetUint(value.(uint64))
}
} else {
if isNative {
field.SetInt(value.(int64))
}
}
case fieldType == TypeFloatField || fieldType == TypeDecimalField:
if isNative {
field.SetFloat(value.(float64))
}
case fieldType&IsRelField > 0:
fieldType = fi.relModelInfo.fields.pk[0].fieldType
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
md := mf.Interface().(Modeler)
md.Init(md)
field.Set(mf)
f := mf.Elem().Field(fi.relModelInfo.fields.pk[0].fieldIndex)
field = &f
goto setValue
}
if isNative == false {
fd := field.Addr().Interface().(Fielder)
err := fd.SetRaw(value)
if err != nil {
return nil, err
}
}
return value, nil
}
func (d *dbBase) xsetValue(fi *fieldInfo, val interface{}, field *reflect.Value) (interface{}, error) {
if val == nil {
return nil, nil
}
var value interface{}
var str *StrTo
switch v := val.(type) {
case []byte:
s := StrTo(string(v))
str = &s
case string:
s := StrTo(v)
str = &s
}
fieldType := fi.fieldType
isNative := fi.isFielder == false
setValue:
switch {
case fieldType == TypeBooleanField:
if str == nil {
switch v := val.(type) {
case int64:
b := v == 1
if isNative {
field.SetBool(b)
}
value = b
default:
s := StrTo(ToStr(v))
str = &s
}
}
if str != nil {
b, err := str.Bool()
if err != nil {
return nil, err
}
if isNative {
field.SetBool(b)
}
value = b
}
case fieldType == TypeCharField || fieldType == TypeTextField:
s := str.String()
if str == nil {
s = ToStr(val)
}
if isNative {
field.SetString(s)
}
value = s
case fieldType == TypeDateField || fieldType == TypeDateTimeField:
if str == nil {
switch v := val.(type) {
case time.Time:
if isNative {
field.Set(reflect.ValueOf(v))
}
value = v
default:
s := StrTo(ToStr(v))
str = &s
}
}
if str != nil {
format := format_DateTime
if fi.fieldType == TypeDateField {
format = format_Date
}
t, err := timeParse(str.String(), format)
if err != nil {
return nil, err
}
if isNative {
field.Set(reflect.ValueOf(t))
}
value = t
}
case fieldType&IsIntegerField > 0:
if str == nil {
s := StrTo(ToStr(val))
str = &s
}
if str != nil {
var err error
switch fieldType {
case TypeSmallIntegerField:
value, err = str.Int16()
case TypeIntegerField:
value, err = str.Int32()
case TypeBigIntegerField:
value, err = str.Int64()
case TypePositiveSmallIntegerField:
value, err = str.Uint16()
case TypePositiveIntegerField:
value, err = str.Uint32()
case TypePositiveBigIntegerField:
value, err = str.Uint64()
}
if err != nil {
return nil, err
}
if fieldType&IsPostiveIntegerField > 0 {
v, _ := str.Uint64()
if isNative {
field.SetUint(v)
}
} else {
v, _ := str.Int64()
if isNative {
field.SetInt(v)
}
}
}
case fieldType == TypeFloatField || fieldType == TypeDecimalField:
if str == nil {
switch v := val.(type) {
case float64:
if isNative {
field.SetFloat(v)
}
value = v
default:
s := StrTo(ToStr(v))
str = &s
}
}
if str != nil {
v, err := str.Float64()
if err != nil {
return nil, err
}
if isNative {
field.SetFloat(v)
}
value = v
}
case fieldType&IsRelField > 0:
fieldType = fi.relModelInfo.fields.pk[0].fieldType
mf := reflect.New(fi.relModelInfo.addrField.Elem().Type())
md := mf.Interface().(Modeler)
md.Init(md)
field.Set(mf)
f := mf.Elem().Field(fi.relModelInfo.fields.pk[0].fieldIndex)
field = &f
goto setValue
}
if isNative == false {
fd := field.Addr().Interface().(Fielder)
err := fd.SetRaw(value)
if err != nil {
return nil, err
}
}
return value, nil
}
func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}) (int64, error) {
var (
maps []Params
lists []ParamsList
list ParamsList
)
typ := 0
switch container.(type) {
case *[]Params:
typ = 1
case *[]ParamsList:
typ = 2
case *ParamsList:
typ = 3
default:
panic(fmt.Sprintf("unsupport read values type `%T`", container))
}
tables := newDbTables(mi, d.ins)
var (
cols []string
infos []*fieldInfo
)
hasExprs := len(exprs) > 0
if hasExprs {
cols = make([]string, 0, len(exprs))
infos = make([]*fieldInfo, 0, len(exprs))
for _, ex := range exprs {
index, col, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep))
if suc == false {
panic(fmt.Errorf("unknown field/column name `%s`", ex))
}
cols = append(cols, fmt.Sprintf("%s.`%s`", index, col))
infos = append(infos, fi)
}
} else {
cols = make([]string, 0, len(mi.fields.dbcols))
infos = make([]*fieldInfo, 0, len(exprs))
for _, fi := range mi.fields.fieldsDB {
cols = append(cols, fmt.Sprintf("T0.`%s`", fi.column))
infos = append(infos, fi)
}
}
where, args := tables.getCondSql(cond, false)
orderBy := tables.getOrderSql(qs.orders)
limit := tables.getLimitSql(qs.offset, qs.limit)
join := tables.getJoinSql()
sels := strings.Join(cols, ", ")
query := fmt.Sprintf("SELECT %s FROM `%s` T0 %s%s%s%s", sels, mi.table, join, where, orderBy, limit)
var rs *sql.Rows
if r, err := q.Query(query, args...); err != nil {
return 0, err
} else {
rs = r
}
refs := make([]interface{}, len(cols))
for i, _ := range refs {
var ref string
refs[i] = &ref
}
var cnt int64
for rs.Next() {
if err := rs.Scan(refs...); err != nil {
return 0, err
}
switch typ {
case 1:
params := make(Params, len(cols))
for i, ref := range refs {
fi := infos[i]
val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
value, err := d.getValue(fi, val)
if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
}
if hasExprs {
params[exprs[i]] = value
} else {
params[mi.fields.dbcols[i]] = value
}
}
maps = append(maps, params)
case 2:
params := make(ParamsList, 0, len(cols))
for i, ref := range refs {
fi := infos[i]
val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
value, err := d.getValue(fi, val)
if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
}
params = append(params, value)
}
lists = append(lists, params)
case 3:
for i, ref := range refs {
fi := infos[i]
val := reflect.Indirect(reflect.ValueOf(ref)).Interface()
value, err := d.getValue(fi, val)
if err != nil {
panic(fmt.Sprintf("db value convert failed `%v` %s", val, err.Error()))
}
list = append(list, value)
}
}
cnt++
}
switch v := container.(type) {
case *[]Params:
*v = maps
case *[]ParamsList:
*v = lists
case *ParamsList:
*v = list
}
return cnt, nil
}
package orm
import (
"database/sql"
"fmt"
"os"
"sync"
)
const defaultMaxIdle = 30
type driverType int
const (
_ driverType = iota
DR_MySQL
DR_Sqlite
DR_Oracle
DR_Postgres
)
var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
drivers = make(map[string]driverType)
dbBasers = map[driverType]dbBaser{
DR_MySQL: newdbBaseMysql(),
DR_Sqlite: newdbBaseSqlite(),
DR_Oracle: newdbBaseMysql(),
DR_Postgres: newdbBasePostgres(),
}
)
type _dbCache struct {
mux sync.RWMutex
cache map[string]*alias
}
func (ac *_dbCache) add(name string, al *alias) (added bool) {
ac.mux.Lock()
defer ac.mux.Unlock()
if _, ok := ac.cache[name]; ok == false {
ac.cache[name] = al
added = true
}
return
}
func (ac *_dbCache) get(name string) (al *alias, ok bool) {
ac.mux.RLock()
defer ac.mux.RUnlock()
al, ok = ac.cache[name]
return
}
func (ac *_dbCache) getDefault() (al *alias) {
al, _ = ac.get("default")
return
}
type alias struct {
Name string
DriverName string
DataSource string
MaxIdle int
DB *sql.DB
DbBaser dbBaser
}
func RegisterDataBase(name, driverName, dataSource string, maxIdle int) {
if maxIdle <= 0 {
maxIdle = defaultMaxIdle
}
al := new(alias)
al.Name = name
al.DriverName = driverName
al.DataSource = dataSource
al.MaxIdle = maxIdle
var (
err error
)
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
} else {
err = fmt.Errorf("driver name `%s` have not registered", driverName)
goto end
}
if dataBaseCache.add(name, al) == false {
err = fmt.Errorf("db name `%s` already registered, cannot reuse", name)
goto end
}
al.DB, err = sql.Open(driverName, dataSource)
if err != nil {
err = fmt.Errorf("register db `%s`, %s", name, err.Error())
goto end
}
err = al.DB.Ping()
if err != nil {
err = fmt.Errorf("register db `%s`, %s", name, err.Error())
goto end
}
end:
if err != nil {
fmt.Println(err.Error())
os.Exit(2)
}
}
func RegisterDriver(name string, typ driverType) {
if _, ok := drivers[name]; ok == false {
drivers[name] = typ
} else {
fmt.Println("name `%s` db driver already registered")
os.Exit(2)
}
}
func init() {
// RegisterDriver("mysql", DR_MySQL)
RegisterDriver("mymysql", DR_MySQL)
}
package orm
type dbBaseMysql struct {
dbBase
}
func (d *dbBaseMysql) GetOperatorSql(mi *modelInfo, operator string, args []interface{}) (sql string, params []interface{}) {
return d.dbBase.GetOperatorSql(mi, operator, args)
}
func newdbBaseMysql() dbBaser {
b := new(dbBaseMysql)
b.ins = b
return b
}
package orm
type dbBaseOracle struct {
dbBase
}
func newdbBaseOracle() dbBaser {
b := new(dbBaseOracle)
b.ins = b
return b
}
package orm
type dbBasePostgres struct {
dbBase
}
func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres)
b.ins = b
return b
}
package orm
type dbBaseSqlite struct {
dbBase
}
func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite)
b.ins = b
return b
}
package orm
import (
"log"
"os"
"sync"
)
const (
od_CASCADE = "cascade"
od_SET_NULL = "set_null"
od_SET_DEFAULT = "set_default"
od_DO_NOTHING = "do_nothing"
defaultStructTagName = "orm"
)
var (
errLog *log.Logger
modelCache = &_modelCache{cache: make(map[string]*modelInfo)}
supportTag = map[string]int{
"null": 1,
"blank": 1,
"index": 1,
"unique": 1,
"pk": 1,
"auto": 1,
"auto_now": 1,
"auto_now_add": 1,
"max_length": 2,
"choices": 2,
"column": 2,
"default": 2,
"rel": 2,
"reverse": 2,
"rel_table": 2,
"rel_through": 2,
"digits": 2,
"decimals": 2,
"on_delete": 2,
}
)
func init() {
errLog = log.New(os.Stderr, "[ORM] ", log.Ldate|log.Ltime|log.Lshortfile)
}
type _modelCache struct {
sync.RWMutex
orders []string
cache map[string]*modelInfo
}
func (mc *_modelCache) all() map[string]*modelInfo {
m := make(map[string]*modelInfo, len(mc.cache))
for k, v := range mc.cache {
m[k] = v
}
return m
}
func (mc *_modelCache) allOrdered() []*modelInfo {
m := make([]*modelInfo, 0, len(mc.orders))
for _, v := range mc.cache {
m = append(m, v)
}
return m
}
func (mc *_modelCache) get(table string) (mi *modelInfo, ok bool) {
mi, ok = mc.cache[table]
return
}
func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
mii := mc.cache[table]
mc.cache[table] = mi
if mii == nil {
mc.orders = append(mc.orders, table)
}
return mii
}
package orm
import (
"errors"
"fmt"
"os"
"reflect"
"strings"
)
func RegisterModel(model Modeler) {
info := newModelInfo(model)
model.Init(model)
table := model.GetTableName()
if _, ok := modelCache.get(table); ok {
fmt.Printf("model <%T> redeclared, must be unique\n", model)
os.Exit(2)
}
if info.fields.pk == nil {
fmt.Printf("model <%T> need a primary key field\n", model)
os.Exit(2)
}
info.table = table
info.pkg = getPkgPath(model)
info.model = model
info.manual = true
modelCache.set(table, info)
}
func BootStrap() {
modelCache.Lock()
defer modelCache.Unlock()
var (
err error
models map[string]*modelInfo
)
if dataBaseCache.getDefault() == nil {
err = fmt.Errorf("must have one register alias named `default`")
goto end
}
models = modelCache.all()
for _, mi := range models {
for _, fi := range mi.fields.columns {
if fi.rel || fi.reverse {
elm := fi.addrValue.Type().Elem()
switch fi.fieldType {
case RelReverseMany, RelManyToMany:
elm = elm.Elem()
}
tn := getTableName(reflect.New(elm).Interface().(Modeler))
mii, ok := modelCache.get(tn)
if ok == false || mii.pkg != elm.PkgPath() {
err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
goto end
}
fi.relModelInfo = mii
if fi.rel {
if mii.fields.pk.IsMulti() {
err = fmt.Errorf("field `%s` unsupport rel to multi primary key field", fi.fullName)
goto end
}
}
switch fi.fieldType {
case RelManyToMany:
if fi.relThrough != "" {
msg := fmt.Sprintf("filed `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
pn := fi.relThrough[:i]
mn := fi.relThrough[i+1:]
tn := snakeString(mn)
rmi, ok := modelCache.get(tn)
if ok == false || pn != rmi.pkg {
err = errors.New(msg + " cannot find table")
goto end
}
fi.relThroughModelInfo = rmi
fi.relTable = rmi.table
} else {
err = errors.New(msg)
goto end
}
err = nil
} else {
i := newM2MModelInfo(mi, mii)
if fi.relTable != "" {
i.table = fi.relTable
}
if v := modelCache.set(i.table, i); v != nil {
err = fmt.Errorf("the rel table name `%s` already registered, cannot be use, please change one", fi.relTable)
goto end
}
fi.relTable = i.table
fi.relThroughModelInfo = i
}
}
}
}
}
models = modelCache.all()
for _, mi := range models {
for _, fi := range mi.fields.fieldsRel {
switch fi.fieldType {
case RelForeignKey, RelOneToOne, RelManyToMany:
inModel := false
for _, ffi := range fi.relModelInfo.fields.fieldsReverse {
if ffi.relModelInfo == mi {
inModel = true
break
}
}
if inModel == false {
rmi := fi.relModelInfo
ffi := new(fieldInfo)
ffi.name = mi.name
ffi.column = ffi.name
ffi.fullName = rmi.fullName + "." + ffi.name
ffi.reverse = true
ffi.relModelInfo = mi
ffi.mi = rmi
if fi.fieldType == RelOneToOne {
ffi.fieldType = RelReverseOne
} else {
ffi.fieldType = RelReverseMany
}
if rmi.fields.Add(ffi) == false {
added := false
for cnt := 0; cnt < 5; cnt++ {
ffi.name = fmt.Sprintf("%s%d", mi.name, cnt)
ffi.column = ffi.name
ffi.fullName = rmi.fullName + "." + ffi.name
if added = rmi.fields.Add(ffi); added {
break
}
}
if added == false {
panic(fmt.Sprintf("cannot generate auto reverse field info `%s` to `%s`", fi.fullName, ffi.fullName))
}
}
}
}
}
}
for _, mi := range models {
if fields, ok := mi.fields.fieldsByType[RelReverseOne]; ok {
for _, fi := range fields {
found := false
mForA:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelOneToOne] {
if ffi.relModelInfo == mi {
found = true
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
break mForA
}
}
if found == false {
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
goto end
}
}
}
if fields, ok := mi.fields.fieldsByType[RelReverseMany]; ok {
for _, fi := range fields {
found := false
mForB:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelForeignKey] {
if ffi.relModelInfo == mi {
found = true
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
break mForB
}
}
if found == false {
mForC:
for _, ffi := range fi.relModelInfo.fields.fieldsByType[RelManyToMany] {
if ffi.relModelInfo == mi {
found = true
fi.reverseField = ffi.name
fi.reverseFieldInfo = ffi
break mForC
}
}
}
if found == false {
err = fmt.Errorf("reverse field `%s` not found in model `%s`", fi.fullName, fi.relModelInfo.fullName)
goto end
}
}
}
}
end:
if err != nil {
fmt.Println(err)
os.Exit(2)
}
runCommand()
}
package orm
import (
"errors"
"fmt"
"strconv"
"time"
)
const (
// bool
TypeBooleanField = 1 << iota
// string
TypeCharField
// string
TypeTextField
// time.Time
TypeDateField
// time.Time
TypeDateTimeField
// int16
TypeSmallIntegerField
// int32
TypeIntegerField
// int64
TypeBigIntegerField
// uint16
TypePositiveSmallIntegerField
// uint32
TypePositiveIntegerField
// uint64
TypePositiveBigIntegerField
// float64
TypeFloatField
// float64
TypeDecimalField
RelForeignKey
RelOneToOne
RelManyToMany
RelReverseOne
RelReverseMany
)
const (
IsIntegerField = ^-TypePositiveBigIntegerField >> 4 << 5
IsPostiveIntegerField = ^-TypePositiveBigIntegerField >> 7 << 8
IsRelField = ^-RelReverseMany >> 12 << 13
IsFieldType = ^-RelReverseMany<<1 + 1
)
// A true/false field.
type BooleanField bool
func (e BooleanField) Value() bool {
return bool(e)
}
func (e *BooleanField) Set(d bool) {
*e = BooleanField(d)
}
func (e *BooleanField) String() string {
return strconv.FormatBool(e.Value())
}
func (e *BooleanField) FieldType() int {
return TypeBooleanField
}
func (e *BooleanField) SetRaw(value interface{}) error {
switch d := value.(type) {
case bool:
e.Set(d)
case string:
v, err := StrTo(d).Bool()
if err != nil {
e.Set(v)
}
return err
default:
return errors.New(fmt.Sprintf("<BooleanField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *BooleanField) RawValue() interface{} {
return e.Value()
}
// A string field
// required values tag: max_length
// The max_length is enforced at the database level and in models’s validation.
// eg: `max_length:"120"`
type CharField string
func (e CharField) Value() string {
return string(e)
}
func (e *CharField) Set(d string) {
*e = CharField(d)
}
func (e *CharField) String() string {
return e.Value()
}
func (e *CharField) FieldType() int {
return TypeCharField
}
func (e *CharField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return errors.New(fmt.Sprintf("<CharField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *CharField) RawValue() interface{} {
return e.Value()
}
// A date, represented in go by a time.Time instance.
// only date values like 2006-01-02
// Has a few extra, optional attr tag:
//
// auto_now:
// Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps.
// Note that the current date is always used; it’s not just a default value that you can override.
//
// auto_now_add:
// Automatically set the field to now when the object is first created. Useful for creation of timestamps.
// Note that the current date is always used; it’s not just a default value that you can override.
//
// eg: `attr:"auto_now"` or `attr:"auto_now_add"`
type DateField time.Time
func (e DateField) Value() time.Time {
return time.Time(e)
}
func (e *DateField) Set(d time.Time) {
*e = DateField(d)
}
func (e *DateField) String() string {
return e.Value().String()
}
func (e *DateField) FieldType() int {
return TypeDateField
}
func (e *DateField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, format_Date)
if err != nil {
e.Set(v)
}
return err
default:
return errors.New(fmt.Sprintf("<DateField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *DateField) RawValue() interface{} {
return e.Value()
}
// A date, represented in go by a time.Time instance.
// datetime values like 2006-01-02 15:04:05
// Takes the same extra arguments as DateField.
type DateTimeField time.Time
func (e DateTimeField) Value() time.Time {
return time.Time(e)
}
func (e *DateTimeField) Set(d time.Time) {
*e = DateTimeField(d)
}
func (e *DateTimeField) String() string {
return e.Value().String()
}
func (e *DateTimeField) FieldType() int {
return TypeDateTimeField
}
func (e *DateTimeField) SetRaw(value interface{}) error {
switch d := value.(type) {
case time.Time:
e.Set(d)
case string:
v, err := timeParse(d, format_DateTime)
if err != nil {
e.Set(v)
}
return err
default:
return errors.New(fmt.Sprintf("<DateTimeField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *DateTimeField) RawValue() interface{} {
return e.Value()
}
// A floating-point number represented in go by a float32 value.
type FloatField float64
func (e FloatField) Value() float64 {
return float64(e)
}
func (e *FloatField) Set(d float64) {
*e = FloatField(d)
}
func (e *FloatField) String() string {
return ToStr(e.Value(), -1, 32)
}
func (e *FloatField) FieldType() int {
return TypeFloatField
}
func (e *FloatField) SetRaw(value interface{}) error {
switch d := value.(type) {
case float32:
e.Set(float64(d))
case float64:
e.Set(d)
case string:
v, err := StrTo(d).Float64()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<FloatField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *FloatField) RawValue() interface{} {
return e.Value()
}
// -32768 to 32767
type SmallIntegerField int16
func (e SmallIntegerField) Value() int16 {
return int16(e)
}
func (e *SmallIntegerField) Set(d int16) {
*e = SmallIntegerField(d)
}
func (e *SmallIntegerField) String() string {
return ToStr(e.Value())
}
func (e *SmallIntegerField) FieldType() int {
return TypeSmallIntegerField
}
func (e *SmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int16:
e.Set(d)
case string:
v, err := StrTo(d).Int16()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<SmallIntegerField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *SmallIntegerField) RawValue() interface{} {
return e.Value()
}
// -2147483648 to 2147483647
type IntegerField int32
func (e IntegerField) Value() int32 {
return int32(e)
}
func (e *IntegerField) Set(d int32) {
*e = IntegerField(d)
}
func (e *IntegerField) String() string {
return ToStr(e.Value())
}
func (e *IntegerField) FieldType() int {
return TypeIntegerField
}
func (e *IntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int32:
e.Set(d)
case string:
v, err := StrTo(d).Int32()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<IntegerField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *IntegerField) RawValue() interface{} {
return e.Value()
}
// -9223372036854775808 to 9223372036854775807.
type BigIntegerField int64
func (e BigIntegerField) Value() int64 {
return int64(e)
}
func (e *BigIntegerField) Set(d int64) {
*e = BigIntegerField(d)
}
func (e *BigIntegerField) String() string {
return ToStr(e.Value())
}
func (e *BigIntegerField) FieldType() int {
return TypeBigIntegerField
}
func (e *BigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case int64:
e.Set(d)
case string:
v, err := StrTo(d).Int64()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<BigIntegerField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *BigIntegerField) RawValue() interface{} {
return e.Value()
}
// 0 to 65535
type PositiveSmallIntegerField uint16
func (e PositiveSmallIntegerField) Value() uint16 {
return uint16(e)
}
func (e *PositiveSmallIntegerField) Set(d uint16) {
*e = PositiveSmallIntegerField(d)
}
func (e *PositiveSmallIntegerField) String() string {
return ToStr(e.Value())
}
func (e *PositiveSmallIntegerField) FieldType() int {
return TypePositiveSmallIntegerField
}
func (e *PositiveSmallIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint16:
e.Set(d)
case string:
v, err := StrTo(d).Uint16()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<PositiveSmallIntegerField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *PositiveSmallIntegerField) RawValue() interface{} {
return e.Value()
}
// 0 to 4294967295
type PositiveIntegerField uint32
func (e PositiveIntegerField) Value() uint32 {
return uint32(e)
}
func (e *PositiveIntegerField) Set(d uint32) {
*e = PositiveIntegerField(d)
}
func (e *PositiveIntegerField) String() string {
return ToStr(e.Value())
}
func (e *PositiveIntegerField) FieldType() int {
return TypePositiveIntegerField
}
func (e *PositiveIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint32:
e.Set(d)
case string:
v, err := StrTo(d).Uint32()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<PositiveIntegerField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *PositiveIntegerField) RawValue() interface{} {
return e.Value()
}
// 0 to 18446744073709551615
type PositiveBigIntegerField uint64
func (e PositiveBigIntegerField) Value() uint64 {
return uint64(e)
}
func (e *PositiveBigIntegerField) Set(d uint64) {
*e = PositiveBigIntegerField(d)
}
func (e *PositiveBigIntegerField) String() string {
return ToStr(e.Value())
}
func (e *PositiveBigIntegerField) FieldType() int {
return TypePositiveIntegerField
}
func (e *PositiveBigIntegerField) SetRaw(value interface{}) error {
switch d := value.(type) {
case uint64:
e.Set(d)
case string:
v, err := StrTo(d).Uint64()
if err != nil {
e.Set(v)
}
default:
return errors.New(fmt.Sprintf("<PositiveBigIntegerField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *PositiveBigIntegerField) RawValue() interface{} {
return e.Value()
}
// A large text field.
type TextField string
func (e TextField) Value() string {
return string(e)
}
func (e *TextField) Set(d string) {
*e = TextField(d)
}
func (e *TextField) String() string {
return e.Value()
}
func (e *TextField) FieldType() int {
return TypeTextField
}
func (e *TextField) SetRaw(value interface{}) error {
switch d := value.(type) {
case string:
e.Set(d)
default:
return errors.New(fmt.Sprintf("<TextField.SetRaw> unknown value `%s`", value))
}
return nil
}
func (e *TextField) RawValue() interface{} {
return e.Value()
}
package orm
import (
"errors"
"fmt"
"reflect"
"strings"
)
type fieldChoices []StrTo
func (f *fieldChoices) Add(s StrTo) {
if f.Have(s) == false {
*f = append(*f, s)
}
}
func (f *fieldChoices) Clear() {
*f = fieldChoices([]StrTo{})
}
func (f *fieldChoices) Have(s StrTo) bool {
for _, v := range *f {
if v == s {
return true
}
}
return false
}
func (f *fieldChoices) Clone() fieldChoices {
return *f
}
type primaryKeys []*fieldInfo
func (p *primaryKeys) Add(fi *fieldInfo) {
*p = append(*p, fi)
}
func (p primaryKeys) Exist(fi *fieldInfo) (int, bool) {
for i, v := range p {
if v == fi {
return i, true
}
}
return -1, false
}
func (p primaryKeys) IsMulti() bool {
return len(p) > 1
}
func (p primaryKeys) IsEmpty() bool {
return len(p) == 0
}
type fields struct {
pk primaryKeys
auto *fieldInfo
columns map[string]*fieldInfo
fields map[string]*fieldInfo
fieldsLow map[string]*fieldInfo
fieldsByType map[int][]*fieldInfo
fieldsRel []*fieldInfo
fieldsReverse []*fieldInfo
fieldsDB []*fieldInfo
rels []*fieldInfo
orders []string
dbcols []string
}
func (f *fields) Add(fi *fieldInfo) (added bool) {
if f.fields[fi.name] == nil && f.columns[fi.column] == nil {
f.columns[fi.column] = fi
f.fields[fi.name] = fi
f.fieldsLow[strings.ToLower(fi.name)] = fi
} else {
return
}
if _, ok := f.fieldsByType[fi.fieldType]; ok == false {
f.fieldsByType[fi.fieldType] = make([]*fieldInfo, 0)
}
f.fieldsByType[fi.fieldType] = append(f.fieldsByType[fi.fieldType], fi)
f.orders = append(f.orders, fi.column)
if fi.dbcol {
f.dbcols = append(f.dbcols, fi.column)
f.fieldsDB = append(f.fieldsDB, fi)
}
if fi.rel {
f.fieldsRel = append(f.fieldsRel, fi)
}
if fi.reverse {
f.fieldsReverse = append(f.fieldsReverse, fi)
}
return true
}
func (f *fields) GetByName(name string) *fieldInfo {
return f.fields[name]
}
func (f *fields) GetByColumn(column string) *fieldInfo {
return f.columns[column]
}
func (f *fields) GetByAny(name string) (*fieldInfo, bool) {
if fi, ok := f.fields[name]; ok {
return fi, ok
}
if fi, ok := f.fieldsLow[strings.ToLower(name)]; ok {
return fi, ok
}
if fi, ok := f.columns[name]; ok {
return fi, ok
}
return nil, false
}
func newFields() *fields {
f := new(fields)
f.fields = make(map[string]*fieldInfo)
f.fieldsLow = make(map[string]*fieldInfo)
f.columns = make(map[string]*fieldInfo)
f.fieldsByType = make(map[int][]*fieldInfo)
return f
}
type fieldInfo struct {
mi *modelInfo
fieldIndex int
fieldType int
dbcol bool
inModel bool
name string
fullName string
column string
addrValue *reflect.Value
sf *reflect.StructField
auto bool
pk bool
null bool
blank bool
index bool
unique bool
initial StrTo
choices fieldChoices
maxLength int
auto_now bool
auto_now_add bool
rel bool
reverse bool
reverseField string
reverseFieldInfo *fieldInfo
relTable string
relThrough string
relThroughModelInfo *modelInfo
relModelInfo *modelInfo
digits int
decimals int
isFielder bool
onDelete string
}
func newFieldInfo(mi *modelInfo, field reflect.Value, sf reflect.StructField) (fi *fieldInfo, err error) {
var (
tag string
tagValue string
choices fieldChoices
values fieldChoices
initial StrTo
fieldType int
attrs map[string]bool
tags map[string]string
parts []string
addrField reflect.Value
)
fi = new(fieldInfo)
if field.Kind() != reflect.Ptr && field.Kind() != reflect.Slice && field.CanAddr() {
addrField = field.Addr()
} else {
addrField = field
}
parseStructTag(sf.Tag.Get(defaultStructTagName), &attrs, &tags)
digits := tags["digits"]
decimals := tags["decimals"]
maxLength := tags["max_length"]
onDelete := tags["on_delete"]
checkType:
switch f := addrField.Interface().(type) {
case Fielder:
fi.isFielder = true
if field.Kind() == reflect.Ptr {
err = fmt.Errorf("the model Fielder can not be use ptr")
goto end
}
fieldType = f.FieldType()
if fieldType&IsRelField > 0 {
err = fmt.Errorf("unsupport rel type custom field")
goto end
}
default:
tag = "rel"
tagValue = tags[tag]
if tagValue != "" {
switch tagValue {
case "fk":
fieldType = RelForeignKey
break checkType
case "one":
fieldType = RelOneToOne
break checkType
case "m2m":
fieldType = RelManyToMany
if tv := tags["rel_table"]; tv != "" {
fi.relTable = tv
} else if tv := tags["rel_through"]; tv != "" {
fi.relThrough = tv
}
break checkType
default:
err = fmt.Errorf("error")
goto wrongTag
}
}
tag = "reverse"
tagValue = tags[tag]
if tagValue != "" {
switch tagValue {
case "one":
fieldType = RelReverseOne
break checkType
case "many":
fieldType = RelReverseMany
break checkType
default:
err = fmt.Errorf("error")
goto wrongTag
}
}
fieldType, err = getFieldType(addrField)
if err != nil {
goto end
}
if fieldType == TypeTextField && maxLength != "" {
fieldType = TypeCharField
}
if fieldType == TypeFloatField && (digits != "" || decimals != "") {
fieldType = TypeDecimalField
}
if fieldType == TypeDateTimeField && attrs["date"] {
fieldType = TypeDateField
}
}
switch fieldType {
case RelForeignKey, RelOneToOne, RelReverseOne:
if _, ok := addrField.Interface().(Modeler); ok == false {
err = fmt.Errorf("rel/reverse:one field must be implements Modeler")
goto end
}
if field.Kind() != reflect.Ptr {
err = fmt.Errorf("rel/reverse:one field must be *%s", field.Type().Name())
goto end
}
case RelManyToMany, RelReverseMany:
if field.Kind() != reflect.Slice {
err = fmt.Errorf("rel/reverse:many field must be slice")
goto end
} else {
if field.Type().Elem().Kind() != reflect.Ptr {
err = fmt.Errorf("rel/reverse:many slice must be []*%s", field.Type().Elem().Name())
goto end
}
if _, ok := reflect.New(field.Type().Elem()).Elem().Interface().(Modeler); ok == false {
err = fmt.Errorf("rel/reverse:many slice element must be implements Modeler")
goto end
}
}
}
if fieldType&IsFieldType == 0 {
err = fmt.Errorf("wrong field type")
goto end
}
fi.fieldType = fieldType
fi.name = sf.Name
fi.column = getColumnName(fieldType, addrField, sf, tags["column"])
fi.addrValue = &addrField
fi.sf = &sf
fi.fullName = mi.fullName + "." + sf.Name
fi.null = attrs["null"]
fi.blank = attrs["blank"]
fi.index = attrs["index"]
fi.auto = attrs["auto"]
fi.pk = attrs["pk"]
fi.unique = attrs["unique"]
switch fieldType {
case RelManyToMany, RelReverseMany, RelReverseOne:
fi.null = false
fi.blank = false
fi.index = false
fi.auto = false
fi.pk = false
fi.unique = false
default:
fi.dbcol = true
}
switch fieldType {
case RelForeignKey, RelOneToOne, RelManyToMany:
fi.rel = true
if fieldType == RelOneToOne {
fi.unique = true
}
case RelReverseMany, RelReverseOne:
fi.reverse = true
}
if fi.rel && fi.dbcol {
switch onDelete {
case od_CASCADE, od_DO_NOTHING:
case od_SET_DEFAULT:
if tags["default"] == "" {
err = errors.New("on_delete: set_default need set field a default value")
goto end
}
case od_SET_NULL:
if fi.null == false {
err = errors.New("on_delete: set_null need set field null")
goto end
}
default:
if onDelete == "" {
onDelete = od_CASCADE
} else {
err = fmt.Errorf("on_delete value expected choice in `cascade,set_null,set_default,do_nothing`, unknown `%s`", onDelete)
goto end
}
}
fi.onDelete = onDelete
}
switch fieldType {
case TypeBooleanField:
case TypeCharField:
if maxLength != "" {
v, e := StrTo(maxLength).Int32()
if e != nil {
err = fmt.Errorf("wrong maxLength value `%s`", maxLength)
} else {
fi.maxLength = int(v)
}
} else {
err = fmt.Errorf("maxLength must be specify")
}
case TypeTextField:
fi.index = false
fi.unique = false
case TypeDateField, TypeDateTimeField:
if attrs["auto_now"] {
fi.auto_now = true
} else if attrs["auto_now_add"] {
fi.auto_now_add = true
}
case TypeFloatField:
case TypeDecimalField:
d1 := digits
d2 := decimals
v1, er1 := StrTo(d1).Int16()
v2, er2 := StrTo(d2).Int16()
if er1 != nil || er2 != nil {
err = fmt.Errorf("wrong digits/decimals value %s/%s", d2, d1)
goto end
}
fi.digits = int(v1)
fi.decimals = int(v2)
default:
switch {
case fieldType&IsIntegerField > 0:
case fieldType&IsRelField > 0:
}
}
if fieldType&IsIntegerField == 0 {
if fi.auto {
err = fmt.Errorf("non-integer type cannot set auto")
goto end
}
if fi.pk || fi.index || fi.unique {
if fieldType != TypeCharField && fieldType != RelOneToOne {
err = fmt.Errorf("cannot set pk/index/unique")
goto end
}
}
}
if fi.auto || fi.pk {
if fi.auto {
fi.pk = true
}
fi.null = false
fi.blank = false
fi.index = false
fi.unique = false
}
if fi.unique {
fi.null = false
fi.blank = false
fi.index = false
}
parts = strings.Split(tags["choices"], ",")
if len(parts) > 1 {
for _, v := range parts {
choices.Add(StrTo(strings.TrimSpace(v)))
}
}
initial.Clear()
if v, ok := tags["default"]; ok {
initial.Set(v)
}
if fi.auto || fi.pk || fi.unique || fieldType == TypeDateField || fieldType == TypeDateTimeField {
// can not set default
choices.Clear()
initial.Clear()
}
values = choices.Clone()
if initial.Exist() {
values.Add(initial)
}
for i, v := range values {
switch fieldType {
case TypeBooleanField:
_, err = v.Bool()
case TypeFloatField, TypeDecimalField:
_, err = v.Float64()
case TypeSmallIntegerField:
_, err = v.Int16()
case TypeIntegerField:
_, err = v.Int32()
case TypeBigIntegerField:
_, err = v.Int64()
case TypePositiveSmallIntegerField:
_, err = v.Uint16()
case TypePositiveIntegerField:
_, err = v.Uint32()
case TypePositiveBigIntegerField:
_, err = v.Uint64()
}
if err != nil {
if initial.Exist() && len(values) == i {
tag, tagValue = "default", tags["default"]
} else {
tag, tagValue = "choices", tags["choices"]
}
goto wrongTag
}
}
if len(choices) > 0 && initial.Exist() {
if choices.Have(initial) == false {
err = fmt.Errorf("default value `%s` not in choices `%s`", tags["default"], tags["choices"])
goto end
}
}
fi.choices = choices
fi.initial = initial
end:
if err != nil {
return nil, err
}
return
wrongTag:
return nil, fmt.Errorf("wrong tag format: `%s:\"%s\"`, %s", tag, tagValue, err)
}
package orm
import (
"errors"
"fmt"
"os"
"reflect"
)
type modelInfo struct {
pkg string
name string
fullName string
table string
model Modeler
fields *fields
manual bool
addrField reflect.Value
}
func newModelInfo(model Modeler) (info *modelInfo) {
var (
err error
fi *fieldInfo
sf reflect.StructField
)
info = &modelInfo{}
info.fields = newFields()
val := reflect.ValueOf(model)
ind := reflect.Indirect(val)
typ := ind.Type()
info.addrField = ind.Addr()
info.name = typ.Name()
info.fullName = typ.PkgPath() + "." + typ.Name()
for i := 0; i < ind.NumField(); i++ {
field := ind.Field(i)
sf = ind.Type().Field(i)
if field.CanAddr() {
addr := field.Addr()
if _, ok := addr.Interface().(*Manager); ok {
continue
}
}
fi, err = newFieldInfo(info, field, sf)
if err != nil {
break
}
added := info.fields.Add(fi)
if added == false {
err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column))
break
}
if fi.pk {
if info.fields.pk != nil {
err = errors.New(fmt.Sprintf("one model must have one pk field only"))
break
} else {
info.fields.pk.Add(fi)
}
}
if fi.auto {
info.fields.auto = fi
}
fi.fieldIndex = i
fi.mi = info
}
if _, ok := info.fields.pk.Exist(info.fields.auto); info.fields.auto != nil && ok == false {
err = errors.New(fmt.Sprintf("when auto field exists, you cannot set other pk field"))
goto end
}
if err != nil {
fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
os.Exit(2)
}
end:
if err != nil {
fmt.Println(err)
os.Exit(2)
}
return
}
func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
info = new(modelInfo)
info.fields = newFields()
info.table = m1.table + "_" + m2.table + "_rel"
info.name = camelString(info.table)
info.fullName = m1.pkg + "." + info.name
fa := new(fieldInfo)
f1 := new(fieldInfo)
f2 := new(fieldInfo)
fa.fieldType = TypeBigIntegerField
fa.auto = true
fa.pk = true
fa.dbcol = true
f1.dbcol = true
f2.dbcol = true
f1.fieldType = RelForeignKey
f2.fieldType = RelForeignKey
f1.name = camelString(m1.table)
f2.name = camelString(m2.table)
f1.fullName = info.fullName + "." + f1.name
f2.fullName = info.fullName + "." + f2.name
f1.column = m1.table + "_id"
f2.column = m2.table + "_id"
f1.rel = true
f2.rel = true
f1.relTable = m1.table
f2.relTable = m2.table
f1.relModelInfo = m1
f2.relModelInfo = m2
f1.mi = info
f2.mi = info
info.fields.Add(fa)
info.fields.Add(f1)
info.fields.Add(f2)
info.fields.pk.Add(fa)
return
}
package orm
import ()
// non cleaned field errors
type FieldErrors map[string]error
func (fe FieldErrors) Get(name string) error {
return fe[name]
}
func (fe FieldErrors) Set(name string, value error) {
fe[name] = value
}
type Manager struct {
ins Modeler
inited bool
}
// func (m *Manager) init(model reflect.Value) {
// elm := model.Elem()
// for i := 0; i < elm.NumField(); i++ {
// field := elm.Field(i)
// if _, ok := field.Interface().(Fielder); ok && field.CanSet() {
// if field.Elem().Kind() != reflect.Struct {
// field.Set(reflect.New(field.Type().Elem()))
// }
// }
// }
// }
func (m *Manager) Init(model Modeler) Modeler {
if m.inited {
return m.ins
}
m.inited = true
m.ins = model
return model
}
func (m *Manager) IsInited() bool {
return m.inited
}
func (m *Manager) Clean() FieldErrors {
return nil
}
func (m *Manager) CleanFields(name string) FieldErrors {
return nil
}
func (m *Manager) GetTableName() string {
return getTableName(m.ins)
}
package orm
import (
"fmt"
"reflect"
"strings"
"time"
)
func getTableName(model Modeler) string {
val := reflect.ValueOf(model)
ind := reflect.Indirect(val)
fun := val.MethodByName("TableName")
if fun.IsValid() {
vals := fun.Call([]reflect.Value{})
if len(vals) > 0 {
val := vals[0]
if val.Kind() == reflect.String {
return val.String()
}
}
}
return snakeString(ind.Type().Name())
}
func getPkgPath(model Modeler) string {
val := reflect.ValueOf(model)
return val.Type().Elem().PkgPath()
}
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
column := strings.ToLower(col)
if column == "" {
column = snakeString(sf.Name)
}
switch ft {
case RelForeignKey, RelOneToOne:
column = column + "_id"
case RelManyToMany, RelReverseMany, RelReverseOne:
column = sf.Name
}
return column
}
func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val)
switch elm.Kind() {
case reflect.Int16:
ft = TypeSmallIntegerField
case reflect.Int32, reflect.Int:
ft = TypeIntegerField
case reflect.Int64:
ft = TypeBigIntegerField
case reflect.Uint16:
ft = TypePositiveSmallIntegerField
case reflect.Uint32:
ft = TypePositiveIntegerField
case reflect.Uint64:
ft = TypePositiveBigIntegerField
case reflect.Float32, reflect.Float64:
ft = TypeFloatField
case reflect.Bool:
ft = TypeBooleanField
case reflect.String:
ft = TypeTextField
case reflect.Invalid:
default:
if elm.CanInterface() {
if _, ok := elm.Interface().(time.Time); ok {
ft = TypeDateTimeField
}
}
}
if ft&IsFieldType == 0 {
err = fmt.Errorf("unsupport field type %s, may be miss setting tag", val)
}
return
}
func parseStructTag(data string, attrs *map[string]bool, tags *map[string]string) {
attr := make(map[string]bool)
tag := make(map[string]string)
for _, v := range strings.Split(data, ";") {
v = strings.TrimSpace(v)
if supportTag[v] == 1 {
attr[v] = true
} else if i := strings.Index(v, "("); i > 0 && strings.Index(v, ")") == len(v)-1 {
name := v[:i]
if supportTag[name] == 2 {
v = v[i+1 : len(v)-1]
tag[name] = v
}
}
}
*attrs = attr
*tags = tag
}
package orm
import (
"database/sql"
"errors"
"fmt"
"time"
)
var (
ErrTXHasBegin = errors.New("<Ormer.Begin> transaction already begin")
ErrTXNotBegin = errors.New("<Ormer.Commit/Rollback> transaction not begin")
ErrMultiRows = errors.New("<QuerySeter.One> return multi rows")
ErrStmtClosed = errors.New("<QuerySeter.Insert> stmt already closed")
DefaultRowsLimit = 1000
DefaultRelsDepth = 5
DefaultTimeLoc = time.Local
)
type Params map[string]interface{}
type ParamsList []interface{}
type orm struct {
alias *alias
db dbQuerier
isTx bool
}
func (o *orm) Object(md Modeler) ObjectSeter {
name := md.GetTableName()
if mi, ok := modelCache.get(name); ok {
return newObject(o, mi, md)
}
panic(fmt.Sprintf("<orm.Object> table name: `%s` not exists", name))
}
func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
name := ""
if table, ok := ptrStructOrTableName.(string); ok {
name = snakeString(table)
} else if m, ok := ptrStructOrTableName.(Modeler); ok {
name = m.GetTableName()
}
if mi, ok := modelCache.get(name); ok {
return newQuerySet(o, mi)
}
panic(fmt.Sprintf("<orm.SetTable> table name: `%s` not exists", name))
}
func (o *orm) Using(name string) error {
if o.isTx {
panic("<orm.Using> transaction has been start, cannot change db")
}
if al, ok := dataBaseCache.get(name); ok {
o.alias = al
o.db = al.DB
} else {
return errors.New(fmt.Sprintf("<orm.Using> unknown db alias name `%s`", name))
}
return nil
}
func (o *orm) Begin() error {
if o.isTx {
return ErrTXHasBegin
}
tx, err := o.alias.DB.Begin()
if err != nil {
return err
}
o.isTx = true
o.db = tx
return nil
}
func (o *orm) Commit() error {
if o.isTx == false {
return ErrTXNotBegin
}
err := o.db.(*sql.Tx).Commit()
if err == nil {
o.isTx = false
o.db = o.alias.DB
}
return err
}
func (o *orm) Rollback() error {
if o.isTx == false {
return ErrTXNotBegin
}
err := o.db.(*sql.Tx).Rollback()
if err == nil {
o.isTx = false
o.db = o.alias.DB
}
return err
}
func (o *orm) Raw(query string, args ...interface{}) RawSeter {
return newRawSet(o, query, args)
}
func NewOrm() Ormer {
o := new(orm)
err := o.Using("default")
if err != nil {
panic(err)
}
return o
}
package orm
import (
"strings"
)
const (
ExprSep = "__"
)
type condValue struct {
exprs []string
args []interface{}
cond *Condition
isOr bool
isNot bool
isCond bool
}
type Condition struct {
params []condValue
}
func NewCondition() *Condition {
c := &Condition{}
return c
}
func (c *Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic("<Condition.And> args cannot empty")
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args})
return c
}
func (c *Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic("<Condition.AndNot> args cannot empty")
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true})
return c
}
func (c *Condition) AndCond(cond *Condition) *Condition {
if c == cond {
panic("cannot use self as sub cond")
}
if cond != nil {
c.params = append(c.params, condValue{cond: cond, isCond: true})
}
return c
}
func (c *Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic("<Condition.Or> args cannot empty")
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true})
return c
}
func (c *Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic("<Condition.OrNot> args cannot empty")
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true})
return c
}
func (c *Condition) OrCond(cond *Condition) *Condition {
if c == cond {
panic("cannot use self as sub cond")
}
if cond != nil {
c.params = append(c.params, condValue{cond: cond, isCond: true, isOr: true})
}
return c
}
func (c *Condition) IsEmpty() bool {
return len(c.params) == 0
}
func (c Condition) Clone() *Condition {
params := c.params
c.params = make([]condValue, len(params))
copy(c.params, params)
return &c
}
func (c *Condition) Merge() (expr string, args []interface{}) {
return expr, args
}
package orm
import (
"database/sql"
"fmt"
"reflect"
)
type insertSet struct {
mi *modelInfo
orm *orm
stmt *sql.Stmt
closed bool
}
func (o *insertSet) Insert(md Modeler) (int64, error) {
if o.closed {
return 0, ErrStmtClosed
}
val := reflect.ValueOf(md)
ind := reflect.Indirect(val)
if val.Type() != o.mi.addrField.Type() {
panic(fmt.Sprintf("<Inserter.Insert> need type `%s` but found `%s`", o.mi.addrField.Type(), val.Type()))
}
id, err := o.orm.alias.DbBaser.InsertStmt(o.stmt, o.mi, ind)
if err != nil {
return id, err
}
if id > 0 {
if o.mi.fields.auto != nil {
ind.Field(o.mi.fields.auto.fieldIndex).SetInt(id)
}
}
return id, nil
}
func (o *insertSet) Close() error {
o.closed = true
return o.stmt.Close()
}
func newInsertSet(orm *orm, mi *modelInfo) (Inserter, error) {
bi := new(insertSet)
bi.orm = orm
bi.mi = mi
st, err := orm.alias.DbBaser.PrepareInsert(orm.db, mi)
if err != nil {
return nil, err
}
bi.stmt = st
return bi, nil
}
type object struct {
ind reflect.Value
mi *modelInfo
orm *orm
}
func (o *object) Insert() (int64, error) {
id, err := o.orm.alias.DbBaser.Insert(o.orm.db, o.mi, o.ind)
if err != nil {
return id, err
}
if id > 0 {
if o.mi.fields.auto != nil {
o.ind.Field(o.mi.fields.auto.fieldIndex).SetInt(id)
}
}
return id, nil
}
func (o *object) Update() (int64, error) {
num, err := o.orm.alias.DbBaser.Update(o.orm.db, o.mi, o.ind)
if err != nil {
return num, err
}
return 0, nil
}
func (o *object) Delete() (int64, error) {
return o.orm.alias.DbBaser.Delete(o.orm.db, o.mi, o.ind)
}
func newObject(orm *orm, mi *modelInfo, md Modeler) ObjectSeter {
o := new(object)
ind := reflect.Indirect(reflect.ValueOf(md))
o.ind = ind
o.mi = mi
o.orm = orm
return o
}
package orm
import (
"fmt"
)
type querySet struct {
mi *modelInfo
cond *Condition
related []string
relDepth int
limit int
offset int64
orders []string
orm *orm
}
func (o *querySet) Filter(expr string, args ...interface{}) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
}
o.cond.And(expr, args...)
return o.Clone()
}
func (o *querySet) Exclude(expr string, args ...interface{}) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
}
o.cond.AndNot(expr, args...)
return o.Clone()
}
func (o *querySet) Limit(limit int, args ...int64) QuerySeter {
o.limit = limit
if len(args) > 0 {
o.offset = args[0]
}
return o.Clone()
}
func (o *querySet) Offset(offset int64) QuerySeter {
o.offset = offset
return o.Clone()
}
func (o *querySet) OrderBy(orders ...string) QuerySeter {
o.orders = orders
return o.Clone()
}
func (o *querySet) RelatedSel(params ...interface{}) QuerySeter {
var related []string
if len(params) == 0 {
o.relDepth = DefaultRelsDepth
} else {
for _, p := range params {
switch val := p.(type) {
case string:
related = append(o.related, val)
case int:
o.relDepth = val
default:
panic(fmt.Sprintf("<querySet.RelatedSel> wrong param kind: %v", val))
}
}
}
o.related = related
return o.Clone()
}
func (o querySet) Clone() QuerySeter {
if o.cond != nil {
o.cond = o.cond.Clone()
}
return &o
}
func (o *querySet) SetCond(cond *Condition) error {
o.cond = cond
return nil
}
func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond)
}
func (o *querySet) Update(values Params) (int64, error) {
return o.orm.alias.DbBaser.UpdateBatch(o.orm.db, o, o.mi, o.cond, values)
}
func (o *querySet) Delete() (int64, error) {
return o.orm.alias.DbBaser.DeleteBatch(o.orm.db, o, o.mi, o.cond)
}
func (o *querySet) PrepareInsert() (Inserter, error) {
return newInsertSet(o.orm, o.mi)
}
func (o *querySet) All(container interface{}) (int64, error) {
return o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container)
}
func (o *querySet) One(container Modeler) error {
num, err := o.orm.alias.DbBaser.ReadBatch(o.orm.db, o, o.mi, o.cond, container)
if err != nil {
return err
}
if num > 1 {
return ErrMultiRows
}
return nil
}
func (o *querySet) Values(results *[]Params, args ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, args, results)
}
func (o *querySet) ValuesList(results *[]ParamsList, args ...string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, args, results)
}
func (o *querySet) ValuesFlat(result *ParamsList, arg string) (int64, error) {
return o.orm.alias.DbBaser.ReadValues(o.orm.db, o, o.mi, o.cond, []string{arg}, result)
}
func newQuerySet(orm *orm, mi *modelInfo) QuerySeter {
o := new(querySet)
o.mi = mi
o.orm = orm
return o
}
package orm
import (
"database/sql"
"fmt"
"reflect"
)
func getResult(res sql.Result) (int64, error) {
if num, err := res.LastInsertId(); err != nil {
return 0, err
} else {
if num > 0 {
return num, nil
}
}
if num, err := res.RowsAffected(); err != nil {
return num, err
} else {
if num > 0 {
return num, nil
}
}
return 0, nil
}
type rawPrepare struct {
rs *rawSet
stmt *sql.Stmt
closed bool
}
func (o *rawPrepare) Exec(args ...interface{}) (int64, error) {
if o.closed {
return 0, ErrStmtClosed
}
res, err := o.stmt.Exec(args...)
if err != nil {
return 0, err
}
return getResult(res)
}
func (o *rawPrepare) Close() error {
o.closed = true
return o.stmt.Close()
}
func newRawPreparer(rs *rawSet) (RawPreparer, error) {
o := new(rawPrepare)
o.rs = rs
st, err := rs.orm.db.Prepare(rs.query)
if err != nil {
return nil, err
}
o.stmt = st
return o, nil
}
type rawSet struct {
query string
args []interface{}
orm *orm
}
func (o rawSet) SetArgs(args ...interface{}) RawSeter {
o.args = args
return &o
}
func (o *rawSet) Exec() (int64, error) {
res, err := o.orm.db.Exec(o.query, o.args...)
if err != nil {
return 0, err
}
return getResult(res)
}
func (o *rawSet) Mapper(...interface{}) (int64, error) {
//TODO
return 0, nil
}
func (o *rawSet) readValues(container interface{}) (int64, error) {
var (
maps []Params
lists []ParamsList
list ParamsList
)
typ := 0
switch container.(type) {
case *[]Params:
typ = 1
case *[]ParamsList:
typ = 2
case *ParamsList:
typ = 3
default:
panic(fmt.Sprintf("unsupport read values type `%T`", container))
}
var rs *sql.Rows
if r, err := o.orm.db.Query(o.query, o.args...); err != nil {
return 0, err
} else {
rs = r
}
var (
refs []interface{}
cnt int64
cols []string
)
for rs.Next() {
if cnt == 0 {
if columns, err := rs.Columns(); err != nil {
return 0, err
} else {
cols = columns
refs = make([]interface{}, len(cols))
for i, _ := range refs {
var ref string
refs[i] = &ref
}
}
}
if err := rs.Scan(refs...); err != nil {
return 0, err
}
switch typ {
case 1:
params := make(Params, len(cols))
for i, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface()
params[cols[i]] = value
}
maps = append(maps, params)
case 2:
params := make(ParamsList, 0, len(cols))
for _, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface()
params = append(params, value)
}
lists = append(lists, params)
case 3:
for _, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface()
list = append(list, value)
}
}
cnt++
}
switch v := container.(type) {
case *[]Params:
*v = maps
case *[]ParamsList:
*v = lists
case *ParamsList:
*v = list
}
return cnt, nil
}
func (o *rawSet) Values(container *[]Params) (int64, error) {
return o.readValues(container)
}
func (o *rawSet) ValuesList(container *[]ParamsList) (int64, error) {
return o.readValues(container)
}
func (o *rawSet) ValuesFlat(container *ParamsList) (int64, error) {
return o.readValues(container)
}
func (o *rawSet) Prepare() (RawPreparer, error) {
return newRawPreparer(o)
}
func newRawSet(orm *orm, query string, args []interface{}) RawSeter {
o := new(rawSet)
o.query = query
o.args = args
o.orm = orm
return o
}
package orm
import (
"database/sql"
"reflect"
)
type Fielder interface {
String() string
FieldType() int
SetRaw(interface{}) error
RawValue() interface{}
}
type Modeler interface {
Init(Modeler) Modeler
IsInited() bool
Clean() FieldErrors
CleanFields(string) FieldErrors
GetTableName() string
}
type Ormer interface {
Object(Modeler) ObjectSeter
QueryTable(interface{}) QuerySeter
Using(string) error
Begin() error
Commit() error
Rollback() error
Raw(string, ...interface{}) RawSeter
}
type ObjectSeter interface {
Insert() (int64, error)
Update() (int64, error)
Delete() (int64, error)
}
type Inserter interface {
Insert(Modeler) (int64, error)
Close() error
}
type QuerySeter interface {
Filter(string, ...interface{}) QuerySeter
Exclude(string, ...interface{}) QuerySeter
Limit(int, ...int64) QuerySeter
Offset(int64) QuerySeter
OrderBy(...string) QuerySeter
RelatedSel(...interface{}) QuerySeter
Clone() QuerySeter
SetCond(*Condition) error
Count() (int64, error)
Update(Params) (int64, error)
Delete() (int64, error)
PrepareInsert() (Inserter, error)
All(interface{}) (int64, error)
One(Modeler) error
Values(*[]Params, ...string) (int64, error)
ValuesList(*[]ParamsList, ...string) (int64, error)
ValuesFlat(*ParamsList, string) (int64, error)
}
type RawPreparer interface {
Close() error
}
type RawSeter interface {
Exec() (int64, error)
Mapper(...interface{}) (int64, error)
Values(*[]Params) (int64, error)
ValuesList(*[]ParamsList) (int64, error)
ValuesFlat(*ParamsList) (int64, error)
Prepare() (RawPreparer, error)
}
type dbQuerier interface {
Prepare(query string) (*sql.Stmt, error)
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type dbBaser interface {
Insert(dbQuerier, *modelInfo, reflect.Value) (int64, error)
InsertStmt(*sql.Stmt, *modelInfo, reflect.Value) (int64, error)
Update(dbQuerier, *modelInfo, reflect.Value) (int64, error)
Delete(dbQuerier, *modelInfo, reflect.Value) (int64, error)
ReadBatch(dbQuerier, *querySet, *modelInfo, *Condition, interface{}) (int64, error)
UpdateBatch(dbQuerier, *querySet, *modelInfo, *Condition, Params) (int64, error)
DeleteBatch(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
Count(dbQuerier, *querySet, *modelInfo, *Condition) (int64, error)
GetOperatorSql(*modelInfo, string, []interface{}) (string, []interface{})
PrepareInsert(dbQuerier, *modelInfo) (*sql.Stmt, error)
ReadValues(dbQuerier, *querySet, *modelInfo, *Condition, []string, interface{}) (int64, error)
}
package orm
import (
"fmt"
"strconv"
"strings"
"time"
)
type StrTo string
func (f *StrTo) Set(v string) {
if v != "" {
*f = StrTo(v)
} else {
f.Clear()
}
}
func (f *StrTo) Clear() {
*f = StrTo(0x1E)
}
func (f StrTo) Exist() bool {
return string(f) != string(0x1E)
}
func (f StrTo) Bool() (bool, error) {
return strconv.ParseBool(f.String())
}
func (f StrTo) Float32() (float32, error) {
v, err := strconv.ParseFloat(f.String(), 32)
return float32(v), err
}
func (f StrTo) Float64() (float64, error) {
return strconv.ParseFloat(f.String(), 64)
}
func (f StrTo) Int16() (int16, error) {
v, err := strconv.ParseInt(f.String(), 10, 16)
return int16(v), err
}
func (f StrTo) Int32() (int32, error) {
v, err := strconv.ParseInt(f.String(), 10, 32)
return int32(v), err
}
func (f StrTo) Int64() (int64, error) {
v, err := strconv.ParseInt(f.String(), 10, 64)
return int64(v), err
}
func (f StrTo) Uint16() (uint16, error) {
v, err := strconv.ParseUint(f.String(), 10, 16)
return uint16(v), err
}
func (f StrTo) Uint32() (uint32, error) {
v, err := strconv.ParseUint(f.String(), 10, 32)
return uint32(v), err
}
func (f StrTo) Uint64() (uint64, error) {
v, err := strconv.ParseUint(f.String(), 10, 64)
return uint64(v), err
}
func (f StrTo) String() string {
if f.Exist() {
return string(f)
}
return ""
}
func ToStr(value interface{}, args ...int) (s string) {
switch v := value.(type) {
case bool:
s = strconv.FormatBool(v)
case float32:
s = strconv.FormatFloat(float64(v), 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 32))
case float64:
s = strconv.FormatFloat(v, 'f', argInt(args).Get(0, -1), argInt(args).Get(1, 64))
case int:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int16:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int32:
s = strconv.FormatInt(int64(v), argInt(args).Get(0, 10))
case int64:
s = strconv.FormatInt(v, argInt(args).Get(0, 10))
case uint:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint16:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint32:
s = strconv.FormatUint(uint64(v), argInt(args).Get(0, 10))
case uint64:
s = strconv.FormatUint(v, argInt(args).Get(0, 10))
case string:
s = v
default:
s = fmt.Sprintf("%v", v)
}
return s
}
func snakeString(s string) string {
data := make([]byte, 0, len(s)*2)
j := false
num := len(s)
for i := 0; i < num; i++ {
d := s[i]
if i > 0 && d >= 'A' && d <= 'Z' && j {
data = append(data, '_')
}
if d != '_' {
j = true
}
data = append(data, d)
}
return strings.ToLower(string(data[:len(data)]))
}
func camelString(s string) string {
data := make([]byte, 0, len(s))
j := false
k := false
num := len(s) - 1
for i := 0; i <= num; i++ {
d := s[i]
if k == false && d >= 'A' && d <= 'Z' {
k = true
}
if d >= 'a' && d <= 'z' && (j || k == false) {
d = d - 32
j = false
k = true
}
if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' {
j = true
continue
}
data = append(data, d)
}
return string(data[:len(data)])
}
type argString []string
func (a argString) Get(i int, args ...string) (r string) {
if i >= 0 && i < len(a) {
r = a[i]
} else if len(args) > 0 {
r = args[0]
}
return
}
type argInt []int
func (a argInt) Get(i int, args ...int) (r int) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err
}
func timeFormat(t time.Time, format string) string {
return t.Format(format)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment