Commit e5db3023 authored by Joe Bowers's avatar Joe Bowers

server: expose user disable API endpoint

parent b33cfbf5
...@@ -97,10 +97,25 @@ func (r *userRepo) Create(tx repo.Transaction, usr user.User) (err error) { ...@@ -97,10 +97,25 @@ func (r *userRepo) Create(tx repo.Transaction, usr user.User) (err error) {
} }
err = r.insert(tx, usr) err = r.insert(tx, usr)
return err
}
func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) error {
if userID == "" {
return user.ErrorInvalidID
}
qt := pq.QuoteIdentifier(userTableName)
ex := r.executor(tx)
result, err := ex.Exec(fmt.Sprintf("UPDATE %s SET disabled = $2 WHERE id = $1", qt), userID, disable)
if err != nil { if err != nil {
return err return err
} }
if ct, err := result.RowsAffected(); err == nil && ct == 0 {
return user.ErrorInvalidID
}
return nil return nil
} }
......
...@@ -141,6 +141,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures { ...@@ -141,6 +141,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
f.trans = &tokenHandlerTransport{ f.trans = &tokenHandlerTransport{
Handler: usrSrv.HTTPHandler(), Handler: usrSrv.HTTPHandler(),
Token: userGoodToken,
} }
hc := &http.Client{ hc := &http.Client{
Transport: f.trans, Transport: f.trans,
...@@ -530,6 +531,48 @@ func TestCreateUser(t *testing.T) { ...@@ -530,6 +531,48 @@ func TestCreateUser(t *testing.T) {
} }
} }
func TestDisableUser(t *testing.T) {
tests := []struct {
id string
disable bool
}{
{
id: "ID-2",
disable: true,
},
{
id: "ID-4",
disable: false,
},
}
for i, tt := range tests {
f := makeUserAPITestFixtures()
usr, err := f.client.Users.Get(tt.id).Do()
if err != nil {
t.Fatalf("case %v: unexpected error: %v", i, err)
}
if usr.User.Disabled == tt.disable {
t.Fatalf("case %v: misconfigured test, initial disabled state should be %v but was %v", i, !tt.disable, usr.User.Disabled)
}
_, err = f.client.Users.Disable(tt.id, &schema.UserDisableRequest{
Disable: tt.disable,
}).Do()
if err != nil {
t.Fatalf("case %v: unexpected error: %v", i, err)
}
usr, err = f.client.Users.Get(tt.id).Do()
if err != nil {
t.Fatalf("case %v: unexpected error: %v", i, err)
}
if usr.User.Disabled != tt.disable {
t.Errorf("case %v: user disabled state incorrect. wanted: %v found: %v", i, tt.disable, usr.User.Disabled)
}
}
}
type testEmailer struct { type testEmailer struct {
cantEmail bool cantEmail bool
lastEmail string lastEmail string
......
...@@ -108,6 +108,8 @@ type User struct { ...@@ -108,6 +108,8 @@ type User struct {
CreatedAt string `json:"createdAt,omitempty"` CreatedAt string `json:"createdAt,omitempty"`
Disabled bool `json:"disabled,omitempty"`
DisplayName string `json:"displayName,omitempty"` DisplayName string `json:"displayName,omitempty"`
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
......
...@@ -108,6 +108,9 @@ const DiscoveryJSON = `{ ...@@ -108,6 +108,9 @@ const DiscoveryJSON = `{
"admin": { "admin": {
"type": "boolean" "type": "boolean"
}, },
"disabled": {
"type": "boolean"
},
"createdAt": { "createdAt": {
"type": "string", "type": "string",
"format": "date-time" "format": "date-time"
......
...@@ -102,6 +102,9 @@ ...@@ -102,6 +102,9 @@
"admin": { "admin": {
"type": "boolean" "type": "boolean"
}, },
"disabled": {
"type": "boolean"
},
"createdAt": { "createdAt": {
"type": "string", "type": "string",
"format": "date-time" "format": "date-time"
......
...@@ -244,7 +244,11 @@ func (s *Server) HTTPHandler() http.Handler { ...@@ -244,7 +244,11 @@ func (s *Server) HTTPHandler() http.Handler {
mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler)) mux.Handle(path.Join(apiBasePath, clientPath), s.NewClientTokenAuthHandler(clientHandler))
usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientIdentityRepo, s.UserEmailer, s.localConnectorID) usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientIdentityRepo, s.UserEmailer, s.localConnectorID)
mux.Handle(path.Join(apiBasePath, UsersSubTree), NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler()) handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientIdentityRepo).HTTPHandler()
path := path.Join(apiBasePath, UsersSubTree)
mux.Handle(path, handler)
mux.Handle(path+"/", handler)
return http.Handler(mux) return http.Handler(mux)
} }
......
...@@ -27,6 +27,7 @@ var ( ...@@ -27,6 +27,7 @@ var (
UsersListEndpoint = addBasePath(UsersSubTree) UsersListEndpoint = addBasePath(UsersSubTree)
UsersCreateEndpoint = addBasePath(UsersSubTree) UsersCreateEndpoint = addBasePath(UsersSubTree)
UsersGetEndpoint = addBasePath(UsersSubTree + "/:id") UsersGetEndpoint = addBasePath(UsersSubTree + "/:id")
UsersDisableEndpoint = addBasePath(UsersSubTree + "/:id/disable")
) )
type UserMgmtServer struct { type UserMgmtServer struct {
...@@ -51,6 +52,7 @@ func (s *UserMgmtServer) HTTPHandler() http.Handler { ...@@ -51,6 +52,7 @@ func (s *UserMgmtServer) HTTPHandler() http.Handler {
r.RedirectFixedPath = false r.RedirectFixedPath = false
r.GET(UsersListEndpoint, s.listUsers) r.GET(UsersListEndpoint, s.listUsers)
r.POST(UsersCreateEndpoint, s.createUser) r.POST(UsersCreateEndpoint, s.createUser)
r.POST(UsersDisableEndpoint, s.disableUser)
r.GET(UsersGetEndpoint, s.getUser) r.GET(UsersGetEndpoint, s.getUser)
return r return r
} }
...@@ -140,6 +142,35 @@ func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps h ...@@ -140,6 +142,35 @@ func (s *UserMgmtServer) createUser(w http.ResponseWriter, r *http.Request, ps h
writeResponseWithBody(w, http.StatusOK, createdResponse) writeResponseWithBody(w, http.StatusOK, createdResponse)
} }
func (s *UserMgmtServer) disableUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
creds, err := s.getCreds(r)
if err != nil {
s.writeError(w, err)
return
}
id := ps.ByName("id")
if id == "" {
writeAPIError(w, http.StatusBadRequest,
newAPIError(errorInvalidRequest, "id is required"))
return
}
disableReq := schema.UserDisableRequest{}
err = json.NewDecoder(r.Body).Decode(&disableReq)
if err != nil {
writeInvalidRequest(w, "cannot parse JSON body")
}
resp, err := s.api.DisableUser(creds, id, disableReq.Disable)
if err != nil {
s.writeError(w, err)
return
}
writeResponseWithBody(w, http.StatusOK, resp)
}
func (s *UserMgmtServer) writeError(w http.ResponseWriter, err error) { func (s *UserMgmtServer) writeError(w http.ResponseWriter, err error) {
log.Errorf("Error calling user management API: %v: ", err) log.Errorf("Error calling user management API: %v: ", err)
if apiErr, ok := err.(api.Error); ok { if apiErr, ok := err.(api.Error); ok {
......
...@@ -121,6 +121,21 @@ func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) { ...@@ -121,6 +121,21 @@ func (u *UsersAPI) GetUser(creds Creds, id string) (schema.User, error) {
return userToSchemaUser(usr), nil return userToSchemaUser(usr), nil
} }
func (u *UsersAPI) DisableUser(creds Creds, userID string, disable bool) (schema.UserDisableResponse, error) {
log.Infof("userAPI: DisableUser")
if !u.Authorize(creds) {
return schema.UserDisableResponse{}, ErrorUnauthorized
}
if err := u.manager.Disable(userID, disable); err != nil {
return schema.UserDisableResponse{}, mapError(err)
}
return schema.UserDisableResponse{
Ok: true,
}, nil
}
func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (schema.UserCreateResponse, error) { func (u *UsersAPI) CreateUser(creds Creds, usr schema.User, redirURL url.URL) (schema.UserCreateResponse, error) {
log.Infof("userAPI: CreateUser") log.Infof("userAPI: CreateUser")
if !u.Authorize(creds) { if !u.Authorize(creds) {
...@@ -207,6 +222,7 @@ func userToSchemaUser(usr user.User) schema.User { ...@@ -207,6 +222,7 @@ func userToSchemaUser(usr user.User) schema.User {
EmailVerified: usr.EmailVerified, EmailVerified: usr.EmailVerified,
DisplayName: usr.DisplayName, DisplayName: usr.DisplayName,
Admin: usr.Admin, Admin: usr.Admin,
Disabled: usr.Disabled,
CreatedAt: usr.CreatedAt.UTC().Format(time.RFC3339), CreatedAt: usr.CreatedAt.UTC().Format(time.RFC3339),
} }
} }
...@@ -218,6 +234,7 @@ func schemaUserToUser(usr schema.User) user.User { ...@@ -218,6 +234,7 @@ func schemaUserToUser(usr schema.User) user.User {
EmailVerified: usr.EmailVerified, EmailVerified: usr.EmailVerified,
DisplayName: usr.DisplayName, DisplayName: usr.DisplayName,
Admin: usr.Admin, Admin: usr.Admin,
Disabled: usr.Disabled,
} }
} }
......
...@@ -94,6 +94,13 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { ...@@ -94,6 +94,13 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
Email: "id3@example.com", Email: "id3@example.com",
CreatedAt: clock.Now(), CreatedAt: clock.Now(),
}, },
}, {
User: user.User{
ID: "ID-4",
Email: "id4@example.com",
CreatedAt: clock.Now(),
Disabled: true,
},
}, },
}) })
pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{ pwr := user.NewPasswordInfoRepoFromPasswordInfos([]user.PasswordInfo{
...@@ -369,3 +376,44 @@ func TestCreateUser(t *testing.T) { ...@@ -369,3 +376,44 @@ func TestCreateUser(t *testing.T) {
} }
} }
} }
func TestDisableUsers(t *testing.T) {
tests := []struct {
id string
disable bool
}{
{
id: "ID-1",
disable: true,
},
{
id: "ID-1",
disable: false,
},
{
id: "ID-4",
disable: true,
},
{
id: "ID-4",
disable: false,
},
}
for i, tt := range tests {
api, _ := makeTestFixtures()
_, err := api.DisableUser(goodCreds, tt.id, tt.disable)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
usr, err := api.GetUser(goodCreds, tt.id)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
if usr.Disabled != tt.disable {
t.Errorf("case %d: user disable state wrong. wanted: %v got: %v", i, tt.disable, usr.Disabled)
}
}
}
...@@ -102,6 +102,22 @@ func (m *Manager) CreateUser(user User, hashedPassword Password, connID string) ...@@ -102,6 +102,22 @@ func (m *Manager) CreateUser(user User, hashedPassword Password, connID string)
return user.ID, nil return user.ID, nil
} }
func (m *Manager) Disable(userID string, disabled bool) error {
tx, err := m.begin()
if err = m.userRepo.Disable(tx, userID, disabled); err != nil {
rollback(tx)
return err
}
if err = tx.Commit(); err != nil {
rollback(tx)
return err
}
return nil
}
// RegisterWithRemoteIdentity creates new user and attaches the given remote identity. // RegisterWithRemoteIdentity creates new user and attaches the given remote identity.
func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid RemoteIdentity) (string, error) { func (m *Manager) RegisterWithRemoteIdentity(email string, emailVerified bool, rid RemoteIdentity) (string, error) {
tx, err := m.begin() tx, err := m.begin()
......
...@@ -80,6 +80,8 @@ type UserRepo interface { ...@@ -80,6 +80,8 @@ type UserRepo interface {
GetByEmail(tx repo.Transaction, email string) (User, error) GetByEmail(tx repo.Transaction, email string) (User, error)
Disable(tx repo.Transaction, id string, disabled bool) error
Update(repo.Transaction, User) error Update(repo.Transaction, User) error
GetByRemoteIdentity(repo.Transaction, RemoteIdentity) (User, error) GetByRemoteIdentity(repo.Transaction, RemoteIdentity) (User, error)
...@@ -254,6 +256,16 @@ func (r *memUserRepo) Update(_ repo.Transaction, user User) error { ...@@ -254,6 +256,16 @@ func (r *memUserRepo) Update(_ repo.Transaction, user User) error {
return nil return nil
} }
func (r *memUserRepo) Disable(_ repo.Transaction, id string, disable bool) error {
user, ok := r.usersByID[id]
if !ok {
return ErrorNotFound
}
user.Disabled = disable
r.set(user)
return nil
}
func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error { func (r *memUserRepo) AddRemoteIdentity(_ repo.Transaction, userID string, ri RemoteIdentity) error {
_, ok := r.usersByID[userID] _, ok := r.usersByID[userID]
if !ok { if !ok {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment