Commit 323a1c42 authored by astaxie's avatar astaxie Committed by GitHub

Merge pull request #2485 from astaxie/develop

beego 1.8.0 
parents b55e20ac c5f838e7
.idea .idea
.vscode
.DS_Store .DS_Store
*.swp *.swp
*.swo *.swo
......
...@@ -31,6 +31,8 @@ install: ...@@ -31,6 +31,8 @@ install:
- go get github.com/siddontang/ledisdb/config - go get github.com/siddontang/ledisdb/config
- go get github.com/siddontang/ledisdb/ledis - go get github.com/siddontang/ledisdb/ledis
- go get github.com/ssdb/gossdb/ssdb - go get github.com/ssdb/gossdb/ssdb
- go get github.com/cloudflare/golz4
- go get github.com/gogo/protobuf/proto
before_script: before_script:
- psql --version - psql --version
- sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi" - sh -c "if [ '$ORM_DRIVER' = 'postgres' ]; then psql -c 'create database orm_test;' -U postgres; fi"
......
...@@ -23,7 +23,7 @@ import ( ...@@ -23,7 +23,7 @@ import (
const ( const (
// VERSION represent beego web framework version. // VERSION represent beego web framework version.
VERSION = "1.7.2" VERSION = "1.8.0"
// DEV is for develop // DEV is for develop
DEV = "dev" DEV = "dev"
......
...@@ -22,6 +22,7 @@ import ( ...@@ -22,6 +22,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
...@@ -222,33 +223,13 @@ func exists(path string) (bool, error) { ...@@ -222,33 +223,13 @@ func exists(path string) (bool, error) {
// FileGetContents Get bytes to file. // FileGetContents Get bytes to file.
// if non-exist, create this file. // if non-exist, create this file.
func FileGetContents(filename string) (data []byte, e error) { func FileGetContents(filename string) (data []byte, e error) {
f, e := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm) return ioutil.ReadFile(filename)
if e != nil {
return
}
defer f.Close()
stat, e := f.Stat()
if e != nil {
return
}
data = make([]byte, stat.Size())
result, e := f.Read(data)
if e != nil || int64(result) != stat.Size() {
return nil, e
}
return
} }
// FilePutContents Put bytes to file. // FilePutContents Put bytes to file.
// if non-exist, create this file. // if non-exist, create this file.
func FilePutContents(filename string, content []byte) error { func FilePutContents(filename string, content []byte) error {
fp, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, os.ModePerm) return ioutil.WriteFile(filename, content, os.ModePerm)
if err != nil {
return err
}
defer fp.Close()
_, err = fp.Write(content)
return err
} }
// GobEncode Gob encodes file cache item. // GobEncode Gob encodes file cache item.
......
...@@ -152,7 +152,7 @@ func (rc *Cache) IsExist(key string) bool { ...@@ -152,7 +152,7 @@ func (rc *Cache) IsExist(key string) bool {
if err != nil { if err != nil {
return false return false
} }
if resp[1] == "1" { if len(resp) == 2 && resp[1] == "1" {
return true return true
} }
return false return false
......
...@@ -41,6 +41,7 @@ type Config struct { ...@@ -41,6 +41,7 @@ type Config struct {
EnableGzip bool EnableGzip bool
MaxMemory int64 MaxMemory int64
EnableErrorsShow bool EnableErrorsShow bool
EnableErrorsRender bool
Listen Listen Listen Listen
WebConfig WebConfig WebConfig WebConfig
Log LogConfig Log LogConfig
...@@ -144,9 +145,6 @@ func init() { ...@@ -144,9 +145,6 @@ func init() {
if err = parseConfig(appConfigPath); err != nil { if err = parseConfig(appConfigPath); err != nil {
panic(err) panic(err)
} }
if err = os.Chdir(AppPath); err != nil {
panic(err)
}
} }
func recoverPanic(ctx *context.Context) { func recoverPanic(ctx *context.Context) {
...@@ -174,7 +172,7 @@ func recoverPanic(ctx *context.Context) { ...@@ -174,7 +172,7 @@ func recoverPanic(ctx *context.Context) {
logs.Critical(fmt.Sprintf("%s:%d", file, line)) logs.Critical(fmt.Sprintf("%s:%d", file, line))
stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line)) stack = stack + fmt.Sprintln(fmt.Sprintf("%s:%d", file, line))
} }
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV && BConfig.EnableErrorsRender {
showErr(err, ctx, stack) showErr(err, ctx, stack)
} }
} }
...@@ -192,6 +190,7 @@ func newBConfig() *Config { ...@@ -192,6 +190,7 @@ func newBConfig() *Config {
EnableGzip: false, EnableGzip: false,
MaxMemory: 1 << 26, //64MB MaxMemory: 1 << 26, //64MB
EnableErrorsShow: true, EnableErrorsShow: true,
EnableErrorsRender: true,
Listen: Listen{ Listen: Listen{
Graceful: false, Graceful: false,
ServerTimeOut: 0, ServerTimeOut: 0,
...@@ -257,6 +256,9 @@ func parseConfig(appConfigPath string) (err error) { ...@@ -257,6 +256,9 @@ func parseConfig(appConfigPath string) (err error) {
} }
func assignConfig(ac config.Configer) error { func assignConfig(ac config.Configer) error {
for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} {
assignSingleConfig(i, ac)
}
// set the run mode first // set the run mode first
if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" { if envRunMode := os.Getenv("BEEGO_RUNMODE"); envRunMode != "" {
BConfig.RunMode = envRunMode BConfig.RunMode = envRunMode
...@@ -264,10 +266,6 @@ func assignConfig(ac config.Configer) error { ...@@ -264,10 +266,6 @@ func assignConfig(ac config.Configer) error {
BConfig.RunMode = runMode BConfig.RunMode = runMode
} }
for _, i := range []interface{}{BConfig, &BConfig.Listen, &BConfig.WebConfig, &BConfig.Log, &BConfig.WebConfig.Session} {
assignSingleConfig(i, ac)
}
if sd := ac.String("StaticDir"); sd != "" { if sd := ac.String("StaticDir"); sd != "" {
BConfig.WebConfig.StaticDir = map[string]string{} BConfig.WebConfig.StaticDir = map[string]string{}
sds := strings.Fields(sd) sds := strings.Fields(sd)
......
// Copyright 2014 beego Author. All Rights Reserved.
// Copyright 2017 Faissal Elamraoui. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package env
import (
"fmt"
"os"
"strings"
"github.com/astaxie/beego/utils"
)
var env *utils.BeeMap
func init() {
env = utils.NewBeeMap()
for _, e := range os.Environ() {
splits := strings.Split(e, "=")
env.Set(splits[0], os.Getenv(splits[0]))
}
}
// Get returns a value by key.
// If the key does not exist, the default value will be returned.
func Get(key string, defVal string) string {
if val := env.Get(key); val != nil {
return val.(string)
}
return defVal
}
// MustGet returns a value by key.
// If the key does not exist, it will return an error.
func MustGet(key string) (string, error) {
if val := env.Get(key); val != nil {
return val.(string), nil
}
return "", fmt.Errorf("no env variable with %s", key)
}
// Set sets a value in the ENV copy.
// This does not affect the child process environment.
func Set(key string, value string) {
env.Set(key, value)
}
// MustSet sets a value in the ENV copy and the child process environment.
// It returns an error in case the set operation failed.
func MustSet(key string, value string) error {
err := os.Setenv(key, value)
if err != nil {
return err
}
env.Set(key, value)
return nil
}
// GetAll returns all keys/values in the current child process environment.
func GetAll() map[string]string {
items := env.Items()
envs := make(map[string]string, env.Count())
for key, val := range items {
switch key := key.(type) {
case string:
switch val := val.(type) {
case string:
envs[key] = val
}
}
}
return envs
}
// Copyright 2014 beego Author. All Rights Reserved.
// Copyright 2017 Faissal Elamraoui. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package env
import (
"os"
"testing"
)
func TestEnvGet(t *testing.T) {
gopath := Get("GOPATH", "")
if gopath != os.Getenv("GOPATH") {
t.Error("expected GOPATH not empty.")
}
noExistVar := Get("NOEXISTVAR", "foo")
if noExistVar != "foo" {
t.Errorf("expected NOEXISTVAR to equal foo, got %s.", noExistVar)
}
}
func TestEnvMustGet(t *testing.T) {
gopath, err := MustGet("GOPATH")
if err != nil {
t.Error(err)
}
if gopath != os.Getenv("GOPATH") {
t.Errorf("expected GOPATH to be the same, got %s.", gopath)
}
_, err = MustGet("NOEXISTVAR")
if err == nil {
t.Error("expected error to be non-nil")
}
}
func TestEnvSet(t *testing.T) {
Set("MYVAR", "foo")
myVar := Get("MYVAR", "bar")
if myVar != "foo" {
t.Errorf("expected MYVAR to equal foo, got %s.", myVar)
}
}
func TestEnvMustSet(t *testing.T) {
err := MustSet("FOO", "bar")
if err != nil {
t.Error(err)
}
fooVar := os.Getenv("FOO")
if fooVar != "bar" {
t.Errorf("expected FOO variable to equal bar, got %s.", fooVar)
}
}
func TestEnvGetAll(t *testing.T) {
envMap := GetAll()
if len(envMap) == 0 {
t.Error("expected environment not empty.")
}
}
...@@ -18,16 +18,13 @@ import ( ...@@ -18,16 +18,13 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"path"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
) )
var ( var (
...@@ -52,24 +49,26 @@ func (ini *IniConfig) Parse(name string) (Configer, error) { ...@@ -52,24 +49,26 @@ func (ini *IniConfig) Parse(name string) (Configer, error) {
} }
func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
file, err := os.Open(name) data, err := ioutil.ReadFile(name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ini.parseData(filepath.Dir(name), data)
}
func (ini *IniConfig) parseData(dir string, data []byte) (*IniConfigContainer, error) {
cfg := &IniConfigContainer{ cfg := &IniConfigContainer{
file.Name(), data: make(map[string]map[string]string),
make(map[string]map[string]string), sectionComment: make(map[string]string),
make(map[string]string), keyComment: make(map[string]string),
make(map[string]string), RWMutex: sync.RWMutex{},
sync.RWMutex{},
} }
cfg.Lock() cfg.Lock()
defer cfg.Unlock() defer cfg.Unlock()
defer file.Close()
var comment bytes.Buffer var comment bytes.Buffer
buf := bufio.NewReader(file) buf := bufio.NewReader(bytes.NewBuffer(data))
// check the BOM // check the BOM
head, err := buf.Peek(3) head, err := buf.Peek(3)
if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 { if err == nil && head[0] == 239 && head[1] == 187 && head[2] == 191 {
...@@ -130,16 +129,20 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { ...@@ -130,16 +129,20 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
// handle include "other.conf" // handle include "other.conf"
if len(keyValue) == 1 && strings.HasPrefix(key, "include") { if len(keyValue) == 1 && strings.HasPrefix(key, "include") {
includefiles := strings.Fields(key) includefiles := strings.Fields(key)
if includefiles[0] == "include" && len(includefiles) == 2 { if includefiles[0] == "include" && len(includefiles) == 2 {
otherfile := strings.Trim(includefiles[1], "\"") otherfile := strings.Trim(includefiles[1], "\"")
if !filepath.IsAbs(otherfile) { if !filepath.IsAbs(otherfile) {
otherfile = filepath.Join(filepath.Dir(name), otherfile) otherfile = filepath.Join(dir, otherfile)
} }
i, err := ini.parseFile(otherfile) i, err := ini.parseFile(otherfile)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for sec, dt := range i.data { for sec, dt := range i.data {
if _, ok := cfg.data[sec]; !ok { if _, ok := cfg.data[sec]; !ok {
cfg.data[sec] = make(map[string]string) cfg.data[sec] = make(map[string]string)
...@@ -148,12 +151,15 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { ...@@ -148,12 +151,15 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
cfg.data[sec][k] = v cfg.data[sec][k] = v
} }
} }
for sec, comm := range i.sectionComment { for sec, comm := range i.sectionComment {
cfg.sectionComment[sec] = comm cfg.sectionComment[sec] = comm
} }
for k, comm := range i.keyComment { for k, comm := range i.keyComment {
cfg.keyComment[k] = comm cfg.keyComment[k] = comm
} }
continue continue
} }
} }
...@@ -177,20 +183,18 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) { ...@@ -177,20 +183,18 @@ func (ini *IniConfig) parseFile(name string) (*IniConfigContainer, error) {
} }
// ParseData parse ini the data // ParseData parse ini the data
// When include other.conf,other.conf is either absolute directory
// or under beego in default temporary directory(/tmp/beego).
func (ini *IniConfig) ParseData(data []byte) (Configer, error) { func (ini *IniConfig) ParseData(data []byte) (Configer, error) {
// Save memory data to temporary file dir := filepath.Join(os.TempDir(), "beego")
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond())) os.MkdirAll(dir, os.ModePerm)
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
if err := ioutil.WriteFile(tmpName, data, 0655); err != nil { return ini.parseData(dir, data)
return nil, err
}
return ini.Parse(tmpName)
} }
// IniConfigContainer A Config represents the ini configuration. // IniConfigContainer A Config represents the ini configuration.
// When set and get value, support key as section:name type. // When set and get value, support key as section:name type.
type IniConfigContainer struct { type IniConfigContainer struct {
filename string
data map[string]map[string]string // section=> key:val data map[string]map[string]string // section=> key:val
sectionComment map[string]string // section : comment sectionComment map[string]string // section : comment
keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment. keyComment map[string]string // id: []{comment, key...}; id 1 is for main comment.
...@@ -297,7 +301,7 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro ...@@ -297,7 +301,7 @@ func (c *IniConfigContainer) GetSection(section string) (map[string]string, erro
if v, ok := c.data[section]; ok { if v, ok := c.data[section]; ok {
return v, nil return v, nil
} }
return nil, errors.New("not exist setction") return nil, errors.New("not exist section")
} }
// SaveConfigFile save the config into file. // SaveConfigFile save the config into file.
......
...@@ -35,11 +35,9 @@ import ( ...@@ -35,11 +35,9 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"path"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
"github.com/beego/x2j" "github.com/beego/x2j"
...@@ -52,36 +50,26 @@ type Config struct{} ...@@ -52,36 +50,26 @@ type Config struct{}
// Parse returns a ConfigContainer with parsed xml config map. // Parse returns a ConfigContainer with parsed xml config map.
func (xc *Config) Parse(filename string) (config.Configer, error) { func (xc *Config) Parse(filename string) (config.Configer, error) {
file, err := os.Open(filename) context, err := ioutil.ReadFile(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer file.Close()
return xc.ParseData(context)
}
// ParseData xml data
func (xc *Config) ParseData(data []byte) (config.Configer, error) {
x := &ConfigContainer{data: make(map[string]interface{})} x := &ConfigContainer{data: make(map[string]interface{})}
content, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
}
d, err := x2j.DocToMap(string(content)) d, err := x2j.DocToMap(string(data))
if err != nil { if err != nil {
return nil, err return nil, err
} }
x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{})) x.data = config.ExpandValueEnvForMap(d["config"].(map[string]interface{}))
return x, nil
}
// ParseData xml data return x, nil
func (xc *Config) ParseData(data []byte) (config.Configer, error) {
// Save memory data to temporary file
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond()))
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
if err := ioutil.WriteFile(tmpName, data, 0655); err != nil {
return nil, err
}
return xc.Parse(tmpName)
} }
// ConfigContainer A Config represents the xml configuration. // ConfigContainer A Config represents the xml configuration.
......
...@@ -37,10 +37,8 @@ import ( ...@@ -37,10 +37,8 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"os" "os"
"path"
"strings" "strings"
"sync" "sync"
"time"
"github.com/astaxie/beego/config" "github.com/astaxie/beego/config"
"github.com/beego/goyaml2" "github.com/beego/goyaml2"
...@@ -63,26 +61,30 @@ func (yaml *Config) Parse(filename string) (y config.Configer, err error) { ...@@ -63,26 +61,30 @@ func (yaml *Config) Parse(filename string) (y config.Configer, err error) {
// ParseData parse yaml data // ParseData parse yaml data
func (yaml *Config) ParseData(data []byte) (config.Configer, error) { func (yaml *Config) ParseData(data []byte) (config.Configer, error) {
// Save memory data to temporary file cnf, err := parseYML(data)
tmpName := path.Join(os.TempDir(), "beego", fmt.Sprintf("%d", time.Now().Nanosecond())) if err != nil {
os.MkdirAll(path.Dir(tmpName), os.ModePerm)
if err := ioutil.WriteFile(tmpName, data, 0655); err != nil {
return nil, err return nil, err
} }
return yaml.Parse(tmpName)
return &ConfigContainer{
data: cnf,
}, nil
} }
// ReadYmlReader Read yaml file to map. // ReadYmlReader Read yaml file to map.
// if json like, use json package, unless goyaml2 package. // if json like, use json package, unless goyaml2 package.
func ReadYmlReader(path string) (cnf map[string]interface{}, err error) { func ReadYmlReader(path string) (cnf map[string]interface{}, err error) {
f, err := os.Open(path) buf, err := ioutil.ReadFile(path)
if err != nil { if err != nil {
return return
} }
defer f.Close()
buf, err := ioutil.ReadAll(f) return parseYML(buf)
if err != nil || len(buf) < 3 { }
// parseYML parse yaml formatted []byte to map.
func parseYML(buf []byte) (cnf map[string]interface{}, err error) {
if len(buf) < 3 {
return return
} }
...@@ -250,7 +252,7 @@ func (c *ConfigContainer) GetSection(section string) (map[string]string, error) ...@@ -250,7 +252,7 @@ func (c *ConfigContainer) GetSection(section string) (map[string]string, error)
if v, ok := c.data[section]; ok { if v, ok := c.data[section]; ok {
return v.(map[string]string), nil return v.(map[string]string), nil
} }
return nil, errors.New("not exist setction") return nil, errors.New("not exist section")
} }
// SaveConfigFile save the config into file // SaveConfigFile save the config into file
......
...@@ -413,7 +413,13 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error { ...@@ -413,7 +413,13 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error {
if !value.CanSet() { if !value.CanSet() {
return errors.New("beego: non-settable variable passed to Bind: " + key) return errors.New("beego: non-settable variable passed to Bind: " + key)
} }
rv := input.bind(key, value.Type()) typ := value.Type()
// Get real type if dest define with interface{}.
// e.g var dest interface{} dest=1.0
if value.Kind() == reflect.Interface {
typ = value.Elem().Type()
}
rv := input.bind(key, typ)
if !rv.IsValid() { if !rv.IsValid() {
return errors.New("beego: reflect value is empty") return errors.New("beego: reflect value is empty")
} }
...@@ -422,6 +428,9 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error { ...@@ -422,6 +428,9 @@ func (input *BeegoInput) Bind(dest interface{}, key string) error {
} }
func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value { func (input *BeegoInput) bind(key string, typ reflect.Type) reflect.Value {
if input.Context.Request.Form == nil {
input.Context.Request.ParseForm()
}
rv := reflect.Zero(typ) rv := reflect.Zero(typ)
switch typ.Kind() { switch typ.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
......
...@@ -15,81 +15,97 @@ ...@@ -15,81 +15,97 @@
package context package context
import ( import (
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"testing" "testing"
) )
func TestParse(t *testing.T) { func TestBind(t *testing.T) {
r, _ := http.NewRequest("GET", "/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&user.Name=astaxie", nil) type testItem struct {
beegoInput := NewInput() field string
beegoInput.Context = NewContext() empty interface{}
beegoInput.Context.Reset(httptest.NewRecorder(), r) want interface{}
beegoInput.ParseFormOrMulitForm(1 << 20) }
type Human struct {
var id int ID int
err := beegoInput.Bind(&id, "id") Nick string
if id != 123 || err != nil { Pwd string
t.Fatal("id should has int value") Ms bool
} }
fmt.Println(id)
cases := []struct {
var isok bool request string
err = beegoInput.Bind(&isok, "isok") valueGp []testItem
if !isok || err != nil { }{
t.Fatal("isok should be true") {"/?p=str", []testItem{{"p", interface{}(""), interface{}("str")}}},
}
fmt.Println(isok) {"/?p=", []testItem{{"p", "", ""}}},
{"/?p=str", []testItem{{"p", "", "str"}}},
var float float64
err = beegoInput.Bind(&float, "ft") {"/?p=123", []testItem{{"p", 0, 123}}},
if float != 1.2 || err != nil { {"/?p=123", []testItem{{"p", uint(0), uint(123)}}},
t.Fatal("float should be equal to 1.2")
} {"/?p=1.0", []testItem{{"p", 0.0, 1.0}}},
fmt.Println(float) {"/?p=1", []testItem{{"p", false, true}}},
ol := make([]int, 0, 2) {"/?p=true", []testItem{{"p", false, true}}},
err = beegoInput.Bind(&ol, "ol") {"/?p=ON", []testItem{{"p", false, true}}},
if len(ol) != 2 || err != nil || ol[0] != 1 || ol[1] != 2 { {"/?p=on", []testItem{{"p", false, true}}},
t.Fatal("ol should has two elements") {"/?p=1", []testItem{{"p", false, true}}},
} {"/?p=2", []testItem{{"p", false, false}}},
fmt.Println(ol) {"/?p=false", []testItem{{"p", false, false}}},
ul := make([]string, 0, 2) {"/?p[a]=1&p[b]=2&p[c]=3", []testItem{{"p", map[string]int{}, map[string]int{"a": 1, "b": 2, "c": 3}}}},
err = beegoInput.Bind(&ul, "ul") {"/?p[a]=v1&p[b]=v2&p[c]=v3", []testItem{{"p", map[string]string{}, map[string]string{"a": "v1", "b": "v2", "c": "v3"}}}},
if len(ul) != 2 || err != nil || ul[0] != "str" || ul[1] != "array" {
t.Fatal("ul should has two elements") {"/?p[]=8&p[]=9&p[]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}},
} {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []int{}, []int{8, 9, 10}}}},
fmt.Println(ul) {"/?p[0]=8&p[1]=9&p[2]=10&p[5]=14", []testItem{{"p", []int{}, []int{8, 9, 10, 0, 0, 14}}}},
{"/?p[0]=8.0&p[1]=9.0&p[2]=10.0", []testItem{{"p", []float64{}, []float64{8.0, 9.0, 10.0}}}},
type User struct {
Name string {"/?p[]=10&p[]=9&p[]=8", []testItem{{"p", []string{}, []string{"10", "9", "8"}}}},
} {"/?p[0]=8&p[1]=9&p[2]=10", []testItem{{"p", []string{}, []string{"8", "9", "10"}}}},
user := User{}
err = beegoInput.Bind(&user, "user") {"/?p[0]=true&p[1]=false&p[2]=true&p[5]=1&p[6]=ON&p[7]=other", []testItem{{"p", []bool{}, []bool{true, false, true, false, false, true, true, false}}}},
if err != nil || user.Name != "astaxie" {
t.Fatal("user should has name") {"/?human.Nick=astaxie", []testItem{{"human", Human{}, Human{Nick: "astaxie"}}}},
} {"/?human.ID=888&human.Nick=astaxie&human.Ms=true&human[Pwd]=pass", []testItem{{"human", Human{}, Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass"}}}},
fmt.Println(user) {"/?human[0].ID=888&human[0].Nick=astaxie&human[0].Ms=true&human[0][Pwd]=pass01&human[1].ID=999&human[1].Nick=ysqi&human[1].Ms=On&human[1].Pwd=pass02",
} []testItem{{"human", []Human{}, []Human{
Human{ID: 888, Nick: "astaxie", Ms: true, Pwd: "pass01"},
Human{ID: 999, Nick: "ysqi", Ms: true, Pwd: "pass02"},
}}}},
{
"/?id=123&isok=true&ft=1.2&ol[0]=1&ol[1]=2&ul[]=str&ul[]=array&human.Nick=astaxie",
[]testItem{
{"id", 0, 123},
{"isok", false, true},
{"ft", 0.0, 1.2},
{"ol", []int{}, []int{1, 2}},
{"ul", []string{}, []string{"str", "array"}},
{"human", Human{}, Human{Nick: "astaxie"}},
},
},
}
for _, c := range cases {
r, _ := http.NewRequest("GET", c.request, nil)
beegoInput := NewInput()
beegoInput.Context = NewContext()
beegoInput.Context.Reset(httptest.NewRecorder(), r)
for _, item := range c.valueGp {
got := item.empty
err := beegoInput.Bind(&got, item.field)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, item.want) {
t.Fatalf("Bind %q error,should be:\n%#v \ngot:\n%#v", item.field, item.want, got)
}
}
func TestParse2(t *testing.T) {
r, _ := http.NewRequest("GET", "/?user[0][Username]=Raph&user[1].Username=Leo&user[0].Password=123456&user[1][Password]=654321", nil)
beegoInput := NewInput()
beegoInput.Context = NewContext()
beegoInput.Context.Reset(httptest.NewRecorder(), r)
beegoInput.ParseFormOrMulitForm(1 << 20)
type User struct {
Username string
Password string
}
var users []User
err := beegoInput.Bind(&users, "user")
fmt.Println(users)
if err != nil || users[0].Username != "Raph" || users[0].Password != "123456" || users[1].Username != "Leo" || users[1].Password != "654321" {
t.Fatal("users info wrong")
} }
} }
......
...@@ -67,6 +67,7 @@ func (output *BeegoOutput) Body(content []byte) error { ...@@ -67,6 +67,7 @@ func (output *BeegoOutput) Body(content []byte) error {
} }
if b, n, _ := WriteBody(encoding, buf, content); b { if b, n, _ := WriteBody(encoding, buf, content); b {
output.Header("Content-Encoding", n) output.Header("Content-Encoding", n)
output.Header("Content-Length", strconv.Itoa(buf.Len()))
} else { } else {
output.Header("Content-Length", strconv.Itoa(len(content))) output.Header("Content-Length", strconv.Itoa(len(content)))
} }
...@@ -330,16 +331,17 @@ func (output *BeegoOutput) IsServerError() bool { ...@@ -330,16 +331,17 @@ func (output *BeegoOutput) IsServerError() bool {
func stringsToJSON(str string) string { func stringsToJSON(str string) string {
rs := []rune(str) rs := []rune(str)
jsons := "" var jsons bytes.Buffer
for _, r := range rs { for _, r := range rs {
rint := int(r) rint := int(r)
if rint < 128 { if rint < 128 {
jsons += string(r) jsons.WriteRune(r)
} else { } else {
jsons += "\\u" + strconv.FormatInt(int64(rint), 16) // json jsons.WriteString("\\u")
jsons.WriteString(strconv.FormatInt(int64(rint), 16))
} }
} }
return jsons return jsons.String()
} }
// Session sets session item value with given key. // Session sets session item value with given key.
......
...@@ -69,6 +69,7 @@ type Controller struct { ...@@ -69,6 +69,7 @@ type Controller struct {
// template data // template data
TplName string TplName string
ViewPath string
Layout string Layout string
LayoutSections map[string]string // the key is the section name and the value is the template name LayoutSections map[string]string // the key is the section name and the value is the template name
TplPrefix string TplPrefix string
...@@ -185,7 +186,11 @@ func (c *Controller) Render() error { ...@@ -185,7 +186,11 @@ func (c *Controller) Render() error {
if err != nil { if err != nil {
return err return err
} }
c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
if c.Ctx.ResponseWriter.Header().Get("Content-Type") == "" {
c.Ctx.Output.Header("Content-Type", "text/html; charset=utf-8")
}
return c.Ctx.Output.Body(rb) return c.Ctx.Output.Body(rb)
} }
...@@ -209,7 +214,7 @@ func (c *Controller) RenderBytes() ([]byte, error) { ...@@ -209,7 +214,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
continue continue
} }
buf.Reset() buf.Reset()
err = ExecuteTemplate(&buf, sectionTpl, c.Data) err = ExecuteViewPathTemplate(&buf, sectionTpl, c.viewPath(), c.Data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -218,7 +223,7 @@ func (c *Controller) RenderBytes() ([]byte, error) { ...@@ -218,7 +223,7 @@ func (c *Controller) RenderBytes() ([]byte, error) {
} }
buf.Reset() buf.Reset()
ExecuteTemplate(&buf, c.Layout, c.Data) ExecuteViewPathTemplate(&buf, c.Layout, c.viewPath() ,c.Data)
} }
return buf.Bytes(), err return buf.Bytes(), err
} }
...@@ -244,9 +249,16 @@ func (c *Controller) renderTemplate() (bytes.Buffer, error) { ...@@ -244,9 +249,16 @@ func (c *Controller) renderTemplate() (bytes.Buffer, error) {
} }
} }
} }
BuildTemplate(BConfig.WebConfig.ViewsPath, buildFiles...) BuildTemplate(c.viewPath() , buildFiles...)
}
return buf, ExecuteViewPathTemplate(&buf, c.TplName, c.viewPath(), c.Data)
}
func (c *Controller) viewPath() string {
if c.ViewPath == "" {
return BConfig.WebConfig.ViewsPath
} }
return buf, ExecuteTemplate(&buf, c.TplName, c.Data) return c.ViewPath
} }
// Redirect sends the redirection response to url with status code. // Redirect sends the redirection response to url with status code.
......
...@@ -20,6 +20,8 @@ import ( ...@@ -20,6 +20,8 @@ import (
"testing" "testing"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"os"
"path/filepath"
) )
func TestGetInt(t *testing.T) { func TestGetInt(t *testing.T) {
...@@ -121,3 +123,59 @@ func TestGetUint64(t *testing.T) { ...@@ -121,3 +123,59 @@ func TestGetUint64(t *testing.T) {
t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val) t.Errorf("TestGetUint64 expect %v,get %T,%v", uint64(math.MaxUint64), val, val)
} }
} }
func TestAdditionalViewPaths(t *testing.T) {
dir1 := "_beeTmp"
dir2 := "_beeTmp2"
defer os.RemoveAll(dir1)
defer os.RemoveAll(dir2)
dir1file := "file1.tpl"
dir2file := "file2.tpl"
genFile := func(dir string, name string, content string) {
os.MkdirAll(filepath.Dir(filepath.Join(dir, name)), 0777)
if f, err := os.Create(filepath.Join(dir, name)); err != nil {
t.Fatal(err)
} else {
defer f.Close()
f.WriteString(content)
f.Close()
}
}
genFile(dir1, dir1file, `<div>{{.Content}}</div>`)
genFile(dir2, dir2file, `<html>{{.Content}}</html>`)
AddViewPath(dir1)
AddViewPath(dir2)
ctrl := Controller{
TplName: "file1.tpl",
ViewPath: dir1,
}
ctrl.Data = map[interface{}]interface{}{
"Content": "value2",
}
if result, err := ctrl.RenderString(); err != nil {
t.Fatal(err)
} else {
if result != "<div>value2</div>" {
t.Fatalf("TestAdditionalViewPaths expect %s got %s", "<div>value2</div>", result)
}
}
func() {
ctrl.TplName = "file2.tpl"
defer func() {
if r := recover(); r == nil {
t.Fatal("TestAdditionalViewPaths expected error")
}
}()
ctrl.RenderString();
}()
ctrl.TplName = "file2.tpl"
ctrl.ViewPath = dir2
ctrl.RenderString();
}
...@@ -85,23 +85,31 @@ var ( ...@@ -85,23 +85,31 @@ var (
isChild bool isChild bool
socketOrder string socketOrder string
once sync.Once
hookableSignals []os.Signal
) )
func onceInit() { func init() {
regLock = &sync.Mutex{}
flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)") flag.BoolVar(&isChild, "graceful", false, "listen on open fd (after forking)")
flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started") flag.StringVar(&socketOrder, "socketorder", "", "previous initialization order - used when more than one listener was started")
regLock = &sync.Mutex{}
runningServers = make(map[string]*Server) runningServers = make(map[string]*Server)
runningServersOrder = []string{} runningServersOrder = []string{}
socketPtrOffsetMap = make(map[string]uint) socketPtrOffsetMap = make(map[string]uint)
hookableSignals = []os.Signal{
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
}
} }
// NewServer returns a new graceServer. // NewServer returns a new graceServer.
func NewServer(addr string, handler http.Handler) (srv *Server) { func NewServer(addr string, handler http.Handler) (srv *Server) {
once.Do(onceInit)
regLock.Lock() regLock.Lock()
defer regLock.Unlock() defer regLock.Unlock()
if !flag.Parsed() { if !flag.Parsed() {
flag.Parse() flag.Parse()
} }
......
...@@ -162,9 +162,7 @@ func (srv *Server) handleSignals() { ...@@ -162,9 +162,7 @@ func (srv *Server) handleSignals() {
signal.Notify( signal.Notify(
srv.sigChan, srv.sigChan,
syscall.SIGHUP, hookableSignals...,
syscall.SIGINT,
syscall.SIGTERM,
) )
pid := syscall.Getpid() pid := syscall.Getpid()
...@@ -290,3 +288,19 @@ func (srv *Server) fork() (err error) { ...@@ -290,3 +288,19 @@ func (srv *Server) fork() (err error) {
return return
} }
// RegisterSignalHook registers a function to be run PreSignal or PostSignal for a given signal.
func (srv *Server) RegisterSignalHook(ppFlag int, sig os.Signal, f func()) (err error) {
if ppFlag != PreSignal && ppFlag != PostSignal {
err = fmt.Errorf("Invalid ppFlag argument. Must be either grace.PreSignal or grace.PostSignal.")
return
}
for _, s := range hookableSignals {
if s == sig {
srv.SignalHooks[ppFlag][sig] = append(srv.SignalHooks[ppFlag][sig], f)
return
}
}
err = fmt.Errorf("Signal '%v' is not supported.", sig)
return
}
...@@ -72,7 +72,8 @@ func registerSession() error { ...@@ -72,7 +72,8 @@ func registerSession() error {
} }
func registerTemplate() error { func registerTemplate() error {
if err := BuildTemplate(BConfig.WebConfig.ViewsPath); err != nil { defer lockViewPaths()
if err := AddViewPath(BConfig.WebConfig.ViewsPath); err != nil {
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
logs.Warn(err) logs.Warn(err)
} }
......
...@@ -140,6 +140,7 @@ type BeegoHTTPSettings struct { ...@@ -140,6 +140,7 @@ type BeegoHTTPSettings struct {
EnableCookie bool EnableCookie bool
Gzip bool Gzip bool
DumpBody bool DumpBody bool
Retries int // if set to -1 means will retry forever
} }
// BeegoHTTPRequest provides more useful methods for requesting one url than http.Request. // BeegoHTTPRequest provides more useful methods for requesting one url than http.Request.
...@@ -189,6 +190,15 @@ func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest { ...@@ -189,6 +190,15 @@ func (b *BeegoHTTPRequest) Debug(isdebug bool) *BeegoHTTPRequest {
return b return b
} }
// Retries sets Retries times.
// default is 0 means no retried.
// -1 means retried forever.
// others means retried times.
func (b *BeegoHTTPRequest) Retries(times int) *BeegoHTTPRequest {
b.setting.Retries = times
return b
}
// DumpBody setting whether need to Dump the Body. // DumpBody setting whether need to Dump the Body.
func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest { func (b *BeegoHTTPRequest) DumpBody(isdump bool) *BeegoHTTPRequest {
b.setting.DumpBody = isdump b.setting.DumpBody = isdump
...@@ -390,7 +400,7 @@ func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) { ...@@ -390,7 +400,7 @@ func (b *BeegoHTTPRequest) getResponse() (*http.Response, error) {
} }
// DoRequest will do the client.Do // DoRequest will do the client.Do
func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) { func (b *BeegoHTTPRequest) DoRequest() (resp *http.Response, err error) {
var paramBody string var paramBody string
if len(b.params) > 0 { if len(b.params) > 0 {
var buf bytes.Buffer var buf bytes.Buffer
...@@ -467,7 +477,16 @@ func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) { ...@@ -467,7 +477,16 @@ func (b *BeegoHTTPRequest) DoRequest() (*http.Response, error) {
} }
b.dump = dump b.dump = dump
} }
return client.Do(b.req) // retries default value is 0, it will run once.
// retries equal to -1, it will run forever until success
// retries is setted, it will retries fixed times.
for i := 0; b.setting.Retries == -1 || i <= b.setting.Retries; i++ {
resp, err = client.Do(b.req)
if err == nil {
break
}
}
return resp, err
} }
// String returns the body string in response. // String returns the body string in response.
......
package alils
import (
"encoding/json"
"github.com/astaxie/beego/logs"
"github.com/gogo/protobuf/proto"
"strings"
"sync"
"time"
)
const (
CacheSize int = 64
Delimiter string = "##"
)
type AliLSConfig struct {
Project string `json:"project"`
Endpoint string `json:"endpoint"`
KeyID string `json:"key_id"`
KeySecret string `json:"key_secret"`
LogStore string `json:"log_store"`
Topics []string `json:"topics"`
Source string `json:"source"`
Level int `json:"level"`
FlushWhen int `json:"flush_when"`
}
// aliLSWriter implements LoggerInterface.
// it writes messages in keep-live tcp connection.
type aliLSWriter struct {
store *LogStore
group []*LogGroup
withMap bool
groupMap map[string]*LogGroup
lock *sync.Mutex
AliLSConfig
}
// 创建提供Logger接口的日志服务
func NewAliLS() logs.Logger {
alils := new(aliLSWriter)
alils.Level = logs.LevelTrace
return alils
}
// 读取配置
// 初始化必要的数据结构
func (c *aliLSWriter) Init(jsonConfig string) (err error) {
json.Unmarshal([]byte(jsonConfig), c)
if c.FlushWhen > CacheSize {
c.FlushWhen = CacheSize
}
// 初始化Project
prj := &LogProject{
Name: c.Project,
Endpoint: c.Endpoint,
AccessKeyId: c.KeyID,
AccessKeySecret: c.KeySecret,
}
// 获取logstore
c.store, err = prj.GetLogStore(c.LogStore)
if err != nil {
return err
}
// 创建默认Log Group
c.group = append(c.group, &LogGroup{
Topic: proto.String(""),
Source: proto.String(c.Source),
Logs: make([]*Log, 0, c.FlushWhen),
})
// 创建其它Log Group
c.groupMap = make(map[string]*LogGroup)
for _, topic := range c.Topics {
lg := &LogGroup{
Topic: proto.String(topic),
Source: proto.String(c.Source),
Logs: make([]*Log, 0, c.FlushWhen),
}
c.group = append(c.group, lg)
c.groupMap[topic] = lg
}
if len(c.group) == 1 {
c.withMap = false
} else {
c.withMap = true
}
c.lock = &sync.Mutex{}
return nil
}
// WriteMsg write message in connection.
// if connection is down, try to re-connect.
func (c *aliLSWriter) WriteMsg(when time.Time, msg string, level int) (err error) {
if level > c.Level {
return nil
}
var topic string
var content string
var lg *LogGroup
if c.withMap {
// 解析出Topic,并匹配LogGroup
strs := strings.SplitN(msg, Delimiter, 2)
if len(strs) == 2 {
pos := strings.LastIndex(strs[0], " ")
topic = strs[0][pos+1 : len(strs[0])]
content = strs[0][0:pos] + strs[1]
lg = c.groupMap[topic]
}
// 默认发到空Topic
if lg == nil {
topic = ""
content = msg
lg = c.group[0]
}
} else {
topic = ""
content = msg
lg = c.group[0]
}
// 生成日志
c1 := &Log_Content{
Key: proto.String("msg"),
Value: proto.String(content),
}
l := &Log{
Time: proto.Uint32(uint32(when.Unix())), // 填写日志时间
Contents: []*Log_Content{
c1,
},
}
c.lock.Lock()
lg.Logs = append(lg.Logs, l)
c.lock.Unlock()
// 满足条件则Flush
if len(lg.Logs) >= c.FlushWhen {
c.flush(lg)
}
return nil
}
// Flush implementing method. empty.
func (c *aliLSWriter) Flush() {
// flush所有group
for _, lg := range c.group {
c.flush(lg)
}
}
// Destroy destroy connection writer and close tcp listener.
func (c *aliLSWriter) Destroy() {
}
func (c *aliLSWriter) flush(lg *LogGroup) {
c.lock.Lock()
defer c.lock.Unlock()
// 把以上的LogGroup推送到SLS服务器,
// SLS服务器会根据该logstore的shard个数自动进行负载均衡。
err := c.store.PutLogs(lg)
if err != nil {
return
}
lg.Logs = make([]*Log, 0, c.FlushWhen)
}
func init() {
logs.Register(logs.AdapterAliLS, NewAliLS)
}
package alils
const (
version = "0.5.0" // SDK version
signatureMethod = "hmac-sha1" // Signature method
// OffsetNewest stands for the log head offset, i.e. the offset that will be
// assigned to the next message that will be produced to the shard.
OffsetNewest = "end"
// OffsetOldest stands for the oldest offset available on the logstore for a
// shard.
OffsetOldest = "begin"
)
This diff is collapsed.
package alils
type InputDetail struct {
LogType string `json:"logType"`
LogPath string `json:"logPath"`
FilePattern string `json:"filePattern"`
LocalStorage bool `json:"localStorage"`
TimeFormat string `json:"timeFormat"`
LogBeginRegex string `json:"logBeginRegex"`
Regex string `json:"regex"`
Keys []string `json:"key"`
FilterKeys []string `json:"filterKey"`
FilterRegex []string `json:"filterRegex"`
TopicFormat string `json:"topicFormat"`
}
type OutputDetail struct {
Endpoint string `json:"endpoint"`
LogStoreName string `json:"logstoreName"`
}
type LogConfig struct {
Name string `json:"configName"`
InputType string `json:"inputType"`
InputDetail InputDetail `json:"inputDetail"`
OutputType string `json:"outputType"`
OutputDetail OutputDetail `json:"outputDetail"`
CreateTime uint32
LastModifyTime uint32
project *LogProject
}
// GetAppliedMachineGroup returns applied machine group of this config.
func (c *LogConfig) GetAppliedMachineGroup(confName string) (groupNames []string, err error) {
groupNames, err = c.project.GetAppliedMachineGroups(c.Name)
return
}
This diff is collapsed.
package alils
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
"strconv"
lz4 "github.com/cloudflare/golz4"
"github.com/gogo/protobuf/proto"
)
type LogStore struct {
Name string `json:"logstoreName"`
TTL int
ShardCount int
CreateTime uint32
LastModifyTime uint32
project *LogProject
}
type Shard struct {
ShardID int `json:"shardID"`
}
// ListShards returns shard id list of this logstore.
func (s *LogStore) ListShards() (shardIDs []int, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/logstores/%v/shards", s.Name)
r, err := request(s.project, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to list logstore")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
var shards []*Shard
err = json.Unmarshal(buf, &shards)
if err != nil {
return
}
for _, v := range shards {
shardIDs = append(shardIDs, v.ShardID)
}
return
}
// PutLogs put logs into logstore.
// The callers should transform user logs into LogGroup.
func (s *LogStore) PutLogs(lg *LogGroup) (err error) {
body, err := proto.Marshal(lg)
if err != nil {
return
}
// Compresse body with lz4
out := make([]byte, lz4.CompressBound(body))
n, err := lz4.Compress(body, out)
if err != nil {
return
}
h := map[string]string{
"x-sls-compresstype": "lz4",
"x-sls-bodyrawsize": fmt.Sprintf("%v", len(body)),
"Content-Type": "application/x-protobuf",
}
uri := fmt.Sprintf("/logstores/%v", s.Name)
r, err := request(s.project, "POST", uri, h, out[:n])
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to put logs")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
return
}
// GetCursor gets log cursor of one shard specified by shardId.
// The from can be in three form: a) unix timestamp in seccond, b) "begin", c) "end".
// For more detail please read: http://gitlab.alibaba-inc.com/sls/doc/blob/master/api/shard.md#logstore
func (s *LogStore) GetCursor(shardId int, from string) (cursor string, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/logstores/%v/shards/%v?type=cursor&from=%v",
s.Name, shardId, from)
r, err := request(s.project, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to get cursor")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
type Body struct {
Cursor string
}
body := &Body{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
cursor = body.Cursor
return
}
// GetLogsBytes gets logs binary data from shard specified by shardId according cursor.
// The logGroupMaxCount is the max number of logGroup could be returned.
// The nextCursor is the next curosr can be used to read logs at next time.
func (s *LogStore) GetLogsBytes(shardId int, cursor string,
logGroupMaxCount int) (out []byte, nextCursor string, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
"Accept": "application/x-protobuf",
"Accept-Encoding": "lz4",
}
uri := fmt.Sprintf("/logstores/%v/shards/%v?type=logs&cursor=%v&count=%v",
s.Name, shardId, cursor, logGroupMaxCount)
r, err := request(s.project, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to get cursor")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
v, ok := r.Header["X-Sls-Compresstype"]
if !ok || len(v) == 0 {
err = fmt.Errorf("can't find 'x-sls-compresstype' header")
return
}
if v[0] != "lz4" {
err = fmt.Errorf("unexpected compress type:%v", v[0])
return
}
v, ok = r.Header["X-Sls-Cursor"]
if !ok || len(v) == 0 {
err = fmt.Errorf("can't find 'x-sls-cursor' header")
return
}
nextCursor = v[0]
v, ok = r.Header["X-Sls-Bodyrawsize"]
if !ok || len(v) == 0 {
err = fmt.Errorf("can't find 'x-sls-bodyrawsize' header")
return
}
bodyRawSize, err := strconv.Atoi(v[0])
if err != nil {
return
}
out = make([]byte, bodyRawSize)
err = lz4.Uncompress(buf, out)
if err != nil {
return
}
return
}
// LogsBytesDecode decodes logs binary data retruned by GetLogsBytes API
func LogsBytesDecode(data []byte) (gl *LogGroupList, err error) {
gl = &LogGroupList{}
err = proto.Unmarshal(data, gl)
if err != nil {
return
}
return
}
// GetLogs gets logs from shard specified by shardId according cursor.
// The logGroupMaxCount is the max number of logGroup could be returned.
// The nextCursor is the next curosr can be used to read logs at next time.
func (s *LogStore) GetLogs(shardId int, cursor string,
logGroupMaxCount int) (gl *LogGroupList, nextCursor string, err error) {
out, nextCursor, err := s.GetLogsBytes(shardId, cursor, logGroupMaxCount)
if err != nil {
return
}
gl, err = LogsBytesDecode(out)
if err != nil {
return
}
return
}
package alils
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httputil"
)
type MachinGroupAttribute struct {
ExternalName string `json:"externalName"`
TopicName string `json:"groupTopic"`
}
type MachineGroup struct {
Name string `json:"groupName"`
Type string `json:"groupType"`
MachineIdType string `json:"machineIdentifyType"`
MachineIdList []string `json:"machineList"`
Attribute MachinGroupAttribute `json:"groupAttribute"`
CreateTime uint32
LastModifyTime uint32
project *LogProject
}
type Machine struct {
IP string
UniqueId string `json:"machine-uniqueid"`
UserdefinedId string `json:"userdefined-id"`
}
type MachineList struct {
Total int
Machines []*Machine
}
// ListMachines returns machine list of this machine group.
func (m *MachineGroup) ListMachines() (ms []*Machine, total int, err error) {
h := map[string]string{
"x-sls-bodyrawsize": "0",
}
uri := fmt.Sprintf("/machinegroups/%v/machines", m.Name)
r, err := request(m.project, "GET", uri, h, nil)
if err != nil {
return
}
buf, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
if r.StatusCode != http.StatusOK {
errMsg := &errorMessage{}
err = json.Unmarshal(buf, errMsg)
if err != nil {
err = fmt.Errorf("failed to remove config from machine group")
dump, _ := httputil.DumpResponse(r, true)
fmt.Println(dump)
return
}
err = fmt.Errorf("%v:%v", errMsg.Code, errMsg.Message)
return
}
body := &MachineList{}
err = json.Unmarshal(buf, body)
if err != nil {
return
}
ms = body.Machines
total = body.Total
return
}
// GetAppliedConfigs returns applied configs of this machine group.
func (m *MachineGroup) GetAppliedConfigs() (confNames []string, err error) {
confNames, err = m.project.GetAppliedConfigs(m.Name)
return
}
package alils
import (
"bytes"
"crypto/md5"
"fmt"
"net/http"
)
// request sends a request to SLS.
func request(project *LogProject, method, uri string, headers map[string]string,
body []byte) (resp *http.Response, err error) {
// The caller should provide 'x-sls-bodyrawsize' header
if _, ok := headers["x-sls-bodyrawsize"]; !ok {
err = fmt.Errorf("Can't find 'x-sls-bodyrawsize' header")
return
}
// SLS public request headers
headers["Host"] = project.Name + "." + project.Endpoint
headers["Date"] = nowRFC1123()
headers["x-sls-apiversion"] = version
headers["x-sls-signaturemethod"] = signatureMethod
if body != nil {
bodyMD5 := fmt.Sprintf("%X", md5.Sum(body))
headers["Content-MD5"] = bodyMD5
if _, ok := headers["Content-Type"]; !ok {
err = fmt.Errorf("Can't find 'Content-Type' header")
return
}
}
// Calc Authorization
// Authorization = "SLS <AccessKeyId>:<Signature>"
digest, err := signature(project, method, uri, headers)
if err != nil {
return
}
auth := fmt.Sprintf("SLS %v:%v", project.AccessKeyId, digest)
headers["Authorization"] = auth
// Initialize http request
reader := bytes.NewReader(body)
urlStr := fmt.Sprintf("http://%v.%v%v", project.Name, project.Endpoint, uri)
req, err := http.NewRequest(method, urlStr, reader)
if err != nil {
return
}
for k, v := range headers {
req.Header.Add(k, v)
}
// Get ready to do request
resp, err = http.DefaultClient.Do(req)
if err != nil {
return
}
return
}
package alils
import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"fmt"
"net/url"
"sort"
"strings"
"time"
)
// GMT location
var gmtLoc = time.FixedZone("GMT", 0)
// NowRFC1123 returns now time in RFC1123 format with GMT timezone,
// eg. "Mon, 02 Jan 2006 15:04:05 GMT".
func nowRFC1123() string {
return time.Now().In(gmtLoc).Format(time.RFC1123)
}
// signature calculates a request's signature digest.
func signature(project *LogProject, method, uri string,
headers map[string]string) (digest string, err error) {
var contentMD5, contentType, date, canoHeaders, canoResource string
var slsHeaderKeys sort.StringSlice
// SignString = VERB + "\n"
// + CONTENT-MD5 + "\n"
// + CONTENT-TYPE + "\n"
// + DATE + "\n"
// + CanonicalizedSLSHeaders + "\n"
// + CanonicalizedResource
if val, ok := headers["Content-MD5"]; ok {
contentMD5 = val
}
if val, ok := headers["Content-Type"]; ok {
contentType = val
}
date, ok := headers["Date"]
if !ok {
err = fmt.Errorf("Can't find 'Date' header")
return
}
// Calc CanonicalizedSLSHeaders
slsHeaders := make(map[string]string, len(headers))
for k, v := range headers {
l := strings.TrimSpace(strings.ToLower(k))
if strings.HasPrefix(l, "x-sls-") {
slsHeaders[l] = strings.TrimSpace(v)
slsHeaderKeys = append(slsHeaderKeys, l)
}
}
sort.Sort(slsHeaderKeys)
for i, k := range slsHeaderKeys {
canoHeaders += k + ":" + slsHeaders[k]
if i+1 < len(slsHeaderKeys) {
canoHeaders += "\n"
}
}
// Calc CanonicalizedResource
u, err := url.Parse(uri)
if err != nil {
return
}
canoResource += url.QueryEscape(u.Path)
if u.RawQuery != "" {
var keys sort.StringSlice
vals := u.Query()
for k, _ := range vals {
keys = append(keys, k)
}
sort.Sort(keys)
canoResource += "?"
for i, k := range keys {
if i > 0 {
canoResource += "&"
}
for _, v := range vals[k] {
canoResource += k + "=" + v
}
}
}
signStr := method + "\n" +
contentMD5 + "\n" +
contentType + "\n" +
date + "\n" +
canoHeaders + "\n" +
canoResource
// Signature = base64(hmac-sha1(UTF8-Encoding-Of(SignString),AccessKeySecret))
mac := hmac.New(sha1.New, []byte(project.AccessKeySecret))
_, err = mac.Write([]byte(signStr))
if err != nil {
return
}
digest = base64.StdEncoding.EncodeToString(mac.Sum(nil))
return
}
...@@ -270,6 +270,7 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error { ...@@ -270,6 +270,7 @@ func (w *fileLogWriter) doRotate(logTime time.Time) error {
// Rename the file to its new found name // Rename the file to its new found name
// even if occurs error,we MUST guarantee to restart new logger // even if occurs error,we MUST guarantee to restart new logger
err = os.Rename(w.Filename, fName) err = os.Rename(w.Filename, fName)
err = os.Chmod(fName, os.FileMode(440))
// re-start logger // re-start logger
RESTART_LOGGER: RESTART_LOGGER:
......
...@@ -71,6 +71,7 @@ const ( ...@@ -71,6 +71,7 @@ const (
AdapterEs = "es" AdapterEs = "es"
AdapterJianLiao = "jianliao" AdapterJianLiao = "jianliao"
AdapterSlack = "slack" AdapterSlack = "slack"
AdapterAliLS = "alils"
) )
// Legacy log level constants to ensure backwards compatibility. // Legacy log level constants to ensure backwards compatibility.
......
...@@ -41,6 +41,8 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac ...@@ -41,6 +41,8 @@ func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interfac
vu := v.Int() vu := v.Int()
exist = true exist = true
value = vu value = vu
} else if fi.fieldType&IsRelField > 0 {
_, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v))
} else { } else {
vu := v.String() vu := v.String()
exist = vu != "" exist = vu != ""
......
...@@ -117,7 +117,7 @@ func bootStrap() { ...@@ -117,7 +117,7 @@ func bootStrap() {
name := getFullName(elm) name := getFullName(elm)
mii, ok := modelCache.getByFullName(name) mii, ok := modelCache.getByFullName(name)
if !ok || mii.pkg != elm.PkgPath() { if !ok || mii.pkg != elm.PkgPath() {
err = fmt.Errorf("can not found rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String()) err = fmt.Errorf("can not find rel in field `%s`, `%s` may be miss register", fi.fullName, elm.String())
goto end goto end
} }
fi.relModelInfo = mii fi.relModelInfo = mii
......
...@@ -406,6 +406,11 @@ type UintPk struct { ...@@ -406,6 +406,11 @@ type UintPk struct {
Name string Name string
} }
type PtrPk struct {
ID *IntegerPk `orm:"pk;rel(one)"`
Positive bool
}
var DBARGS = struct { var DBARGS = struct {
Driver string Driver string
Source string Source string
......
...@@ -153,6 +153,8 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i ...@@ -153,6 +153,8 @@ func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, i
id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex) id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
id = int64(vid.Uint()) id = int64(vid.Uint())
} else if mi.fields.pk.rel {
return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name)
} else { } else {
id = vid.Int() id = vid.Int()
} }
......
...@@ -153,6 +153,11 @@ func (o querySet) SetCond(cond *Condition) QuerySeter { ...@@ -153,6 +153,11 @@ func (o querySet) SetCond(cond *Condition) QuerySeter {
return &o return &o
} }
// get condition from QuerySeter
func (o querySet) GetCond() *Condition {
return o.cond
}
// return QuerySeter execution result number // return QuerySeter execution result number
func (o *querySet) Count() (int64, error) { func (o *querySet) Count() (int64, error) {
return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ) return o.orm.alias.DbBaser.Count(o.orm.db, o, o.mi, o.cond, o.orm.alias.TZ)
......
...@@ -193,6 +193,7 @@ func TestSyncDb(t *testing.T) { ...@@ -193,6 +193,7 @@ func TestSyncDb(t *testing.T) {
RegisterModel(new(InLineOneToOne)) RegisterModel(new(InLineOneToOne))
RegisterModel(new(IntegerPk)) RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk)) RegisterModel(new(UintPk))
RegisterModel(new(PtrPk))
err := RunSyncdb("default", true, Debug) err := RunSyncdb("default", true, Debug)
throwFail(t, err) throwFail(t, err)
...@@ -216,6 +217,7 @@ func TestRegisterModels(t *testing.T) { ...@@ -216,6 +217,7 @@ func TestRegisterModels(t *testing.T) {
RegisterModel(new(InLineOneToOne)) RegisterModel(new(InLineOneToOne))
RegisterModel(new(IntegerPk)) RegisterModel(new(IntegerPk))
RegisterModel(new(UintPk)) RegisterModel(new(UintPk))
RegisterModel(new(PtrPk))
BootStrap() BootStrap()
...@@ -2144,6 +2146,48 @@ func TestUintPk(t *testing.T) { ...@@ -2144,6 +2146,48 @@ func TestUintPk(t *testing.T) {
dORM.Delete(u) dORM.Delete(u)
} }
func TestPtrPk(t *testing.T) {
parent := &IntegerPk{ID: 10, Value: "10"}
id, _ := dORM.Insert(parent)
if !IsMysql {
// MySql does not support last_insert_id in this case: see #2382
throwFail(t, AssertIs(id, 10))
}
ptr := PtrPk{ID: parent, Positive: true}
num, err := dORM.InsertMulti(2, []PtrPk{ptr})
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
throwFail(t, AssertIs(ptr.ID, parent))
nptr := &PtrPk{ID: parent}
created, pk, err := dORM.ReadOrCreate(nptr, "ID")
throwFail(t, err)
throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(pk, 10))
throwFail(t, AssertIs(nptr.ID, parent))
throwFail(t, AssertIs(nptr.Positive, true))
nptr = &PtrPk{Positive: true}
created, pk, err = dORM.ReadOrCreate(nptr, "Positive")
throwFail(t, err)
throwFail(t, AssertIs(created, false))
throwFail(t, AssertIs(pk, 10))
throwFail(t, AssertIs(nptr.ID, parent))
nptr.Positive = false
num, err = dORM.Update(nptr)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
throwFail(t, AssertIs(nptr.ID, parent))
throwFail(t, AssertIs(nptr.Positive, false))
num, err = dORM.Delete(nptr)
throwFail(t, err)
throwFail(t, AssertIs(num, 1))
}
func TestSnake(t *testing.T) { func TestSnake(t *testing.T) {
cases := map[string]string{ cases := map[string]string{
"i": "i", "i": "i",
......
...@@ -145,6 +145,16 @@ type QuerySeter interface { ...@@ -145,6 +145,16 @@ type QuerySeter interface {
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000 // //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
// num, err := qs.SetCond(cond1).Count() // num, err := qs.SetCond(cond1).Count()
SetCond(*Condition) QuerySeter SetCond(*Condition) QuerySeter
// get condition from QuerySeter.
// sql's where condition
// cond := orm.NewCondition()
// cond = cond.And("profile__isnull", false).AndNot("status__in", 1)
// qs = qs.SetCond(cond)
// cond = qs.GetCond()
// cond := cond.Or("profile__age__gt", 2000)
// //sql-> WHERE T0.`profile_id` IS NOT NULL AND NOT T0.`Status` IN (?) OR T1.`age` > 2000
// num, err := qs.SetCond(cond).Count()
GetCond() *Condition
// add LIMIT value. // add LIMIT value.
// args[0] means offset, e.g. LIMIT num,offset. // args[0] means offset, e.g. LIMIT num,offset.
// if Limit <= 0 then Limit will be set to default limit ,eg 1000 // if Limit <= 0 then Limit will be set to default limit ,eg 1000
......
...@@ -219,22 +219,17 @@ func snakeString(s string) string { ...@@ -219,22 +219,17 @@ func snakeString(s string) string {
// camel string, xx_yy to XxYy // camel string, xx_yy to XxYy
func camelString(s string) string { func camelString(s string) string {
data := make([]byte, 0, len(s)) data := make([]byte, 0, len(s))
j := false flag, num := true, len(s)-1
k := false
num := len(s) - 1
for i := 0; i <= num; i++ { for i := 0; i <= num; i++ {
d := s[i] d := s[i]
if k == false && d >= 'A' && d <= 'Z' { if d == '_' {
k = true flag = true
}
if d >= 'a' && d <= 'z' && (j || k == false) {
d = d - 32
j = false
k = true
}
if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' {
j = true
continue continue
} else if flag == true {
if d >= 'a' && d <= 'z' {
d = d - 32
}
flag = false
} }
data = append(data, d) data = append(data, d)
} }
......
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"testing"
)
func TestCamelString(t *testing.T) {
snake := []string{"pic_url", "hello_world_", "hello__World", "_HelLO_Word", "pic_url_1", "pic_url__1"}
camel := []string{"PicUrl", "HelloWorld", "HelloWorld", "HelLOWord", "PicUrl1", "PicUrl1"}
answer := make(map[string]string)
for i, v := range snake {
answer[v] = camel[i]
}
for _, v := range snake {
res := camelString(v)
if res != answer[v] {
t.Error("Unit Test Fail:", v, res, answer[v])
}
}
}
...@@ -56,6 +56,7 @@ ...@@ -56,6 +56,7 @@
package apiauth package apiauth
import ( import (
"bytes"
"crypto/hmac" "crypto/hmac"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
...@@ -128,53 +129,32 @@ func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc { ...@@ -128,53 +129,32 @@ func APISecretAuth(f AppIDToAppSecret, timeout int) beego.FilterFunc {
// Signature used to generate signature with the appsecret/method/params/RequestURI // Signature used to generate signature with the appsecret/method/params/RequestURI
func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) { func Signature(appsecret, method string, params url.Values, RequestURL string) (result string) {
var query string var b bytes.Buffer
keys := make([]string, len(params))
pa := make(map[string]string) pa := make(map[string]string)
for k, v := range params { for k, v := range params {
pa[k] = v[0] pa[k] = v[0]
keys = append(keys, k)
} }
vs := mapSorter(pa)
vs.Sort() sort.Strings(keys)
for i := 0; i < vs.Len(); i++ {
if vs.Keys[i] == "signature" { for _, key := range keys {
if key == "signature" {
continue continue
} }
if vs.Keys[i] != "" && vs.Vals[i] != "" {
query = fmt.Sprintf("%v%v%v", query, vs.Keys[i], vs.Vals[i]) val := pa[key]
if key != "" && val != "" {
b.WriteString(key)
b.WriteString(val)
} }
} }
stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, query, RequestURL)
stringToSign := fmt.Sprintf("%v\n%v\n%v\n", method, b.String(), RequestURL)
sha256 := sha256.New sha256 := sha256.New
hash := hmac.New(sha256, []byte(appsecret)) hash := hmac.New(sha256, []byte(appsecret))
hash.Write([]byte(stringToSign)) hash.Write([]byte(stringToSign))
return base64.StdEncoding.EncodeToString(hash.Sum(nil)) return base64.StdEncoding.EncodeToString(hash.Sum(nil))
} }
type valSorter struct {
Keys []string
Vals []string
}
func mapSorter(m map[string]string) *valSorter {
vs := &valSorter{
Keys: make([]string, 0, len(m)),
Vals: make([]string, 0, len(m)),
}
for k, v := range m {
vs.Keys = append(vs.Keys, k)
vs.Vals = append(vs.Vals, v)
}
return vs
}
func (vs *valSorter) Sort() {
sort.Sort(vs)
}
func (vs *valSorter) Len() int { return len(vs.Keys) }
func (vs *valSorter) Less(i, j int) bool { return vs.Keys[i] < vs.Keys[j] }
func (vs *valSorter) Swap(i, j int) {
vs.Vals[i], vs.Vals[j] = vs.Vals[j], vs.Vals[i]
vs.Keys[i], vs.Keys[j] = vs.Keys[j], vs.Keys[i]
}
package apiauth
import (
"net/url"
"testing"
)
func TestSignature(t *testing.T) {
appsecret := "beego secret"
method := "GET"
RequestURL := "http://localhost/test/url"
params := make(url.Values)
params.Add("arg1", "hello")
params.Add("arg2", "beego")
signature := "mFdpvLh48ca4mDVEItE9++AKKQ/IVca7O/ZyyB8hR58="
if Signature(appsecret, method, params, RequestURL) != signature {
t.Error("Signature error")
}
}
...@@ -51,15 +51,22 @@ const ( ...@@ -51,15 +51,22 @@ const (
var ( var (
// HTTPMETHOD list the supported http methods. // HTTPMETHOD list the supported http methods.
HTTPMETHOD = map[string]string{ HTTPMETHOD = map[string]string{
"GET": "GET", "GET": "GET",
"POST": "POST", "POST": "POST",
"PUT": "PUT", "PUT": "PUT",
"DELETE": "DELETE", "DELETE": "DELETE",
"PATCH": "PATCH", "PATCH": "PATCH",
"OPTIONS": "OPTIONS", "OPTIONS": "OPTIONS",
"HEAD": "HEAD", "HEAD": "HEAD",
"TRACE": "TRACE", "TRACE": "TRACE",
"CONNECT": "CONNECT", "CONNECT": "CONNECT",
"MKCOL": "MKCOL",
"COPY": "COPY",
"MOVE": "MOVE",
"PROPFIND": "PROPFIND",
"PROPPATCH": "PROPPATCH",
"LOCK": "LOCK",
"UNLOCK": "UNLOCK",
} }
// these beego.Controller's methods shouldn't reflect to AutoRouter // these beego.Controller's methods shouldn't reflect to AutoRouter
exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString", exceptMethod = []string{"Init", "Prepare", "Finish", "Render", "RenderString",
......
...@@ -143,6 +143,7 @@ func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error { ...@@ -143,6 +143,7 @@ func (mp *Provider) SessionInit(maxlifetime int64, savePath string) error {
// SessionRead get mysql session by sid // SessionRead get mysql session by sid
func (mp *Provider) SessionRead(sid string) (session.Store, error) { func (mp *Provider) SessionRead(sid string) (session.Store, error) {
c := mp.connectInit() c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid) row := c.QueryRow("select session_data from "+TableName+" where session_key=?", sid)
var sessiondata []byte var sessiondata []byte
err := row.Scan(&sessiondata) err := row.Scan(&sessiondata)
...@@ -179,6 +180,7 @@ func (mp *Provider) SessionExist(sid string) bool { ...@@ -179,6 +180,7 @@ func (mp *Provider) SessionExist(sid string) bool {
// SessionRegenerate generate new sid for mysql session // SessionRegenerate generate new sid for mysql session
func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) { func (mp *Provider) SessionRegenerate(oldsid, sid string) (session.Store, error) {
c := mp.connectInit() c := mp.connectInit()
defer c.Close()
row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid) row := c.QueryRow("select session_data from "+TableName+" where session_key=?", oldsid)
var sessiondata []byte var sessiondata []byte
err := row.Scan(&sessiondata) err := row.Scan(&sessiondata)
......
...@@ -15,8 +15,7 @@ ...@@ -15,8 +15,7 @@
package session package session
import ( import (
"errors" "fmt"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
...@@ -135,6 +134,9 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { ...@@ -135,6 +134,9 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
} else { } else {
return nil, err return nil, err
} }
defer f.Close()
os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now()) os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now())
var kv map[interface{}]interface{} var kv map[interface{}]interface{}
b, err := ioutil.ReadAll(f) b, err := ioutil.ReadAll(f)
...@@ -149,7 +151,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) { ...@@ -149,7 +151,7 @@ func (fp *FileProvider) SessionRead(sid string) (Store, error) {
return nil, err return nil, err
} }
} }
f.Close()
ss := &FileSessionStore{sid: sid, values: kv} ss := &FileSessionStore{sid: sid, values: kv}
return ss, nil return ss, nil
} }
...@@ -204,49 +206,58 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) { ...@@ -204,49 +206,58 @@ func (fp *FileProvider) SessionRegenerate(oldsid, sid string) (Store, error) {
filepder.lock.Lock() filepder.lock.Lock()
defer filepder.lock.Unlock() defer filepder.lock.Unlock()
err := os.MkdirAll(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])), 0777) oldPath := path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]))
if err != nil { oldSidFile := path.Join(oldPath, oldsid)
SLogger.Println(err.Error()) newPath := path.Join(fp.savePath, string(sid[0]), string(sid[1]))
newSidFile := path.Join(newPath, sid)
// new sid file is exist
_, err := os.Stat(newSidFile)
if err == nil {
return nil, fmt.Errorf("newsid %s exist", newSidFile)
} }
err = os.MkdirAll(path.Join(fp.savePath, string(sid[0]), string(sid[1])), 0777)
err = os.MkdirAll(newPath, 0777)
if err != nil { if err != nil {
SLogger.Println(err.Error()) SLogger.Println(err.Error())
} }
_, err = os.Stat(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
var newf *os.File
if err == nil {
return nil, errors.New("newsid exist")
} else if os.IsNotExist(err) {
newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
}
_, err = os.Stat(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid)) // if old sid file exist
var f *os.File // 1.read and parse file content
// 2.write content to new sid file
// 3.remove old sid file, change new sid file atime and ctime
// 4.return FileSessionStore
_, err = os.Stat(oldSidFile)
if err == nil { if err == nil {
f, err = os.OpenFile(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1]), oldsid), os.O_RDWR, 0777) b, err := ioutil.ReadFile(oldSidFile)
io.Copy(newf, f)
} else if os.IsNotExist(err) {
newf, err = os.Create(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid))
} else {
return nil, err
}
f.Close()
os.Remove(path.Join(fp.savePath, string(oldsid[0]), string(oldsid[1])))
os.Chtimes(path.Join(fp.savePath, string(sid[0]), string(sid[1]), sid), time.Now(), time.Now())
var kv map[interface{}]interface{}
b, err := ioutil.ReadAll(newf)
if err != nil {
return nil, err
}
if len(b) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = DecodeGob(b)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var kv map[interface{}]interface{}
if len(b) == 0 {
kv = make(map[interface{}]interface{})
} else {
kv, err = DecodeGob(b)
if err != nil {
return nil, err
}
}
ioutil.WriteFile(newSidFile, b, 0777)
os.Remove(oldSidFile)
os.Chtimes(newSidFile, time.Now(), time.Now())
ss := &FileSessionStore{sid: sid, values: kv}
return ss, nil
} }
ss := &FileSessionStore{sid: sid, values: kv}
// if old sid file not exist, just create new sid file and return
newf, err := os.Create(newSidFile)
if err != nil {
return nil, err
}
newf.Close()
ss := &FileSessionStore{sid: sid, values: make(map[interface{}]interface{})}
return ss, nil return ss, nil
} }
......
...@@ -32,8 +32,9 @@ import ( ...@@ -32,8 +32,9 @@ import (
var ( var (
beegoTplFuncMap = make(template.FuncMap) beegoTplFuncMap = make(template.FuncMap)
// beeTemplates caching map and supported template file extensions. beeViewPathTemplateLocked = false
beeTemplates = make(map[string]*template.Template) // beeViewPathTemplates caching map and supported template file extensions per view
beeViewPathTemplates = make(map[string]map[string]*template.Template)
templatesLock sync.RWMutex templatesLock sync.RWMutex
// beeTemplateExt stores the template extension which will build // beeTemplateExt stores the template extension which will build
beeTemplateExt = []string{"tpl", "html"} beeTemplateExt = []string{"tpl", "html"}
...@@ -45,23 +46,33 @@ var ( ...@@ -45,23 +46,33 @@ var (
// writing the output to wr. // writing the output to wr.
// A template will be executed safely in parallel. // A template will be executed safely in parallel.
func ExecuteTemplate(wr io.Writer, name string, data interface{}) error { func ExecuteTemplate(wr io.Writer, name string, data interface{}) error {
return ExecuteViewPathTemplate(wr,name, BConfig.WebConfig.ViewsPath, data)
}
// ExecuteViewPathTemplate applies the template with name and from specific viewPath to the specified data object,
// writing the output to wr.
// A template will be executed safely in parallel.
func ExecuteViewPathTemplate(wr io.Writer, name string, viewPath string, data interface{}) error {
if BConfig.RunMode == DEV { if BConfig.RunMode == DEV {
templatesLock.RLock() templatesLock.RLock()
defer templatesLock.RUnlock() defer templatesLock.RUnlock()
} }
if t, ok := beeTemplates[name]; ok { if beeTemplates,ok := beeViewPathTemplates[viewPath]; ok {
var err error if t, ok := beeTemplates[name]; ok {
if t.Lookup(name) != nil { var err error
err = t.ExecuteTemplate(wr, name, data) if t.Lookup(name) != nil {
} else { err = t.ExecuteTemplate(wr, name, data)
err = t.Execute(wr, data) } else {
} err = t.Execute(wr, data)
if err != nil { }
logs.Trace("template Execute err:", err) if err != nil {
logs.Trace("template Execute err:", err)
}
return err
} }
return err panic("can't find templatefile in the path:" + viewPath + "/" + name)
} }
panic("can't find templatefile in the path:" + name) panic("Uknown view path:" + viewPath)
} }
func init() { func init() {
...@@ -149,6 +160,21 @@ func AddTemplateExt(ext string) { ...@@ -149,6 +160,21 @@ func AddTemplateExt(ext string) {
beeTemplateExt = append(beeTemplateExt, ext) beeTemplateExt = append(beeTemplateExt, ext)
} }
// AddViewPath adds a new path to the supported view paths.
//Can later be used by setting a controller ViewPath to this folder
//will panic if called after beego.Run()
func AddViewPath(viewPath string) error {
if beeViewPathTemplateLocked {
panic("Can not add new view paths after beego.Run()")
}
beeViewPathTemplates[viewPath] = make(map[string]*template.Template)
return BuildTemplate(viewPath)
}
func lockViewPaths() {
beeViewPathTemplateLocked = true
}
// BuildTemplate will build all template files in a directory. // BuildTemplate will build all template files in a directory.
// it makes beego can render any template file in view directory. // it makes beego can render any template file in view directory.
func BuildTemplate(dir string, files ...string) error { func BuildTemplate(dir string, files ...string) error {
...@@ -158,6 +184,10 @@ func BuildTemplate(dir string, files ...string) error { ...@@ -158,6 +184,10 @@ func BuildTemplate(dir string, files ...string) error {
} }
return errors.New("dir open err") return errors.New("dir open err")
} }
beeTemplates,ok := beeViewPathTemplates[dir];
if !ok {
panic("Unknown view path: " + dir)
}
self := &templateFile{ self := &templateFile{
root: dir, root: dir,
files: make(map[string][]string), files: make(map[string][]string),
...@@ -224,7 +254,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp ...@@ -224,7 +254,7 @@ func getTplDeep(root, file, parent string, t *template.Template) (*template.Temp
if !HasTemplateExt(m[1]) { if !HasTemplateExt(m[1]) {
continue continue
} }
t, _, err = getTplDeep(root, m[1], file, t) _, _, err = getTplDeep(root, m[1], file, t)
if err != nil { if err != nil {
return nil, [][]string{}, err return nil, [][]string{}, err
} }
......
...@@ -67,9 +67,10 @@ func TestTemplate(t *testing.T) { ...@@ -67,9 +67,10 @@ func TestTemplate(t *testing.T) {
f.Close() f.Close()
} }
} }
if err := BuildTemplate(dir); err != nil { if err := AddViewPath(dir); err != nil {
t.Fatal(err) t.Fatal(err)
} }
beeTemplates := beeViewPathTemplates[dir]
if len(beeTemplates) != 3 { if len(beeTemplates) != 3 {
t.Fatalf("should be 3 but got %v", len(beeTemplates)) t.Fatalf("should be 3 but got %v", len(beeTemplates))
} }
...@@ -103,6 +104,12 @@ var user = `<!DOCTYPE html> ...@@ -103,6 +104,12 @@ var user = `<!DOCTYPE html>
func TestRelativeTemplate(t *testing.T) { func TestRelativeTemplate(t *testing.T) {
dir := "_beeTmp" dir := "_beeTmp"
//Just add dir to known viewPaths
if err := AddViewPath(dir); err != nil {
t.Fatal(err)
}
files := []string{ files := []string{
"easyui/public/menu.tpl", "easyui/public/menu.tpl",
"easyui/rbac/user.tpl", "easyui/rbac/user.tpl",
...@@ -126,6 +133,7 @@ func TestRelativeTemplate(t *testing.T) { ...@@ -126,6 +133,7 @@ func TestRelativeTemplate(t *testing.T) {
if err := BuildTemplate(dir, files[1]); err != nil { if err := BuildTemplate(dir, files[1]); err != nil {
t.Fatal(err) t.Fatal(err)
} }
beeTemplates := beeViewPathTemplates[dir]
if err := beeTemplates["easyui/rbac/user.tpl"].ExecuteTemplate(os.Stdout, "easyui/rbac/user.tpl", nil); err != nil { if err := beeTemplates["easyui/rbac/user.tpl"].ExecuteTemplate(os.Stdout, "easyui/rbac/user.tpl", nil); err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -117,7 +117,9 @@ func (m *URLMap) GetMap() map[string]interface{} { ...@@ -117,7 +117,9 @@ func (m *URLMap) GetMap() map[string]interface{} {
// GetMapData return all mapdata // GetMapData return all mapdata
func (m *URLMap) GetMapData() []map[string]interface{} { func (m *URLMap) GetMapData() []map[string]interface{} {
m.lock.Lock()
defer m.lock.Unlock()
var resultLists []map[string]interface{} var resultLists []map[string]interface{}
for k, v := range m.urlmap { for k, v := range m.urlmap {
......
...@@ -61,10 +61,8 @@ func (m *BeeMap) Set(k interface{}, v interface{}) bool { ...@@ -61,10 +61,8 @@ func (m *BeeMap) Set(k interface{}, v interface{}) bool {
func (m *BeeMap) Check(k interface{}) bool { func (m *BeeMap) Check(k interface{}) bool {
m.lock.RLock() m.lock.RLock()
defer m.lock.RUnlock() defer m.lock.RUnlock()
if _, ok := m.bm[k]; !ok { _, ok := m.bm[k]
return false return ok
}
return true
} }
// Delete the given key and value. // Delete the given key and value.
...@@ -84,3 +82,10 @@ func (m *BeeMap) Items() map[interface{}]interface{} { ...@@ -84,3 +82,10 @@ func (m *BeeMap) Items() map[interface{}]interface{} {
} }
return r return r
} }
// Count returns the number of items within the map.
func (m *BeeMap) Count() int {
m.lock.RLock()
defer m.lock.RUnlock()
return len(m.bm)
}
...@@ -14,25 +14,44 @@ ...@@ -14,25 +14,44 @@
package utils package utils
import ( import "testing"
"testing"
) var safeMap *BeeMap
func Test_beemap(t *testing.T) { func TestNewBeeMap(t *testing.T) {
bm := NewBeeMap() safeMap = NewBeeMap()
if !bm.Set("astaxie", 1) { if safeMap == nil {
t.Error("set Error") t.Fatal("expected to return non-nil BeeMap", "got", safeMap)
}
}
func TestSet(t *testing.T) {
if ok := safeMap.Set("astaxie", 1); !ok {
t.Error("expected", true, "got", false)
}
}
func TestCheck(t *testing.T) {
if exists := safeMap.Check("astaxie"); !exists {
t.Error("expected", true, "got", false)
} }
if !bm.Check("astaxie") { }
t.Error("check err")
func TestGet(t *testing.T) {
if val := safeMap.Get("astaxie"); val.(int) != 1 {
t.Error("expected value", 1, "got", val)
} }
}
if v := bm.Get("astaxie"); v.(int) != 1 { func TestDelete(t *testing.T) {
t.Error("get err") safeMap.Delete("astaxie")
if exists := safeMap.Check("astaxie"); exists {
t.Error("expected element to be deleted")
} }
}
bm.Delete("astaxie") func TestCount(t *testing.T) {
if bm.Check("astaxie") { if count := safeMap.Count(); count != 0 {
t.Error("delete err") t.Error("expected count to be", 0, "got", count)
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment