Commit edb010ca authored by Eric Chiang's avatar Eric Chiang Committed by GitHub

Merge pull request #510 from ericchiang/add-groups-scope-and-ldap-implementation

Add groups scope and LDAP implementation
parents af6aade6 607d9920
......@@ -56,3 +56,19 @@ For situations in which an app does not have access to a browser, the out-of-ban
\* In OpenID Connect a client is called a "Relying Party", but "client" seems to
be the more common ter, has been around longer and is present in paramter names
like "client_id" so we prefer it over "Relying Party" usually.
## Groups
Connectors that support groups (currently only the LDAP connector) can embed the groups a user belongs to in the ID Token. Using the scope "groups" during the initial redirect with a connector that supports groups will return an JWT with the following field.
```
{
"groups": [
"cn=ipausers,cn=groups,cn=accounts,dc=example,dc=com,
"cn=team-engineering,cn=groups,cn=accounts,dc=example,dc=com"
],
...
}
```
If the client has also requested a refresh token, the groups field is updated during each refresh request.
......@@ -153,6 +153,7 @@ In addition to `id` and `type`, the `ldap` connector takes the following additio
* emailAttribute: a `string`. Required. Attribute to map to Email. Default: `mail`
* searchBeforeAuth: a `boolean`. Perform search for entryDN to be used for bind.
* searchFilter: a `string`. Filter to apply to search. Variable substititions: `%u` User supplied username/e-mail address. `%b` BaseDN. Searches that return multiple entries are considered ambiguous and will return an error.
* searchGroupFilter: a `string`. A filter which should return group entry for a given user. The string is formatted the same as `searchFilter`, execpt `%u` is replaced by the fully qualified user entry. Groups are only searched if the client request the "groups" scope.
* searchScope: a `string`. Scope of the search. `base|one|sub`. Default: `one`
* searchBindDN: a `string`. DN to bind as for search operations.
* searchBindPw: a `string`. Password for bind for search operations.
......@@ -180,19 +181,20 @@ uid=janedoe,cn=users,cn=accounts,dc=auth,dc=example,dc=com
The connector then attempts to bind as this entry using the password provided by the end user.
### Example: Searching the directory
### Example: Searching a FreeIPA server with groups
The following configuration will search a directory using an LDAP filter. With FreeIPA
The following configuration will search a FreeIPA directory using an LDAP filter.
```
{
"type": "ldap",
"id": "ldap",
"host": "127.0.0.1:389",
"baseDN": "cn=auth,dc=example,dc=com",
"baseDN": "cn=accounts,dc=example,dc=com",
"searchBeforeAuth": true,
"searchFilter": "(&(objectClass=person)(uid=%u))",
"searchGroupFilter": "(&(objectClass=ipausergroup)(member=%u))",
"searchScope": "sub",
"searchBindDN": "serviceAccountUser",
......@@ -206,9 +208,15 @@ The following configuration will search a directory using an LDAP filter. With F
(&(objectClass=person)(uid=janedoe))
```
If the search finds an entry, it will attempt to use the provided password to bind as that entry.
If the search finds an entry, it will attempt to use the provided password to bind as that entry. Searches that return multiple entries are considered ambiguous and will return an error.
__NOTE__: Searches that return multiple entries will return an error.
"searchGroupFilter" is a format string similar to "searchFilter" except `%u` is replaced by the fully qualified user entry returned by "searchFilter". So if the initial search returns "uid=janedoe,cn=users,cn=accounts,dc=example,dc=com", the connector will use the search query:
```
(&(objectClass=ipausergroup)(member=uid=janedoe,cn=users,cn=accounts,dc=example,dc=com))
```
If the client requests the "groups" scope, the names of all returned entries are added to the ID Token "groups" claim.
## Setting the Configuration
......
......@@ -107,11 +107,12 @@ type LDAPConnector struct {
nameAttribute string
emailAttribute string
searchBeforeAuth bool
searchFilter string
searchScope int
searchBindDN string
searchBindPw string
searchBeforeAuth bool
searchFilter string
searchScope int
searchBindDN string
searchBindPw string
searchGroupFilter string
bindTemplate string
......@@ -203,19 +204,20 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t
}
idpc := &LDAPConnector{
id: cfg.ID,
namespace: ns,
loginFunc: lf,
loginTpl: tpl,
baseDN: cfg.BaseDN,
nameAttribute: cfg.NameAttribute,
emailAttribute: cfg.EmailAttribute,
searchBeforeAuth: cfg.SearchBeforeAuth,
searchFilter: cfg.SearchFilter,
searchScope: searchScope,
searchBindDN: cfg.SearchBindDN,
searchBindPw: cfg.SearchBindPw,
bindTemplate: cfg.BindTemplate,
id: cfg.ID,
namespace: ns,
loginFunc: lf,
loginTpl: tpl,
baseDN: cfg.BaseDN,
nameAttribute: cfg.NameAttribute,
emailAttribute: cfg.EmailAttribute,
searchBeforeAuth: cfg.SearchBeforeAuth,
searchFilter: cfg.SearchFilter,
searchGroupFilter: cfg.SearchGroupFilter,
searchScope: searchScope,
searchBindDN: cfg.SearchBindDN,
searchBindPw: cfg.SearchBindPw,
bindTemplate: cfg.BindTemplate,
ldapPool: &LDAPPool{
MaxIdleConn: cfg.MaxIdleConn,
PoolCheckTimer: defaultPoolCheckTimer,
......@@ -433,12 +435,47 @@ func invalidBindCredentials(err error) bool {
func (c *LDAPConnector) formatDN(template, username string) string {
result := template
result = strings.Replace(result, "%u", username, -1)
result = strings.Replace(result, "%u", ldap.EscapeFilter(username), -1)
result = strings.Replace(result, "%b", c.baseDN, -1)
return result
}
func (c *LDAPConnector) Groups(fullUserID string) ([]string, error) {
if !c.searchBeforeAuth {
return nil, fmt.Errorf("cannot search without service account")
}
if c.searchGroupFilter == "" {
return nil, fmt.Errorf("no group filter specified")
}
var groups []string
err := c.ldapPool.Do(func(conn *ldap.Conn) error {
if err := conn.Bind(c.searchBindDN, c.searchBindPw); err != nil {
if !invalidBindCredentials(err) {
log.Errorf("failed to connect to LDAP for search bind: %v", err)
}
return fmt.Errorf("failed to bind: %v", err)
}
req := &ldap.SearchRequest{
BaseDN: c.baseDN,
Scope: c.searchScope,
Filter: c.formatDN(c.searchGroupFilter, fullUserID),
}
resp, err := conn.Search(req)
if err != nil {
return fmt.Errorf("search failed: %v", err)
}
groups = make([]string, len(resp.Entries))
for i, entry := range resp.Entries {
groups[i] = entry.DN
}
return nil
})
return groups, err
}
func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, error) {
var (
identity *oidc.Identity
......@@ -447,8 +484,10 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err
if c.searchBeforeAuth {
err = c.ldapPool.Do(func(conn *ldap.Conn) error {
if err := conn.Bind(c.searchBindDN, c.searchBindPw); err != nil {
// Don't wrap error as it may be a specific LDAP error.
return err
if !invalidBindCredentials(err) {
log.Errorf("failed to connect to LDAP for search bind: %v", err)
}
return fmt.Errorf("failed to bind: %v", err)
}
filter := c.formatDN(c.searchFilter, username)
......@@ -491,8 +530,10 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err
err = c.ldapPool.Do(func(conn *ldap.Conn) error {
userBindDN := c.formatDN(c.bindTemplate, username)
if err := conn.Bind(userBindDN, password); err != nil {
// Don't wrap error as it may be a specific LDAP error.
return err
if !invalidBindCredentials(err) {
log.Errorf("failed to connect to LDAP for search bind: %v", err)
}
return fmt.Errorf("failed to bind: %v", err)
}
req := &ldap.SearchRequest{
......@@ -522,11 +563,7 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err
return nil
})
}
if err != nil {
if !invalidBindCredentials(err) {
log.Errorf("failed to connect to LDAP for search bind: %v", err)
}
return nil, err
}
return identity, nil
......
......@@ -60,6 +60,12 @@ type ConnectorConfig interface {
Connector(ns url.URL, loginFunc oidc.LoginFunc, tpls *template.Template) (Connector, error)
}
// GroupsConnector is a strategy for mapping a user to a set of groups. This is optionally
// implemented by some connectors.
type GroupsConnector interface {
Groups(fullUserID string) ([]string, error)
}
type ConnectorConfigRepo interface {
All() ([]ConnectorConfig, error)
GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error)
......
......@@ -41,6 +41,7 @@ CREATE TABLE refresh_token (
payload_hash blob,
user_id text,
client_id text,
connector_id text,
scopes text
);
......@@ -63,7 +64,8 @@ CREATE TABLE session (
user_id text,
register integer,
nonce text,
scope text
scope text,
groups text
);
CREATE TABLE session_key (
......
-- +migrate Up
ALTER TABLE refresh_token ADD COLUMN "connector_id" text;
ALTER TABLE session ADD COLUMN "groups" text;
......@@ -90,5 +90,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n",
},
},
{
Id: "0014_add_groups.sql",
Up: []string{
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"connector_id\" text;\nALTER TABLE session ADD COLUMN \"groups\" text;\n",
},
},
},
}
......@@ -41,6 +41,7 @@ type refreshTokenModel struct {
PayloadHash []byte `db:"payload_hash"`
UserID string `db:"user_id"`
ClientID string `db:"client_id"`
ConnectorID string `db:"connector_id"`
Scopes string `db:"scopes"`
}
......@@ -89,7 +90,7 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG
}
}
func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (string, error) {
func (r *refreshTokenRepo) Create(userID, clientID, connectorID string, scopes []string) (string, error) {
if userID == "" {
return "", refresh.ErrorInvalidUserID
}
......@@ -112,6 +113,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str
PayloadHash: payloadHash,
UserID: userID,
ClientID: clientID,
ConnectorID: connectorID,
Scopes: strings.Join(scopes, " "),
}
......@@ -122,24 +124,24 @@ func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (str
return buildToken(record.ID, tokenPayload), nil
}
func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, error) {
func (r *refreshTokenRepo) Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error) {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return "", nil, err
return
}
record, err := r.get(nil, tokenID)
if err != nil {
return "", nil, err
return
}
if record.ClientID != clientID {
return "", nil, refresh.ErrorInvalidClientID
return "", "", nil, refresh.ErrorInvalidClientID
}
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return "", nil, err
if err = checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return
}
var scopes []string
......@@ -147,7 +149,7 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes,
scopes = strings.Split(record.Scopes, " ")
}
return record.UserID, scopes, nil
return record.UserID, record.ConnectorID, scopes, nil
}
func (r *refreshTokenRepo) Revoke(userID, token string) error {
......
......@@ -44,6 +44,7 @@ type sessionModel struct {
Register bool `db:"register"`
Nonce string `db:"nonce"`
Scope string `db:"scope"`
Groups string `db:"groups"`
}
func (s *sessionModel) session() (*session.Session, error) {
......@@ -75,6 +76,11 @@ func (s *sessionModel) session() (*session.Session, error) {
Nonce: s.Nonce,
Scope: strings.Fields(s.Scope),
}
if s.Groups != "" {
if err := json.Unmarshal([]byte(s.Groups), &ses.Groups); err != nil {
return nil, fmt.Errorf("failed to decode groups in session: %v", err)
}
}
if s.CreatedAt != 0 {
ses.CreatedAt = time.Unix(s.CreatedAt, 0).UTC()
......@@ -107,6 +113,14 @@ func newSessionModel(s *session.Session) (*sessionModel, error) {
Scope: strings.Join(s.Scope, " "),
}
if s.Groups != nil {
data, err := json.Marshal(s.Groups)
if err != nil {
return nil, fmt.Errorf("failed to marshal groups: %v", err)
}
sm.Groups = string(data)
}
if !s.CreatedAt.IsZero() {
sm.CreatedAt = s.CreatedAt.Unix()
}
......
......@@ -68,7 +68,7 @@ func (fi bindataFileInfo) Sys() interface{} {
return nil
}
var _dataIndexHtml = []byte("\x1f\x8b\x08\x00\x00\x09\x6e\x88\x00\xff\x94\x52\xcd\x4e\xc3\x30\x0c\xbe\xef\x29\xac\x9c\xe0\x30\x7a\x47\x6d\x25\x40\xdc\x90\x26\xf1\x02\x53\x9a\x78\x6d\xb4\xfc\x4c\x89\x8b\x36\x4d\x7b\x77\xdc\x96\xae\x5b\x81\x09\x6e\xfe\x14\xfb\xfb\x89\x9d\x37\xe4\x6c\xb9\x00\xc8\xab\xa0\x0f\xe5\x82\x2b\xae\x37\x21\x3a\x90\x8a\x4c\xf0\x85\xc8\x6c\xa8\x8d\x17\x65\xff\xc4\x8f\x24\x2b\x8b\x23\xea\x70\x9c\x40\x07\x75\x09\x4f\x2d\x35\xe8\xc9\x28\x49\x08\x4c\xf6\x78\xd1\xd0\x49\x5d\x4d\x00\xdc\xa9\xe0\x9c\x5c\x26\xdc\xc9\xc8\x13\x1a\xac\x49\x04\x61\x03\xca\x1a\xa6\x59\x1a\x9d\xee\x2f\x25\x32\xd6\x98\x4b\xe6\xc6\xef\x5a\x02\x3a\xec\xb0\x10\x84\x7b\x12\xe0\xa5\xe3\x5a\xc5\x90\xd2\x7a\x60\x12\x50\xce\xa6\x19\x9d\xcd\x70\x3d\x44\x3b\x1e\xc1\x6c\xe0\x61\xb5\x7a\x86\xd3\x69\x6a\xbd\x54\x48\x6d\xe5\x0c\xf3\x7d\x48\xdb\x32\x7c\xeb\xbf\xa8\x8b\xea\x48\xc6\x1a\xa9\x10\xeb\xca\x4a\xbf\x15\x3d\x1b\xda\x84\xff\xa4\x1a\xe6\xbc\x1e\xc7\xf2\xac\x23\xe7\x05\x7d\x37\x37\x5b\x97\x92\xd6\x56\x52\x6d\x05\x38\xa4\x26\xe8\x42\xb0\x9f\x8e\x70\xd0\x7e\x09\x1a\x17\x3f\xd8\xb8\xfa\x33\xee\x39\x1b\x9a\x36\x3f\xed\xed\x56\x80\xd7\xbd\x6a\xa4\xaf\xb1\x57\x1a\x75\x47\xfb\xd7\xa1\xbe\xc2\xf8\x40\xb7\x02\x45\xac\xf9\x1e\x30\x8a\xbf\xa8\xbf\x8f\xcd\x00\xd9\xef\xd2\x79\x36\x9c\x7b\x9e\x0d\xf7\xff\x19\x00\x00\xff\xff\xaf\x0b\xca\x75\x07\x03\x00\x00")
var _dataIndexHtml = []byte("\x1f\x8b\x08\x00\x00\x09\x6e\x88\x00\xff\x94\x93\xcf\x8a\xe3\x30\x0c\xc6\xef\x7d\x0a\xe1\x7b\x37\xf7\xc5\x29\xec\x0e\xbd\x0d\x14\xe6\x05\x8a\x63\xab\x89\xa9\xff\x61\x2b\x43\x4b\xe9\xbb\x8f\x53\x37\x61\x52\xd2\xa1\x73\x93\xd1\x27\x7d\x3f\x49\x98\x77\x64\xcd\x66\x05\xc0\x1b\xaf\xce\x43\x90\xc3\x83\x8f\x16\x84\x24\xed\x5d\xcd\x2a\xe3\x5b\xed\x58\x49\x0d\xd9\x30\x85\x00\xff\x7a\xea\xd0\x91\x96\x82\x10\x72\xd9\x5f\xae\x5d\xe8\x09\xe8\x1c\xb0\x66\x84\x27\x62\xe0\x84\xcd\xb1\x8c\x3e\xa5\xbd\x34\x3a\xcb\x19\x04\x23\x24\x76\xde\x28\x8c\x39\xe5\xad\x15\xeb\x84\x41\xc4\xdc\x46\x81\xd1\x89\xc0\x1f\xa0\x88\xd7\x5a\xa5\x6f\xee\x55\x58\x26\xd9\x9e\x28\x0a\x48\xd2\x07\x4c\xcf\x29\x70\x50\xed\x8b\xea\x45\x8a\xbb\x78\x4e\x70\xb9\x80\x3e\xc0\x9f\xdd\xee\x3f\x5c\xaf\x13\xc4\xcc\x36\xf5\x8d\xd5\xd9\xf8\x53\x98\x3e\x3f\xdf\x6f\x5b\x1c\x76\x64\x49\xc4\x16\xa9\x66\xfb\xc6\x08\x77\x64\xb7\x6e\x68\x12\xfe\xb2\x55\xa9\x73\x6a\x2c\xe3\xd5\xd0\x7c\xb3\x5a\x80\x7b\xb8\xa8\x14\xc6\x34\x42\x1e\x19\x58\xa4\xce\xab\x9a\x65\x9e\xa1\x61\xf1\x7e\xf3\x0a\x57\x0b\x18\xb3\x73\x66\xcd\x04\x34\x2d\x87\x37\x71\xb3\x54\xf9\x30\xc0\xf6\x24\x3b\xe1\x5a\xbc\x39\x8d\xbe\x23\xfe\x7c\xa8\xfb\x30\xce\xd3\x4f\x03\x45\x6c\xf3\xb5\x30\xb2\x57\xdc\x3f\x46\x31\x40\xf5\xdc\x9a\x57\xe5\x43\xf0\xaa\xfc\x90\xaf\x00\x00\x00\xff\xff\x9c\x89\xe2\x28\x29\x03\x00\x00")
func dataIndexHtmlBytes() ([]byte, error) {
return bindataRead(
......@@ -83,7 +83,7 @@ func dataIndexHtml() (*asset, error) {
return nil, err
}
info := bindataFileInfo{name: "data/index.html", size: 775, mode: os.FileMode(420), modTime: time.Unix(1466378108, 0)}
info := bindataFileInfo{name: "data/index.html", size: 809, mode: os.FileMode(436), modTime: time.Unix(1468620773, 0)}
a := &asset{bytes: bytes, info: info}
return a, nil
}
......
<html>
<body>
<form action="/login">
<table>
<tr>
<td> Authenticate for:
<br>
(comma-separated list of client-ids)
</td>
<td> <input type="text" name="cross_client" > </td>
</tr>
</table>
<p>
Authenticate for:<input type="text" name="cross_client" placeholder="comma-separated list of client-ids">
</p>
<p>
Extra scopes:<input type="text" name="extra_scopes" placeholder="comma-separated list of scopes">
</p>
{{ if .OOB }}
<input type="submit" value="Login" formtarget="_blank">
{{ else }}
......
......@@ -218,18 +218,25 @@ func handleLoginFunc(c *oidc.Client) http.HandlerFunc {
panic("unable to proceed")
}
xClient := r.Form.Get("cross_client")
if xClient != "" {
var scopes []string
q := u.Query()
if scope := q.Get("scope"); scope != "" {
scopes = strings.Split(scope, " ")
}
if xClient := r.Form.Get("cross_client"); xClient != "" {
xClients := strings.Split(xClient, ",")
for i, x := range xClients {
xClients[i] = scope.ScopeGoogleCrossClient + x
for _, x := range xClients {
scopes = append(scopes, scope.ScopeGoogleCrossClient+x)
}
q := u.Query()
scope := q.Get("scope")
scopes := strings.Split(scope, " ")
scopes = append(scopes, xClients...)
scope = strings.Join(scopes, " ")
q.Set("scope", scope)
}
if extraScopes := r.Form.Get("extra_scopes"); extraScopes != "" {
scopes = append(scopes, strings.Split(extraScopes, ",")...)
}
if scopes != nil {
q.Set("scope", strings.Join(scopes, " "))
u.RawQuery = q.Encode()
}
......@@ -292,57 +299,69 @@ func handleResendFunc(c *oidc.Client, issuerURL, resendURL, cbURL url.URL) http.
func handleCallbackFunc(c *oidc.Client) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
refreshToken := r.URL.Query().Get("refresh_token")
code := r.URL.Query().Get("code")
if code == "" {
phttp.WriteError(w, http.StatusBadRequest, "code query param must be set")
return
}
tokens, err := exchangeAuthCode(c, code)
oac, err := c.OAuthClient()
if err != nil {
phttp.WriteError(w, http.StatusBadRequest,
fmt.Sprintf("unable to verify auth code with issuer: %v", err))
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to create OAuth2 client: %v", err))
return
}
tok, err := jose.ParseJWT(tokens.IDToken)
if err != nil {
phttp.WriteError(w, http.StatusBadRequest,
fmt.Sprintf("unable to parse JWT: %v", err))
var token oauth2.TokenResponse
switch {
case code != "":
if token, err = oac.RequestToken(oauth2.GrantTypeAuthCode, code); err != nil {
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to verify auth code with issuer: %v", err))
return
}
case refreshToken != "":
if token, err = oac.RequestToken(oauth2.GrantTypeRefreshToken, refreshToken); err != nil {
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to refresh token: %v", err))
return
}
if token.RefreshToken == "" {
token.RefreshToken = refreshToken
}
default:
phttp.WriteError(w, http.StatusBadRequest, "code query param must be set")
return
}
claims, err := tok.Claims()
tok, err := jose.ParseJWT(token.IDToken)
if err != nil {
phttp.WriteError(w, http.StatusBadRequest,
fmt.Sprintf("unable to construct claims: %v", err))
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to parse JWT: %v", err))
return
}
claims := new(bytes.Buffer)
if err := json.Indent(claims, tok.Payload, "", " "); err != nil {
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to construct claims: %v", err))
return
}
s := fmt.Sprintf(`
<html>
<head>
<style>
/* make pre wrap */
pre {
white-space: pre-wrap; /* css-3 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word; /* Internet Explorer 5.5+ */
}
</style>
</head>
<body>
<p> Token: %v</p>
<p> Claims: %v </p>
<a href="/resend?jwt=%s">Resend Verification Email</a>
<p> Refresh Token: %v </p>
<p> Token: <pre><code>%v</code></pre></p>
<p> Claims: <pre><code>%v</code></pre></p>
<p> Refresh Token: <pre><code>%v</code></pre></p>
<p><a href="%s?refresh_token=%s">Redeem refresh token</a><p>
<p><a href="/resend?jwt=%s">Resend Verification Email</a></p>
</body>
</html>`, tok.Encode(), claims, tok.Encode(), tokens.RefreshToken)
</html>`, tok.Encode(), claims.String(), token.RefreshToken, r.URL.Path, token.RefreshToken, tok.Encode())
w.Write([]byte(s))
}
}
func exchangeAuthCode(c *oidc.Client, code string) (oauth2.TokenResponse, error) {
oac, err := c.OAuthClient()
if err != nil {
return oauth2.TokenResponse{}, err
}
t, err := oac.RequestToken(oauth2.GrantTypeAuthCode, code)
if err != nil {
return oauth2.TokenResponse{}, err
}
return t, nil
}
......@@ -20,7 +20,10 @@ import (
var (
testRefreshClientID = "client1"
testRefreshClientID2 = "client2"
testRefreshClients = []client.LoadableClient{
testRefreshConnectorID = "IDPC-1"
testRefreshClients = []client.LoadableClient{
{
Client: client.Client{
Credentials: oidc.ClientCredentials{
......@@ -59,7 +62,7 @@ var (
},
RemoteIdentities: []user.RemoteIdentity{
{
ConnectorID: "IDPC-1",
ConnectorID: testRefreshConnectorID,
ID: "RID-1",
},
},
......@@ -103,12 +106,12 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) {
for i, tt := range tests {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, tt.createScopes)
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, testRefreshConnectorID, tt.createScopes)
if err != nil {
t.Fatalf("case %d: failed to create refresh token: %v", i, err)
}
tokUserID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
tokUserID, gotConnectorID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
if tt.wantVerifyErr {
if err == nil {
t.Errorf("case %d: want non-nil error.", i)
......@@ -126,6 +129,10 @@ func TestRefreshTokenRepoCreateVerify(t *testing.T) {
t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i,
testRefreshUserID, tokUserID)
}
if gotConnectorID != testRefreshConnectorID {
t.Errorf("case %d: wanted connector_id=%q got=%q", i, testRefreshConnectorID, gotConnectorID)
}
}
}
......@@ -138,7 +145,7 @@ func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope)
token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
......@@ -209,7 +216,7 @@ func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
}
for i, tt := range tests {
result, _, err := r.Verify(tt.creds.ID, tt.token)
result, _, _, err := r.Verify(tt.creds.ID, tt.token)
if err != tt.err {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
}
......@@ -232,7 +239,7 @@ func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
for _, clientID := range tt.clientIDs {
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"})
_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
if err != nil {
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
}
......@@ -281,7 +288,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
for _, clientID := range tt.createIDs {
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"})
_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"})
if err != nil {
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
}
......@@ -318,7 +325,7 @@ func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
func TestRefreshRepoRevoke(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope)
token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
......
......@@ -104,6 +104,13 @@ func TestSessionRepoCreateGet(t *testing.T) {
ExpiresAt: time.Unix(789, 0).UTC(),
Nonce: "oncenay",
},
session.Session{
ID: "anID",
ClientState: "blargh",
ExpiresAt: time.Unix(789, 0).UTC(),
Nonce: "oncenay",
Groups: []string{"group1", "group2"},
},
}
for i, tt := range tests {
......
......@@ -149,7 +149,7 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, user := range userUsers {
if _, err := refreshRepo.Create(user.User.ID, testClientID,
append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
"", append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
panic("Failed to create refresh token: " + err.Error())
}
}
......
......@@ -44,12 +44,12 @@ type RefreshTokenRepo interface {
// The scopes will be stored with the refresh token, and used to verify
// against future OIDC refresh requests' scopes.
// On success the token will be returned.
Create(userID, clientID string, scope []string) (string, error)
Create(userID, clientID, connectorID string, scope []string) (string, error)
// Verify verifies that a token belongs to the client.
// It returns the user ID to which the token belongs, and the scopes stored
// with token.
Verify(clientID, token string) (string, scope.Scopes, error)
Verify(clientID, token string) (userID, connectorID string, scope scope.Scopes, err error)
// Revoke deletes the refresh token if the token belongs to the given userID.
Revoke(userID, token string) error
......
......@@ -6,6 +6,9 @@ const (
// Scope prefix which indicates initiation of a cross-client authentication flow.
// See https://developers.google.com/identity/protocols/CrossClientAuth
ScopeGoogleCrossClient = "audience:server:client_id:"
// ScopeGroups indicates that groups should be added to the ID Token.
ScopeGroups = "groups"
)
type Scopes []string
......
......@@ -421,6 +421,7 @@ func validateScopes(srv OIDCServer, clientID string, scopes []string) error {
foundOpenIDScope = true
case curScope == "profile":
case curScope == "email":
case curScope == scope.ScopeGroups:
case curScope == "offline_access":
// According to the spec, for offline_access scope, the client must
// use a response_type value that would result in an Authorization
......
......@@ -75,7 +75,8 @@ type Server struct {
OOBTemplate *template.Template
HealthChecks []health.Checkable
Connectors []connector.Connector
// TODO(ericchiang): Make this a map of ID to connector.
Connectors []connector.Connector
ClientRepo client.ClientRepo
ConnectorConfigRepo connector.ConnectorConfigRepo
......@@ -306,6 +307,15 @@ func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL ur
return s.SessionManager.NewSessionKey(sessionID)
}
func (s *Server) connector(id string) (connector.Connector, bool) {
for _, c := range s.Connectors {
if c.ID() == id {
return c, true
}
}
return nil, false
}
func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
sessionID, err := s.SessionManager.ExchangeKey(key)
if err != nil {
......@@ -318,6 +328,29 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
}
log.Infof("Session %s remote identity attached: clientID=%s identity=%#v", sessionID, ses.ClientID, ident)
// Get the connector used to log the user in.
conn, ok := s.connector(ses.ConnectorID)
if !ok {
return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
}
// If the client has requested access to groups, add them here.
if ses.Scope.HasScope(scope.ScopeGroups) {
grouper, ok := conn.(connector.GroupsConnector)
if !ok {
return "", fmt.Errorf("scope %q provided but connector does not support groups", scope.ScopeGroups)
}
groups, err := grouper.Groups(ident.ID)
if err != nil {
return "", fmt.Errorf("failed to retrieve user groups for %q %v", ident.ID, err)
}
// Update the session.
if ses, err = s.SessionManager.AttachGroups(sessionID, groups); err != nil {
return "", fmt.Errorf("failed save groups")
}
}
if ses.Register {
code, err := s.SessionManager.NewSessionKey(sessionID)
if err != nil {
......@@ -334,18 +367,6 @@ func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
remoteIdentity := user.RemoteIdentity{ConnectorID: ses.ConnectorID, ID: ses.Identity.ID}
// Get the connector used to log the user in.
var conn connector.Connector
for _, c := range s.Connectors {
if c.ID() == ses.ConnectorID {
conn = c
break
}
}
if conn == nil {
return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
}
usr, err := s.UserRepo.GetByRemoteIdentity(nil, remoteIdentity)
if err == user.ErrorNotFound {
if ses.Identity.Email == "" {
......@@ -508,7 +529,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
if scope == "offline_access" {
log.Infof("Session %s requests offline access, will generate refresh token", sessionID)
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope)
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.ConnectorID, ses.Scope)
switch err {
case nil:
break
......@@ -535,7 +556,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
}
userID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
switch err {
case nil:
break
......@@ -555,7 +576,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
}
}
user, err := s.UserRepo.Get(nil, userID)
usr, err := s.UserRepo.Get(nil, userID)
if err != nil {
// The error can be user.ErrorNotFound, but we are not deleting
// user at this moment, so this shouldn't happen.
......@@ -563,6 +584,43 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
var groups []string
if rtScopes.HasScope(scope.ScopeGroups) {
conn, ok := s.connector(connectorID)
if !ok {
log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
grouper, ok := conn.(connector.GroupsConnector)
if !ok {
log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
if err != nil {
log.Errorf("failed to get remote identities: %v", err)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
remoteIdentity, ok := func() (user.RemoteIdentity, bool) {
for _, ri := range remoteIdentities {
if ri.ConnectorID == connectorID {
return ri, true
}
}
return user.RemoteIdentity{}, false
}()
if !ok {
log.Errorf("failed to get remote identity for connector %s", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
log.Errorf("failed to get groups for refresh token: %v", connectorID)
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
}
signer, err := s.KeyManager.Signer()
if err != nil {
log.Errorf("Failed to refresh ID token: %v", err)
......@@ -572,8 +630,14 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
now := time.Now()
expireAt := now.Add(session.DefaultSessionValidityWindow)
claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt)
user.AddToClaims(claims)
claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expireAt)
usr.AddToClaims(claims)
if rtScopes.HasScope(scope.ScopeGroups) {
if groups == nil {
groups = []string{}
}
claims["groups"] = groups
}
s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID)
......
......@@ -785,8 +785,7 @@ func TestServerRefreshToken(t *testing.T) {
t.Errorf("case %d: error creating other client: %v", i, err)
}
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID,
tt.createScopes); err != nil {
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID, "", tt.createScopes); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
......
......@@ -144,6 +144,18 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.S
return s, nil
}
func (m *SessionManager) AttachGroups(sessionID string, groups []string) (*session.Session, error) {
s, err := m.sessions.Get(sessionID)
if err != nil {
return nil, err
}
s.Groups = groups
if err = m.sessions.Update(*s); err != nil {
return nil, err
}
return s, nil
}
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
s, err := m.sessions.Get(sessionID)
if err != nil {
......
......@@ -55,6 +55,9 @@ type Session struct {
// Scope is the 'scope' field in the authentication request. Example scopes
// are 'openid', 'email', 'offline', etc.
Scope scope.Scopes
// Groups the user belongs to.
Groups []string
}
// Claims returns a new set of Claims for the current session.
......@@ -65,5 +68,8 @@ func (s *Session) Claims(issuerURL string) jose.Claims {
if s.Nonce != "" {
claims["nonce"] = s.Nonce
}
if s.Scope.HasScope(scope.ScopeGroups) {
claims["groups"] = s.Groups
}
return claims
}
......@@ -192,7 +192,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
}
refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, token := range refreshTokens {
if _, err := refreshRepo.Create(token.userID, token.clientID, []string{"openid"}); err != nil {
if _, err := refreshRepo.Create(token.userID, token.clientID, "local", []string{"openid"}); err != nil {
panic("Failed to create refresh token: " + err.Error())
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment