Commit 6b4aa883 authored by Eric Chiang's avatar Eric Chiang

Merge pull request #280 from ericchiang/user_api

*: move user API auth to middleware and fix return status
parents ace8253c 0ada4c80
...@@ -51,24 +51,33 @@ func (s *UserMgmtServer) HTTPHandler() http.Handler { ...@@ -51,24 +51,33 @@ func (s *UserMgmtServer) HTTPHandler() http.Handler {
r := httprouter.New() r := httprouter.New()
r.RedirectTrailingSlash = false r.RedirectTrailingSlash = false
r.RedirectFixedPath = false r.RedirectFixedPath = false
r.GET(UsersListEndpoint, s.listUsers) r.GET(UsersListEndpoint, s.authAPIHandle(s.listUsers))
r.POST(UsersCreateEndpoint, s.createUser) r.POST(UsersCreateEndpoint, s.authAPIHandle(s.createUser))
r.POST(UsersDisableEndpoint, s.disableUser) r.POST(UsersDisableEndpoint, s.authAPIHandle(s.disableUser))
r.GET(UsersGetEndpoint, s.getUser) r.GET(UsersGetEndpoint, s.authAPIHandle(s.getUser))
return r return r
} }
func (s *UserMgmtServer) listUsers(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { // authedHandle is an HTTP handle which requires requests to be authenticated as an admin user.
type authedHandle func(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds)
// authAPIHandle is a middleware function with authenticates an HTTP request before passing
// it along to the authedHandle.
func (s *UserMgmtServer) authAPIHandle(handle authedHandle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
creds, err := s.getCreds(r) creds, err := s.getCreds(r)
if err != nil { if err != nil {
s.writeError(w, err) s.writeError(w, err)
return return
} }
handle(w, r, ps, creds)
}
}
func (s *UserMgmtServer) listUsers(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) {
nextPageToken := r.URL.Query().Get("nextPageToken") nextPageToken := r.URL.Query().Get("nextPageToken")
maxResults, err := intFromQuery(r.URL.Query(), "maxResults", defaultMaxResults) maxResults, err := intFromQuery(r.URL.Query(), "maxResults", defaultMaxResults)
if err != nil { if err != nil {
writeAPIError(w, http.StatusBadRequest, writeAPIError(w, http.StatusBadRequest,
newAPIError(errorInvalidRequest, "maxResults must be an integer")) newAPIError(errorInvalidRequest, "maxResults must be an integer"))
...@@ -88,13 +97,7 @@ func (s *UserMgmtServer) listUsers(w http.ResponseWriter, r *http.Request, ps ht ...@@ -88,13 +97,7 @@ func (s *UserMgmtServer) listUsers(w http.ResponseWriter, r *http.Request, ps ht
writeResponseWithBody(w, http.StatusOK, usersResponse) writeResponseWithBody(w, http.StatusOK, usersResponse)
} }
func (s *UserMgmtServer) getUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { func (s *UserMgmtServer) getUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) {
creds, err := s.getCreds(r)
if err != nil {
s.writeError(w, err)
return
}
id := ps.ByName("id") id := ps.ByName("id")
if id == "" { if id == "" {
writeAPIError(w, http.StatusBadRequest, writeAPIError(w, http.StatusBadRequest,
...@@ -113,16 +116,9 @@ func (s *UserMgmtServer) getUser(w http.ResponseWriter, r *http.Request, ps http ...@@ -113,16 +116,9 @@ func (s *UserMgmtServer) getUser(w http.ResponseWriter, r *http.Request, ps http
writeResponseWithBody(w, http.StatusOK, userResponse) writeResponseWithBody(w, http.StatusOK, userResponse)
} }
func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) {
creds, err := s.getCreds(r)
if err != nil {
s.writeError(w, err)
return
}
createReq := schema.UserCreateRequest{} createReq := schema.UserCreateRequest{}
err = json.NewDecoder(r.Body).Decode(&createReq) if err := json.NewDecoder(r.Body).Decode(&createReq); err != nil {
if err != nil {
writeInvalidRequest(w, "cannot parse JSON body") writeInvalidRequest(w, "cannot parse JSON body")
return return
} }
...@@ -143,13 +139,7 @@ func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps h ...@@ -143,13 +139,7 @@ func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps h
writeResponseWithBody(w, http.StatusOK, createdResponse) writeResponseWithBody(w, http.StatusOK, createdResponse)
} }
func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params, creds api.Creds) {
creds, err := s.getCreds(r)
if err != nil {
s.writeError(w, err)
return
}
id := ps.ByName("id") id := ps.ByName("id")
if id == "" { if id == "" {
writeAPIError(w, http.StatusBadRequest, writeAPIError(w, http.StatusBadRequest,
...@@ -158,8 +148,7 @@ func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps ...@@ -158,8 +148,7 @@ func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps
} }
disableReq := schema.UserDisableRequest{} disableReq := schema.UserDisableRequest{}
err = json.NewDecoder(r.Body).Decode(&disableReq) if err := json.NewDecoder(r.Body).Decode(&disableReq); err != nil {
if err != nil {
writeInvalidRequest(w, "cannot parse JSON body") writeInvalidRequest(w, "cannot parse JSON body")
} }
...@@ -240,7 +229,7 @@ func (s *UserMgmtServer) getCreds(r *http.Request) (api.Creds, error) { ...@@ -240,7 +229,7 @@ func (s *UserMgmtServer) getCreds(r *http.Request) (api.Creds, error) {
return api.Creds{}, err return api.Creds{}, err
} }
if !isAdmin { if !isAdmin {
return api.Creds{}, api.ErrorUnauthorized return api.Creds{}, api.ErrorForbidden
} }
return api.Creds{ return api.Creds{
......
...@@ -31,7 +31,8 @@ var ( ...@@ -31,7 +31,8 @@ var (
ErrorDuplicateEmail = newError("duplicate_email", "Email already in use.", http.StatusBadRequest) ErrorDuplicateEmail = newError("duplicate_email", "Email already in use.", http.StatusBadRequest)
ErrorResourceNotFound = newError("resource_not_found", "Resource could not be found.", http.StatusNotFound) ErrorResourceNotFound = newError("resource_not_found", "Resource could not be found.", http.StatusNotFound)
ErrorUnauthorized = newError("unauthorized", "The given user and client are not authorized to make this request.", http.StatusUnauthorized) ErrorUnauthorized = newError("unauthorized", "Necessary credentials not provided.", http.StatusUnauthorized)
ErrorForbidden = newError("forbidden", "The given user and client are not authorized to make this request.", http.StatusForbidden)
ErrorMaxResultsTooHigh = newError("max_results_too_high", fmt.Sprintf("The max number of results per page is %d", maxUsersPerPage), http.StatusBadRequest) ErrorMaxResultsTooHigh = newError("max_results_too_high", fmt.Sprintf("The max number of results per page is %d", maxUsersPerPage), http.StatusBadRequest)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment