Commit 6497f29e authored by astaxie's avatar astaxie

version 1.1.2 release

parents 443aaadc 1705b425
.DS_Store
*.swp
*.swo
......@@ -2,6 +2,7 @@ package beego
import (
"net/http"
"os"
"path"
"path/filepath"
"strconv"
......@@ -12,7 +13,7 @@ import (
)
// beego web framework version.
const VERSION = "1.1.1"
const VERSION = "1.1.2"
type hookfunc func() error //hook function to run
var hooks []hookfunc //hook function slice to store the hookfunc
......@@ -174,6 +175,16 @@ func AddAPPStartHook(hf hookfunc) {
// Run beego application.
// it's alias of App.Run.
func Run() {
initBeforeHttpRun()
if EnableAdmin {
go BeeAdminApp.Run()
}
BeeApp.Run()
}
func initBeforeHttpRun() {
// if AppConfigPath not In the conf/app.conf reParse config
if AppConfigPath != filepath.Join(AppPath, "conf", "app.conf") {
err := ParseConfig()
......@@ -222,12 +233,13 @@ func Run() {
middleware.VERSION = VERSION
middleware.AppName = AppName
middleware.RegisterErrorHandler()
}
if EnableAdmin {
go BeeAdminApp.Run()
}
BeeApp.Run()
func TestBeegoInit(apppath string) {
AppPath = apppath
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
os.Chdir(AppPath)
initBeforeHttpRun()
}
func init() {
......
......@@ -11,12 +11,14 @@ import (
"github.com/astaxie/beego/config"
"github.com/astaxie/beego/logs"
"github.com/astaxie/beego/session"
"github.com/astaxie/beego/utils"
)
var (
BeeApp *App // beego application
AppName string
AppPath string
workPath string
AppConfigPath string
StaticDir map[string]string
TemplateCache map[string]*template.Template // template caching map
......@@ -58,15 +60,28 @@ var (
EnableAdmin bool // flag of enable admin module to log every request info.
AdminHttpAddr string // http server configurations for admin module.
AdminHttpPort int
FlashName string // name of the flash variable found in response header and cookie
FlashSeperator string // used to seperate flash key:value
)
func init() {
// create beego application
BeeApp = NewApp()
workPath, _ = os.Getwd()
workPath, _ = filepath.Abs(workPath)
// initialize default configurations
AppPath, _ = filepath.Abs(filepath.Dir(os.Args[0]))
os.Chdir(AppPath)
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
if workPath != AppPath {
if utils.FileExists(AppConfigPath) {
os.Chdir(AppPath)
} else {
AppConfigPath = filepath.Join(workPath, "conf", "app.conf")
}
}
StaticDir = make(map[string]string)
StaticDir["/static"] = "static"
......@@ -105,8 +120,6 @@ func init() {
EnableGzip = false
AppConfigPath = filepath.Join(AppPath, "conf", "app.conf")
HttpServerTimeOut = 0
ErrorsShow = true
......@@ -123,6 +136,9 @@ func init() {
AdminHttpAddr = "127.0.0.1"
AdminHttpPort = 8088
FlashName = "BEEGO_FLASH"
FlashSeperator = "BEEGOFLASH"
runtime.GOMAXPROCS(runtime.NumCPU())
// init BeeLogger
......@@ -271,6 +287,14 @@ func ParseConfig() (err error) {
BeegoServerName = serverName
}
if flashname := AppConfig.String("FlashName"); flashname != "" {
FlashName = flashname
}
if flashseperator := AppConfig.String("FlashSeperator"); flashseperator != "" {
FlashSeperator = flashseperator
}
if sd := AppConfig.String("StaticDir"); sd != "" {
for k := range StaticDir {
delete(StaticDir, k)
......
package beego
import (
"testing"
)
func TestDefaults(t *testing.T) {
if FlashName != "BEEGO_FLASH" {
t.Errorf("FlashName was not set to default.")
}
if FlashSeperator != "BEEGOFLASH" {
t.Errorf("FlashName was not set to default.")
}
}
......@@ -4,6 +4,7 @@ import (
"bytes"
"io/ioutil"
"net/http"
"reflect"
"strconv"
"strings"
......@@ -13,11 +14,13 @@ import (
// BeegoInput operates the http request header ,data ,cookie and body.
// it also contains router params and current session.
type BeegoInput struct {
CruSession session.SessionStore
Params map[string]string
Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
Request *http.Request
RequestBody []byte
CruSession session.SessionStore
Params map[string]string
Data map[interface{}]interface{} // store some values in this context when calling context in filter or controller.
Request *http.Request
RequestBody []byte
RunController reflect.Type
RunMethod string
}
// NewInput return BeegoInput generated by http.Request.
......
......@@ -62,7 +62,6 @@ type ControllerInterface interface {
// Init generates default values of controller operations.
func (c *Controller) Init(ctx *context.Context, controllerName, actionName string, app interface{}) {
c.Data = make(map[interface{}]interface{})
c.Layout = ""
c.TplNames = ""
c.controllerName = controllerName
......
......@@ -6,9 +6,6 @@ import (
"strings"
)
// the separation string when encoding flash data.
const BEEGO_FLASH_SEP = "#BEEGOFLASH#"
// FlashData is a tools to maintain data when using across request.
type FlashData struct {
Data map[string]string
......@@ -54,29 +51,27 @@ func (fd *FlashData) Store(c *Controller) {
c.Data["flash"] = fd.Data
var flashValue string
for key, value := range fd.Data {
flashValue += "\x00" + key + BEEGO_FLASH_SEP + value + "\x00"
flashValue += "\x00" + key + "\x23" + FlashSeperator + "\x23" + value + "\x00"
}
c.Ctx.SetCookie("BEEGO_FLASH", url.QueryEscape(flashValue), 0, "/")
c.Ctx.SetCookie(FlashName, url.QueryEscape(flashValue), 0, "/")
}
// ReadFromRequest parsed flash data from encoded values in cookie.
func ReadFromRequest(c *Controller) *FlashData {
flash := &FlashData{
Data: make(map[string]string),
}
if cookie, err := c.Ctx.Request.Cookie("BEEGO_FLASH"); err == nil {
flash := NewFlash()
if cookie, err := c.Ctx.Request.Cookie(FlashName); err == nil {
v, _ := url.QueryUnescape(cookie.Value)
vals := strings.Split(v, "\x00")
for _, v := range vals {
if len(v) > 0 {
kv := strings.Split(v, BEEGO_FLASH_SEP)
kv := strings.Split(v, FlashSeperator)
if len(kv) == 2 {
flash.Data[kv[0]] = kv[1]
}
}
}
//read one time then delete it
c.Ctx.SetCookie("BEEGO_FLASH", "", -1, "/")
c.Ctx.SetCookie(FlashName, "", -1, "/")
}
c.Data["flash"] = flash.Data
return flash
......
package beego
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
type TestFlashController struct {
Controller
}
func (this *TestFlashController) TestWriteFlash() {
flash := NewFlash()
flash.Notice("TestFlashString")
flash.Store(&this.Controller)
// we choose to serve json because we don't want to load a template html file
this.ServeJson(true)
}
func TestFlashHeader(t *testing.T) {
// create fake GET request
r, _ := http.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
// setup the handler
handler := NewControllerRegistor()
handler.Add("/", &TestFlashController{}, "get:TestWriteFlash")
handler.ServeHTTP(w, r)
// get the Set-Cookie value
sc := w.Header().Get("Set-Cookie")
// match for the expected header
res := strings.Contains(sc, "BEEGO_FLASH=%00notice%23BEEGOFLASH%23TestFlashString%00")
// validate the assertion
if res != true {
t.Errorf("TestFlashHeader() unable to validate flash message")
}
}
......@@ -22,12 +22,21 @@ func SetLevel(l int) {
BeeLogger.SetLevel(l)
}
func SetLogFuncCall(b bool) {
BeeLogger.EnableFuncCallDepth(b)
BeeLogger.SetLogFuncCallDepth(3)
}
// logger references the used application logger.
var BeeLogger *logs.BeeLogger
// SetLogger sets a new logger.
func SetLogger(adaptername string, config string) {
BeeLogger.SetLogger(adaptername, config)
func SetLogger(adaptername string, config string) error {
err := BeeLogger.SetLogger(adaptername, config)
if err != nil {
return err
}
return nil
}
// Trace logs a message at trace level.
......
......@@ -6,6 +6,7 @@ import (
func TestConsole(t *testing.T) {
log := NewLogger(10000)
log.EnableFuncCallDepth(true)
log.SetLogger("console", "")
log.Trace("trace")
log.Info("info")
......@@ -23,6 +24,7 @@ func TestConsole(t *testing.T) {
func BenchmarkConsole(b *testing.B) {
log := NewLogger(10000)
log.EnableFuncCallDepth(true)
log.SetLogger("console", "")
for i := 0; i < b.N; i++ {
log.Trace("trace")
......
......@@ -97,12 +97,12 @@ func (w *FileLogWriter) Init(jsonconfig string) error {
if len(w.Filename) == 0 {
return errors.New("jsonconfig must have filename")
}
err = w.StartLogger()
err = w.startLogger()
return err
}
// start file logger. create log file and set to locker-inside file writer.
func (w *FileLogWriter) StartLogger() error {
func (w *FileLogWriter) startLogger() error {
fd, err := w.createLogFile()
if err != nil {
return err
......@@ -199,7 +199,7 @@ func (w *FileLogWriter) DoRotate() error {
}
// re-start logger
err = w.StartLogger()
err = w.startLogger()
if err != nil {
return fmt.Errorf("Rotate StartLogger: %s\n", err)
}
......
......@@ -2,6 +2,8 @@ package logs
import (
"fmt"
"path"
"runtime"
"sync"
)
......@@ -43,10 +45,12 @@ func Register(name string, log loggerType) {
// BeeLogger is default logger in beego application.
// it can contain several providers and log message into all providers.
type BeeLogger struct {
lock sync.Mutex
level int
msg chan *logMsg
outputs map[string]LoggerInterface
lock sync.Mutex
level int
enableFuncCallDepth bool
loggerFuncCallDepth int
msg chan *logMsg
outputs map[string]LoggerInterface
}
type logMsg struct {
......@@ -59,10 +63,11 @@ type logMsg struct {
// if the buffering chan is full, logger adapters write to file or other way.
func NewLogger(channellen int64) *BeeLogger {
bl := new(BeeLogger)
bl.loggerFuncCallDepth = 2
bl.msg = make(chan *logMsg, channellen)
bl.outputs = make(map[string]LoggerInterface)
//bl.SetLogger("console", "") // default output to console
go bl.StartLogger()
go bl.startLogger()
return bl
}
......@@ -73,7 +78,10 @@ func (bl *BeeLogger) SetLogger(adaptername string, config string) error {
defer bl.lock.Unlock()
if log, ok := adapters[adaptername]; ok {
lg := log()
lg.Init(config)
err := lg.Init(config)
if err != nil {
return err
}
bl.outputs[adaptername] = lg
return nil
} else {
......@@ -100,7 +108,17 @@ func (bl *BeeLogger) writerMsg(loglevel int, msg string) error {
}
lm := new(logMsg)
lm.level = loglevel
lm.msg = msg
if bl.enableFuncCallDepth {
_, file, line, ok := runtime.Caller(bl.loggerFuncCallDepth)
if ok {
_, filename := path.Split(file)
lm.msg = fmt.Sprintf("[%s:%d] %s", filename, line, msg)
} else {
lm.msg = msg
}
} else {
lm.msg = msg
}
bl.msg <- lm
return nil
}
......@@ -111,9 +129,19 @@ func (bl *BeeLogger) SetLevel(l int) {
bl.level = l
}
// set log funcCallDepth
func (bl *BeeLogger) SetLogFuncCallDepth(d int) {
bl.loggerFuncCallDepth = d
}
// enable log funcCallDepth
func (bl *BeeLogger) EnableFuncCallDepth(b bool) {
bl.enableFuncCallDepth = b
}
// start logger chan reading.
// when chan is full, write logs.
func (bl *BeeLogger) StartLogger() {
func (bl *BeeLogger) startLogger() {
for {
select {
case bm := <-bl.msg:
......
......@@ -51,9 +51,16 @@ outFor:
continue
}
switch v := arg.(type) {
case []byte:
case string:
kind := val.Kind()
if kind == reflect.Ptr {
val = val.Elem()
kind = val.Kind()
arg = val.Interface()
}
switch kind {
case reflect.String:
v := val.String()
if fi != nil {
if fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
var t time.Time
......@@ -78,61 +85,66 @@ outFor:
}
}
arg = v
case time.Time:
if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(format_Date)
} else {
arg = v.In(tz).Format(format_DateTime)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
arg = val.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
arg = val.Uint()
case reflect.Float32:
arg, _ = StrTo(ToStr(arg)).Float64()
case reflect.Float64:
arg = val.Float()
case reflect.Bool:
arg = val.Bool()
case reflect.Slice, reflect.Array:
if _, ok := arg.([]byte); ok {
continue outFor
}
default:
kind := val.Kind()
switch kind {
case reflect.Slice, reflect.Array:
var args []interface{}
for i := 0; i < val.Len(); i++ {
v := val.Index(i)
var vu interface{}
if v.CanInterface() {
vu = v.Interface()
}
if vu == nil {
continue
}
var args []interface{}
for i := 0; i < val.Len(); i++ {
v := val.Index(i)
args = append(args, vu)
var vu interface{}
if v.CanInterface() {
vu = v.Interface()
}
if len(args) > 0 {
p := getFlatParams(fi, args, tz)
params = append(params, p...)
if vu == nil {
continue
}
continue outFor
case reflect.Ptr, reflect.Struct:
ind := reflect.Indirect(val)
args = append(args, vu)
}
if ind.Kind() == reflect.Struct {
typ := ind.Type()
name := getFullName(typ)
var value interface{}
if mmi, ok := modelCache.getByFN(name); ok {
if _, vu, exist := getExistPk(mmi, ind); exist {
value = vu
}
if len(args) > 0 {
p := getFlatParams(fi, args, tz)
params = append(params, p...)
}
continue outFor
case reflect.Struct:
if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(format_Date)
} else {
arg = v.In(tz).Format(format_DateTime)
}
} else {
typ := val.Type()
name := getFullName(typ)
var value interface{}
if mmi, ok := modelCache.getByFN(name); ok {
if _, vu, exist := getExistPk(mmi, val); exist {
value = vu
}
arg = value
}
arg = value
if arg == nil {
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
}
} else {
arg = ind.Interface()
if arg == nil {
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
}
}
}
params = append(params, arg)
}
return
......
......@@ -144,6 +144,45 @@ type DataNull struct {
NullInt64 sql.NullInt64 `orm:"null"`
}
type String string
type Boolean bool
type Byte byte
type Rune rune
type Int int
type Int8 int8
type Int16 int16
type Int32 int32
type Int64 int64
type Uint uint
type Uint8 uint8
type Uint16 uint16
type Uint32 uint32
type Uint64 uint64
type Float32 float64
type Float64 float64
type DataCustom struct {
Id int
Boolean Boolean
Char string `orm:"size(50)"`
Text string `orm:"type(text)"`
Byte Byte
Rune Rune
Int Int
Int8 Int8
Int16 Int16
Int32 Int32
Int64 Int64
Uint Uint
Uint8 Uint8
Uint16 Uint16
Uint32 Uint32
Uint64 Uint64
Float32 Float32
Float64 Float64
Decimal Float64 `orm:"digits(8);decimals(4)"`
}
// only for mysql
type UserBig struct {
Id uint64
......@@ -155,7 +194,7 @@ type User struct {
UserName string `orm:"size(30);unique"`
Email string `orm:"size(100)"`
Password string `orm:"size(100)"`
Status int16
Status int16 `orm:"column(Status)"`
IsStaff bool
IsActive bool `orm:"default(1)"`
Created time.Time `orm:"auto_now_add;type(date)"`
......
......@@ -80,7 +80,6 @@ func getTableUnique(val reflect.Value) [][]string {
// get snaked column name
func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col string) string {
col = strings.ToLower(col)
column := col
if col == "" {
column = snakeString(sf.Name)
......@@ -99,34 +98,41 @@ func getColumnName(ft int, addrField reflect.Value, sf reflect.StructField, col
// return field type as type constant from reflect.Value
func getFieldType(val reflect.Value) (ft int, err error) {
elm := reflect.Indirect(val)
switch elm.Interface().(type) {
case int8:
switch elm.Kind() {
case reflect.Int8:
ft = TypeBitField
case int16:
case reflect.Int16:
ft = TypeSmallIntegerField
case int32, int:
case reflect.Int32, reflect.Int:
ft = TypeIntegerField
case int64, sql.NullInt64:
case reflect.Int64:
ft = TypeBigIntegerField
case uint8:
case reflect.Uint8:
ft = TypePositiveBitField
case uint16:
case reflect.Uint16:
ft = TypePositiveSmallIntegerField
case uint32, uint:
case reflect.Uint32, reflect.Uint:
ft = TypePositiveIntegerField
case uint64:
case reflect.Uint64:
ft = TypePositiveBigIntegerField
case float32, float64, sql.NullFloat64:
case reflect.Float32, reflect.Float64:
ft = TypeFloatField
case bool, sql.NullBool:
case reflect.Bool:
ft = TypeBooleanField
case string, sql.NullString:
case reflect.String:
ft = TypeCharField
default:
if elm.CanInterface() {
if _, ok := elm.Interface().(time.Time); ok {
ft = TypeDateTimeField
}
switch elm.Interface().(type) {
case sql.NullInt64:
ft = TypeBigIntegerField
case sql.NullFloat64:
ft = TypeFloatField
case sql.NullBool:
ft = TypeBooleanField
case sql.NullString:
ft = TypeCharField
case time.Time:
ft = TypeDateTimeField
}
}
if ft&IsFieldType == 0 {
......
......@@ -149,7 +149,7 @@ func TestGetDB(t *testing.T) {
}
func TestSyncDb(t *testing.T) {
RegisterModel(new(Data), new(DataNull))
RegisterModel(new(Data), new(DataNull), new(DataCustom))
RegisterModel(new(User))
RegisterModel(new(Profile))
RegisterModel(new(Post))
......@@ -165,7 +165,7 @@ func TestSyncDb(t *testing.T) {
}
func TestRegisterModels(t *testing.T) {
RegisterModel(new(Data), new(DataNull))
RegisterModel(new(Data), new(DataNull), new(DataCustom))
RegisterModel(new(User))
RegisterModel(new(Profile))
RegisterModel(new(Post))
......@@ -309,6 +309,39 @@ func TestNullDataTypes(t *testing.T) {
throwFail(t, AssertIs(d.NullFloat64.Float64, 42.42))
}
func TestDataCustomTypes(t *testing.T) {
d := DataCustom{}
ind := reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
}
e.Set(reflect.ValueOf(value).Convert(e.Type()))
}
id, err := dORM.Insert(&d)
throwFail(t, err)
throwFail(t, AssertIs(id, 1))
d = DataCustom{Id: 1}
err = dORM.Read(&d)
throwFail(t, err)
ind = reflect.Indirect(reflect.ValueOf(&d))
for name, value := range Data_Values {
e := ind.FieldByName(name)
if !e.IsValid() {
continue
}
vu := e.Interface()
value = reflect.ValueOf(value).Convert(e.Type()).Interface()
throwFail(t, AssertIs(vu == value, true), value, vu)
}
}
func TestCRUD(t *testing.T) {
profile := NewProfile()
profile.Age = 30
......@@ -562,6 +595,10 @@ func TestOperators(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name__exact", String("slene")).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
num, err = qs.Filter("user_name__exact", "slene").Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
......@@ -602,11 +639,11 @@ func TestOperators(t *testing.T) {
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
num, err = qs.Filter("status__lt", 3).Count()
num, err = qs.Filter("status__lt", Uint(3)).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 2))
num, err = qs.Filter("status__lte", 3).Count()
num, err = qs.Filter("status__lte", Int(3)).Count()
throwFail(t, err)
throwFail(t, AssertIs(num, 3))
......@@ -1380,7 +1417,7 @@ func TestRawQueryRow(t *testing.T) {
)
cols = []string{
"id", "status", "profile_id",
"id", "Status", "profile_id",
}
query = fmt.Sprintf("SELECT %s%s%s FROM %suser%s WHERE id = ?", Q, strings.Join(cols, sep), Q, Q, Q)
err = dORM.Raw(query, 4).QueryRow(&uid, &status, &pid)
......@@ -1460,7 +1497,7 @@ func TestRawValues(t *testing.T) {
Q := dDbBaser.TableQuote()
var maps []Params
query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sstatus%s = ?", Q, Q, Q, Q, Q, Q)
query := fmt.Sprintf("SELECT %suser_name%s FROM %suser%s WHERE %sStatus%s = ?", Q, Q, Q, Q, Q, Q)
num, err := dORM.Raw(query, 1).Values(&maps)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
......
......@@ -44,6 +44,11 @@ var (
"GetControllerAndAction"}
)
// To append a slice's value into "exceptMethod", for controller's methods shouldn't reflect to AutoRouter
func ExceptMethodAppend(action string) {
exceptMethod = append(exceptMethod, action)
}
type controllerInfo struct {
pattern string
regex *regexp.Regexp
......@@ -621,29 +626,37 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
context.Input.Body()
}
if context.Input.RunController != nil && context.Input.RunMethod != "" {
findrouter = true
runMethod = context.Input.RunMethod
runrouter = context.Input.RunController
}
//first find path from the fixrouters to Improve Performance
for _, route := range p.fixrouters {
n := len(requestPath)
if requestPath == route.pattern {
runMethod = p.getRunMethod(r.Method, context, route)
if runMethod != "" {
runrouter = route.controllerType
findrouter = true
break
if !findrouter {
for _, route := range p.fixrouters {
n := len(requestPath)
if requestPath == route.pattern {
runMethod = p.getRunMethod(r.Method, context, route)
if runMethod != "" {
runrouter = route.controllerType
findrouter = true
break
}
}
}
// pattern /admin url /admin 200 /admin/ 200
// pattern /admin/ url /admin 301 /admin/ 200
if requestPath[n-1] != '/' && requestPath+"/" == route.pattern {
http.Redirect(w, r, requestPath+"/", 301)
goto Admin
}
if requestPath[n-1] == '/' && route.pattern+"/" == requestPath {
runMethod = p.getRunMethod(r.Method, context, route)
if runMethod != "" {
runrouter = route.controllerType
findrouter = true
break
// pattern /admin url /admin 200 /admin/ 200
// pattern /admin/ url /admin 301 /admin/ 200
if requestPath[n-1] != '/' && requestPath+"/" == route.pattern {
http.Redirect(w, r, requestPath+"/", 301)
goto Admin
}
if requestPath[n-1] == '/' && route.pattern+"/" == requestPath {
runMethod = p.getRunMethod(r.Method, context, route)
if runMethod != "" {
runrouter = route.controllerType
findrouter = true
break
}
}
}
}
......
......@@ -118,6 +118,7 @@ func (pder *CookieProvider) SessionInit(maxlifetime int64, config string) error
if err != nil {
return err
}
pder.maxlifetime = maxlifetime
return nil
}
......
package session
/*
beego session provider for postgresql
-------------------------------------
depends on github.com/lib/pq:
go install github.com/lib/pq
needs this table in your database:
CREATE TABLE session (
session_key char(64) NOT NULL,
session_data bytea,
session_expiry timestamp NOT NULL,
CONSTRAINT session_key PRIMARY KEY(session_key)
);
will be activated with these settings in app.conf:
SessionOn = true
SessionProvider = postgresql
SessionSavePath = "user=a password=b dbname=c sslmode=disable"
SessionName = session
*/
import (
"database/sql"
"net/http"
"sync"
"time"
_ "github.com/lib/pq"
)
var postgresqlpder = &PostgresqlProvider{}
// postgresql session store
type PostgresqlSessionStore struct {
c *sql.DB
sid string
lock sync.RWMutex
values map[interface{}]interface{}
}
// set value in postgresql session.
// it is temp value in map.
func (st *PostgresqlSessionStore) Set(key, value interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
st.values[key] = value
return nil
}
// get value from postgresql session
func (st *PostgresqlSessionStore) Get(key interface{}) interface{} {
st.lock.RLock()
defer st.lock.RUnlock()
if v, ok := st.values[key]; ok {
return v
} else {
return nil
}
return nil
}
// delete value in postgresql session
func (st *PostgresqlSessionStore) Delete(key interface{}) error {
st.lock.Lock()
defer st.lock.Unlock()
delete(st.values, key)
return nil
}
// clear all values in postgresql session
func (st *PostgresqlSessionStore) Flush() error {
st.lock.Lock()
defer st.lock.Unlock()
st.values = make(map[interface{}]interface{})
return nil
}
// get session id of this postgresql session store
func (st *PostgresqlSessionStore) SessionID() string {
return st.sid
}
// save postgresql session values to database.
// must call this method to save values to database.
func (st *PostgresqlSessionStore) SessionRelease(w http.ResponseWriter) {
defer st.c.Close()
b, err := encodeGob(st.values)
if err != nil {
return
}
st.c.Exec("UPDATE session set session_data=$1, session_expiry=$2 where session_key=$3",
b, time.Now().Format(time.RFC3339), st.sid)
}
// postgresql session provider
type PostgresqlProvider struct {
maxlifetime int64
savePath string
}
// connect to postgresql
func (mp *PostgresqlProvider) connectInit() *sql.DB {
db, e := sql.Open("postgres", mp.savePath)
if e != nil {
return nil
}
return db
}
// init postgresql session.
// savepath is the connection string of postgresql.
func (mp *PostgresqlProvider) SessionInit(maxlifetime int64, savePath string) error {
mp.maxlifetime = maxlifetime
mp.savePath = savePath
return nil
}
// get postgresql session by sid
func (mp *PostgresqlProvider) SessionRead(sid string) (SessionStore, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
_, err = c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
sid, "", time.Now().Format(time.RFC3339))
if err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = decodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// check postgresql session exist
func (mp *PostgresqlProvider) SessionExist(sid string) bool {
c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from session where session_key=$1", sid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
return false
} else {
return true
}
}
// generate new sid for postgresql session
func (mp *PostgresqlProvider) SessionRegenerate(oldsid, sid string) (SessionStore, error) {
c := mp.connectInit()
row := c.QueryRow("select session_data from session where session_key=$1", oldsid)
var sessiondata []byte
err := row.Scan(&sessiondata)
if err == sql.ErrNoRows {
c.Exec("insert into session(session_key,session_data,session_expiry) values($1,$2,$3)",
oldsid, "", time.Now().Format(time.RFC3339))
}
c.Exec("update session set session_key=$1 where session_key=$2", sid, oldsid)
var kv map[interface{}]interface{}
if len(sessiondata) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = decodeGob(sessiondata)
if err != nil {
return nil, err
}
}
rs := &PostgresqlSessionStore{c: c, sid: sid, values: kv}
return rs, nil
}
// delete postgresql session by sid
func (mp *PostgresqlProvider) SessionDestroy(sid string) error {
c := mp.connectInit()
c.Exec("DELETE FROM session where session_key=$1", sid)
c.Close()
return nil
}
// delete expired values in postgresql session
func (mp *PostgresqlProvider) SessionGC() {
c := mp.connectInit()
c.Exec("DELETE from session where EXTRACT(EPOCH FROM (current_timestamp - session_expiry)) > $1", mp.maxlifetime)
c.Close()
return
}
// count values in postgresql session
func (mp *PostgresqlProvider) SessionAll() int {
c := mp.connectInit()
defer c.Close()
var total int
err := c.QueryRow("SELECT count(*) as num from session").Scan(&total)
if err != nil {
return 0
}
return total
}
func init() {
Register("postgresql", postgresqlpder)
}
......@@ -151,7 +151,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp
fileabspath = filepath.Join(root, file)
}
if e := utils.FileExists(fileabspath); !e {
panic("can't find template file" + file)
panic("can't find template file:" + file)
}
data, err := ioutil.ReadFile(fileabspath)
if err != nil {
......
......@@ -67,7 +67,7 @@ func (r Required) IsSatisfied(obj interface{}) bool {
}
func (r Required) DefaultMessage() string {
return "Required"
return fmt.Sprint(MessageTmpls["Required"])
}
func (r Required) GetKey() string {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment