Commit edb010ca authored by Eric Chiang's avatar Eric Chiang Committed by GitHub

Merge pull request #510 from ericchiang/add-groups-scope-and-ldap-implementation

Add groups scope and LDAP implementation
parents af6aade6 607d9920
...@@ -56,3 +56,19 @@ For situations in which an app does not have access to a browser, the out-of-ban ...@@ -56,3 +56,19 @@ For situations in which an app does not have access to a browser, the out-of-ban
\* In OpenID Connect a client is called a "Relying Party", but "client" seems to \* In OpenID Connect a client is called a "Relying Party", but "client" seems to
be the more common ter, has been around longer and is present in paramter names be the more common ter, has been around longer and is present in paramter names
like "client_id" so we prefer it over "Relying Party" usually. like "client_id" so we prefer it over "Relying Party" usually.
## Groups
Connectors that support groups (currently only the LDAP connector) can embed the groups a user belongs to in the ID Token. Using the scope "groups" during the initial redirect with a connector that supports groups will return an JWT with the following field.
```
{
"groups": [
"cn=ipausers,cn=groups,cn=accounts,dc=example,dc=com,
"cn=team-engineering,cn=groups,cn=accounts,dc=example,dc=com"
],
...
}
```
If the client has also requested a refresh token, the groups field is updated during each refresh request.
...@@ -153,6 +153,7 @@ In addition to `id` and `type`, the `ldap` connector takes the following additio ...@@ -153,6 +153,7 @@ In addition to `id` and `type`, the `ldap` connector takes the following additio
* emailAttribute: a `string`. Required. Attribute to map to Email. Default: `mail` * emailAttribute: a `string`. Required. Attribute to map to Email. Default: `mail`
* searchBeforeAuth: a `boolean`. Perform search for entryDN to be used for bind. * searchBeforeAuth: a `boolean`. Perform search for entryDN to be used for bind.
* searchFilter: a `string`. Filter to apply to search. Variable substititions: `%u` User supplied username/e-mail address. `%b` BaseDN. Searches that return multiple entries are considered ambiguous and will return an error. * searchFilter: a `string`. Filter to apply to search. Variable substititions: `%u` User supplied username/e-mail address. `%b` BaseDN. Searches that return multiple entries are considered ambiguous and will return an error.
* searchGroupFilter: a `string`. A filter which should return group entry for a given user. The string is formatted the same as `searchFilter`, execpt `%u` is replaced by the fully qualified user entry. Groups are only searched if the client request the "groups" scope.
* searchScope: a `string`. Scope of the search. `base|one|sub`. Default: `one` * searchScope: a `string`. Scope of the search. `base|one|sub`. Default: `one`
* searchBindDN: a `string`. DN to bind as for search operations. * searchBindDN: a `string`. DN to bind as for search operations.
* searchBindPw: a `string`. Password for bind for search operations. * searchBindPw: a `string`. Password for bind for search operations.
...@@ -180,19 +181,20 @@ uid=janedoe,cn=users,cn=accounts,dc=auth,dc=example,dc=com ...@@ -180,19 +181,20 @@ uid=janedoe,cn=users,cn=accounts,dc=auth,dc=example,dc=com
The connector then attempts to bind as this entry using the password provided by the end user. The connector then attempts to bind as this entry using the password provided by the end user.
### Example: Searching the directory ### Example: Searching a FreeIPA server with groups
The following configuration will search a directory using an LDAP filter. With FreeIPA The following configuration will search a FreeIPA directory using an LDAP filter.
``` ```
{ {
"type": "ldap", "type": "ldap",
"id": "ldap", "id": "ldap",
"host": "127.0.0.1:389", "host": "127.0.0.1:389",
"baseDN": "cn=auth,dc=example,dc=com", "baseDN": "cn=accounts,dc=example,dc=com",
"searchBeforeAuth": true, "searchBeforeAuth": true,
"searchFilter": "(&(objectClass=person)(uid=%u))", "searchFilter": "(&(objectClass=person)(uid=%u))",
"searchGroupFilter": "(&(objectClass=ipausergroup)(member=%u))",
"searchScope": "sub", "searchScope": "sub",
"searchBindDN": "serviceAccountUser", "searchBindDN": "serviceAccountUser",
...@@ -206,9 +208,15 @@ The following configuration will search a directory using an LDAP filter. With F ...@@ -206,9 +208,15 @@ The following configuration will search a directory using an LDAP filter. With F
(&(objectClass=person)(uid=janedoe)) (&(objectClass=person)(uid=janedoe))
``` ```
If the search finds an entry, it will attempt to use the provided password to bind as that entry. If the search finds an entry, it will attempt to use the provided password to bind as that entry. Searches that return multiple entries are considered ambiguous and will return an error.
__NOTE__: Searches that return multiple entries will return an error. "searchGroupFilter" is a format string similar to "searchFilter" except `%u` is replaced by the fully qualified user entry returned by "searchFilter". So if the initial search returns "uid=janedoe,cn=users,cn=accounts,dc=example,dc=com", the connector will use the search query:
```
(&(objectClass=ipausergroup)(member=uid=janedoe,cn=users,cn=accounts,dc=example,dc=com))
```
If the client requests the "groups" scope, the names of all returned entries are added to the ID Token "groups" claim.
## Setting the Configuration ## Setting the Configuration
......
...@@ -107,11 +107,12 @@ type LDAPConnector struct { ...@@ -107,11 +107,12 @@ type LDAPConnector struct {
nameAttribute string nameAttribute string
emailAttribute string emailAttribute string
searchBeforeAuth bool searchBeforeAuth bool
searchFilter string searchFilter string
searchScope int searchScope int
searchBindDN string searchBindDN string
searchBindPw string searchBindPw string
searchGroupFilter string
bindTemplate string bindTemplate string
...@@ -203,19 +204,20 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t ...@@ -203,19 +204,20 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t
} }
idpc := &LDAPConnector{ idpc := &LDAPConnector{
id: cfg.ID, id: cfg.ID,
namespace: ns, namespace: ns,
loginFunc: lf, loginFunc: lf,
loginTpl: tpl, loginTpl: tpl,
baseDN: cfg.BaseDN, baseDN: cfg.BaseDN,
nameAttribute: cfg.NameAttribute, nameAttribute: cfg.NameAttribute,
emailAttribute: cfg.EmailAttribute, emailAttribute: cfg.EmailAttribute,
searchBeforeAuth: cfg.SearchBeforeAuth, searchBeforeAuth: cfg.SearchBeforeAuth,
searchFilter: cfg.SearchFilter, searchFilter: cfg.SearchFilter,
searchScope: searchScope, searchGroupFilter: cfg.SearchGroupFilter,
searchBindDN: cfg.SearchBindDN, searchScope: searchScope,
searchBindPw: cfg.SearchBindPw, searchBindDN: cfg.SearchBindDN,
bindTemplate: cfg.BindTemplate, searchBindPw: cfg.SearchBindPw,
bindTemplate: cfg.BindTemplate,
ldapPool: &LDAPPool{ ldapPool: &LDAPPool{
MaxIdleConn: cfg.MaxIdleConn, MaxIdleConn: cfg.MaxIdleConn,
PoolCheckTimer: defaultPoolCheckTimer, PoolCheckTimer: defaultPoolCheckTimer,
...@@ -433,12 +435,47 @@ func invalidBindCredentials(err error) bool { ...@@ -433,12 +435,47 @@ func invalidBindCredentials(err error) bool {
func (c *LDAPConnector) formatDN(template, username string) string { func (c *LDAPConnector) formatDN(template, username string) string {
result := template result := template
result = strings.Replace(result, "%u", username, -1) result = strings.Replace(result, "%u", ldap.EscapeFilter(username), -1)
result = strings.Replace(result, "%b", c.baseDN, -1) result = strings.Replace(result, "%b", c.baseDN, -1)
return result return result
} }
func (c *LDAPConnector) Groups(fullUserID string) ([]string, error) {
if !c.searchBeforeAuth {
return nil, fmt.Errorf("cannot search without service account")
}
if c.searchGroupFilter == "" {
return nil, fmt.Errorf("no group filter specified")
}
var groups []string
err := c.ldapPool.Do(func(conn *ldap.Conn) error {
if err := conn.Bind(c.searchBindDN, c.searchBindPw); err != nil {
if !invalidBindCredentials(err) {
log.Errorf("failed to connect to LDAP for search bind: %v", err)
}
return fmt.Errorf("failed to bind: %v", err)
}
req := &ldap.SearchRequest{
BaseDN: c.baseDN,
Scope: c.searchScope,
Filter: c.formatDN(c.searchGroupFilter, fullUserID),
}
resp, err := conn.Search(req)
if err != nil {
return fmt.Errorf("search failed: %v", err)
}
groups = make([]string, len(resp.Entries))
for i, entry := range resp.Entries {
groups[i] = entry.DN
}
return nil
})
return groups, err
}
func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, error) { func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, error) {
var ( var (
identity *oidc.Identity identity *oidc.Identity
...@@ -447,8 +484,10 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err ...@@ -447,8 +484,10 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err
if c.searchBeforeAuth { if c.searchBeforeAuth {
err = c.ldapPool.Do(func(conn *ldap.Conn) error { err = c.ldapPool.Do(func(conn *ldap.Conn) error {
if err := conn.Bind(c.searchBindDN, c.searchBindPw); err != nil { if err := conn.Bind(c.searchBindDN, c.searchBindPw); err != nil {
// Don't wrap error as it may be a specific LDAP error. if !invalidBindCredentials(err) {
return err log.Errorf("failed to connect to LDAP for search bind: %v", err)
}
return fmt.Errorf("failed to bind: %v", err)
} }
filter := c.formatDN(c.searchFilter, username) filter := c.formatDN(c.searchFilter, username)
...@@ -491,8 +530,10 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err ...@@ -491,8 +530,10 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err
err = c.ldapPool.Do(func(conn *ldap.Conn) error { err = c.ldapPool.Do(func(conn *ldap.Conn) error {
userBindDN := c.formatDN(c.bindTemplate, username) userBindDN := c.formatDN(c.bindTemplate, username)
if err := conn.Bind(userBindDN, password); err != nil { if err := conn.Bind(userBindDN, password); err != nil {
// Don't wrap error as it may be a specific LDAP error. if !invalidBindCredentials(err) {
return err log.Errorf("failed to connect to LDAP for search bind: %v", err)
}
return fmt.Errorf("failed to bind: %v", err)
} }
req := &ldap.SearchRequest{ req := &ldap.SearchRequest{
...@@ -522,11 +563,7 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err ...@@ -522,11 +563,7 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err
return nil return nil
}) })
} }
if err != nil { if err != nil {
if !invalidBindCredentials(err) {
log.Errorf("failed to connect to LDAP for search bind: %v", err)
}
return nil, err return nil, err
} }
return identity, nil return identity, nil
......
...@@ -60,6 +60,12 @@ type ConnectorConfig interface { ...@@ -60,6 +60,12 @@ type ConnectorConfig interface {
Connector(ns url.URL, loginFunc oidc.LoginFunc, tpls *template.Template) (Connector, error) Connector(ns url.URL, loginFunc oidc.LoginFunc, tpls *template.Template) (Connector, error)
} }
// GroupsConnector is a strategy for mapping a user to a set of groups. This is optionally
// implemented by some connectors.
type GroupsConnector interface {
Groups(fullUserID string) ([]string, error)
}
type ConnectorConfigRepo interface { type ConnectorConfigRepo interface {
All() ([]ConnectorConfig, error) All() ([]ConnectorConfig, error)
GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error) GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error)
......
...@@ -41,6 +41,7 @@ CREATE TABLE refresh_token ( ...@@ -41,6 +41,7 @@ CREATE TABLE refresh_token (
payload_hash blob, payload_hash blob,
user_id text, user_id text,
client_id text, client_id text,
connector_id text,
scopes text scopes text
); );
...@@ -63,7 +64,8 @@ CREATE TABLE session ( ...@@ -63,7 +64,8 @@ CREATE TABLE session (
user_id text, user_id text,
register integer, register integer,
nonce text, nonce text,
scope text scope text,
groups text
); );
CREATE TABLE session_key ( CREATE TABLE session_key (
......
-- +migrate Up
ALTER TABLE refresh_token ADD COLUMN "connector_id" text;
ALTER TABLE session ADD COLUMN "groups" text;
...@@ -90,5 +90,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{ ...@@ -90,5 +90,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n", "-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n",
}, },
}, },
{
Id: "0014_add_groups.sql",
Up: []string{
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"connector_id\" text;\nALTER TABLE session ADD COLUMN \"groups\" text;\n",
},
},
}, },
} }
...@@ -41,6 +41,7 @@ type refreshTokenModel struct { ...@@ -41,6 +41,7 @@ type refreshTokenModel struct {
PayloadHash []byte `db:"payload_hash"` PayloadHash []byte `db:"payload_hash"`
UserID string `db:"user_id"` UserID string `db:"user_id"`
ClientID string `db:"client_id"` ClientID string `db:"client_id"`
ConnectorID string `db:"connector_id"`
Scopes string `db:"scopes"` Scopes string `db:"scopes"`
} }
...@@ -89,7 +90,7 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG ...@@ -89,7 +90,7 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG
} }
} }
func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (string, error) { func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) {
if userID == "" { if userID == "" {
return "", refresh.ErrorInvalidUserID return "", refresh.ErrorInvalidUserID
} }
...@@ -112,6 +113,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str ...@@ -112,6 +113,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str
PayloadHash: payloadHash, PayloadHash: payloadHash,
UserID: userID, UserID: userID,
ClientID: clientID, ClientID: clientID,
ConnectorID: connectorID,
Scopes: strings.Join(scopes, " "), Scopes: strings.Join(scopes, " "),
} }
...@@ -122,24 +124,24 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str ...@@ -122,24 +124,24 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str
return buildToken(record.ID, tokenPayload), nil return buildToken(record.ID, tokenPayload), nil
} }
func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, error) { func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
tokenID, tokenPayload, err := parseToken(token) tokenID, tokenPayload, err := parseToken(token)
if err != nil { if err != nil {
return "", nil, err return
} }
record, err := r.get(nil, tokenID) record, err := r.get(nil, tokenID)
if err != nil { if err != nil {
return "", nil, err return
} }
if record.ClientID != clientID { if record.ClientID != clientID {
return "", nil, refresh.ErrorInvalidClientID return "", "", nil, refresh.ErrorInvalidClientID
} }
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil { if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return "", nil, err return
} }
var scopes []string var scopes []string
...@@ -147,7 +149,7 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, ...@@ -147,7 +149,7 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes,
scopes = strings.Split(record.Scopes, " ") scopes = strings.Split(record.Scopes, " ")
} }
return record.UserID, scopes, nil return record.UserID, record.ConnectorID, scopes, nil
} }
func (r *refreshTokenRepo) Revoke(userID, token string) error { func (r *refreshTokenRepo) Revoke(userID, token string) error {
......
...@@ -44,6 +44,7 @@ type sessionModel struct { ...@@ -44,6 +44,7 @@ type sessionModel struct {
Register bool `db:"register"` Register bool `db:"register"`
Nonce string `db:"nonce"` Nonce string `db:"nonce"`
Scope string `db:"scope"` Scope string `db:"scope"`
Groups string `db:"groups"`
} }
func (s *sessionModel) session() (*session.Session, error) { func (s *sessionModel) session() (*session.Session, error) {
...@@ -75,6 +76,11 @@ func (s *sessionModel) session() (*session.Session, error) { ...@@ -75,6 +76,11 @@ func (s *sessionModel) session() (*session.Session, error) {
Nonce: s.Nonce, Nonce: s.Nonce,
Scope: strings.Fields(s.Scope), Scope: strings.Fields(s.Scope),
} }
if s.Groups != "" {
if err := json.Unmarshal([]byte(s.Groups), &ses.Groups); err != nil {
return nil, fmt.Errorf("failed to decode groups in session: %v", err)
}
}
if s.CreatedAt != 0 { if s.CreatedAt != 0 {
ses.CreatedAt = time.Unix(s.CreatedAt, 0).UTC() ses.CreatedAt = time.Unix(s.CreatedAt, 0).UTC()
...@@ -107,6 +113,14 @@ func newSessionModel(s *session.Session) (*sessionModel, error) { ...@@ -107,6 +113,14 @@ func newSessionModel(s *session.Session) (*sessionModel, error) {
Scope: strings.Join(s.Scope, " "), Scope: strings.Join(s.Scope, " "),
} }
if s.Groups != nil {
data, err := json.Marshal(s.Groups)
if err != nil {
return nil, fmt.Errorf("failed to marshal groups: %v", err)
}
sm.Groups = string(data)
}
if !s.CreatedAt.IsZero() { if !s.CreatedAt.IsZero() {
sm.CreatedAt = s.CreatedAt.Unix() sm.CreatedAt = s.CreatedAt.Unix()
} }
......
...@@ -68,7 +68,7 @@ func (fi bindataFileInfo) Sys() interface{} { ...@@ -68,7 +68,7 @@ func (fi bindataFileInfo) Sys() interface{} {
return nil return nil
} }
var _dataIndexHtml = []byte("\x1f\x8b\x08\x00\x00\x09\x6e\x88\x00\xff\x94\x52\xcd\x4e\xc3\x30\x0c\xbe\xef\x29\xac\x9c\xe0\x30\x7a\x47\x6d\x25\x40\xdc\x90\x26\xf1\x02\x53\x9a\x78\x6d\xb4\xfc\x4c\x89\x8b\x36\x4d\x7b\x77\xdc\x96\xae\x5b\x81\x09\x6e\xfe\x14\xfb\xfb\x89\x9d\x37\xe4\x6c\xb9\x00\xc8\xab\xa0\x0f\xe5\x82\x2b\xae\x37\x21\x3a\x90\x8a\x4c\xf0\x85\xc8\x6c\xa8\x8d\x17\x65\xff\xc4\x8f\x24\x2b\x8b\x23\xea\x70\x9c\x40\x07\x75\x09\x4f\x2d\x35\xe8\xc9\x28\x49\x08\x4c\xf6\x78\xd1\xd0\x49\x5d\x4d\x00\xdc\xa9\xe0\x9c\x5c\x26\xdc\xc9\xc8\x13\x1a\xac\x49\x04\x61\x03\xca\x1a\xa6\x59\x1a\x9d\xee\x2f\x25\x32\xd6\x98\x4b\xe6\xc6\xef\x5a\x02\x3a\xec\xb0\x10\x84\x7b\x12\xe0\xa5\xe3\x5a\xc5\x90\xd2\x7a\x60\x12\x50\xce\xa6\x19\x9d\xcd\x70\x3d\x44\x3b\x1e\xc1\x6c\xe0\x61\xb5\x7a\x86\xd3\x69\x6a\xbd\x54\x48\x6d\xe5\x0c\xf3\x7d\x48\xdb\x32\x7c\xeb\xbf\xa8\x8b\xea\x48\xc6\x1a\xa9\x10\xeb\xca\x4a\xbf\x15\x3d\x1b\xda\x84\xff\xa4\x1a\xe6\xbc\x1e\xc7\xf2\xac\x23\xe7\x05\x7d\x37\x37\x5b\x97\x92\xd6\x56\x52\x6d\x05\x38\xa4\x26\xe8\x42\xb0\x9f\x8e\x70\xd0\x7e\x09\x1a\x17\x3f\xd8\xb8\xfa\x33\xee\x39\x1b\x9a\x36\x3f\xed\xed\x56\x80\xd7\xbd\x6a\xa4\xaf\xb1\x57\x1a\x75\x47\xfb\xd7\xa1\xbe\xc2\xf8\x40\xb7\x02\x45\xac\xf9\x1e\x30\x8a\xbf\xa8\xbf\x8f\xcd\x00\xd9\xef\xd2\x79\x36\x9c\x7b\x9e\x0d\xf7\xff\x19\x00\x00\xff\xff\xaf\x0b\xca\x75\x07\x03\x00\x00") var _dataIndexHtml = []byte("\x1f\x8b\x08\x00\x00\x09\x6e\x88\x00\xff\x94\x93\xcf\x8a\xe3\x30\x0c\xc6\xef\x7d\x0a\xe1\x7b\x37\xf7\xc5\x29\xec\x0e\xbd\x0d\x14\xe6\x05\x8a\x63\xab\x89\xa9\xff\x61\x2b\x43\x4b\xe9\xbb\x8f\x53\x37\x61\x52\xd2\xa1\x73\x93\xd1\x27\x7d\x3f\x49\x98\x77\x64\xcd\x66\x05\xc0\x1b\xaf\xce\x43\x90\xc3\x83\x8f\x16\x84\x24\xed\x5d\xcd\x2a\xe3\x5b\xed\x58\x49\x0d\xd9\x30\x85\x00\xff\x7a\xea\xd0\x91\x96\x82\x10\x72\xd9\x5f\xae\x5d\xe8\x09\xe8\x1c\xb0\x66\x84\x27\x62\xe0\x84\xcd\xb1\x8c\x3e\xa5\xbd\x34\x3a\xcb\x19\x04\x23\x24\x76\xde\x28\x8c\x39\xe5\xad\x15\xeb\x84\x41\xc4\xdc\x46\x81\xd1\x89\xc0\x1f\xa0\x88\xd7\x5a\xa5\x6f\xee\x55\x58\x26\xd9\x9e\x28\x0a\x48\xd2\x07\x4c\xcf\x29\x70\x50\xed\x8b\xea\x45\x8a\xbb\x78\x4e\x70\xb9\x80\x3e\xc0\x9f\xdd\xee\x3f\x5c\xaf\x13\xc4\xcc\x36\xf5\x8d\xd5\xd9\xf8\x53\x98\x3e\x3f\xdf\x6f\x5b\x1c\x76\x64\x49\xc4\x16\xa9\x66\xfb\xc6\x08\x77\x64\xb7\x6e\x68\x12\xfe\xb2\x55\xa9\x73\x6a\x2c\xe3\xd5\xd0\x7c\xb3\x5a\x80\x7b\xb8\xa8\x14\xc6\x34\x42\x1e\x19\x58\xa4\xce\xab\x9a\x65\x9e\xa1\x61\xf1\x7e\xf3\x0a\x57\x0b\x18\xb3\x73\x66\xcd\x04\x34\x2d\x87\x37\x71\xb3\x54\xf9\x30\xc0\xf6\x24\x3b\xe1\x5a\xbc\x39\x8d\xbe\x23\xfe\x7c\xa8\xfb\x30\xce\xd3\x4f\x03\x45\x6c\xf3\xb5\x30\xb2\x57\xdc\x3f\x46\x31\x40\xf5\xdc\x9a\x57\xe5\x43\xf0\xaa\xfc\x90\xaf\x00\x00\x00\xff\xff\x9c\x89\xe2\x28\x29\x03\x00\x00")
func dataIndexHtmlBytes() ([]byte, error) { func dataIndexHtmlBytes() ([]byte, error) {
return bindataRead( return bindataRead(
...@@ -83,7 +83,7 @@ func dataIndexHtml() (*asset, error) { ...@@ -83,7 +83,7 @@ func dataIndexHtml() (*asset, error) {
return nil, err return nil, err
} }
info := bindataFileInfo{name: "data/index.html", size: 775, mode: os.FileMode(420), modTime: time.Unix(1466378108, 0)} info := bindataFileInfo{name: "data/index.html", size: 809, mode: os.FileMode(436), modTime: time.Unix(1468620773, 0)}
a := &asset{bytes: bytes, info: info} a := &asset{bytes: bytes, info: info}
return a, nil return a, nil
} }
......
<html> <html>
<body> <body>
<form action="/login"> <form action="/login">
<table> <p>
<tr> Authenticate for:<input type="text" name="cross_client" placeholder="comma-separated list of client-ids">
<td> Authenticate for: </p>
<br> <p>
(comma-separated list of client-ids) Extra scopes:<input type="text" name="extra_scopes" placeholder="comma-separated list of scopes">
</td> </p>
<td> <input type="text" name="cross_client" > </td>
</tr>
</table>
{{ if .OOB }} {{ if .OOB }}
<input type="submit" value="Login" formtarget="_blank"> <input type="submit" value="Login" formtarget="_blank">
{{ else }} {{ else }}
......
...@@ -218,18 +218,25 @@ func handleLoginFunc(c *oidc.Client) http.HandlerFunc { ...@@ -218,18 +218,25 @@ func handleLoginFunc(c *oidc.Client) http.HandlerFunc {
panic("unable to proceed") panic("unable to proceed")
} }
xClient := r.Form.Get("cross_client") var scopes []string
if xClient != "" { q := u.Query()
if scope := q.Get("scope"); scope != "" {
scopes = strings.Split(scope, " ")
}
if xClient := r.Form.Get("cross_client"); xClient != "" {
xClients := strings.Split(xClient, ",") xClients := strings.Split(xClient, ",")
for i, x := range xClients { for _, x := range xClients {
xClients[i] = scope.ScopeGoogleCrossClient + x scopes = append(scopes, scope.ScopeGoogleCrossClient+x)
} }
q := u.Query() }
scope := q.Get("scope")
scopes := strings.Split(scope, " ") if extraScopes := r.Form.Get("extra_scopes"); extraScopes != "" {
scopes = append(scopes, xClients...) scopes = append(scopes, strings.Split(extraScopes, ",")...)
scope = strings.Join(scopes, " ") }
q.Set("scope", scope)
if scopes != nil {
q.Set("scope", strings.Join(scopes, " "))
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
} }
...@@ -292,57 +299,69 @@ func handleResendFunc(c *oidc.Client, issuerURL, resendURL, cbURL url.URL) http. ...@@ -292,57 +299,69 @@ func handleResendFunc(c *oidc.Client, issuerURL, resendURL, cbURL url.URL) http.
func handleCallbackFunc(c *oidc.Client) http.HandlerFunc { func handleCallbackFunc(c *oidc.Client) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
refreshToken := r.URL.Query().Get("refresh_token")
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
if code == "" {
phttp.WriteError(w, http.StatusBadRequest, "code query param must be set")
return
}
tokens, err := exchangeAuthCode(c, code) oac, err := c.OAuthClient()
if err != nil { if err != nil {
phttp.WriteError(w, http.StatusBadRequest, phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to create OAuth2 client: %v", err))
fmt.Sprintf("unable to verify auth code with issuer: %v", err))
return return
} }
tok, err := jose.ParseJWT(tokens.IDToken) var token oauth2.TokenResponse
if err != nil {
phttp.WriteError(w, http.StatusBadRequest, switch {
fmt.Sprintf("unable to parse JWT: %v", err)) case code != "":
if token, err = oac.RequestToken(oauth2.GrantTypeAuthCode, code); err != nil {
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to verify auth code with issuer: %v", err))
return
}
case refreshToken != "":
if token, err = oac.RequestToken(oauth2.GrantTypeRefreshToken, refreshToken); err != nil {
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to refresh token: %v", err))
return
}
if token.RefreshToken == "" {
token.RefreshToken = refreshToken
}
default:
phttp.WriteError(w, http.StatusBadRequest, "code query param must be set")
return return
} }
claims, err := tok.Claims() tok, err := jose.ParseJWT(token.IDToken)
if err != nil { if err != nil {
phttp.WriteError(w, http.StatusBadRequest, phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to parse JWT: %v", err))
fmt.Sprintf("unable to construct claims: %v", err))
return return
} }
claims := new(bytes.Buffer)
if err := json.Indent(claims, tok.Payload, "", " "); err != nil {
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to construct claims: %v", err))
return
}
s := fmt.Sprintf(` s := fmt.Sprintf(`
<html> <html>
<head>
<style>
/* make pre wrap */
pre {
white-space: pre-wrap; /* css-3 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word; /* Internet Explorer 5.5+ */
}
</style>
</head>
<body> <body>
<p> Token: %v</p> <p> Token: <pre><code>%v</code></pre></p>
<p> Claims: %v </p> <p> Claims: <pre><code>%v</code></pre></p>
<a href="/resend?jwt=%s">Resend Verification Email</a> <p> Refresh Token: <pre><code>%v</code></pre></p>
<p> Refresh Token: %v </p> <p><a href="%s?refresh_token=%s">Redeem refresh token</a><p>
<p><a href="/resend?jwt=%s">Resend Verification Email</a></p>
</body> </body>
</html>`, tok.Encode(), claims, tok.Encode(), tokens.RefreshToken) </html>`, tok.Encode(), claims.String(), token.RefreshToken, r.URL.Path, token.RefreshToken, tok.Encode())
w.Write([]byte(s)) w.Write([]byte(s))
} }
} }
func exchangeAuthCode(c *oidc.Client, code string) (oauth2.TokenResponse, error) {
oac, err := c.OAuthClient()
if err != nil {
return oauth2.TokenResponse{}, err
}
t, err := oac.RequestToken(oauth2.GrantTypeAuthCode, code)
if err != nil {
return oauth2.TokenResponse{}, err
}
return t, nil
}
...@@ -20,7 +20,10 @@ import ( ...@@ -20,7 +20,10 @@ import (
var ( var (
testRefreshClientID = "client1" testRefreshClientID = "client1"
testRefreshClientID2 = "client2" testRefreshClientID2 = "client2"
testRefreshClients = []client.LoadableClient{
testRefreshConnectorID = "IDPC-1"
testRefreshClients = []client.LoadableClient{
{ {
Client: client.Client{ Client: client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
...@@ -59,7 +62,7 @@ var ( ...@@ -59,7 +62,7 @@ var (
}, },
RemoteIdentities: []user.RemoteIdentity{ RemoteIdentities: []user.RemoteIdentity{
{ {
ConnectorID: "IDPC-1", ConnectorID: testRefreshConnectorID,
ID: "RID-1", ID: "RID-1",
}, },
}, },
...@@ -103,12 +106,12 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) { ...@@ -103,12 +106,12 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, tt.createScopes) tok, err := repo.Create(testRefreshUserID, testRefreshClientID, testRefreshConnectorID, tt.createScopes)
if err != nil { if err != nil {
t.Fatalf("case %d: failed to create refresh token: %v", i, err) t.Fatalf("case %d: failed to create refresh token: %v", i, err)
} }
tokUserID, gotScopes, err := repo.Verify(tt.verifyClientID, tok) tokUserID, gotConnectorID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
if tt.wantVerifyErr { if tt.wantVerifyErr {
if err == nil { if err == nil {
t.Errorf("case %d: want non-nil error.", i) t.Errorf("case %d: want non-nil error.", i)
...@@ -126,6 +129,10 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) { ...@@ -126,6 +129,10 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) {
t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i, t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i,
testRefreshUserID, tokUserID) testRefreshUserID, tokUserID)
} }
if gotConnectorID != testRefreshConnectorID {
t.Errorf("case %d: wanted connector_id=%q got=%q", i, testRefreshConnectorID, gotConnectorID)
}
} }
} }
...@@ -138,7 +145,7 @@ func buildRefreshToken(tokenID int64, tokenPayload []byte) string { ...@@ -138,7 +145,7 @@ func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
func TestRefreshRepoVerifyInvalidTokens(t *testing.T) { func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t)) r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope) token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
...@@ -209,7 +216,7 @@ func TestRefreshRepoVerifyInvalidTokens(t *testing.T) { ...@@ -209,7 +216,7 @@ func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
result, _, err := r.Verify(tt.creds.ID, tt.token) result, _, _, err := r.Verify(tt.creds.ID, tt.token)
if err != tt.err { if err != tt.err {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
} }
...@@ -232,7 +239,7 @@ func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) { ...@@ -232,7 +239,7 @@ func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
for _, clientID := range tt.clientIDs { for _, clientID := range tt.clientIDs {
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"}) _, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err) t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
} }
...@@ -281,7 +288,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) { ...@@ -281,7 +288,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
for _, clientID := range tt.createIDs { for _, clientID := range tt.createIDs {
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"}) _, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err) t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
} }
...@@ -318,7 +325,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) { ...@@ -318,7 +325,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
func TestRefreshRepoRevoke(t *testing.T) { func TestRefreshRepoRevoke(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t)) r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope) token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
......
...@@ -104,6 +104,13 @@ func TestSessionRepoCreateGet(t *testing.T) { ...@@ -104,6 +104,13 @@ func TestSessionRepoCreateGet(t *testing.T) {
ExpiresAt: time.Unix(789, 0).UTC(), ExpiresAt: time.Unix(789, 0).UTC(),
Nonce: "oncenay", Nonce: "oncenay",
}, },
session.Session{
ID: "anID",
ClientState: "blargh",
ExpiresAt: time.Unix(789, 0).UTC(),
Nonce: "oncenay",
Groups: []string{"group1", "group2"},
},
} }
for i, tt := range tests { for i, tt := range tests {
......
...@@ -149,7 +149,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures { ...@@ -149,7 +149,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
refreshRepo := db.NewRefreshTokenRepo(dbMap) refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, user := range userUsers { for _, user := range userUsers {
if _, err := refreshRepo.Create(user.User.ID, testClientID, if _, err := refreshRepo.Create(user.User.ID, testClientID,
append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil { "", append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
panic("Failed to create refresh token: " + err.Error()) panic("Failed to create refresh token: " + err.Error())
} }
} }
......
...@@ -44,12 +44,12 @@ type RefreshTokenRepo interface { ...@@ -44,12 +44,12 @@ type RefreshTokenRepo interface {
// The scopes will be stored with the refresh token, and used to verify // The scopes will be stored with the refresh token, and used to verify
// against future OIDC refresh requests' scopes. // against future OIDC refresh requests' scopes.
// On success the token will be returned. // On success the token will be returned.
Create(userID, clientID string, scope []string) (string, error) Create(userID, clientID, connectorID string, scope []string) (string, error)
// Verify verifies that a token belongs to the client. // Verify verifies that a token belongs to the client.
// It returns the user ID to which the token belongs, and the scopes stored // It returns the user ID to which the token belongs, and the scopes stored
// with token. // with token.
Verify(clientID, token string) (string, scope.Scopes, error) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error)
// Revoke deletes the refresh token if the token belongs to the given userID. // Revoke deletes the refresh token if the token belongs to the given userID.
Revoke(userID, token string) error Revoke(userID, token string) error
......
...@@ -6,6 +6,9 @@ const ( ...@@ -6,6 +6,9 @@ const (
// Scope prefix which indicates initiation of a cross-client authentication flow. // Scope prefix which indicates initiation of a cross-client authentication flow.
// See https://developers.google.com/identity/protocols/CrossClientAuth // See https://developers.google.com/identity/protocols/CrossClientAuth
ScopeGoogleCrossClient = "audience:server:client_id:" ScopeGoogleCrossClient = "audience:server:client_id:"
// ScopeGroups indicates that groups should be added to the ID Token.
ScopeGroups = "groups"
) )
type Scopes []string type Scopes []string
......
...@@ -421,6 +421,7 @@ func validateScopes(srv OIDCServer, clientID string, scopes []string) error { ...@@ -421,6 +421,7 @@ func validateScopes(srv OIDCServer, clientID string, scopes []string) error {
foundOpenIDScope = true foundOpenIDScope = true
case curScope == "profile": case curScope == "profile":
case curScope == "email": case curScope == "email":
case curScope == scope.ScopeGroups:
case curScope == "offline_access": case curScope == "offline_access":
// According to the spec, for offline_access scope, the client must // According to the spec, for offline_access scope, the client must
// use a response_type value that would result in an Authorization // use a response_type value that would result in an Authorization
......
...@@ -75,7 +75,8 @@ type Server struct { ...@@ -75,7 +75,8 @@ type Server struct {
OOBTemplate *template.Template OOBTemplate *template.Template
HealthChecks []health.Checkable HealthChecks []health.Checkable
Connectors []connector.Connector // TODO(ericchiang): Make this a map of ID to connector.
Connectors []connector.Connector
ClientRepo client.ClientRepo ClientRepo client.ClientRepo
ConnectorConfigRepo connector.ConnectorConfigRepo ConnectorConfigRepo connector.ConnectorConfigRepo
...@@ -306,6 +307,15 @@ func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL ur ...@@ -306,6 +307,15 @@ func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL ur
return s.SessionManager.NewSessionKey(sessionID) return s.SessionManager.NewSessionKey(sessionID)
} }
func (s *Server) connector(id string) (connector.Connector, bool) {
for _, c := range s.Connectors {
if c.ID() == id {
return c, true
}
}
return nil, false
}
func (s *Server) Login(ident oidc.Identity, key string) (string, error) { func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
sessionID, err := s.SessionManager.ExchangeKey(key) sessionID, err := s.SessionManager.ExchangeKey(key)
if err != nil { if err != nil {
...@@ -318,6 +328,29 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) { ...@@ -318,6 +328,29 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
} }
log.Infof("Session %s remote identity attached: clientID=%s identity=%#v", sessionID, ses.ClientID, ident) log.Infof("Session %s remote identity attached: clientID=%s identity=%#v", sessionID, ses.ClientID, ident)
// Get the connector used to log the user in.
conn, ok := s.connector(ses.ConnectorID)
if !ok {
return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
}
// If the client has requested access to groups, add them here.
if ses.Scope.HasScope(scope.ScopeGroups) {
grouper, ok := conn.(connector.GroupsConnector)
if !ok {
return "", fmt.Errorf("scope %q provided but connector does not support groups", scope.ScopeGroups)
}
groups, err := grouper.Groups(ident.ID)
if err != nil {
return "", fmt.Errorf("failed to retrieve user groups for %q %v", ident.ID, err)
}
// Update the session.
if ses, err = s.SessionManager.AttachGroups(sessionID, groups); err != nil {
return "", fmt.Errorf("failed save groups")
}
}
if ses.Register { if ses.Register {
code, err := s.SessionManager.NewSessionKey(sessionID) code, err := s.SessionManager.NewSessionKey(sessionID)
if err != nil { if err != nil {
...@@ -334,18 +367,6 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) { ...@@ -334,18 +367,6 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
remoteIdentity := user.RemoteIdentity{ConnectorID: ses.ConnectorID, ID: ses.Identity.ID} remoteIdentity := user.RemoteIdentity{ConnectorID: ses.ConnectorID, ID: ses.Identity.ID}
// Get the connector used to log the user in.
var conn connector.Connector
for _, c := range s.Connectors {
if c.ID() == ses.ConnectorID {
conn = c
break
}
}
if conn == nil {
return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
}
usr, err := s.UserRepo.GetByRemoteIdentity(nil, remoteIdentity) usr, err := s.UserRepo.GetByRemoteIdentity(nil, remoteIdentity)
if err == user.ErrorNotFound { if err == user.ErrorNotFound {
if ses.Identity.Email == "" { if ses.Identity.Email == "" {
...@@ -508,7 +529,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo ...@@ -508,7 +529,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
if scope == "offline_access" { if scope == "offline_access" {
log.Infof("Session %s requests offline access, will generate refresh token", sessionID) log.Infof("Session %s requests offline access, will generate refresh token", sessionID)
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope) refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.ConnectorID, ses.Scope)
switch err { switch err {
case nil: case nil:
break break
...@@ -535,7 +556,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, ...@@ -535,7 +556,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
return nil, oauth2.NewError(oauth2.ErrorInvalidClient) return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
} }
userID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token) userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
switch err { switch err {
case nil: case nil:
break break
...@@ -555,7 +576,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, ...@@ -555,7 +576,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
} }
} }
user, err := s.UserRepo.Get(nil, userID) usr, err := s.UserRepo.Get(nil, userID)
if err != nil { if err != nil {
// The error can be user.ErrorNotFound, but we are not deleting // The error can be user.ErrorNotFound, but we are not deleting
// user at this moment, so this shouldn't happen. // user at this moment, so this shouldn't happen.
...@@ -563,6 +584,43 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, ...@@ -563,6 +584,43 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, oauth2.NewError(oauth2.ErrorServerError)
} }
var groups []string
if rtScopes.HasScope(scope.ScopeGroups) {
conn, ok := s.connector(connectorID)
if !ok {
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
grouper, ok := conn.(connector.GroupsConnector)
if !ok {
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
if err != nil {
log.Errorf("failed to get remote identities: %v", err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
for _, ri := range remoteIdentities {
if ri.ConnectorID == connectorID {
return ri, true
}
}
return user.RemoteIdentity{}, false
}()
if !ok {
log.Errorf("failed to get remote identity for connector %s", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
log.Errorf("failed to get groups for refresh token: %v", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
}
signer, err := s.KeyManager.Signer() signer, err := s.KeyManager.Signer()
if err != nil { if err != nil {
log.Errorf("Failed to refresh ID token: %v", err) log.Errorf("Failed to refresh ID token: %v", err)
...@@ -572,8 +630,14 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, ...@@ -572,8 +630,14 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
now := time.Now() now := time.Now()
expireAt := now.Add(session.DefaultSessionValidityWindow) expireAt := now.Add(session.DefaultSessionValidityWindow)
claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt) claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expireAt)
user.AddToClaims(claims) usr.AddToClaims(claims)
if rtScopes.HasScope(scope.ScopeGroups) {
if groups == nil {
groups = []string{}
}
claims["groups"] = groups
}
s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID) s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID)
......
...@@ -785,8 +785,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -785,8 +785,7 @@ func TestServerRefreshToken(t *testing.T) {
t.Errorf("case %d: error creating other client: %v", i, err) t.Errorf("case %d: error creating other client: %v", i, err)
} }
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID, if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID, "", tt.createScopes); err != nil {
tt.createScopes); err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
......
...@@ -144,6 +144,18 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.S ...@@ -144,6 +144,18 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.S
return s, nil return s, nil
} }
func (m *SessionManager) AttachGroups(sessionID string, groups []string) (*session.Session, error) {
s, err := m.sessions.Get(sessionID)
if err != nil {
return nil, err
}
s.Groups = groups
if err = m.sessions.Update(*s); err != nil {
return nil, err
}
return s, nil
}
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) { func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
s, err := m.sessions.Get(sessionID) s, err := m.sessions.Get(sessionID)
if err != nil { if err != nil {
......
...@@ -55,6 +55,9 @@ type Session struct { ...@@ -55,6 +55,9 @@ type Session struct {
// Scope is the 'scope' field in the authentication request. Example scopes // Scope is the 'scope' field in the authentication request. Example scopes
// are 'openid', 'email', 'offline', etc. // are 'openid', 'email', 'offline', etc.
Scope scope.Scopes Scope scope.Scopes
// Groups the user belongs to.
Groups []string
} }
// Claims returns a new set of Claims for the current session. // Claims returns a new set of Claims for the current session.
...@@ -65,5 +68,8 @@ func (s *Session) Claims(issuerURL string) jose.Claims { ...@@ -65,5 +68,8 @@ func (s *Session) Claims(issuerURL string) jose.Claims {
if s.Nonce != "" { if s.Nonce != "" {
claims["nonce"] = s.Nonce claims["nonce"] = s.Nonce
} }
if s.Scope.HasScope(scope.ScopeGroups) {
claims["groups"] = s.Groups
}
return claims return claims
} }
...@@ -192,7 +192,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { ...@@ -192,7 +192,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
} }
refreshRepo := db.NewRefreshTokenRepo(dbMap) refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, token := range refreshTokens { for _, token := range refreshTokens {
if _, err := refreshRepo.Create(token.userID, token.clientID, []string{"openid"}); err != nil { if _, err := refreshRepo.Create(token.userID, token.clientID, "local", []string{"openid"}); err != nil {
panic("Failed to create refresh token: " + err.Error()) panic("Failed to create refresh token: " + err.Error())
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment