Commit b02a3a31 authored by Eric Chiang's avatar Eric Chiang

*: add "groups" scope

parent 731dadb2
......@@ -41,6 +41,7 @@ CREATE TABLE refresh_token (
payload_hash blob,
user_id text,
client_id text,
connector_id text,
scopes text
);
......@@ -63,7 +64,8 @@ CREATE TABLE session (
user_id text,
register integer,
nonce text,
scope text
scope text,
groups text
);
CREATE TABLE session_key (
......
-- +migrate Up
ALTER TABLE refresh_token ADD COLUMN "connector_id" text;
ALTER TABLE session ADD COLUMN "groups" text;
......@@ -90,5 +90,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n",
},
},
{
Id: "0014_add_groups.sql",
Up: []string{
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"connector_id\" text;\nALTER TABLE session ADD COLUMN \"groups\" text;\n",
},
},
},
}
......@@ -41,6 +41,7 @@ type refreshTokenModel struct {
PayloadHash []byte `db:"payload_hash"`
UserID string `db:"user_id"`
ClientID string `db:"client_id"`
ConnectorID string `db:"connector_id"`
Scopes string `db:"scopes"`
}
......@@ -89,7 +90,7 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG
}
}
func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (string, error) {
func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) {
if userID == "" {
return "", refresh.ErrorInvalidUserID
}
......@@ -112,6 +113,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str
PayloadHash: payloadHash,
UserID: userID,
ClientID: clientID,
ConnectorID: connectorID,
Scopes: strings.Join(scopes, " "),
}
......@@ -122,24 +124,24 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str
return buildToken(record.ID, tokenPayload), nil
}
func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, error) {
func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return "", nil, err
return
}
record, err := r.get(nil, tokenID)
if err != nil {
return "", nil, err
return
}
if record.ClientID != clientID {
return "", nil, refresh.ErrorInvalidClientID
return "", "", nil, refresh.ErrorInvalidClientID
}
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return "", nil, err
if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return
}
var scopes []string
......@@ -147,7 +149,7 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes,
scopes = strings.Split(record.Scopes, " ")
}
return record.UserID, scopes, nil
return record.UserID, record.ConnectorID, scopes, nil
}
func (r *refreshTokenRepo) Revoke(userID, token string) error {
......
......@@ -44,6 +44,7 @@ type sessionModel struct {
Register bool `db:"register"`
Nonce string `db:"nonce"`
Scope string `db:"scope"`
Groups string `db:"groups"`
}
func (s *sessionModel) session() (*session.Session, error) {
......@@ -75,6 +76,11 @@ func (s *sessionModel) session() (*session.Session, error) {
Nonce: s.Nonce,
Scope: strings.Fields(s.Scope),
}
if s.Groups != "" {
if err := json.Unmarshal([]byte(s.Groups), &ses.Groups); err != nil {
return nil, fmt.Errorf("failed to decode groups in session: %v", err)
}
}
if s.CreatedAt != 0 {
ses.CreatedAt = time.Unix(s.CreatedAt, 0).UTC()
......@@ -107,6 +113,14 @@ func newSessionModel(s *session.Session) (*sessionModel, error) {
Scope: strings.Join(s.Scope, " "),
}
if s.Groups != nil {
data, err := json.Marshal(s.Groups)
if err != nil {
return nil, fmt.Errorf("failed to marshal groups: %v", err)
}
sm.Groups = string(data)
}
if !s.CreatedAt.IsZero() {
sm.CreatedAt = s.CreatedAt.Unix()
}
......
......@@ -20,7 +20,10 @@ import (
var (
testRefreshClientID = "client1"
testRefreshClientID2 = "client2"
testRefreshClients = []client.LoadableClient{
testRefreshConnectorID = "IDPC-1"
testRefreshClients = []client.LoadableClient{
{
Client: client.Client{
Credentials: oidc.ClientCredentials{
......@@ -59,7 +62,7 @@ var (
},
RemoteIdentities: []user.RemoteIdentity{
{
ConnectorID: "IDPC-1",
ConnectorID: testRefreshConnectorID,
ID: "RID-1",
},
},
......@@ -103,12 +106,12 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) {
for i, tt := range tests {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, tt.createScopes)
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, testRefreshConnectorID, tt.createScopes)
if err != nil {
t.Fatalf("case %d: failed to create refresh token: %v", i, err)
}
tokUserID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
tokUserID, gotConnectorID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
if tt.wantVerifyErr {
if err == nil {
t.Errorf("case %d: want non-nil error.", i)
......@@ -126,6 +129,10 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) {
t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i,
testRefreshUserID, tokUserID)
}
if gotConnectorID != testRefreshConnectorID {
t.Errorf("case %d: wanted connector_id=%q got=%q", i, testRefreshConnectorID, gotConnectorID)
}
}
}
......@@ -138,7 +145,7 @@ func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope)
token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
......@@ -209,7 +216,7 @@ func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
}
for i, tt := range tests {
result, _, err := r.Verify(tt.creds.ID, tt.token)
result, _, _, err := r.Verify(tt.creds.ID, tt.token)
if err != tt.err {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
}
......@@ -232,7 +239,7 @@ func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
for _, clientID := range tt.clientIDs {
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"})
_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
if err != nil {
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
}
......@@ -281,7 +288,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
for _, clientID := range tt.createIDs {
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"})
_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
if err != nil {
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
}
......@@ -318,7 +325,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
func TestRefreshRepoRevoke(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope)
token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
......
......@@ -104,6 +104,13 @@ func TestSessionRepoCreateGet(t *testing.T) {
ExpiresAt: time.Unix(789, 0).UTC(),
Nonce: "oncenay",
},
session.Session{
ID: "anID",
ClientState: "blargh",
ExpiresAt: time.Unix(789, 0).UTC(),
Nonce: "oncenay",
Groups: []string{"group1", "group2"},
},
}
for i, tt := range tests {
......
......@@ -149,7 +149,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, user := range userUsers {
if _, err := refreshRepo.Create(user.User.ID, testClientID,
append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
"", append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
panic("Failed to create refresh token: " + err.Error())
}
}
......
......@@ -44,12 +44,12 @@ type RefreshTokenRepo interface {
// The scopes will be stored with the refresh token, and used to verify
// against future OIDC refresh requests' scopes.
// On success the token will be returned.
Create(userID, clientID string, scope []string) (string, error)
Create(userID, clientID, connectorID string, scope []string) (string, error)
// Verify verifies that a token belongs to the client.
// It returns the user ID to which the token belongs, and the scopes stored
// with token.
Verify(clientID, token string) (string, scope.Scopes, error)
Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error)
// Revoke deletes the refresh token if the token belongs to the given userID.
Revoke(userID, token string) error
......
......@@ -6,6 +6,9 @@ const (
// Scope prefix which indicates initiation of a cross-client authentication flow.
// See https://developers.google.com/identity/protocols/CrossClientAuth
ScopeGoogleCrossClient = "audience:server:client_id:"
// ScopeGroups indicates that groups should be added to the ID Token.
ScopeGroups = "groups"
)
type Scopes []string
......
......@@ -421,6 +421,7 @@ func validateScopes(srv OIDCServer, clientID string, scopes []string) error {
foundOpenIDScope = true
case curScope == "profile":
case curScope == "email":
case curScope == scope.ScopeGroups:
case curScope == "offline_access":
// According to the spec, for offline_access scope, the client must
// use a response_type value that would result in an Authorization
......
......@@ -75,7 +75,8 @@ type Server struct {
OOBTemplate *template.Template
HealthChecks []health.Checkable
Connectors []connector.Connector
// TODO(ericchiang): Make this a map of ID to connector.
Connectors []connector.Connector
ClientRepo client.ClientRepo
ConnectorConfigRepo connector.ConnectorConfigRepo
......@@ -306,6 +307,15 @@ func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL ur
return s.SessionManager.NewSessionKey(sessionID)
}
func (s *Server) connector(id string) (connector.Connector, bool) {
for _, c := range s.Connectors {
if c.ID() == id {
return c, true
}
}
return nil, false
}
func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
sessionID, err := s.SessionManager.ExchangeKey(key)
if err != nil {
......@@ -318,6 +328,29 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
}
log.Infof("Session %s remote identity attached: clientID=%s identity=%#v", sessionID, ses.ClientID, ident)
// Get the connector used to log the user in.
conn, ok := s.connector(ses.ConnectorID)
if !ok {
return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
}
// If the client has requested access to groups, add them here.
if ses.Scope.HasScope(scope.ScopeGroups) {
grouper, ok := conn.(connector.GroupsConnector)
if !ok {
return "", fmt.Errorf("scope %q provided but connector does not support groups", scope.ScopeGroups)
}
groups, err := grouper.Groups(ident.ID)
if err != nil {
return "", fmt.Errorf("failed to retrieve user groups for %q %v", ident.ID, err)
}
// Update the session.
if ses, err = s.SessionManager.AttachGroups(sessionID, groups); err != nil {
return "", fmt.Errorf("failed save groups")
}
}
if ses.Register {
code, err := s.SessionManager.NewSessionKey(sessionID)
if err != nil {
......@@ -334,18 +367,6 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
remoteIdentity := user.RemoteIdentity{ConnectorID: ses.ConnectorID, ID: ses.Identity.ID}
// Get the connector used to log the user in.
var conn connector.Connector
for _, c := range s.Connectors {
if c.ID() == ses.ConnectorID {
conn = c
break
}
}
if conn == nil {
return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
}
usr, err := s.UserRepo.GetByRemoteIdentity(nil, remoteIdentity)
if err == user.ErrorNotFound {
if ses.Identity.Email == "" {
......@@ -508,7 +529,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
if scope == "offline_access" {
log.Infof("Session %s requests offline access, will generate refresh token", sessionID)
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope)
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.ConnectorID, ses.Scope)
switch err {
case nil:
break
......@@ -535,7 +556,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
}
userID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
switch err {
case nil:
break
......@@ -555,7 +576,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
}
}
user, err := s.UserRepo.Get(nil, userID)
usr, err := s.UserRepo.Get(nil, userID)
if err != nil {
// The error can be user.ErrorNotFound, but we are not deleting
// user at this moment, so this shouldn't happen.
......@@ -563,6 +584,43 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
var groups []string
if rtScopes.HasScope(scope.ScopeGroups) {
conn, ok := s.connector(connectorID)
if !ok {
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
grouper, ok := conn.(connector.GroupsConnector)
if !ok {
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
if err != nil {
log.Errorf("failed to get remote identities: %v", err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
for _, ri := range remoteIdentities {
if ri.ConnectorID == connectorID {
return ri, true
}
}
return user.RemoteIdentity{}, false
}()
if !ok {
log.Errorf("failed to get remote identity for connector %s", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
log.Errorf("failed to get groups for refresh token: %v", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
}
signer, err := s.KeyManager.Signer()
if err != nil {
log.Errorf("Failed to refresh ID token: %v", err)
......@@ -572,8 +630,14 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
now := time.Now()
expireAt := now.Add(session.DefaultSessionValidityWindow)
claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt)
user.AddToClaims(claims)
claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expireAt)
usr.AddToClaims(claims)
if rtScopes.HasScope(scope.ScopeGroups) {
if groups == nil {
groups = []string{}
}
claims["groups"] = groups
}
s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID)
......
......@@ -785,8 +785,7 @@ func TestServerRefreshToken(t *testing.T) {
t.Errorf("case %d: error creating other client: %v", i, err)
}
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID,
tt.createScopes); err != nil {
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID, "", tt.createScopes); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
......
......@@ -144,6 +144,18 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.S
return s, nil
}
func (m *SessionManager) AttachGroups(sessionID string, groups []string) (*session.Session, error) {
s, err := m.sessions.Get(sessionID)
if err != nil {
return nil, err
}
s.Groups = groups
if err = m.sessions.Update(*s); err != nil {
return nil, err
}
return s, nil
}
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
s, err := m.sessions.Get(sessionID)
if err != nil {
......
......@@ -55,6 +55,9 @@ type Session struct {
// Scope is the 'scope' field in the authentication request. Example scopes
// are 'openid', 'email', 'offline', etc.
Scope scope.Scopes
// Groups the user belongs to.
Groups []string
}
// Claims returns a new set of Claims for the current session.
......@@ -65,5 +68,8 @@ func (s *Session) Claims(issuerURL string) jose.Claims {
if s.Nonce != "" {
claims["nonce"] = s.Nonce
}
if s.Scope.HasScope(scope.ScopeGroups) {
claims["groups"] = s.Groups
}
return claims
}
......@@ -192,7 +192,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
}
refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, token := range refreshTokens {
if _, err := refreshRepo.Create(token.userID, token.clientID, []string{"openid"}); err != nil {
if _, err := refreshRepo.Create(token.userID, token.clientID, "local", []string{"openid"}); err != nil {
panic("Failed to create refresh token: " + err.Error())
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment