Commit 6c8a7f13 authored by astaxie's avatar astaxie

beego: router change to method Tree

parent e00eab7f
...@@ -35,9 +35,16 @@ const ( ...@@ -35,9 +35,16 @@ const (
var ( var (
// custom error when user stop request handler manually. // custom error when user stop request handler manually.
USERSTOPRUN = errors.New("User stop run") USERSTOPRUN = errors.New("User stop run")
GlobalControllerRouter map[string]map[string]*Tree //pkgpath+controller:method:routertree GlobalControllerRouter map[string]*ControllerComments //pkgpath+controller:comments
) )
// store the comment for the controller method
type ControllerComments struct {
method string
router string
allowHTTPMethods []string
}
// Controller defines some basic http request handler operations, such as // Controller defines some basic http request handler operations, such as
// http context, template and view, session and xsrf. // http context, template and view, session and xsrf.
type Controller struct { type Controller struct {
...@@ -56,7 +63,7 @@ type Controller struct { ...@@ -56,7 +63,7 @@ type Controller struct {
AppController interface{} AppController interface{}
EnableRender bool EnableRender bool
EnableXSRF bool EnableXSRF bool
Routers map[string]*Tree //method:routertree methodMapping map[string]func() //method:routertree
} }
// ControllerInterface is an interface to uniform all controller handler. // ControllerInterface is an interface to uniform all controller handler.
...@@ -74,7 +81,7 @@ type ControllerInterface interface { ...@@ -74,7 +81,7 @@ type ControllerInterface interface {
Render() error Render() error
XsrfToken() string XsrfToken() string
CheckXsrfCookie() bool CheckXsrfCookie() bool
HandlerFunc(fn interface{}) HandlerFunc(fn string)
URLMapping() URLMapping()
} }
...@@ -90,7 +97,7 @@ func (c *Controller) Init(ctx *context.Context, controllerName, actionName strin ...@@ -90,7 +97,7 @@ func (c *Controller) Init(ctx *context.Context, controllerName, actionName strin
c.EnableRender = true c.EnableRender = true
c.EnableXSRF = true c.EnableXSRF = true
c.Data = ctx.Input.Data c.Data = ctx.Input.Data
c.Routers = make(map[string]*Tree) c.methodMapping = make(map[string]func())
} }
// Prepare runs after Init before request function execution. // Prepare runs after Init before request function execution.
...@@ -139,9 +146,11 @@ func (c *Controller) Options() { ...@@ -139,9 +146,11 @@ func (c *Controller) Options() {
} }
// call function fn // call function fn
func (c *Controller) HandlerFunc(fn interface{}) { func (c *Controller) HandlerFunc(fnname string) {
if v, ok := fn.(func()); ok { if v, ok := c.methodMapping[fnname]; ok {
v() v()
} else {
Error("call funcname not exist in the methodMapping: " + fnname)
} }
} }
...@@ -149,19 +158,8 @@ func (c *Controller) HandlerFunc(fn interface{}) { ...@@ -149,19 +158,8 @@ func (c *Controller) HandlerFunc(fn interface{}) {
func (c *Controller) URLMapping() { func (c *Controller) URLMapping() {
} }
func (c *Controller) Mapping(method, pattern string, fn func()) { func (c *Controller) Mapping(method string, fn func()) {
method = strings.ToLower(method) c.methodMapping[method] = fn
if !utils.InSlice(method, HTTPMETHOD) && method != "*" {
Critical("add mapping method:" + method + " is a valid method")
return
}
if t, ok := c.Routers[method]; ok {
t.AddRouter(pattern, fn)
} else {
t = NewTree()
t.AddRouter(pattern, fn)
c.Routers[method] = t
}
} }
// Render sends the response with rendered template bytes as text/html type. // Render sends the response with rendered template bytes as text/html type.
......
...@@ -182,7 +182,15 @@ func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace { ...@@ -182,7 +182,15 @@ func (n *Namespace) Handler(rootpath string, h http.Handler) *Namespace {
//) //)
func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
for _, ni := range ns { for _, ni := range ns {
n.handlers.routers.AddTree(ni.prefix, ni.handlers.routers) for k, v := range ni.handlers.routers {
if t, ok := n.handlers.routers[k]; ok {
n.handlers.routers[k].AddTree(ni.prefix, v)
} else {
t = NewTree()
t.AddTree(ni.prefix, v)
n.handlers.routers[k] = t
}
}
if n.handlers.enableFilter { if n.handlers.enableFilter {
for pos, filterList := range ni.handlers.filters { for pos, filterList := range ni.handlers.filters {
for _, mr := range filterList { for _, mr := range filterList {
...@@ -201,7 +209,15 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace { ...@@ -201,7 +209,15 @@ func (n *Namespace) Namespace(ns ...*Namespace) *Namespace {
// support multi Namespace // support multi Namespace
func AddNamespace(nl ...*Namespace) { func AddNamespace(nl ...*Namespace) {
for _, n := range nl { for _, n := range nl {
BeeApp.Handlers.routers.AddTree(n.prefix, n.handlers.routers) for k, v := range n.handlers.routers {
if t, ok := BeeApp.Handlers.routers[k]; ok {
BeeApp.Handlers.routers[k].AddTree(n.prefix, v)
} else {
t = NewTree()
t.AddTree(n.prefix, v)
BeeApp.Handlers.routers[k] = t
}
}
if n.handlers.enableFilter { if n.handlers.enableFilter {
for pos, filterList := range n.handlers.filters { for pos, filterList := range n.handlers.filters {
for _, mr := range filterList { for _, mr := range filterList {
......
...@@ -4,3 +4,43 @@ ...@@ -4,3 +4,43 @@
// @license http://github.com/astaxie/beego/blob/master/LICENSE // @license http://github.com/astaxie/beego/blob/master/LICENSE
// @authors astaxie // @authors astaxie
package beego package beego
import (
"os"
"path/filepath"
)
var globalControllerRouter = `package routers
import (
"github.com/astaxie/beego"
)
func init() {
{{.globalinfo}}
}
`
func parserPkg(pkgpath string) error {
err := filepath.Walk(pkgpath, func(path string, info os.FileInfo, err error) error {
if err != nil {
Error("error scan app Controller source:", err)
return err
}
//if is normal file or name is temp skip
//directory is needed
if !info.IsDir() || info.Name() == "tmp" {
return nil
}
//fileSet := token.NewFileSet()
//astPkgs, err := parser.ParseDir(fileSet, path, func(info os.FileInfo) bool {
// name := info.Name()
// return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
//}, parser.ParseComments)
return nil
})
return err
}
...@@ -12,7 +12,9 @@ import ( ...@@ -12,7 +12,9 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"os"
"path" "path"
"path/filepath"
"reflect" "reflect"
"runtime" "runtime"
"strconv" "strconv"
...@@ -67,7 +69,7 @@ type controllerInfo struct { ...@@ -67,7 +69,7 @@ type controllerInfo struct {
// ControllerRegistor containers registered router rules, controller handlers and filters. // ControllerRegistor containers registered router rules, controller handlers and filters.
type ControllerRegistor struct { type ControllerRegistor struct {
routers *Tree routers map[string]*Tree
enableFilter bool enableFilter bool
filters map[int][]*FilterRouter filters map[int][]*FilterRouter
} }
...@@ -75,7 +77,7 @@ type ControllerRegistor struct { ...@@ -75,7 +77,7 @@ type ControllerRegistor struct {
// NewControllerRegistor returns a new ControllerRegistor. // NewControllerRegistor returns a new ControllerRegistor.
func NewControllerRegistor() *ControllerRegistor { func NewControllerRegistor() *ControllerRegistor {
return &ControllerRegistor{ return &ControllerRegistor{
routers: NewTree(), routers: make(map[string]*Tree),
filters: make(map[int][]*FilterRouter), filters: make(map[int][]*FilterRouter),
} }
} }
...@@ -120,17 +122,69 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM ...@@ -120,17 +122,69 @@ func (p *ControllerRegistor) Add(pattern string, c ControllerInterface, mappingM
route.methods = methods route.methods = methods
route.routerType = routerTypeBeego route.routerType = routerTypeBeego
route.controllerType = t route.controllerType = t
p.routers.AddRouter(pattern, route) if len(methods) == 0 {
for _, m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
} else {
for k, _ := range methods {
if k == "*" {
for _, m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
} else {
p.addToRouter(k, pattern, route)
}
}
}
}
func (p *ControllerRegistor) addToRouter(method, pattern string, r *controllerInfo) {
if t, ok := p.routers[method]; ok {
t.AddRouter(pattern, r)
} else {
t := NewTree()
t.AddRouter(pattern, r)
p.routers[method] = t
}
} }
// only when the Runmode is dev will generate router file in the router/auto.go from the controller // only when the Runmode is dev will generate router file in the router/auto.go from the controller
// Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{}) // Include(&BankAccount{}, &OrderController{},&RefundController{},&ReceiptController{})
func (p *ControllerRegistor) Include(cList ...ControllerInterface) { func (p *ControllerRegistor) Include(cList ...ControllerInterface) {
if RunMode == "dev" { if RunMode == "dev" {
skip := make(map[string]bool, 10)
for _, c := range cList {
reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type()
gopath := os.Getenv("GOPATH")
if gopath == "" {
panic("you are in dev mode. So please set gopath")
}
pkgpath := ""
wgopath := filepath.SplitList(gopath)
for _, wg := range wgopath {
wg, _ = filepath.EvalSymlinks(filepath.Join(wg, "src", t.PkgPath()))
if utils.FileExists(wg) {
pkgpath = wg
break
}
}
if pkgpath != "" {
if _, ok := skip[pkgpath]; !ok {
skip[pkgpath] = true
parserPkg(pkgpath)
}
}
}
}
for _, c := range cList { for _, c := range cList {
reflectVal := reflect.ValueOf(c) reflectVal := reflect.ValueOf(c)
t := reflect.Indirect(reflectVal).Type() t := reflect.Indirect(reflectVal).Type()
t.PkgPath() key := t.PkgPath() + ":" + t.Name()
if comm, ok := GlobalControllerRouter[key]; ok {
p.Add(comm.router, c, strings.Join(comm.allowHTTPMethods, ",")+":"+comm.method)
} }
} }
} }
...@@ -228,7 +282,15 @@ func (p *ControllerRegistor) AddMethod(method, pattern string, f FilterFunc) { ...@@ -228,7 +282,15 @@ func (p *ControllerRegistor) AddMethod(method, pattern string, f FilterFunc) {
methods[method] = method methods[method] = method
} }
route.methods = methods route.methods = methods
p.routers.AddRouter(pattern, route) for k, _ := range methods {
if k == "*" {
for _, m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
} else {
p.addToRouter(k, pattern, route)
}
}
} }
// add user defined Handler // add user defined Handler
...@@ -241,7 +303,9 @@ func (p *ControllerRegistor) Handler(pattern string, h http.Handler, options ... ...@@ -241,7 +303,9 @@ func (p *ControllerRegistor) Handler(pattern string, h http.Handler, options ...
pattern = path.Join(pattern, "?:all") pattern = path.Join(pattern, "?:all")
} }
} }
p.routers.AddRouter(pattern, route) for _, m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
} }
// Add auto router to ControllerRegistor. // Add auto router to ControllerRegistor.
...@@ -270,7 +334,9 @@ func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface) ...@@ -270,7 +334,9 @@ func (p *ControllerRegistor) AddAutoPrefix(prefix string, c ControllerInterface)
route.methods = map[string]string{"*": rt.Method(i).Name} route.methods = map[string]string{"*": rt.Method(i).Name}
route.controllerType = ct route.controllerType = ct
pattern := path.Join(prefix, controllerName, strings.ToLower(rt.Method(i).Name), "*") pattern := path.Join(prefix, controllerName, strings.ToLower(rt.Method(i).Name), "*")
p.routers.AddRouter(pattern, route) for _, m := range HTTPMETHOD {
p.addToRouter(m, pattern, route)
}
} }
} }
} }
...@@ -317,12 +383,13 @@ func (p *ControllerRegistor) UrlFor(endpoint string, values ...string) string { ...@@ -317,12 +383,13 @@ func (p *ControllerRegistor) UrlFor(endpoint string, values ...string) string {
} }
controllName := strings.Join(paths[:len(paths)-1], ".") controllName := strings.Join(paths[:len(paths)-1], ".")
methodName := paths[len(paths)-1] methodName := paths[len(paths)-1]
ok, url := p.geturl(p.routers, "/", controllName, methodName, params) for _, t := range p.routers {
ok, url := p.geturl(t, "/", controllName, methodName, params)
if ok { if ok {
return url return url
} else {
return ""
} }
}
return ""
} }
func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName string, params map[string]string) (bool, string) { func (p *ControllerRegistor) geturl(t *Tree, url, controllName, methodName string, params map[string]string) (bool, string) {
...@@ -436,6 +503,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -436,6 +503,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
starttime := time.Now() starttime := time.Now()
requestPath := r.URL.Path requestPath := r.URL.Path
method := strings.ToLower(r.Method)
var runrouter reflect.Type var runrouter reflect.Type
var findrouter bool var findrouter bool
var runMethod string var runMethod string
...@@ -485,7 +553,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -485,7 +553,7 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
}() }()
} }
if !utils.InSlice(strings.ToLower(r.Method), HTTPMETHOD) { if !utils.InSlice(method, HTTPMETHOD) {
http.Error(w, "Method Not Allowed", 405) http.Error(w, "Method Not Allowed", 405)
goto Admin goto Admin
} }
...@@ -512,7 +580,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -512,7 +580,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
if !findrouter { if !findrouter {
runObject, p := p.routers.Match(requestPath) if t, ok := p.routers[method]; ok {
runObject, p := t.Match(requestPath)
if r, ok := runObject.(*controllerInfo); ok { if r, ok := runObject.(*controllerInfo); ok {
routerInfo = r routerInfo = r
findrouter = true findrouter = true
...@@ -526,6 +595,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request) ...@@ -526,6 +595,8 @@ func (p *ControllerRegistor) ServeHTTP(rw http.ResponseWriter, r *http.Request)
} }
} }
}
//if no matches to url, throw a not found exception //if no matches to url, throw a not found exception
if !findrouter { if !findrouter {
middleware.Exception("404", rw, r, "") middleware.Exception("404", rw, r, "")
......
...@@ -291,6 +291,17 @@ func (a *AdminController) Get() { ...@@ -291,6 +291,17 @@ func (a *AdminController) Get() {
a.Ctx.WriteString("hello") a.Ctx.WriteString("hello")
} }
func TestRouterFunc(t *testing.T) {
mux := NewControllerRegistor()
mux.Get("/action", beegoFilterFunc)
mux.Post("/action", beegoFilterFunc)
rw, r := testRequest("GET", "/action")
mux.ServeHTTP(rw, r)
if rw.Body.String() != "hello" {
t.Errorf("TestRouterFunc can't run")
}
}
func BenchmarkFunc(b *testing.B) { func BenchmarkFunc(b *testing.B) {
mux := NewControllerRegistor() mux := NewControllerRegistor()
mux.Get("/action", beegoFilterFunc) mux.Get("/action", beegoFilterFunc)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment