Commit c38abf35 authored by slene's avatar slene

orm support auto create db

parent 1fedaf21
package orm
import (
"flag"
"fmt"
"os"
"strings"
)
type commander interface {
Parse([]string)
Run()
}
var (
commands = make(map[string]commander)
)
func printHelp(errs ...string) {
content := `orm command usage:
syncdb - auto create tables
sqlall - print sql of create tables
help - print this help
`
if len(errs) > 0 {
fmt.Println(errs[0])
}
fmt.Println(content)
os.Exit(2)
}
func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" {
return
}
BootStrap()
args := argString(os.Args[2:])
name := args.Get(0)
if name == "help" {
printHelp()
}
if cmd, ok := commands[name]; ok {
cmd.Parse(os.Args[3:])
cmd.Run()
os.Exit(0)
} else {
if name == "" {
printHelp()
} else {
printHelp(fmt.Sprintf("unknown command %s", name))
}
}
}
type commandSyncDb struct {
al *alias
force bool
verbose bool
}
func (d *commandSyncDb) Parse(args []string) {
var name string
flagSet := flag.NewFlagSet("orm command: syncdb", flag.ExitOnError)
flagSet.StringVar(&name, "db", "default", "DataBase alias name")
flagSet.BoolVar(&d.force, "force", false, "drop tables before create")
flagSet.BoolVar(&d.verbose, "v", false, "verbose info")
flagSet.Parse(args)
d.al = getDbAlias(name)
}
func (d *commandSyncDb) Run() {
var drops []string
if d.force {
drops = getDbDropSql(d.al)
}
db := d.al.DB
if d.force {
for i, mi := range modelCache.allOrdered() {
query := drops[i]
_, err := db.Exec(query)
result := ""
if err != nil {
result = err.Error()
}
fmt.Printf("drop table `%s` %s\n", mi.table, result)
if d.verbose {
fmt.Printf(" %s\n\n", query)
}
}
}
tables := getDbCreateSql(d.al)
for i, mi := range modelCache.allOrdered() {
query := tables[i]
_, err := db.Exec(query)
fmt.Printf("create table `%s` \n", mi.table)
if d.verbose {
query = " " + strings.Join(strings.Split(query, "\n"), "\n ")
fmt.Println(query)
}
if err != nil {
fmt.Printf(" %s\n", err.Error())
}
if d.verbose {
fmt.Println("")
}
}
}
type commandSqlAll struct {
al *alias
}
func (d *commandSqlAll) Parse(args []string) {
var name string
flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError)
flagSet.StringVar(&name, "db", "default", "DataBase alias name")
flagSet.Parse(args)
d.al = getDbAlias(name)
}
func (d *commandSqlAll) Run() {
sqls := getDbCreateSql(d.al)
sql := strings.Join(sqls, "\n\n")
fmt.Println(sql)
}
func init() {
commands["syncdb"] = new(commandSyncDb)
commands["sqlall"] = new(commandSqlAll)
}
package orm
import (
"fmt"
"os"
"strings"
)
func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok {
return al
} else {
fmt.Println(fmt.Sprintf("unknown DataBase alias name %s", name))
os.Exit(2)
}
return nil
}
func getDbDropSql(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
os.Exit(2)
}
Q := al.DbBaser.TableQuote()
for _, mi := range modelCache.allOrdered() {
sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q))
}
return sqls
}
func getDbCreateSql(al *alias) (sqls []string) {
if len(modelCache.cache) == 0 {
fmt.Println("no Model found, need register your model")
os.Exit(2)
}
Q := al.DbBaser.TableQuote()
T := al.DbBaser.DbTypes()
for _, mi := range modelCache.allOrdered() {
sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName)
sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q)
columns := make([]string, 0, len(mi.fields.fieldsDB))
for _, fi := range mi.fields.fieldsDB {
fieldType := fi.fieldType
column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q)
col := ""
checkColumn:
switch fieldType {
case TypeBooleanField:
col = T["bool"]
case TypeCharField:
col = fmt.Sprintf(T["string"], fi.size)
case TypeTextField:
col = T["string-text"]
case TypeDateField:
col = T["time.Time-date"]
case TypeDateTimeField:
col = T["time.Time"]
case TypeBitField:
col = T["int8"]
case TypeSmallIntegerField:
col = T["int16"]
case TypeIntegerField:
col = T["int32"]
case TypeBigIntegerField:
if al.Driver == DR_Sqlite {
fieldType = TypeIntegerField
goto checkColumn
}
col = T["int64"]
case TypePositiveBitField:
col = T["uint8"]
case TypePositiveSmallIntegerField:
col = T["uint16"]
case TypePositiveIntegerField:
col = T["uint32"]
case TypePositiveBigIntegerField:
col = T["uint64"]
case TypeFloatField:
col = T["float64"]
case TypeDecimalField:
s := T["float64-decimal"]
if strings.Index(s, "%d") == -1 {
col = s
} else {
col = fmt.Sprintf(s, fi.digits, fi.decimals)
}
case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType
goto checkColumn
}
if fi.auto {
if al.Driver == DR_Postgres {
column += T["auto"]
} else {
column += col + " " + T["auto"]
}
} else if fi.pk {
column += col + " " + T["pk"]
} else {
column += col
if fi.null == false {
column += " " + "NOT NULL"
}
if fi.unique {
column += " " + "UNIQUE"
}
}
if strings.Index(column, "%COL%") != -1 {
column = strings.Replace(column, "%COL%", fi.column, -1)
}
columns = append(columns, column)
}
sql += strings.Join(columns, ",\n")
sql += "\n)"
if al.Driver == DR_MySQL {
sql += " ENGINE=INNODB"
}
sqls = append(sqls, sql)
}
return sqls
}
package orm
import (
"flag"
"fmt"
"os"
)
func printHelp() {
}
func getSqlAll() (sql string) {
for _, mi := range modelCache.allOrdered() {
_ = mi
}
return
}
func runCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" {
return
}
_ = flag.NewFlagSet("orm command", flag.ExitOnError)
args := argString(os.Args[2:])
cmd := args.Get(0)
switch cmd {
case "syncdb":
case "sqlall":
sql := getSqlAll()
fmt.Println(sql)
default:
if cmd != "" {
fmt.Printf("unknown command %s", cmd)
} else {
printHelp()
}
os.Exit(2)
}
}
......@@ -805,7 +805,7 @@ setValue:
_, err = str.Int32()
case TypeBigIntegerField:
_, err = str.Int64()
case TypePostiveBitField:
case TypePositiveBitField:
_, err = str.Uint8()
case TypePositiveSmallIntegerField:
_, err = str.Uint16()
......@@ -1112,3 +1112,7 @@ func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) {
func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) {
*t = t.In(tz)
}
func (d *dbBase) DbTypes() map[string]string {
return nil
}
......@@ -17,6 +17,26 @@ var mysqlOperators = map[string]string{
"iendswith": "LIKE ?",
}
var mysqlTypes = map[string]string{
"auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "varchar(%d)",
"string-text": "longtext",
"time.Time-date": "date",
"time.Time": "datetime",
"int8": "tinyint",
"int16": "smallint",
"int32": "integer",
"int64": "bigint",
"uint8": "tinyint unsigned",
"uint16": "smallint unsigned",
"uint32": "integer unsigned",
"uint64": "bigint unsigned",
"float64": "double precision",
"float64-decimal": "numeric(%d, %d)",
}
type dbBaseMysql struct {
dbBase
}
......@@ -27,6 +47,10 @@ func (d *dbBaseMysql) OperatorSql(operator string) string {
return mysqlOperators[operator]
}
func (d *dbBaseMysql) DbTypes() map[string]string {
return mysqlTypes
}
func newdbBaseMysql() dbBaser {
b := new(dbBaseMysql)
b.ins = b
......
......@@ -20,6 +20,26 @@ var postgresOperators = map[string]string{
"iendswith": "LIKE UPPER(?)",
}
var postgresTypes = map[string]string{
"auto": "serial NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "varchar(%d)",
"string-text": "text",
"time.Time-date": "date",
"time.Time": "timestamp with time zone",
"int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`,
"int16": "smallint",
"int32": "integer",
"int64": "bigint",
"uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`,
"uint16": `integer CHECK("%COL%" >= 0)`,
"uint32": `bigint CHECK("%COL%" >= 0)`,
"uint64": `bigint CHECK("%COL%" >= 0)`,
"float64": "double precision",
"float64-decimal": "numeric(%d, %d)",
}
type dbBasePostgres struct {
dbBase
}
......@@ -87,6 +107,10 @@ func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) (has bool)
return
}
func (d *dbBasePostgres) DbTypes() map[string]string {
return postgresTypes
}
func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres)
b.ins = b
......
......@@ -19,6 +19,26 @@ var sqliteOperators = map[string]string{
"iendswith": "LIKE ? ESCAPE '\\'",
}
var sqliteTypes = map[string]string{
"auto": "NOT NULL PRIMARY KEY AUTOINCREMENT",
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "varchar(%d)",
"string-text": "text",
"time.Time-date": "date",
"time.Time": "datetime",
"int8": "tinyint",
"int16": "smallint",
"int32": "integer",
"int64": "bigint",
"uint8": "tinyint unsigned",
"uint16": "smallint unsigned",
"uint32": "integer unsigned",
"uint64": "bigint unsigned",
"float64": "real",
"float64-decimal": "decimal",
}
type dbBaseSqlite struct {
dbBase
}
......@@ -43,6 +63,10 @@ func (d *dbBaseSqlite) MaxLimit() uint64 {
return 9223372036854775807
}
func (d *dbBaseSqlite) DbTypes() map[string]string {
return sqliteTypes
}
func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite)
b.ins = b
......
......@@ -84,3 +84,10 @@ func (mc *_modelCache) set(table string, mi *modelInfo) *modelInfo {
}
return mii
}
func (mc *_modelCache) clean() {
mc.orders = make([]string, 0)
mc.cache = make(map[string]*modelInfo)
mc.cacheByFN = make(map[string]*modelInfo)
mc.done = false
}
......@@ -8,7 +8,7 @@ import (
"strings"
)
func registerModel(model interface{}) {
func registerModel(model interface{}, prefix string) {
val := reflect.ValueOf(model)
ind := reflect.Indirect(val)
typ := ind.Type()
......@@ -17,20 +17,25 @@ func registerModel(model interface{}) {
panic(fmt.Sprintf("<orm.RegisterModel> cannot use non-ptr model struct `%s`", getFullName(typ)))
}
info := newModelInfo(val)
table := getTableName(val)
if prefix != "" {
table = prefix + table
}
name := getFullName(typ)
if _, ok := modelCache.getByFN(name); ok {
fmt.Printf("<orm.RegisterModel> model `%s` redeclared, must be unique\n", name)
fmt.Printf("<orm.RegisterModel> model `%s` repeat register, must be unique\n", name)
os.Exit(2)
}
table := getTableName(val)
if _, ok := modelCache.get(table); ok {
fmt.Printf("<orm.RegisterModel> table name `%s` redeclared, must be unique\n", table)
fmt.Printf("<orm.RegisterModel> table name `%s` repeat register, must be unique\n", table)
os.Exit(2)
}
info := newModelInfo(val)
if info.fields.pk == nil {
outFor:
for _, fi := range info.fields.fieldsDB {
......@@ -58,6 +63,7 @@ func registerModel(model interface{}) {
info.pkg = typ.PkgPath()
info.model = model
info.manual = true
modelCache.set(table, info)
}
......@@ -72,7 +78,7 @@ func bootStrap() {
)
if dataBaseCache.getDefault() == nil {
err = fmt.Errorf("must have one register alias named `default`")
err = fmt.Errorf("must have one register DataBase alias named `default`")
goto end
}
......@@ -97,7 +103,7 @@ func bootStrap() {
switch fi.fieldType {
case RelManyToMany:
if fi.relThrough != "" {
msg := fmt.Sprintf("filed `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
msg := fmt.Sprintf("field `%s` wrong rel_through value `%s`", fi.fullName, fi.relThrough)
if i := strings.LastIndex(fi.relThrough, "."); i != -1 && len(fi.relThrough) > (i+1) {
pn := fi.relThrough[:i]
mn := fi.relThrough[i+1:]
......@@ -238,11 +244,22 @@ end:
func RegisterModel(models ...interface{}) {
if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run begore BootStrap"))
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
}
for _, model := range models {
registerModel(model, "")
}
}
// register model with a prefix
func RegisterModelWithPrefix(prefix string, models ...interface{}) {
if modelCache.done {
panic(fmt.Errorf("RegisterModel must be run before BootStrap"))
}
for _, model := range models {
registerModel(model)
registerModel(model, prefix)
}
}
......
......@@ -31,7 +31,7 @@ const (
// int64
TypeBigIntegerField
// uint8
TypePostiveBitField
TypePositiveBitField
// uint16
TypePositiveSmallIntegerField
// uint32
......
......@@ -399,7 +399,7 @@ checkType:
_, err = v.Int32()
case TypeBigIntegerField:
_, err = v.Int64()
case TypePostiveBitField:
case TypePositiveBitField:
_, err = v.Uint8()
case TypePositiveSmallIntegerField:
_, err = v.Uint16()
......
......@@ -90,6 +90,9 @@ func newM2MModelInfo(m1, m2 *modelInfo) (info *modelInfo) {
fa.auto = true
fa.pk = true
fa.dbcol = true
fa.name = "Id"
fa.column = "id"
fa.fullName = info.fullName + "." + fa.name
f1.dbcol = true
f2.dbcol = true
......
This diff is collapsed.
......@@ -52,7 +52,7 @@ func getFieldType(val reflect.Value) (ft int, err error) {
case reflect.Int64:
ft = TypeBigIntegerField
case reflect.Uint8:
ft = TypePostiveBitField
ft = TypePositiveBitField
case reflect.Uint16:
ft = TypePositiveSmallIntegerField
case reflect.Uint32, reflect.Uint:
......
......@@ -189,6 +189,47 @@ func throwFailNow(t *testing.T, err error, args ...interface{}) {
}
}
func TestSyncDb(t *testing.T) {
RegisterModel(new(Data), new(DataNull))
RegisterModel(new(User))
RegisterModel(new(Profile))
RegisterModel(new(Post))
RegisterModel(new(Tag))
RegisterModel(new(Comment))
BootStrap()
al := dataBaseCache.getDefault()
db := al.DB
drops := getDbDropSql(al)
for _, query := range drops {
_, err := db.Exec(query)
throwFailNow(t, err, query)
}
tables := getDbCreateSql(al)
for _, query := range tables {
_, err := db.Exec(query)
throwFailNow(t, err, query)
}
modelCache.clean()
}
func TestRegisterModels(t *testing.T) {
RegisterModel(new(Data), new(DataNull))
RegisterModel(new(User))
RegisterModel(new(Profile))
RegisterModel(new(Post))
RegisterModel(new(Tag))
RegisterModel(new(Comment))
BootStrap()
dORM = NewOrm()
}
func TestModelSyntax(t *testing.T) {
user := &User{}
ind := reflect.ValueOf(user).Elem()
......
......@@ -132,4 +132,5 @@ type dbBaser interface {
HasReturningID(*modelInfo, *string) bool
TimeFromDB(*time.Time, *time.Location)
TimeToDB(*time.Time, *time.Location)
DbTypes() map[string]string
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment