Commit cab8458c authored by astaxie's avatar astaxie Committed by GitHub

Merge pull request #2651 from astaxie/develop

v1.8.3
parents 83814a76 f0b95c55
...@@ -33,6 +33,8 @@ install: ...@@ -33,6 +33,8 @@ install:
- go get github.com/ssdb/gossdb/ssdb - go get github.com/ssdb/gossdb/ssdb
- go get github.com/cloudflare/golz4 - go get github.com/cloudflare/golz4
- go get github.com/gogo/protobuf/proto - go get github.com/gogo/protobuf/proto
- go get github.com/Knetic/govaluate
- go get github.com/hsluoyz/casbin
- go get -u honnef.co/go/tools/cmd/gosimple - go get -u honnef.co/go/tools/cmd/gosimple
- go get -u github.com/mdempsky/unconvert - go get -u github.com/mdempsky/unconvert
- go get -u github.com/gordonklaus/ineffassign - go get -u github.com/gordonklaus/ineffassign
......
...@@ -23,7 +23,7 @@ import ( ...@@ -23,7 +23,7 @@ import (
const ( const (
// VERSION represent beego web framework version. // VERSION represent beego web framework version.
VERSION = "1.8.2" VERSION = "1.8.3"
// DEV is for develop // DEV is for develop
DEV = "dev" DEV = "dev"
......
...@@ -28,7 +28,7 @@ func GetString(v interface{}) string { ...@@ -28,7 +28,7 @@ func GetString(v interface{}) string {
return string(result) return string(result)
default: default:
if v != nil { if v != nil {
return fmt.Sprintf("%v", result) return fmt.Sprint(result)
} }
} }
return "" return ""
......
...@@ -171,6 +171,22 @@ func (ctx *Context) CheckXSRFCookie() bool { ...@@ -171,6 +171,22 @@ func (ctx *Context) CheckXSRFCookie() bool {
return true return true
} }
// RenderMethodResult renders the return value of a controller method to the output
func (ctx *Context) RenderMethodResult(result interface{}) {
if result != nil {
renderer, ok := result.(Renderer)
if !ok {
err, ok := result.(error)
if ok {
renderer = errorRenderer(err)
} else {
renderer = jsonRenderer(result)
}
}
renderer.Render(ctx)
}
}
//Response is a wrapper for the http.ResponseWriter //Response is a wrapper for the http.ResponseWriter
//started set to true if response was written to then don't execute other handler //started set to true if response was written to then don't execute other handler
type Response struct { type Response struct {
......
...@@ -168,6 +168,19 @@ func sanitizeValue(v string) string { ...@@ -168,6 +168,19 @@ func sanitizeValue(v string) string {
return cookieValueSanitizer.Replace(v) return cookieValueSanitizer.Replace(v)
} }
func jsonRenderer(value interface{}) Renderer {
return rendererFunc(func(ctx *Context) {
ctx.Output.JSON(value, false, false)
})
}
func errorRenderer(err error) Renderer {
return rendererFunc(func(ctx *Context) {
ctx.Output.SetStatus(500)
ctx.WriteString(err.Error())
})
}
// JSON writes json to response body. // JSON writes json to response body.
// if coding is true, it converts utf-8 to \u0000 type. // if coding is true, it converts utf-8 to \u0000 type.
func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error { func (output *BeegoOutput) JSON(data interface{}, hasIndent bool, coding bool) error {
...@@ -330,9 +343,8 @@ func (output *BeegoOutput) IsServerError() bool { ...@@ -330,9 +343,8 @@ func (output *BeegoOutput) IsServerError() bool {
} }
func stringsToJSON(str string) string { func stringsToJSON(str string) string {
rs := []rune(str)
var jsons bytes.Buffer var jsons bytes.Buffer
for _, r := range rs { for _, r := range str {
rint := int(r) rint := int(r)
if rint < 128 { if rint < 128 {
jsons.WriteRune(r) jsons.WriteRune(r)
......
package param
import (
"fmt"
"reflect"
beecontext "github.com/astaxie/beego/context"
"github.com/astaxie/beego/logs"
)
// ConvertParams converts http method params to values that will be passed to the method controller as arguments
func ConvertParams(methodParams []*MethodParam, methodType reflect.Type, ctx *beecontext.Context) (result []reflect.Value) {
result = make([]reflect.Value, 0, len(methodParams))
for i := 0; i < len(methodParams); i++ {
reflectValue := convertParam(methodParams[i], methodType.In(i), ctx)
result = append(result, reflectValue)
}
return
}
func convertParam(param *MethodParam, paramType reflect.Type, ctx *beecontext.Context) (result reflect.Value) {
paramValue := getParamValue(param, ctx)
if paramValue == "" {
if param.required {
ctx.Abort(400, fmt.Sprintf("Missing parameter %s", param.name))
} else {
paramValue = param.defaultValue
}
}
reflectValue, err := parseValue(param, paramValue, paramType)
if err != nil {
logs.Debug(fmt.Sprintf("Error converting param %s to type %s. Value: %v, Error: %s", param.name, paramType, paramValue, err))
ctx.Abort(400, fmt.Sprintf("Invalid parameter %s. Can not convert %v to type %s", param.name, paramValue, paramType))
}
return reflectValue
}
func getParamValue(param *MethodParam, ctx *beecontext.Context) string {
switch param.in {
case body:
return string(ctx.Input.RequestBody)
case header:
return ctx.Input.Header(param.name)
case path:
return ctx.Input.Query(":" + param.name)
default:
return ctx.Input.Query(param.name)
}
}
func parseValue(param *MethodParam, paramValue string, paramType reflect.Type) (result reflect.Value, err error) {
if paramValue == "" {
return reflect.Zero(paramType), nil
}
parser := getParser(param, paramType)
value, err := parser.parse(paramValue, paramType)
if err != nil {
return result, err
}
return safeConvert(reflect.ValueOf(value), paramType)
}
func safeConvert(value reflect.Value, t reflect.Type) (result reflect.Value, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok = r.(error)
if !ok {
err = fmt.Errorf("%v", r)
}
}
}()
result = value.Convert(t)
return
}
package param
import (
"fmt"
"strings"
)
//MethodParam keeps param information to be auto passed to controller methods
type MethodParam struct {
name string
in paramType
required bool
defaultValue string
}
type paramType byte
const (
param paramType = iota
path
body
header
)
//New creates a new MethodParam with name and specific options
func New(name string, opts ...MethodParamOption) *MethodParam {
return newParam(name, nil, opts)
}
func newParam(name string, parser paramParser, opts []MethodParamOption) (param *MethodParam) {
param = &MethodParam{name: name}
for _, option := range opts {
option(param)
}
return
}
//Make creates an array of MethodParmas or an empty array
func Make(list ...*MethodParam) []*MethodParam {
if len(list) > 0 {
return list
}
return nil
}
func (mp *MethodParam) String() string {
options := []string{}
result := "param.New(\"" + mp.name + "\""
if mp.required {
options = append(options, "param.IsRequired")
}
switch mp.in {
case path:
options = append(options, "param.InPath")
case body:
options = append(options, "param.InBody")
case header:
options = append(options, "param.InHeader")
}
if mp.defaultValue != "" {
options = append(options, fmt.Sprintf(`param.Default("%s")`, mp.defaultValue))
}
if len(options) > 0 {
result += ", "
}
result += strings.Join(options, ", ")
result += ")"
return result
}
package param
import (
"fmt"
)
// MethodParamOption defines a func which apply options on a MethodParam
type MethodParamOption func(*MethodParam)
// IsRequired indicates that this param is required and can not be ommited from the http request
var IsRequired MethodParamOption = func(p *MethodParam) {
p.required = true
}
// InHeader indicates that this param is passed via an http header
var InHeader MethodParamOption = func(p *MethodParam) {
p.in = header
}
// InPath indicates that this param is part of the URL path
var InPath MethodParamOption = func(p *MethodParam) {
p.in = path
}
// InBody indicates that this param is passed as an http request body
var InBody MethodParamOption = func(p *MethodParam) {
p.in = body
}
// Default provides a default value for the http param
func Default(defaultValue interface{}) MethodParamOption {
return func(p *MethodParam) {
if defaultValue != nil {
p.defaultValue = fmt.Sprint(defaultValue)
}
}
}
package param
import (
"encoding/json"
"reflect"
"strconv"
"strings"
"time"
)
type paramParser interface {
parse(value string, toType reflect.Type) (interface{}, error)
}
func getParser(param *MethodParam, t reflect.Type) paramParser {
switch t.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return intParser{}
case reflect.Slice:
if t.Elem().Kind() == reflect.Uint8 { //treat []byte as string
return stringParser{}
}
if param.in == body {
return jsonParser{}
}
elemParser := getParser(param, t.Elem())
if elemParser == (jsonParser{}) {
return elemParser
}
return sliceParser(elemParser)
case reflect.Bool:
return boolParser{}
case reflect.String:
return stringParser{}
case reflect.Float32, reflect.Float64:
return floatParser{}
case reflect.Ptr:
elemParser := getParser(param, t.Elem())
if elemParser == (jsonParser{}) {
return elemParser
}
return ptrParser(elemParser)
default:
if t.PkgPath() == "time" && t.Name() == "Time" {
return timeParser{}
}
return jsonParser{}
}
}
type parserFunc func(value string, toType reflect.Type) (interface{}, error)
func (f parserFunc) parse(value string, toType reflect.Type) (interface{}, error) {
return f(value, toType)
}
type boolParser struct {
}
func (p boolParser) parse(value string, toType reflect.Type) (interface{}, error) {
return strconv.ParseBool(value)
}
type stringParser struct {
}
func (p stringParser) parse(value string, toType reflect.Type) (interface{}, error) {
return value, nil
}
type intParser struct {
}
func (p intParser) parse(value string, toType reflect.Type) (interface{}, error) {
return strconv.Atoi(value)
}
type floatParser struct {
}
func (p floatParser) parse(value string, toType reflect.Type) (interface{}, error) {
if toType.Kind() == reflect.Float32 {
res, err := strconv.ParseFloat(value, 32)
if err != nil {
return nil, err
}
return float32(res), nil
}
return strconv.ParseFloat(value, 64)
}
type timeParser struct {
}
func (p timeParser) parse(value string, toType reflect.Type) (result interface{}, err error) {
result, err = time.Parse(time.RFC3339, value)
if err != nil {
result, err = time.Parse("2006-01-02", value)
}
return
}
type jsonParser struct {
}
func (p jsonParser) parse(value string, toType reflect.Type) (interface{}, error) {
pResult := reflect.New(toType)
v := pResult.Interface()
err := json.Unmarshal([]byte(value), v)
if err != nil {
return nil, err
}
return pResult.Elem().Interface(), nil
}
func sliceParser(elemParser paramParser) paramParser {
return parserFunc(func(value string, toType reflect.Type) (interface{}, error) {
values := strings.Split(value, ",")
result := reflect.MakeSlice(toType, 0, len(values))
elemType := toType.Elem()
for _, v := range values {
parsedValue, err := elemParser.parse(v, elemType)
if err != nil {
return nil, err
}
result = reflect.Append(result, reflect.ValueOf(parsedValue))
}
return result.Interface(), nil
})
}
func ptrParser(elemParser paramParser) paramParser {
return parserFunc(func(value string, toType reflect.Type) (interface{}, error) {
parsedValue, err := elemParser.parse(value, toType.Elem())
if err != nil {
return nil, err
}
newValPtr := reflect.New(toType.Elem())
newVal := reflect.Indirect(newValPtr)
convertedVal, err := safeConvert(reflect.ValueOf(parsedValue), toType.Elem())
if err != nil {
return nil, err
}
newVal.Set(convertedVal)
return newValPtr.Interface(), nil
})
}
package param
import "testing"
import "reflect"
import "time"
type testDefinition struct {
strValue string
expectedValue interface{}
expectedParser paramParser
}
func Test_Parsers(t *testing.T) {
//ints
checkParser(testDefinition{"1", 1, intParser{}}, t)
checkParser(testDefinition{"-1", int64(-1), intParser{}}, t)
checkParser(testDefinition{"1", uint64(1), intParser{}}, t)
//floats
checkParser(testDefinition{"1.0", float32(1.0), floatParser{}}, t)
checkParser(testDefinition{"-1.0", float64(-1.0), floatParser{}}, t)
//strings
checkParser(testDefinition{"AB", "AB", stringParser{}}, t)
checkParser(testDefinition{"AB", []byte{65, 66}, stringParser{}}, t)
//bools
checkParser(testDefinition{"true", true, boolParser{}}, t)
checkParser(testDefinition{"0", false, boolParser{}}, t)
//timeParser
checkParser(testDefinition{"2017-05-30T13:54:53Z", time.Date(2017, 5, 30, 13, 54, 53, 0, time.UTC), timeParser{}}, t)
checkParser(testDefinition{"2017-05-30", time.Date(2017, 5, 30, 0, 0, 0, 0, time.UTC), timeParser{}}, t)
//json
checkParser(testDefinition{`{"X": 5, "Y":"Z"}`, struct {
X int
Y string
}{5, "Z"}, jsonParser{}}, t)
//slice in query is parsed as comma delimited
checkParser(testDefinition{`1,2`, []int{1, 2}, sliceParser(intParser{})}, t)
//slice in body is parsed as json
checkParser(testDefinition{`["a","b"]`, []string{"a", "b"}, jsonParser{}}, t, MethodParam{in: body})
//pointers
var someInt = 1
checkParser(testDefinition{`1`, &someInt, ptrParser(intParser{})}, t)
var someStruct = struct{ X int }{5}
checkParser(testDefinition{`{"X": 5}`, &someStruct, jsonParser{}}, t)
}
func checkParser(def testDefinition, t *testing.T, methodParam ...MethodParam) {
toType := reflect.TypeOf(def.expectedValue)
var mp MethodParam
if len(methodParam) == 0 {
mp = MethodParam{}
} else {
mp = methodParam[0]
}
parser := getParser(&mp, toType)
if reflect.TypeOf(parser) != reflect.TypeOf(def.expectedParser) {
t.Errorf("Invalid parser for value %v. Expected: %v, actual: %v", def.strValue, reflect.TypeOf(def.expectedParser).Name(), reflect.TypeOf(parser).Name())
return
}
result, err := parser.parse(def.strValue, toType)
if err != nil {
t.Errorf("Parsing error for value %v. Expected result: %v, error: %v", def.strValue, def.expectedValue, err)
return
}
convResult, err := safeConvert(reflect.ValueOf(result), toType)
if err != nil {
t.Errorf("Convertion error for %v. from value: %v, toType: %v, error: %v", def.strValue, result, toType, err)
return
}
if !reflect.DeepEqual(convResult.Interface(), def.expectedValue) {
t.Errorf("Parsing error for value %v. Expected result: %v, actual: %v", def.strValue, def.expectedValue, result)
}
}
package context
// Renderer defines an http response renderer
type Renderer interface {
Render(ctx *Context)
}
type rendererFunc func(ctx *Context)
func (f rendererFunc) Render(ctx *Context) {
f(ctx)
}
package context
import (
"strconv"
"net/http"
)
const (
//BadRequest indicates http error 400
BadRequest StatusCode = http.StatusBadRequest
//NotFound indicates http error 404
NotFound StatusCode = http.StatusNotFound
)
// StatusCode sets the http response status code
type StatusCode int
func (s StatusCode) Error() string {
return strconv.Itoa(int(s))
}
// Render sets the http status code
func (s StatusCode) Render(ctx *Context) {
ctx.Output.SetStatus(int(s))
}
...@@ -28,6 +28,7 @@ import ( ...@@ -28,6 +28,7 @@ import (
"strings" "strings"
"github.com/astaxie/beego/context" "github.com/astaxie/beego/context"
"github.com/astaxie/beego/context/param"
"github.com/astaxie/beego/session" "github.com/astaxie/beego/session"
) )
...@@ -51,6 +52,7 @@ type ControllerComments struct { ...@@ -51,6 +52,7 @@ type ControllerComments struct {
Router string Router string
AllowHTTPMethods []string AllowHTTPMethods []string
Params []map[string]string Params []map[string]string
MethodParams []*param.MethodParam
} }
// Controller defines some basic http request handler operations, such as // Controller defines some basic http request handler operations, such as
......
...@@ -252,6 +252,30 @@ func forbidden(rw http.ResponseWriter, r *http.Request) { ...@@ -252,6 +252,30 @@ func forbidden(rw http.ResponseWriter, r *http.Request) {
) )
} }
// show 422 missing xsrf token
func missingxsrf(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
422,
"<br>The page you have requested is forbidden."+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br>'_xsrf' argument missing from POST"+
"</ul>",
)
}
// show 417 invalid xsrf token
func invalidxsrf(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r,
417,
"<br>The page you have requested is forbidden."+
"<br>Perhaps you are here because:"+
"<br><br><ul>"+
"<br>expected XSRF not found"+
"</ul>",
)
}
// show 404 not found error. // show 404 not found error.
func notFound(rw http.ResponseWriter, r *http.Request) { func notFound(rw http.ResponseWriter, r *http.Request) {
responseError(rw, r, responseError(rw, r,
......
...@@ -32,6 +32,8 @@ func registerDefaultErrorHandler() error { ...@@ -32,6 +32,8 @@ func registerDefaultErrorHandler() error {
"502": badGateway, "502": badGateway,
"503": serviceUnavailable, "503": serviceUnavailable,
"504": gatewayTimeout, "504": gatewayTimeout,
"417": invalidxsrf,
"422": missingxsrf,
} }
for e, h := range m { for e, h := range m {
if _, ok := ErrorMaps[e]; !ok { if _, ok := ErrorMaps[e]; !ok {
......
...@@ -833,7 +833,11 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con ...@@ -833,7 +833,11 @@ func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Con
if err := rs.Scan(&ref); err != nil { if err := rs.Scan(&ref); err != nil {
return 0, err return 0, err
} }
args = append(args, reflect.ValueOf(ref).Interface()) pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz)
if err != nil {
return 0, err
}
args = append(args, pkValue)
cnt++ cnt++
} }
......
...@@ -75,7 +75,7 @@ func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) { ...@@ -75,7 +75,7 @@ func registerModel(PrefixOrSuffix string, model interface{}, isPrefix bool) {
} }
if mi.fields.pk == nil { if mi.fields.pk == nil {
fmt.Printf("<orm.RegisterModel> `%s` need a primary key field, default use 'id' if not set\n", name) fmt.Printf("<orm.RegisterModel> `%s` needs a primary key field, default is to use 'id' if not set\n", name)
os.Exit(2) os.Exit(2)
} }
......
...@@ -107,7 +107,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect ...@@ -107,7 +107,7 @@ func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect
if mi, ok := modelCache.getByFullName(name); ok { if mi, ok := modelCache.getByFullName(name); ok {
return mi, ind return mi, ind
} }
panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name)) panic(fmt.Errorf("<Ormer> table: `%s` not found, make sure it was registered with `RegisterModel()`", name))
} }
// get field info from model info by given field name // get field info from model info by given field name
......
...@@ -493,19 +493,33 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) { ...@@ -493,19 +493,33 @@ func (o *rawSet) QueryRows(containers ...interface{}) (int64, error) {
} }
} }
} else { } else {
for i := 0; i < ind.NumField(); i++ { // define recursive function
f := ind.Field(i) var recursiveSetField func(rv reflect.Value)
fe := ind.Type().Field(i) recursiveSetField = func(rv reflect.Value) {
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName)) for i := 0; i < rv.NumField(); i++ {
var col string f := rv.Field(i)
if col = tags["column"]; col == "" { fe := rv.Type().Field(i)
col = snakeString(fe.Name)
} // check if the field is a Struct
if v, ok := columnsMp[col]; ok { // recursive the Struct type
value := reflect.ValueOf(v).Elem().Interface() if fe.Type.Kind() == reflect.Struct {
o.setFieldValue(f, value) recursiveSetField(f)
}
_, tags := parseStructTag(fe.Tag.Get(defaultStructTagName))
var col string
if col = tags["column"]; col == "" {
col = snakeString(fe.Name)
}
if v, ok := columnsMp[col]; ok {
value := reflect.ValueOf(v).Elem().Interface()
o.setFieldValue(f, value)
}
} }
} }
// init call the recursive function
recursiveSetField(ind)
} }
if eTyps[0].Kind() == reflect.Ptr { if eTyps[0].Kind() == reflect.Ptr {
......
...@@ -1661,6 +1661,13 @@ func TestRawQueryRow(t *testing.T) { ...@@ -1661,6 +1661,13 @@ func TestRawQueryRow(t *testing.T) {
throwFail(t, AssertIs(pid, nil)) throwFail(t, AssertIs(pid, nil))
} }
// user_profile table
type userProfile struct {
User
Age int
Money float64
}
func TestQueryRows(t *testing.T) { func TestQueryRows(t *testing.T) {
Q := dDbBaser.TableQuote() Q := dDbBaser.TableQuote()
...@@ -1731,6 +1738,19 @@ func TestQueryRows(t *testing.T) { ...@@ -1731,6 +1738,19 @@ func TestQueryRows(t *testing.T) {
throwFailNow(t, AssertIs(usernames[1], "astaxie")) throwFailNow(t, AssertIs(usernames[1], "astaxie"))
throwFailNow(t, AssertIs(ids[2], 4)) throwFailNow(t, AssertIs(ids[2], 4))
throwFailNow(t, AssertIs(usernames[2], "nobody")) throwFailNow(t, AssertIs(usernames[2], "nobody"))
//test query rows by nested struct
var l []userProfile
query = fmt.Sprintf("SELECT * FROM %suser_profile%s LEFT JOIN %suser%s ON %suser_profile%s.%sid%s = %suser%s.%sid%s", Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q)
num, err = dORM.Raw(query).QueryRows(&l)
throwFailNow(t, err)
throwFailNow(t, AssertIs(num, 2))
throwFailNow(t, AssertIs(len(l), 2))
throwFailNow(t, AssertIs(l[0].UserName, "slene"))
throwFailNow(t, AssertIs(l[0].Age, 28))
throwFailNow(t, AssertIs(l[1].UserName, "astaxie"))
throwFailNow(t, AssertIs(l[1].Age, 30))
} }
func TestRawValues(t *testing.T) { func TestRawValues(t *testing.T) {
......
...@@ -24,9 +24,13 @@ import ( ...@@ -24,9 +24,13 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"sort" "sort"
"strconv"
"strings" "strings"
"unicode"
"github.com/astaxie/beego/context/param"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
) )
...@@ -35,6 +39,7 @@ var globalRouterTemplate = `package routers ...@@ -35,6 +39,7 @@ var globalRouterTemplate = `package routers
import ( import (
"github.com/astaxie/beego" "github.com/astaxie/beego"
"github.com/astaxie/beego/context/param"
) )
func init() { func init() {
...@@ -81,7 +86,7 @@ func parserPkg(pkgRealpath, pkgpath string) error { ...@@ -81,7 +86,7 @@ func parserPkg(pkgRealpath, pkgpath string) error {
if specDecl.Recv != nil { if specDecl.Recv != nil {
exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser
if ok { if ok {
parserComments(specDecl.Doc, specDecl.Name.String(), fmt.Sprint(exp.X), pkgpath) parserComments(specDecl, fmt.Sprint(exp.X), pkgpath)
} }
} }
} }
...@@ -93,42 +98,167 @@ func parserPkg(pkgRealpath, pkgpath string) error { ...@@ -93,42 +98,167 @@ func parserPkg(pkgRealpath, pkgpath string) error {
return nil return nil
} }
func parserComments(comments *ast.CommentGroup, funcName, controllerName, pkgpath string) error { type parsedComment struct {
if comments != nil && comments.List != nil { routerPath string
for _, c := range comments.List { methods []string
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//")) params map[string]parsedParam
if strings.HasPrefix(t, "@router") { }
elements := strings.TrimLeft(t, "@router ")
e1 := strings.SplitN(elements, " ", 2) type parsedParam struct {
if len(e1) < 1 { name string
return errors.New("you should has router information") datatype string
} location string
key := pkgpath + ":" + controllerName defValue string
cc := ControllerComments{} required bool
cc.Method = funcName }
cc.Router = e1[0]
if len(e1) == 2 && e1[1] != "" { func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error {
e1 = strings.SplitN(e1[1], " ", 2) if f.Doc != nil {
if len(e1) >= 1 { parsedComment, err := parseComment(f.Doc.List)
cc.AllowHTTPMethods = strings.Split(strings.Trim(e1[0], "[]"), ",") if err != nil {
} else { return err
cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get") }
} if parsedComment.routerPath != "" {
key := pkgpath + ":" + controllerName
cc := ControllerComments{}
cc.Method = f.Name.String()
cc.Router = parsedComment.routerPath
cc.AllowHTTPMethods = parsedComment.methods
cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment)
genInfoList[key] = append(genInfoList[key], cc)
}
}
return nil
}
func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam {
result := make([]*param.MethodParam, 0, len(funcParams))
for _, fparam := range funcParams {
methodParam := buildMethodParam(fparam, pc)
result = append(result, methodParam)
}
return result
}
func buildMethodParam(fparam *ast.Field, pc *parsedComment) *param.MethodParam {
options := []param.MethodParamOption{}
name := fparam.Names[0].Name
if cparam, ok := pc.params[name]; ok {
//Build param from comment info
name = cparam.name
if cparam.required {
options = append(options, param.IsRequired)
}
switch cparam.location {
case "body":
options = append(options, param.InBody)
case "header":
options = append(options, param.InHeader)
case "path":
options = append(options, param.InPath)
}
if cparam.defValue != "" {
options = append(options, param.Default(cparam.defValue))
}
} else {
if paramInPath(name, pc.routerPath) {
options = append(options, param.InPath)
}
}
return param.New(name, options...)
}
func paramInPath(name, route string) bool {
return strings.HasSuffix(route, ":"+name) ||
strings.Contains(route, ":"+name+"/")
}
var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`)
func parseComment(lines []*ast.Comment) (pc *parsedComment, err error) {
pc = &parsedComment{}
for _, c := range lines {
t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
if strings.HasPrefix(t, "@router") {
matches := routeRegex.FindStringSubmatch(t)
if len(matches) == 3 {
pc.routerPath = matches[1]
methods := matches[2]
if methods == "" {
pc.methods = []string{"get"}
//pc.hasGet = true
} else { } else {
cc.AllowHTTPMethods = append(cc.AllowHTTPMethods, "get") pc.methods = strings.Split(methods, ",")
//pc.hasGet = strings.Contains(methods, "get")
} }
if len(e1) == 2 && e1[1] != "" { } else {
keyval := strings.Split(strings.Trim(e1[1], "[]"), " ") return nil, errors.New("Router information is missing")
for _, kv := range keyval { }
kk := strings.Split(kv, ":") } else if strings.HasPrefix(t, "@Param") {
cc.Params = append(cc.Params, map[string]string{strings.Join(kk[:len(kk)-1], ":"): kk[len(kk)-1]}) pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param")))
} if len(pv) < 4 {
} logs.Error("Invalid @Param format. Needs at least 4 parameters")
genInfoList[key] = append(genInfoList[key], cc) }
p := parsedParam{}
names := strings.SplitN(pv[0], "=>", 2)
p.name = names[0]
funcParamName := p.name
if len(names) > 1 {
funcParamName = names[1]
}
p.location = pv[1]
p.datatype = pv[2]
switch len(pv) {
case 5:
p.required, _ = strconv.ParseBool(pv[3])
case 6:
p.defValue = pv[3]
p.required, _ = strconv.ParseBool(pv[4])
}
if pc.params == nil {
pc.params = map[string]parsedParam{}
} }
pc.params[funcParamName] = p
} }
} }
return nil return
}
// direct copy from bee\g_docs.go
// analisys params return []string
// @Param query form string true "The email for login"
// [query form string true "The email for login"]
func getparams(str string) []string {
var s []rune
var j int
var start bool
var r []string
var quoted int8
for _, c := range str {
if unicode.IsSpace(c) && quoted == 0 {
if !start {
continue
} else {
start = false
j++
r = append(r, string(s))
s = make([]rune, 0)
continue
}
}
start = true
if c == '"' {
quoted ^= 1
continue
}
s = append(s, c)
}
if len(s) > 0 {
r = append(r, string(s))
}
return r
} }
func genRouterCode(pkgRealpath string) { func genRouterCode(pkgRealpath string) {
...@@ -163,12 +293,24 @@ func genRouterCode(pkgRealpath string) { ...@@ -163,12 +293,24 @@ func genRouterCode(pkgRealpath string) {
} }
params = strings.TrimRight(params, ",") + "}" params = strings.TrimRight(params, ",") + "}"
} }
methodParams := "param.Make("
if len(c.MethodParams) > 0 {
lines := make([]string, 0, len(c.MethodParams))
for _, m := range c.MethodParams {
lines = append(lines, fmt.Sprint(m))
}
methodParams += "\n " +
strings.Join(lines, ",\n ") +
",\n "
}
methodParams += ")"
globalinfo = globalinfo + ` globalinfo = globalinfo + `
beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"], beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
beego.ControllerComments{ beego.ControllerComments{
Method: "` + strings.TrimSpace(c.Method) + `", Method: "` + strings.TrimSpace(c.Method) + `",
` + "Router: `" + c.Router + "`" + `, ` + "Router: `" + c.Router + "`" + `,
AllowHTTPMethods: ` + allmethod + `, AllowHTTPMethods: ` + allmethod + `,
MethodParams: ` + methodParams + `,
Params: ` + params + `}) Params: ` + params + `})
` `
} }
......
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package authz provides handlers to enable ACL, RBAC, ABAC authorization support.
// Simple Usage:
// import(
// "github.com/astaxie/beego"
// "github.com/astaxie/beego/plugins/authz"
// "github.com/hsluoyz/casbin"
// )
//
// func main(){
// // mediate the access for every request
// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
// beego.Run()
// }
//
//
// Advanced Usage:
//
// func main(){
// e := casbin.NewEnforcer("authz_model.conf", "")
// e.AddRoleForUser("alice", "admin")
// e.AddPolicy(...)
//
// beego.InsertFilter("*", beego.BeforeRouter, authz.NewAuthorizer(e))
// beego.Run()
// }
package authz
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/context"
"github.com/hsluoyz/casbin"
"net/http"
)
// NewAuthorizer returns the authorizer.
// Use a casbin enforcer as input
func NewAuthorizer(e *casbin.Enforcer) beego.FilterFunc {
return func(ctx *context.Context) {
a := &BasicAuthorizer{enforcer: e}
if !a.CheckPermission(ctx.Request) {
a.RequirePermission(ctx.ResponseWriter)
}
}
}
// BasicAuthorizer stores the casbin handler
type BasicAuthorizer struct {
enforcer *casbin.Enforcer
}
// GetUserName gets the user name from the request.
// Currently, only HTTP basic authentication is supported
func (a *BasicAuthorizer) GetUserName(r *http.Request) string {
username, _, _ := r.BasicAuth()
return username
}
// CheckPermission checks the user/method/path combination from the request.
// Returns true (permission granted) or false (permission forbidden)
func (a *BasicAuthorizer) CheckPermission(r *http.Request) bool {
user := a.GetUserName(r)
method := r.Method
path := r.URL.Path
return a.enforcer.Enforce(user, path, method)
}
// RequirePermission returns the 403 Forbidden to the client
func (a *BasicAuthorizer) RequirePermission(w http.ResponseWriter) {
w.WriteHeader(403)
w.Write([]byte("403 Forbidden\n"))
}
[request_definition]
r = sub, obj, act
[policy_definition]
p = sub, obj, act
[role_definition]
g = _, _
[policy_effect]
e = some(where (p.eft == allow))
[matchers]
m = g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*")
\ No newline at end of file
p, alice, /dataset1/*, GET
p, alice, /dataset1/resource1, POST
p, bob, /dataset2/resource1, *
p, bob, /dataset2/resource2, GET
p, bob, /dataset2/folder1/*, POST
p, dataset1_admin, /dataset1/*, *
g, cathy, dataset1_admin
\ No newline at end of file
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package authz
import (
"github.com/astaxie/beego"
"github.com/astaxie/beego/context"
"github.com/astaxie/beego/plugins/auth"
"github.com/hsluoyz/casbin"
"net/http"
"net/http/httptest"
"testing"
)
func testRequest(t *testing.T, handler *beego.ControllerRegister, user string, path string, method string, code int) {
r, _ := http.NewRequest(method, path, nil)
r.SetBasicAuth(user, "123")
w := httptest.NewRecorder()
handler.ServeHTTP(w, r)
if w.Code != code {
t.Errorf("%s, %s, %s: %d, supposed to be %d", user, path, method, w.Code, code)
}
}
func TestBasic(t *testing.T) {
handler := beego.NewControllerRegister()
handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("alice", "123"))
handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
handler.Any("*", func(ctx *context.Context) {
ctx.Output.SetStatus(200)
})
testRequest(t, handler, "alice", "/dataset1/resource1", "GET", 200)
testRequest(t, handler, "alice", "/dataset1/resource1", "POST", 200)
testRequest(t, handler, "alice", "/dataset1/resource2", "GET", 200)
testRequest(t, handler, "alice", "/dataset1/resource2", "POST", 403)
}
func TestPathWildcard(t *testing.T) {
handler := beego.NewControllerRegister()
handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("bob", "123"))
handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")))
handler.Any("*", func(ctx *context.Context) {
ctx.Output.SetStatus(200)
})
testRequest(t, handler, "bob", "/dataset2/resource1", "GET", 200)
testRequest(t, handler, "bob", "/dataset2/resource1", "POST", 200)
testRequest(t, handler, "bob", "/dataset2/resource1", "DELETE", 200)
testRequest(t, handler, "bob", "/dataset2/resource2", "GET", 200)
testRequest(t, handler, "bob", "/dataset2/resource2", "POST", 403)
testRequest(t, handler, "bob", "/dataset2/resource2", "DELETE", 403)
testRequest(t, handler, "bob", "/dataset2/folder1/item1", "GET", 403)
testRequest(t, handler, "bob", "/dataset2/folder1/item1", "POST", 200)
testRequest(t, handler, "bob", "/dataset2/folder1/item1", "DELETE", 403)
testRequest(t, handler, "bob", "/dataset2/folder1/item2", "GET", 403)
testRequest(t, handler, "bob", "/dataset2/folder1/item2", "POST", 200)
testRequest(t, handler, "bob", "/dataset2/folder1/item2", "DELETE", 403)
}
func TestRBAC(t *testing.T) {
handler := beego.NewControllerRegister()
handler.InsertFilter("*", beego.BeforeRouter, auth.Basic("cathy", "123"))
e := casbin.NewEnforcer("authz_model.conf", "authz_policy.csv")
handler.InsertFilter("*", beego.BeforeRouter, NewAuthorizer(e))
handler.Any("*", func(ctx *context.Context) {
ctx.Output.SetStatus(200)
})
// cathy can access all /dataset1/* resources via all methods because it has the dataset1_admin role.
testRequest(t, handler, "cathy", "/dataset1/item", "GET", 200)
testRequest(t, handler, "cathy", "/dataset1/item", "POST", 200)
testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 200)
testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
// delete all roles on user cathy, so cathy cannot access any resources now.
e.DeleteRolesForUser("cathy")
testRequest(t, handler, "cathy", "/dataset1/item", "GET", 403)
testRequest(t, handler, "cathy", "/dataset1/item", "POST", 403)
testRequest(t, handler, "cathy", "/dataset1/item", "DELETE", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "GET", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "POST", 403)
testRequest(t, handler, "cathy", "/dataset2/item", "DELETE", 403)
}
...@@ -27,6 +27,7 @@ import ( ...@@ -27,6 +27,7 @@ import (
"time" "time"
beecontext "github.com/astaxie/beego/context" beecontext "github.com/astaxie/beego/context"
"github.com/astaxie/beego/context/param"
"github.com/astaxie/beego/logs" "github.com/astaxie/beego/logs"
"github.com/astaxie/beego/toolbox" "github.com/astaxie/beego/toolbox"
"github.com/astaxie/beego/utils" "github.com/astaxie/beego/utils"
...@@ -116,6 +117,7 @@ type ControllerInfo struct { ...@@ -116,6 +117,7 @@ type ControllerInfo struct {
handler http.Handler handler http.Handler
runFunction FilterFunc runFunction FilterFunc
routerType int routerType int
methodParams []*param.MethodParam
} }
// ControllerRegister containers registered router rules, controller handlers and filters. // ControllerRegister containers registered router rules, controller handlers and filters.
...@@ -151,6 +153,10 @@ func NewControllerRegister() *ControllerRegister { ...@@ -151,6 +153,10 @@ func NewControllerRegister() *ControllerRegister {
// Add("/api",&RestController{},"get,post:ApiFunc" // Add("/api",&RestController{},"get,post:ApiFunc"
// Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc") // Add("/simple",&SimpleController{},"get:GetFunc;post:PostFunc")
func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) { func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingMethods ...string) {
p.addWithMethodParams(pattern, c, nil, mappingMethods...)
}
func (p *ControllerRegister) addWithMethodParams(pattern string, c ControllerInterface, methodParams []*param.MethodParam, mappingMethods ...string) {
reflectVal := reflect.ValueOf(c) reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type() t := reflect.Indirect(reflectVal).Type()
methods := make(map[string]string) methods := make(map[string]string)
...@@ -181,6 +187,7 @@ func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingM ...@@ -181,6 +187,7 @@ func (p *ControllerRegister) Add(pattern string, c ControllerInterface, mappingM
route.methods = methods route.methods = methods
route.routerType = routerTypeBeego route.routerType = routerTypeBeego
route.controllerType = t route.controllerType = t
route.methodParams = methodParams
if len(methods) == 0 { if len(methods) == 0 {
for _, m := range HTTPMETHOD { for _, m := range HTTPMETHOD {
p.addToRouter(m, pattern, route) p.addToRouter(m, pattern, route)
...@@ -245,7 +252,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) { ...@@ -245,7 +252,7 @@ func (p *ControllerRegister) Include(cList ...ControllerInterface) {
key := t.PkgPath() + ":" + t.Name() key := t.PkgPath() + ":" + t.Name()
if comm, ok := GlobalControllerRouter[key]; ok { if comm, ok := GlobalControllerRouter[key]; ok {
for _, a := range comm { for _, a := range comm {
p.Add(a.Router, c, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method) p.addWithMethodParams(a.Router, c, a.MethodParams, strings.Join(a.AllowHTTPMethods, ",")+":"+a.Method)
} }
} }
} }
...@@ -624,11 +631,12 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str ...@@ -624,11 +631,12 @@ func (p *ControllerRegister) execFilter(context *beecontext.Context, urlPath str
func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) { func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
startTime := time.Now() startTime := time.Now()
var ( var (
runRouter reflect.Type runRouter reflect.Type
findRouter bool findRouter bool
runMethod string runMethod string
routerInfo *ControllerInfo methodParams []*param.MethodParam
isRunnable bool routerInfo *ControllerInfo
isRunnable bool
) )
context := p.pool.Get().(*beecontext.Context) context := p.pool.Get().(*beecontext.Context)
context.Reset(rw, r) context.Reset(rw, r)
...@@ -740,6 +748,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -740,6 +748,7 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
routerInfo.handler.ServeHTTP(rw, r) routerInfo.handler.ServeHTTP(rw, r)
} else { } else {
runRouter = routerInfo.controllerType runRouter = routerInfo.controllerType
methodParams = routerInfo.methodParams
method := r.Method method := r.Method
if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPost { if r.Method == http.MethodPost && context.Input.Query("_method") == http.MethodPost {
method = http.MethodPut method = http.MethodPut
...@@ -802,9 +811,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -802,9 +811,14 @@ func (p *ControllerRegister) ServeHTTP(rw http.ResponseWriter, r *http.Request)
execController.Options() execController.Options()
default: default:
if !execController.HandlerFunc(runMethod) { if !execController.HandlerFunc(runMethod) {
var in []reflect.Value
method := vc.MethodByName(runMethod) method := vc.MethodByName(runMethod)
method.Call(in) in := param.ConvertParams(methodParams, method.Type(), context)
out := method.Call(in)
//For backward compatibility we only handle response if we had incoming methodParams
if methodParams != nil {
p.handleParamResponse(context, execController, out)
}
} }
} }
...@@ -884,6 +898,20 @@ Admin: ...@@ -884,6 +898,20 @@ Admin:
} }
} }
func (p *ControllerRegister) handleParamResponse(context *beecontext.Context, execController ControllerInterface, results []reflect.Value) {
//looping in reverse order for the case when both error and value are returned and error sets the response status code
for i := len(results) - 1; i >= 0; i-- {
result := results[i]
if result.Kind() != reflect.Interface || !result.IsNil() {
resultValue := result.Interface()
context.RenderMethodResult(resultValue)
}
}
if !context.ResponseWriter.Started && context.Output.Status == 0 {
context.Output.SetStatus(200)
}
}
// FindRouter Find Router info for URL // FindRouter Find Router info for URL
func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) { func (p *ControllerRegister) FindRouter(context *beecontext.Context) (routerInfo *ControllerInfo, isFind bool) {
var urlPath = context.Input.URL() var urlPath = context.Input.URL()
......
...@@ -22,19 +22,19 @@ package swagger ...@@ -22,19 +22,19 @@ package swagger
// Swagger list the resource // Swagger list the resource
type Swagger struct { type Swagger struct {
SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"` SwaggerVersion string `json:"swagger,omitempty" yaml:"swagger,omitempty"`
Infos Information `json:"info" yaml:"info"` Infos Information `json:"info" yaml:"info"`
Host string `json:"host,omitempty" yaml:"host,omitempty"` Host string `json:"host,omitempty" yaml:"host,omitempty"`
BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"` BasePath string `json:"basePath,omitempty" yaml:"basePath,omitempty"`
Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"`
Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"`
Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"`
Paths map[string]*Item `json:"paths" yaml:"paths"` Paths map[string]*Item `json:"paths" yaml:"paths"`
Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"` Definitions map[string]Schema `json:"definitions,omitempty" yaml:"definitions,omitempty"`
SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"` SecurityDefinitions map[string]Security `json:"securityDefinitions,omitempty" yaml:"securityDefinitions,omitempty"`
Security map[string][]string `json:"security,omitempty" yaml:"security,omitempty"` Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"`
Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"` Tags []Tag `json:"tags,omitempty" yaml:"tags,omitempty"`
ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"` ExternalDocs *ExternalDocs `json:"externalDocs,omitempty" yaml:"externalDocs,omitempty"`
} }
// Information Provides metadata about the API. The metadata can be used by the clients if needed. // Information Provides metadata about the API. The metadata can be used by the clients if needed.
...@@ -75,16 +75,17 @@ type Item struct { ...@@ -75,16 +75,17 @@ type Item struct {
// Operation Describes a single API operation on a path. // Operation Describes a single API operation on a path.
type Operation struct { type Operation struct {
Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"` Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"`
Summary string `json:"summary,omitempty" yaml:"summary,omitempty"` Summary string `json:"summary,omitempty" yaml:"summary,omitempty"`
Description string `json:"description,omitempty" yaml:"description,omitempty"` Description string `json:"description,omitempty" yaml:"description,omitempty"`
OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"` OperationID string `json:"operationId,omitempty" yaml:"operationId,omitempty"`
Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"` Consumes []string `json:"consumes,omitempty" yaml:"consumes,omitempty"`
Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"` Produces []string `json:"produces,omitempty" yaml:"produces,omitempty"`
Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"` Schemes []string `json:"schemes,omitempty" yaml:"schemes,omitempty"`
Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"` Parameters []Parameter `json:"parameters,omitempty" yaml:"parameters,omitempty"`
Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"` Responses map[string]Response `json:"responses,omitempty" yaml:"responses,omitempty"`
Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"` Security []map[string][]string `json:"security,omitempty" yaml:"security,omitempty"`
Deprecated bool `json:"deprecated,omitempty" yaml:"deprecated,omitempty"`
} }
// Parameter Describes a single operation parameter. // Parameter Describes a single operation parameter.
......
...@@ -27,9 +27,10 @@ import ( ...@@ -27,9 +27,10 @@ import (
) )
const ( const (
formatTime = "15:04:05" formatTime = "15:04:05"
formatDate = "2006-01-02" formatDate = "2006-01-02"
formatDateTime = "2006-01-02 15:04:05" formatDateTime = "2006-01-02 15:04:05"
formatDateTimeT = "2006-01-02T15:04:05"
) )
// Substr returns the substr from start to length. // Substr returns the substr from start to length.
...@@ -53,21 +54,21 @@ func Substr(s string, start, length int) string { ...@@ -53,21 +54,21 @@ func Substr(s string, start, length int) string {
// HTML2str returns escaping text convert from html. // HTML2str returns escaping text convert from html.
func HTML2str(html string) string { func HTML2str(html string) string {
re, _ := regexp.Compile("\\<[\\S\\s]+?\\>") re, _ := regexp.Compile(`\<[\S\s]+?\>`)
html = re.ReplaceAllStringFunc(html, strings.ToLower) html = re.ReplaceAllStringFunc(html, strings.ToLower)
//remove STYLE //remove STYLE
re, _ = regexp.Compile("\\<style[\\S\\s]+?\\</style\\>") re, _ = regexp.Compile(`\<style[\S\s]+?\</style\>`)
html = re.ReplaceAllString(html, "") html = re.ReplaceAllString(html, "")
//remove SCRIPT //remove SCRIPT
re, _ = regexp.Compile("\\<script[\\S\\s]+?\\</script\\>") re, _ = regexp.Compile(`\<script[\S\s]+?\</script\>`)
html = re.ReplaceAllString(html, "") html = re.ReplaceAllString(html, "")
re, _ = regexp.Compile("\\<[\\S\\s]+?\\>") re, _ = regexp.Compile(`\<[\S\s]+?\>`)
html = re.ReplaceAllString(html, "\n") html = re.ReplaceAllString(html, "\n")
re, _ = regexp.Compile("\\s{2,}") re, _ = regexp.Compile(`\s{2,}`)
html = re.ReplaceAllString(html, "\n") html = re.ReplaceAllString(html, "\n")
return strings.TrimSpace(html) return strings.TrimSpace(html)
...@@ -360,8 +361,13 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e ...@@ -360,8 +361,13 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e
value = value[:25] value = value[:25]
t, err = time.ParseInLocation(time.RFC3339, value, time.Local) t, err = time.ParseInLocation(time.RFC3339, value, time.Local)
} else if len(value) >= 19 { } else if len(value) >= 19 {
value = value[:19] if strings.Contains(value, "T") {
t, err = time.ParseInLocation(formatDateTime, value, time.Local) value = value[:19]
t, err = time.ParseInLocation(formatDateTimeT, value, time.Local)
} else {
value = value[:19]
t, err = time.ParseInLocation(formatDateTime, value, time.Local)
}
} else if len(value) >= 10 { } else if len(value) >= 10 {
if len(value) > 10 { if len(value) > 10 {
value = value[:10] value = value[:10]
...@@ -373,7 +379,6 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e ...@@ -373,7 +379,6 @@ func parseFormToStruct(form url.Values, objT reflect.Type, objV reflect.Value) e
} }
t, err = time.ParseInLocation(formatTime, value, time.Local) t, err = time.ParseInLocation(formatTime, value, time.Local)
} }
if err != nil { if err != nil {
return err return err
} }
......
...@@ -35,6 +35,12 @@ func TestRequired(t *testing.T) { ...@@ -35,6 +35,12 @@ func TestRequired(t *testing.T) {
if valid.Required("", "string").Ok { if valid.Required("", "string").Ok {
t.Error("\"'\" string should be false") t.Error("\"'\" string should be false")
} }
if valid.Required(" ", "string").Ok {
t.Error("\" \" string should be false") // For #2361
}
if valid.Required("\n", "string").Ok {
t.Error("new line string should be false") // For #2361
}
if !valid.Required("astaxie", "string").Ok { if !valid.Required("astaxie", "string").Ok {
t.Error("string should be true") t.Error("string should be true")
} }
...@@ -175,10 +181,10 @@ func TestAlphaNumeric(t *testing.T) { ...@@ -175,10 +181,10 @@ func TestAlphaNumeric(t *testing.T) {
func TestMatch(t *testing.T) { func TestMatch(t *testing.T) {
valid := Validation{} valid := Validation{}
if valid.Match("suchuangji@gmail", regexp.MustCompile("^\\w+@\\w+\\.\\w+$"), "match").Ok { if valid.Match("suchuangji@gmail", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok {
t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false") t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be false")
} }
if !valid.Match("suchuangji@gmail.com", regexp.MustCompile("^\\w+@\\w+\\.\\w+$"), "match").Ok { if !valid.Match("suchuangji@gmail.com", regexp.MustCompile(`^\w+@\w+\.\w+$`), "match").Ok {
t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true") t.Error("\"suchuangji@gmail\" match \"^\\w+@\\w+\\.\\w+$\" should be true")
} }
} }
...@@ -186,10 +192,10 @@ func TestMatch(t *testing.T) { ...@@ -186,10 +192,10 @@ func TestMatch(t *testing.T) {
func TestNoMatch(t *testing.T) { func TestNoMatch(t *testing.T) {
valid := Validation{} valid := Validation{}
if valid.NoMatch("123@gmail", regexp.MustCompile("[^\\w\\d]"), "nomatch").Ok { if valid.NoMatch("123@gmail", regexp.MustCompile(`[^\w\d]`), "nomatch").Ok {
t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false") t.Error("\"123@gmail\" not match \"[^\\w\\d]\" should be false")
} }
if !valid.NoMatch("123gmail", regexp.MustCompile("[^\\w\\d]"), "match").Ok { if !valid.NoMatch("123gmail", regexp.MustCompile(`[^\w\d]`), "match").Ok {
t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true") t.Error("\"123@gmail\" not match \"[^\\w\\d@]\" should be true")
} }
} }
......
...@@ -18,6 +18,7 @@ import ( ...@@ -18,6 +18,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
"strings"
"time" "time"
"unicode/utf8" "unicode/utf8"
) )
...@@ -98,7 +99,7 @@ func (r Required) IsSatisfied(obj interface{}) bool { ...@@ -98,7 +99,7 @@ func (r Required) IsSatisfied(obj interface{}) bool {
} }
if str, ok := obj.(string); ok { if str, ok := obj.(string); ok {
return len(str) > 0 return len(strings.TrimSpace(str)) > 0
} }
if _, ok := obj.(bool); ok { if _, ok := obj.(bool); ok {
return true return true
...@@ -145,7 +146,7 @@ func (r Required) IsSatisfied(obj interface{}) bool { ...@@ -145,7 +146,7 @@ func (r Required) IsSatisfied(obj interface{}) bool {
// DefaultMessage return the default error message // DefaultMessage return the default error message
func (r Required) DefaultMessage() string { func (r Required) DefaultMessage() string {
return fmt.Sprint(MessageTmpls["Required"]) return MessageTmpls["Required"]
} }
// GetKey return the r.Key // GetKey return the r.Key
...@@ -364,7 +365,7 @@ func (a Alpha) IsSatisfied(obj interface{}) bool { ...@@ -364,7 +365,7 @@ func (a Alpha) IsSatisfied(obj interface{}) bool {
// DefaultMessage return the default Length error message // DefaultMessage return the default Length error message
func (a Alpha) DefaultMessage() string { func (a Alpha) DefaultMessage() string {
return fmt.Sprint(MessageTmpls["Alpha"]) return MessageTmpls["Alpha"]
} }
// GetKey return the m.Key // GetKey return the m.Key
...@@ -397,7 +398,7 @@ func (n Numeric) IsSatisfied(obj interface{}) bool { ...@@ -397,7 +398,7 @@ func (n Numeric) IsSatisfied(obj interface{}) bool {
// DefaultMessage return the default Length error message // DefaultMessage return the default Length error message
func (n Numeric) DefaultMessage() string { func (n Numeric) DefaultMessage() string {
return fmt.Sprint(MessageTmpls["Numeric"]) return MessageTmpls["Numeric"]
} }
// GetKey return the n.Key // GetKey return the n.Key
...@@ -430,7 +431,7 @@ func (a AlphaNumeric) IsSatisfied(obj interface{}) bool { ...@@ -430,7 +431,7 @@ func (a AlphaNumeric) IsSatisfied(obj interface{}) bool {
// DefaultMessage return the default Length error message // DefaultMessage return the default Length error message
func (a AlphaNumeric) DefaultMessage() string { func (a AlphaNumeric) DefaultMessage() string {
return fmt.Sprint(MessageTmpls["AlphaNumeric"]) return MessageTmpls["AlphaNumeric"]
} }
// GetKey return the a.Key // GetKey return the a.Key
...@@ -495,7 +496,7 @@ func (n NoMatch) GetLimitValue() interface{} { ...@@ -495,7 +496,7 @@ func (n NoMatch) GetLimitValue() interface{} {
return n.Regexp.String() return n.Regexp.String()
} }
var alphaDashPattern = regexp.MustCompile("[^\\d\\w-_]") var alphaDashPattern = regexp.MustCompile(`[^\d\w-_]`)
// AlphaDash check not Alpha // AlphaDash check not Alpha
type AlphaDash struct { type AlphaDash struct {
...@@ -505,7 +506,7 @@ type AlphaDash struct { ...@@ -505,7 +506,7 @@ type AlphaDash struct {
// DefaultMessage return the default AlphaDash error message // DefaultMessage return the default AlphaDash error message
func (a AlphaDash) DefaultMessage() string { func (a AlphaDash) DefaultMessage() string {
return fmt.Sprint(MessageTmpls["AlphaDash"]) return MessageTmpls["AlphaDash"]
} }
// GetKey return the n.Key // GetKey return the n.Key
...@@ -518,7 +519,7 @@ func (a AlphaDash) GetLimitValue() interface{} { ...@@ -518,7 +519,7 @@ func (a AlphaDash) GetLimitValue() interface{} {
return nil return nil
} }
var emailPattern = regexp.MustCompile("^[\\w!#$%&'*+/=?^_`{|}~-]+(?:\\.[\\w!#$%&'*+/=?^_`{|}~-]+)*@(?:[\\w](?:[\\w-]*[\\w])?\\.)+[a-zA-Z0-9](?:[\\w-]*[\\w])?$") var emailPattern = regexp.MustCompile(`^[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+(?:\.[\w!#$%&'*+/=?^_` + "`" + `{|}~-]+)*@(?:[\w](?:[\w-]*[\w])?\.)+[a-zA-Z0-9](?:[\w-]*[\w])?$`)
// Email check struct // Email check struct
type Email struct { type Email struct {
...@@ -528,7 +529,7 @@ type Email struct { ...@@ -528,7 +529,7 @@ type Email struct {
// DefaultMessage return the default Email error message // DefaultMessage return the default Email error message
func (e Email) DefaultMessage() string { func (e Email) DefaultMessage() string {
return fmt.Sprint(MessageTmpls["Email"]) return MessageTmpls["Email"]
} }
// GetKey return the n.Key // GetKey return the n.Key
...@@ -541,7 +542,7 @@ func (e Email) GetLimitValue() interface{} { ...@@ -541,7 +542,7 @@ func (e Email) GetLimitValue() interface{} {
return nil return nil
} }
var ipPattern = regexp.MustCompile("^((2[0-4]\\d|25[0-5]|[01]?\\d\\d?)\\.){3}(2[0-4]\\d|25[0-5]|[01]?\\d\\d?)$") var ipPattern = regexp.MustCompile(`^((2[0-4]\d|25[0-5]|[01]?\d\d?)\.){3}(2[0-4]\d|25[0-5]|[01]?\d\d?)$`)
// IP check struct // IP check struct
type IP struct { type IP struct {
...@@ -551,7 +552,7 @@ type IP struct { ...@@ -551,7 +552,7 @@ type IP struct {
// DefaultMessage return the default IP error message // DefaultMessage return the default IP error message
func (i IP) DefaultMessage() string { func (i IP) DefaultMessage() string {
return fmt.Sprint(MessageTmpls["IP"]) return MessageTmpls["IP"]
} }
// GetKey return the i.Key // GetKey return the i.Key
...@@ -564,7 +565,7 @@ func (i IP) GetLimitValue() interface{} { ...@@ -564,7 +565,7 @@ func (i IP) GetLimitValue() interface{} {
return nil return nil
} }
var base64Pattern = regexp.MustCompile("^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$") var base64Pattern = regexp.MustCompile(`^(?:[A-Za-z0-99+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$`)
// Base64 check struct // Base64 check struct
type Base64 struct { type Base64 struct {
...@@ -574,7 +575,7 @@ type Base64 struct { ...@@ -574,7 +575,7 @@ type Base64 struct {
// DefaultMessage return the default Base64 error message // DefaultMessage return the default Base64 error message
func (b Base64) DefaultMessage() string { func (b Base64) DefaultMessage() string {
return fmt.Sprint(MessageTmpls["Base64"]) return MessageTmpls["Base64"]
} }
// GetKey return the b.Key // GetKey return the b.Key
...@@ -588,7 +589,7 @@ func (b Base64) GetLimitValue() interface{} { ...@@ -588,7 +589,7 @@ func (b Base64) GetLimitValue() interface{} {
} }
// just for chinese mobile phone number // just for chinese mobile phone number
var mobilePattern = regexp.MustCompile("^((\\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][06789]|[4][579]))\\d{8}$") var mobilePattern = regexp.MustCompile(`^((\+86)|(86))?(1(([35][0-9])|[8][0-9]|[7][06789]|[4][579]))\d{8}$`)
// Mobile check struct // Mobile check struct
type Mobile struct { type Mobile struct {
...@@ -598,7 +599,7 @@ type Mobile struct { ...@@ -598,7 +599,7 @@ type Mobile struct {
// DefaultMessage return the default Mobile error message // DefaultMessage return the default Mobile error message
func (m Mobile) DefaultMessage() string { func (m Mobile) DefaultMessage() string {
return fmt.Sprint(MessageTmpls["Mobile"]) return MessageTmpls["Mobile"]
} }
// GetKey return the m.Key // GetKey return the m.Key
...@@ -612,7 +613,7 @@ func (m Mobile) GetLimitValue() interface{} { ...@@ -612,7 +613,7 @@ func (m Mobile) GetLimitValue() interface{} {
} }
// just for chinese telephone number // just for chinese telephone number
var telPattern = regexp.MustCompile("^(0\\d{2,3}(\\-)?)?\\d{7,8}$") var telPattern = regexp.MustCompile(`^(0\d{2,3}(\-)?)?\d{7,8}$`)
// Tel check telephone struct // Tel check telephone struct
type Tel struct { type Tel struct {
...@@ -622,7 +623,7 @@ type Tel struct { ...@@ -622,7 +623,7 @@ type Tel struct {
// DefaultMessage return the default Tel error message // DefaultMessage return the default Tel error message
func (t Tel) DefaultMessage() string { func (t Tel) DefaultMessage() string {
return fmt.Sprint(MessageTmpls["Tel"]) return MessageTmpls["Tel"]
} }
// GetKey return the t.Key // GetKey return the t.Key
...@@ -649,7 +650,7 @@ func (p Phone) IsSatisfied(obj interface{}) bool { ...@@ -649,7 +650,7 @@ func (p Phone) IsSatisfied(obj interface{}) bool {
// DefaultMessage return the default Phone error message // DefaultMessage return the default Phone error message
func (p Phone) DefaultMessage() string { func (p Phone) DefaultMessage() string {
return fmt.Sprint(MessageTmpls["Phone"]) return MessageTmpls["Phone"]
} }
// GetKey return the p.Key // GetKey return the p.Key
...@@ -663,7 +664,7 @@ func (p Phone) GetLimitValue() interface{} { ...@@ -663,7 +664,7 @@ func (p Phone) GetLimitValue() interface{} {
} }
// just for chinese zipcode // just for chinese zipcode
var zipCodePattern = regexp.MustCompile("^[1-9]\\d{5}$") var zipCodePattern = regexp.MustCompile(`^[1-9]\d{5}$`)
// ZipCode check the zip struct // ZipCode check the zip struct
type ZipCode struct { type ZipCode struct {
...@@ -673,7 +674,7 @@ type ZipCode struct { ...@@ -673,7 +674,7 @@ type ZipCode struct {
// DefaultMessage return the default Zip error message // DefaultMessage return the default Zip error message
func (z ZipCode) DefaultMessage() string { func (z ZipCode) DefaultMessage() string {
return fmt.Sprint(MessageTmpls["ZipCode"]) return MessageTmpls["ZipCode"]
} }
// GetKey return the z.Key // GetKey return the z.Key
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment