Commit 46668b81 authored by slene's avatar slene

some fix / add test

parent 10f4e822
This diff is collapsed.
......@@ -9,24 +9,37 @@ import (
const defaultMaxIdle = 30
type driverType int
type DriverType int
const (
_ driverType = iota
_ DriverType = iota
DR_MySQL
DR_Sqlite
DR_Oracle
DR_Postgres
)
type driver string
func (d driver) Type() DriverType {
a, _ := dataBaseCache.get(string(d))
return a.Driver
}
func (d driver) Name() string {
return string(d)
}
var _ Driver = new(driver)
var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
drivers = map[string]driverType{
drivers = map[string]DriverType{
"mysql": DR_MySQL,
"postgres": DR_Postgres,
"sqlite3": DR_Sqlite,
}
dbBasers = map[driverType]dbBaser{
dbBasers = map[DriverType]dbBaser{
DR_MySQL: newdbBaseMysql(),
DR_Sqlite: newdbBaseSqlite(),
DR_Oracle: newdbBaseMysql(),
......@@ -63,6 +76,7 @@ func (ac *_dbCache) getDefault() (al *alias) {
type alias struct {
Name string
Driver DriverType
DriverName string
DataSource string
MaxIdle int
......@@ -87,6 +101,7 @@ func RegisterDataBase(name, driverName, dataSource string, maxIdle int) {
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
err = fmt.Errorf("driver name `%s` have not registered", driverName)
goto end
......@@ -116,7 +131,7 @@ end:
}
}
func RegisterDriver(name string, typ driverType) {
func RegisterDriver(name string, typ DriverType) {
if t, ok := drivers[name]; ok == false {
drivers[name] = typ
} else {
......
......@@ -49,6 +49,7 @@ type _modelCache struct {
sync.RWMutex
orders []string
cache map[string]*modelInfo
done bool
}
func (mc *_modelCache) all() map[string]*modelInfo {
......
......@@ -8,7 +8,7 @@ import (
"strings"
)
func RegisterModel(model Modeler) {
func registerModel(model Modeler) {
info := newModelInfo(model)
model.Init(model)
table := model.GetTableName()
......@@ -27,9 +27,10 @@ func RegisterModel(model Modeler) {
modelCache.set(table, info)
}
func BootStrap() {
modelCache.Lock()
defer modelCache.Unlock()
func bootStrap() {
if modelCache.done {
return
}
var (
err error
......@@ -59,14 +60,6 @@ func BootStrap() {
}
fi.relModelInfo = mii
if fi.rel {
if mii.fields.pk.IsMulti() {
err = fmt.Errorf("field `%s` unsupport rel to multi primary key field", fi.fullName)
goto end
}
}
switch fi.fieldType {
case RelManyToMany:
if fi.relThrough != "" {
......@@ -207,6 +200,25 @@ end:
fmt.Println(err)
os.Exit(2)
}
}
runCommand()
func RegisterModel(models ...Modeler) {
if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run begore BootStrap"))
}
for _, model := range models {
registerModel(model)
}
}
func BootStrap() {
if modelCache.done {
return
}
modelCache.Lock()
defer modelCache.Unlock()
bootStrap()
modelCache.done = true
}
......@@ -32,32 +32,8 @@ func (f *fieldChoices) Clone() fieldChoices {
return *f
}
type primaryKeys []*fieldInfo
func (p *primaryKeys) Add(fi *fieldInfo) {
*p = append(*p, fi)
}
func (p primaryKeys) Exist(fi *fieldInfo) (int, bool) {
for i, v := range p {
if v == fi {
return i, true
}
}
return -1, false
}
func (p primaryKeys) IsMulti() bool {
return len(p) > 1
}
func (p primaryKeys) IsEmpty() bool {
return len(p) == 0
}
type fields struct {
pk primaryKeys
auto *fieldInfo
pk *fieldInfo
columns map[string]*fieldInfo
fields map[string]*fieldInfo
fieldsLow map[string]*fieldInfo
......
......@@ -50,41 +50,31 @@ func newModelInfo(model Modeler) (info *modelInfo) {
if err != nil {
break
}
added := info.fields.Add(fi)
if added == false {
err = errors.New(fmt.Sprintf("duplicate column name: %s", fi.column))
break
}
if fi.pk {
if info.fields.pk != nil {
err = errors.New(fmt.Sprintf("one model must have one pk field only"))
break
} else {
info.fields.pk.Add(fi)
info.fields.pk = fi
}
}
if fi.auto {
info.fields.auto = fi
}
fi.fieldIndex = i
fi.mi = info
}
if _, ok := info.fields.pk.Exist(info.fields.auto); info.fields.auto != nil && ok == false {
err = errors.New(fmt.Sprintf("when auto field exists, you cannot set other pk field"))
goto end
}
if err != nil {
fmt.Println(fmt.Errorf("field: %s.%s, %s", ind.Type(), sf.Name, err))
os.Exit(2)
}
end:
if err != nil {
fmt.Println(err)
os.Exit(2)
}
return
}
......@@ -125,6 +115,6 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
info.fields.Add(fa)
info.fields.Add(f1)
info.fields.Add(f2)
info.fields.pk.Add(fa)
info.fields.pk = fa
return
}
package orm
import (
"fmt"
"os"
"time"
_ "github.com/bmizerany/pq"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
)
type User struct {
Id int `orm:"auto"`
UserName string `orm:"size(30);unique"`
Email string `orm:"size(100)"`
Password string `orm:"size(100)"`
Status int16 `orm:"choices(0,1,2,3);defalut(0)"`
IsStaff bool `orm:"default(false)"`
IsActive bool `orm:"default(1)"`
Created time.Time `orm:"auto_now_add;type(date)"`
Updated time.Time `orm:"auto_now"`
Profile *Profile `orm:"null;rel(one);on_delete(set_null)"`
Posts []*Post `orm:"reverse(many)" json:"-"`
Manager `json:"-"`
}
func NewUser() *User {
obj := new(User)
obj.Manager.Init(obj)
return obj
}
type Profile struct {
Id int `orm:"auto"`
Age int16 ``
Money float64 ``
User *User `orm:"reverse(one)" json:"-"`
Manager `json:"-"`
}
func (u *Profile) TableName() string {
return "user_profile"
}
func NewProfile() *Profile {
obj := new(Profile)
obj.Manager.Init(obj)
return obj
}
type Post struct {
Id int `orm:"auto"`
User *User `orm:"rel(fk)"` //
Title string `orm:"size(60)"`
Content string ``
Created time.Time `orm:"auto_now_add"`
Updated time.Time `orm:"auto_now"`
Tags []*Tag `orm:"rel(m2m)"`
Manager `json:"-"`
}
func NewPost() *Post {
obj := new(Post)
obj.Manager.Init(obj)
return obj
}
type Tag struct {
Id int `orm:"auto"`
Name string `orm:"size(30)"`
Posts []*Post `orm:"reverse(many)" json:"-"`
Manager `json:"-"`
}
func NewTag() *Tag {
obj := new(Tag)
obj.Manager.Init(obj)
return obj
}
type Comment struct {
Id int `orm:"auto"`
Post *Post `orm:"rel(fk)"`
Content string ``
Parent *Comment `orm:"null;rel(fk)"`
Created time.Time `orm:"auto_now_add"`
Manager `json:"-"`
}
func NewComment() *Comment {
obj := new(Comment)
obj.Manager.Init(obj)
return obj
}
var DBARGS = struct {
Driver string
Source string
}{
os.Getenv("ORM_DRIVER"),
os.Getenv("ORM_SOURCE"),
}
var dORM Ormer
func init() {
RegisterModel(new(User))
RegisterModel(new(Profile))
RegisterModel(new(Post))
RegisterModel(new(Tag))
RegisterModel(new(Comment))
if DBARGS.Driver == "" || DBARGS.Source == "" {
fmt.Println(`need driver and source!
Default DB Drivers.
driver: url
mysql: https://github.com/go-sql-driver/mysql
sqlite3: https://github.com/mattn/go-sqlite3
postgres: https://github.com/bmizerany/pq
eg: mysql
ORM_DRIVER=mysql ORM_SOURCE="root:root@/my_db?charset=utf8" go test github.com/astaxie/beego/orm
`)
os.Exit(2)
}
RegisterDataBase("default", DBARGS.Driver, DBARGS.Source, 20)
BootStrap()
truncateTables()
dORM = NewOrm()
}
func truncateTables() {
logs := "truncate tables for test\n"
o := NewOrm()
for _, m := range modelCache.allOrdered() {
query := fmt.Sprintf("truncate table `%s`", m.table)
_, err := o.Raw(query).Exec()
logs += query + "\n"
if err != nil {
fmt.Println(logs)
fmt.Println(err)
os.Exit(2)
}
}
}
......@@ -9,13 +9,15 @@ import (
)
var (
ErrTXHasBegin = errors.New("<Ormer.Begin> transaction already begin")
ErrTXNotBegin = errors.New("<Ormer.Commit/Rollback> transaction not begin")
ErrMultiRows = errors.New("<QuerySeter.One> return multi rows")
ErrStmtClosed = errors.New("<QuerySeter.Insert> stmt already closed")
DefaultRowsLimit = 1000
DefaultRelsDepth = 5
DefaultTimeLoc = time.Local
ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin")
ErrTxDone = errors.New("<Ormer.Commit/Rollback> transaction not begin")
ErrMultiRows = errors.New("<QuerySeter> return multi rows")
ErrNoRows = errors.New("<QuerySeter> not row found")
ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
ErrNotImplement = errors.New("have not implement")
)
type Params map[string]interface{}
......@@ -27,13 +29,15 @@ type orm struct {
isTx bool
}
var _ Ormer = new(orm)
func (o *orm) getMiInd(md Modeler) (mi *modelInfo, ind reflect.Value) {
md.Init(md, true)
name := md.GetTableName()
if mi, ok := modelCache.get(name); ok {
return mi, reflect.Indirect(reflect.ValueOf(md))
}
panic(fmt.Sprintf("<orm.Object> table name: `%s` not exists", name))
panic(fmt.Sprintf("<orm> table name: `%s` not exists", name))
}
func (o *orm) Read(md Modeler) error {
......@@ -52,8 +56,8 @@ func (o *orm) Insert(md Modeler) (int64, error) {
return id, err
}
if id > 0 {
if mi.fields.auto != nil {
ind.Field(mi.fields.auto.fieldIndex).SetInt(id)
if mi.fields.pk.auto {
ind.Field(mi.fields.pk.fieldIndex).SetInt(id)
}
}
return id, nil
......@@ -75,13 +79,31 @@ func (o *orm) Delete(md Modeler) (int64, error) {
return num, err
}
if num > 0 {
if mi.fields.auto != nil {
ind.Field(mi.fields.auto.fieldIndex).SetInt(0)
if mi.fields.pk.auto {
ind.Field(mi.fields.pk.fieldIndex).SetInt(0)
}
}
return num, nil
}
func (o *orm) M2mAdd(md Modeler, name string, mds ...interface{}) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
}
func (o *orm) M2mDel(md Modeler, name string, mds ...interface{}) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
}
func (o *orm) LoadRel(md Modeler, name string) (int64, error) {
// TODO
panic(ErrNotImplement)
return 0, nil
}
func (o *orm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
name := ""
if table, ok := ptrStructOrTableName.(string); ok {
......@@ -111,7 +133,7 @@ func (o *orm) Using(name string) error {
func (o *orm) Begin() error {
if o.isTx {
return ErrTXHasBegin
return ErrTxHasBegan
}
tx, err := o.alias.DB.Begin()
if err != nil {
......@@ -124,24 +146,28 @@ func (o *orm) Begin() error {
func (o *orm) Commit() error {
if o.isTx == false {
return ErrTXNotBegin
return ErrTxDone
}
err := o.db.(*sql.Tx).Commit()
if err == nil {
o.isTx = false
o.db = o.alias.DB
} else if err == sql.ErrTxDone {
return ErrTxDone
}
return err
}
func (o *orm) Rollback() error {
if o.isTx == false {
return ErrTXNotBegin
return ErrTxDone
}
err := o.db.(*sql.Tx).Rollback()
if err == nil {
o.isTx = false
o.db = o.alias.DB
} else if err == sql.ErrTxDone {
return ErrTxDone
}
return err
}
......@@ -150,7 +176,13 @@ func (o *orm) Raw(query string, args ...interface{}) RawSeter {
return newRawSet(o, query, args)
}
func (o *orm) Driver() Driver {
return driver(o.alias.Name)
}
func NewOrm() Ormer {
BootStrap() // execute only once
o := new(orm)
err := o.Using("default")
if err != nil {
......
......@@ -26,23 +26,24 @@ func NewCondition() *Condition {
return c
}
func (c *Condition) And(expr string, args ...interface{}) *Condition {
func (c Condition) And(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic("<Condition.And> args cannot empty")
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args})
return c
return &c
}
func (c *Condition) AndNot(expr string, args ...interface{}) *Condition {
func (c Condition) AndNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic("<Condition.AndNot> args cannot empty")
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true})
return c
return &c
}
func (c *Condition) AndCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
panic("cannot use self as sub cond")
}
......@@ -52,23 +53,24 @@ func (c *Condition) AndCond(cond *Condition) *Condition {
return c
}
func (c *Condition) Or(expr string, args ...interface{}) *Condition {
func (c Condition) Or(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic("<Condition.Or> args cannot empty")
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isOr: true})
return c
return &c
}
func (c *Condition) OrNot(expr string, args ...interface{}) *Condition {
func (c Condition) OrNot(expr string, args ...interface{}) *Condition {
if expr == "" || len(args) == 0 {
panic("<Condition.OrNot> args cannot empty")
}
c.params = append(c.params, condValue{exprs: strings.Split(expr, ExprSep), args: args, isNot: true, isOr: true})
return c
return &c
}
func (c *Condition) OrCond(cond *Condition) *Condition {
c = c.clone()
if c == cond {
panic("cannot use self as sub cond")
}
......@@ -82,13 +84,6 @@ func (c *Condition) IsEmpty() bool {
return len(c.params) == 0
}
func (c Condition) Clone() *Condition {
params := c.params
c.params = make([]condValue, len(params))
copy(c.params, params)
func (c Condition) clone() *Condition {
return &c
}
func (c *Condition) Merge() (expr string, args []interface{}) {
return expr, args
}
......@@ -13,6 +13,8 @@ type insertSet struct {
closed bool
}
var _ Inserter = new(insertSet)
func (o *insertSet) Insert(md Modeler) (int64, error) {
if o.closed {
return 0, ErrStmtClosed
......@@ -28,14 +30,17 @@ func (o *insertSet) Insert(md Modeler) (int64, error) {
return id, err
}
if id > 0 {
if o.mi.fields.auto != nil {
ind.Field(o.mi.fields.auto.fieldIndex).SetInt(id)
if o.mi.fields.pk.auto {
ind.Field(o.mi.fields.pk.fieldIndex).SetInt(id)
}
}
return id, nil
}
func (o *insertSet) Close() error {
if o.closed {
return ErrStmtClosed
}
o.closed = true
return o.stmt.Close()
}
......
......@@ -15,47 +15,43 @@ type querySet struct {
orm *orm
}
func (o *querySet) Filter(expr string, args ...interface{}) QuerySeter {
o = o.clone()
var _ QuerySeter = new(querySet)
func (o querySet) Filter(expr string, args ...interface{}) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
}
o.cond.And(expr, args...)
return o
o.cond = o.cond.And(expr, args...)
return &o
}
func (o *querySet) Exclude(expr string, args ...interface{}) QuerySeter {
o = o.clone()
func (o querySet) Exclude(expr string, args ...interface{}) QuerySeter {
if o.cond == nil {
o.cond = NewCondition()
}
o.cond.AndNot(expr, args...)
return o
o.cond = o.cond.AndNot(expr, args...)
return &o
}
func (o *querySet) Limit(limit int, args ...int64) QuerySeter {
o = o.clone()
func (o querySet) Limit(limit int, args ...int64) QuerySeter {
o.limit = limit
if len(args) > 0 {
o.offset = args[0]
}
return o
return &o
}
func (o *querySet) Offset(offset int64) QuerySeter {
o = o.clone()
func (o querySet) Offset(offset int64) QuerySeter {
o.offset = offset
return o
return &o
}
func (o *querySet) OrderBy(exprs ...string) QuerySeter {
o = o.clone()
func (o querySet) OrderBy(exprs ...string) QuerySeter {
o.orders = exprs
return o
return &o
}
func (o *querySet) RelatedSel(params ...interface{}) QuerySeter {
o = o.clone()
func (o querySet) RelatedSel(params ...interface{}) QuerySeter {
var related []string
if len(params) == 0 {
o.relDepth = DefaultRelsDepth
......@@ -72,13 +68,6 @@ func (o *querySet) RelatedSel(params ...interface{}) QuerySeter {
}
}
o.related = related
return o
}
func (o querySet) clone() *querySet {
if o.cond != nil {
o.cond = o.cond.Clone()
}
return &o
}
......@@ -115,6 +104,9 @@ func (o *querySet) One(container Modeler) error {
if num > 1 {
return ErrMultiRows
}
if num == 0 {
return ErrNoRows
}
return nil
}
......
......@@ -63,6 +63,8 @@ type rawSet struct {
orm *orm
}
var _ RawSeter = new(rawSet)
func (o rawSet) SetArgs(args ...interface{}) RawSeter {
o.args = args
return &o
......@@ -76,7 +78,12 @@ func (o *rawSet) Exec() (int64, error) {
return getResult(res)
}
func (o *rawSet) Mapper(...interface{}) (int64, error) {
func (o *rawSet) QueryRow(...interface{}) error {
//TODO
return nil
}
func (o *rawSet) QueryRows(...interface{}) (int64, error) {
//TODO
return 0, nil
}
......@@ -120,7 +127,7 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
cols = columns
refs = make([]interface{}, len(cols))
for i, _ := range refs {
var ref string
var ref sql.NullString
refs[i] = &ref
}
}
......@@ -134,21 +141,21 @@ func (o *rawSet) readValues(container interface{}) (int64, error) {
case 1:
params := make(Params, len(cols))
for i, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface()
params[cols[i]] = value
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
params[cols[i]] = value.String
}
maps = append(maps, params)
case 2:
params := make(ParamsList, 0, len(cols))
for _, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface()
params = append(params, value)
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
params = append(params, value.String)
}
lists = append(lists, params)
case 3:
for _, ref := range refs {
value := reflect.Indirect(reflect.ValueOf(ref)).Interface()
list = append(list, value)
value := reflect.Indirect(reflect.ValueOf(ref)).Interface().(sql.NullString)
list = append(list, value.String)
}
}
......
This diff is collapsed.
......@@ -5,6 +5,11 @@ import (
"reflect"
)
type Driver interface {
Name() string
Type() DriverType
}
type Fielder interface {
String() string
FieldType() int
......@@ -26,12 +31,16 @@ type Ormer interface {
Insert(Modeler) (int64, error)
Update(Modeler) (int64, error)
Delete(Modeler) (int64, error)
M2mAdd(Modeler, string, ...interface{}) (int64, error)
M2mDel(Modeler, string, ...interface{}) (int64, error)
LoadRel(Modeler, string) (int64, error)
QueryTable(interface{}) QuerySeter
Using(string) error
Begin() error
Commit() error
Rollback() error
Raw(string, ...interface{}) RawSeter
Driver() Driver
}
type Inserter interface {
......@@ -42,16 +51,15 @@ type Inserter interface {
type QuerySeter interface {
Filter(string, ...interface{}) QuerySeter
Exclude(string, ...interface{}) QuerySeter
SetCond(*Condition) QuerySeter
Limit(int, ...int64) QuerySeter
Offset(int64) QuerySeter
OrderBy(...string) QuerySeter
RelatedSel(...interface{}) QuerySeter
SetCond(*Condition) QuerySeter
Count() (int64, error)
Update(Params) (int64, error)
Delete() (int64, error)
PrepareInsert() (Inserter, error)
All(interface{}) (int64, error)
One(Modeler) error
Values(*[]Params, ...string) (int64, error)
......@@ -60,12 +68,15 @@ type QuerySeter interface {
}
type RawPreparer interface {
Exec(...interface{}) (int64, error)
Close() error
}
type RawSeter interface {
Exec() (int64, error)
Mapper(...interface{}) (int64, error)
QueryRow(...interface{}) error
QueryRows(...interface{}) (int64, error)
SetArgs(...interface{}) RawSeter
Values(*[]Params) (int64, error)
ValuesList(*[]ParamsList) (int64, error)
ValuesFlat(*ParamsList) (int64, error)
......
......@@ -171,6 +171,18 @@ func (a argInt) Get(i int, args ...int) (r int) {
return
}
type argAny []interface{}
func (a argAny) Get(i int, args ...interface{}) (r interface{}) {
if i >= 0 && i < len(a) {
r = a[i]
}
if len(args) > 0 {
r = args[0]
}
return
}
func timeParse(dateString, format string) (time.Time, error) {
tp, err := time.ParseInLocation(format, dateString, DefaultTimeLoc)
return tp, err
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment