Commit 5ed42be7 authored by rithu leena john's avatar rithu leena john Committed by GitHub

Merge pull request #702 from ericchiang/connector-interface-cleanup

connector: add RefreshConnector interface
parents 27fb7c52 6980920a
# Authentication through GitHub
## Overview
One of the login options for dex uses the GitHub OAuth2 flow to identify the end user through their GitHub account.
When a client redeems a refresh token through dex, dex will re-query GitHub to update user information in the ID Token. To do this, __dex stores a readonly GitHub access token in its backing datastore.__ Users that reject dex's access through GitHub will also revoke all dex clients which authenticated them through GitHub.
## Configuration
Register a new application with [GitHub][github-oauth2] ensuring the callback URL is `(dex issuer)/callback`. For example if dex is listening at the non-root path `https://auth.example.com/dex` the callback would be `https://auth.example.com/dex/callback`.
The following is an example of a configuration for `examples/config-dev.yaml`:
```yaml
connectors:
- type: github
id: github
name: GitHub
config:
# Credentials can be string literals or pulled from the environment.
clientID: $GITHUB_CLIENT_ID
clientSecret: $GITHUB_CLIENT_SECRET
redirectURI: http://127.0.0.1:5556/dex/callback
# Optional organization to pull teams from, communicate through the
# "groups" scope.
#
# NOTE: This is an EXPERIMENTAL config option and will likely change.
org: my-oranization
```
[github-oauth2]: https://github.com/settings/applications/new
...@@ -27,6 +27,7 @@ Standards-based token responses allows applications to interact with any OpenID ...@@ -27,6 +27,7 @@ Standards-based token responses allows applications to interact with any OpenID
* [gRPC API](Documentation/api.md) * [gRPC API](Documentation/api.md)
* Identity provider logins * Identity provider logins
* [LDAP](Documentation/ldap-connector.md) * [LDAP](Documentation/ldap-connector.md)
* [GitHub](Documentation/github-connector.md)
* Client libraries * Client libraries
* [Go][go-oidc] * [Go][go-oidc]
......
// Package connector defines interfaces for federated identity strategies. // Package connector defines interfaces for federated identity strategies.
package connector package connector
import "net/http" import (
"net/http"
"golang.org/x/net/context"
)
// Connector is a mechanism for federating login to a remote identity service. // Connector is a mechanism for federating login to a remote identity service.
// //
// Implementations are expected to implement either the PasswordConnector or // Implementations are expected to implement either the PasswordConnector or
// CallbackConnector interface. // CallbackConnector interface.
type Connector interface { type Connector interface{}
Close() error
// Scopes represents additional data requested by the clients about the end user.
type Scopes struct {
// The client has requested a refresh token from the server.
OfflineAccess bool
// The client has requested group information about the end user.
Groups bool
} }
// Identity represents the ID Token claims supported by the server. // Identity represents the ID Token claims supported by the server.
...@@ -18,6 +29,8 @@ type Identity struct { ...@@ -18,6 +29,8 @@ type Identity struct {
Email string Email string
EmailVerified bool EmailVerified bool
Groups []string
// ConnectorData holds data used by the connector for subsequent requests after initial // ConnectorData holds data used by the connector for subsequent requests after initial
// authentication, such as access tokens for upstream provides. // authentication, such as access tokens for upstream provides.
// //
...@@ -25,18 +38,38 @@ type Identity struct { ...@@ -25,18 +38,38 @@ type Identity struct {
ConnectorData []byte ConnectorData []byte
} }
// PasswordConnector is an optional interface for password based connectors. // PasswordConnector is an interface implemented by connectors which take a
// username and password.
type PasswordConnector interface { type PasswordConnector interface {
Login(username, password string) (identity Identity, validPassword bool, err error) Login(ctx context.Context, s Scopes, username, password string) (identity Identity, validPassword bool, err error)
} }
// CallbackConnector is an optional interface for callback based connectors. // CallbackConnector is an interface implemented by connectors which use an OAuth
// style redirect flow to determine user information.
type CallbackConnector interface { type CallbackConnector interface {
LoginURL(callbackURL, state string) (string, error) // The initial URL to redirect the user to.
HandleCallback(r *http.Request) (identity Identity, err error) //
// OAuth2 implementations should request different scopes from the upstream
// identity provider based on the scopes requested by the downstream client.
// For example, if the downstream client requests a refresh token from the
// server, the connector should also request a token from the provider.
//
// Many identity providers have arbitrary restrictions on refresh tokens. For
// example Google only allows a single refresh token per client/user/scopes
// combination, and wont return a refresh token even if offline access is
// requested if one has already been issues. There's no good general answer
// for these kind of restrictions, and may require this package to become more
// aware of the global set of user/connector interactions.
LoginURL(s Scopes, callbackURL, state string) (string, error)
// Handle the callback to the server and return an identity.
HandleCallback(s Scopes, r *http.Request) (identity Identity, err error)
} }
// GroupsConnector is an optional interface for connectors which can map a user to groups. // RefreshConnector is a connector that can update the client claims.
type GroupsConnector interface { type RefreshConnector interface {
Groups(identity Identity) ([]string, error) // Refresh is called when a client attempts to claim a refresh token. The
// connector should attempt to update the identity object to reflect any
// changes since the token was last refreshed.
Refresh(ctx context.Context, s Scopes, identity Identity) (Identity, error)
} }
...@@ -3,6 +3,7 @@ package github ...@@ -3,6 +3,7 @@ package github
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
...@@ -15,7 +16,11 @@ import ( ...@@ -15,7 +16,11 @@ import (
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
) )
const baseURL = "https://api.github.com" const (
baseURL = "https://api.github.com"
scopeEmail = "user:email"
scopeOrgs = "read:org"
)
// Config holds configuration options for github logins. // Config holds configuration options for github logins.
type Config struct { type Config struct {
...@@ -30,15 +35,8 @@ func (c *Config) Open() (connector.Connector, error) { ...@@ -30,15 +35,8 @@ func (c *Config) Open() (connector.Connector, error) {
return &githubConnector{ return &githubConnector{
redirectURI: c.RedirectURI, redirectURI: c.RedirectURI,
org: c.Org, org: c.Org,
oauth2Config: &oauth2.Config{ clientID: c.ClientID,
ClientID: c.ClientID, clientSecret: c.ClientSecret,
ClientSecret: c.ClientSecret,
Endpoint: github.Endpoint,
Scopes: []string{
"user:email", // View user's email
"read:org", // View user's org teams.
},
},
}, nil }, nil
} }
...@@ -49,26 +47,36 @@ type connectorData struct { ...@@ -49,26 +47,36 @@ type connectorData struct {
var ( var (
_ connector.CallbackConnector = (*githubConnector)(nil) _ connector.CallbackConnector = (*githubConnector)(nil)
_ connector.GroupsConnector = (*githubConnector)(nil) _ connector.RefreshConnector = (*githubConnector)(nil)
) )
type githubConnector struct { type githubConnector struct {
redirectURI string redirectURI string
org string org string
oauth2Config *oauth2.Config clientID string
ctx context.Context clientSecret string
cancel context.CancelFunc
} }
func (c *githubConnector) Close() error { func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config {
return nil var githubScopes []string
if scopes.Groups {
githubScopes = []string{scopeEmail, scopeOrgs}
} else {
githubScopes = []string{scopeEmail}
}
return &oauth2.Config{
ClientID: c.clientID,
ClientSecret: c.clientSecret,
Endpoint: github.Endpoint,
Scopes: githubScopes,
}
} }
func (c *githubConnector) LoginURL(callbackURL, state string) (string, error) { func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
if c.redirectURI != callbackURL { if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL did not match the URL in the config") return "", fmt.Errorf("expected callback URL did not match the URL in the config")
} }
return c.oauth2Config.AuthCodeURL(state), nil return c.oauth2Config(scopes).AuthCodeURL(state), nil
} }
type oauth2Error struct { type oauth2Error struct {
...@@ -83,70 +91,144 @@ func (e *oauth2Error) Error() string { ...@@ -83,70 +91,144 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription return e.error + ": " + e.errorDescription
} }
func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) { func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query() q := r.URL.Query()
if errType := q.Get("error"); errType != "" { if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")} return identity, &oauth2Error{errType, q.Get("error_description")}
} }
token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code"))
oauth2Config := c.oauth2Config(s)
ctx := r.Context()
token, err := oauth2Config.Exchange(ctx, q.Get("code"))
if err != nil { if err != nil {
return identity, fmt.Errorf("github: failed to get token: %v", err) return identity, fmt.Errorf("github: failed to get token: %v", err)
} }
resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user") client := oauth2Config.Client(ctx, token)
user, err := c.user(ctx, client)
if err != nil { if err != nil {
return identity, fmt.Errorf("github: get URL %v", err) return identity, fmt.Errorf("github: get user: %v", err)
} }
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { username := user.Name
body, err := ioutil.ReadAll(resp.Body) if username == "" {
if err != nil { username = user.Login
return identity, fmt.Errorf("github: read body: %v", err)
} }
return identity, fmt.Errorf("%s: %s", resp.Status, body) identity = connector.Identity{
UserID: strconv.Itoa(user.ID),
Username: username,
Email: user.Email,
EmailVerified: true,
} }
var user struct {
Name string `json:"name"` if s.Groups && c.org != "" {
Login string `json:"login"` groups, err := c.teams(ctx, client, c.org)
ID int `json:"id"` if err != nil {
Email string `json:"email"` return identity, fmt.Errorf("github: get teams: %v", err)
} }
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { identity.Groups = groups
return identity, fmt.Errorf("failed to decode response: %v", err)
} }
if s.OfflineAccess {
data := connectorData{AccessToken: token.AccessToken} data := connectorData{AccessToken: token.AccessToken}
connData, err := json.Marshal(data) connData, err := json.Marshal(data)
if err != nil { if err != nil {
return identity, fmt.Errorf("marshal connector data: %v", err) return identity, fmt.Errorf("marshal connector data: %v", err)
} }
identity.ConnectorData = connData
}
return identity, nil
}
func (c *githubConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
if len(ident.ConnectorData) == 0 {
return ident, errors.New("no upstream access token found")
}
var data connectorData
if err := json.Unmarshal(ident.ConnectorData, &data); err != nil {
return ident, fmt.Errorf("github: unmarshal access token: %v", err)
}
client := c.oauth2Config(s).Client(ctx, &oauth2.Token{AccessToken: data.AccessToken})
user, err := c.user(ctx, client)
if err != nil {
return ident, fmt.Errorf("github: get user: %v", err)
}
username := user.Name username := user.Name
if username == "" { if username == "" {
username = user.Login username = user.Login
} }
identity = connector.Identity{ ident.Username = username
UserID: strconv.Itoa(user.ID), ident.Email = user.Email
Username: username,
Email: user.Email, if s.Groups && c.org != "" {
EmailVerified: true, groups, err := c.teams(ctx, client, c.org)
ConnectorData: connData, if err != nil {
return ident, fmt.Errorf("github: get teams: %v", err)
} }
return identity, nil ident.Groups = groups
}
return ident, nil
} }
func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) { type user struct {
var data connectorData Name string `json:"name"`
if err := json.Unmarshal(identity.ConnectorData, &data); err != nil { Login string `json:"login"`
return nil, fmt.Errorf("decode connector data: %v", err) ID int `json:"id"`
Email string `json:"email"`
}
// user queries the GitHub API for profile information using the provided client. The HTTP
// client is expected to be constructed by the golang.org/x/oauth2 package, which inserts
// a bearer token as part of the request.
func (c *githubConnector) user(ctx context.Context, client *http.Client) (user, error) {
var u user
req, err := http.NewRequest("GET", baseURL+"/user", nil)
if err != nil {
return u, fmt.Errorf("github: new req: %v", err)
} }
token := &oauth2.Token{AccessToken: data.AccessToken} req = req.WithContext(ctx)
resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user/teams") resp, err := client.Do(req)
if err != nil {
return u, fmt.Errorf("github: get URL %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return u, fmt.Errorf("github: read body: %v", err)
}
return u, fmt.Errorf("%s: %s", resp.Status, body)
}
if err := json.NewDecoder(resp.Body).Decode(&u); err != nil {
return u, fmt.Errorf("failed to decode response: %v", err)
}
return u, nil
}
// teams queries the GitHub API for team membership within a specific organization.
//
// The HTTP passed client is expected to be constructed by the golang.org/x/oauth2 package,
// which inserts a bearer token as part of the request.
func (c *githubConnector) teams(ctx context.Context, client *http.Client, org string) ([]string, error) {
req, err := http.NewRequest("GET", baseURL+"/user/teams", nil)
if err != nil {
return nil, fmt.Errorf("github: new req: %v", err)
}
req = req.WithContext(ctx)
resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("github: get teams: %v", err) return nil, fmt.Errorf("github: get teams: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
...@@ -167,7 +249,7 @@ func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) ...@@ -167,7 +249,7 @@ func (c *githubConnector) Groups(identity connector.Identity) ([]string, error)
} }
groups := []string{} groups := []string{}
for _, team := range teams { for _, team := range teams {
if team.Org.Login == c.org { if team.Org.Login == org {
groups = append(groups, team.Name) groups = append(groups, team.Name)
} }
} }
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"log" "log"
"net" "net"
"golang.org/x/net/context"
"gopkg.in/ldap.v2" "gopkg.in/ldap.v2"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
...@@ -57,6 +58,9 @@ type Config struct { ...@@ -57,6 +58,9 @@ type Config struct {
// Required if LDAP host does not use TLS. // Required if LDAP host does not use TLS.
InsecureNoSSL bool `json:"insecureNoSSL"` InsecureNoSSL bool `json:"insecureNoSSL"`
// Don't verify the CA.
InsecureSkipVerify bool `json:"insecureSkipVerify"`
// Path to a trusted root certificate file. // Path to a trusted root certificate file.
RootCA string `json:"rootCA"` RootCA string `json:"rootCA"`
...@@ -139,11 +143,16 @@ func (c *Config) Open() (connector.Connector, error) { ...@@ -139,11 +143,16 @@ func (c *Config) Open() (connector.Connector, error) {
return connector.Connector(conn), nil return connector.Connector(conn), nil
} }
type refreshData struct {
Username string `json:"username"`
Entry ldap.Entry `json:"entry"`
}
// OpenConnector is the same as Open but returns a type with all implemented connector interfaces. // OpenConnector is the same as Open but returns a type with all implemented connector interfaces.
func (c *Config) OpenConnector() (interface { func (c *Config) OpenConnector() (interface {
connector.Connector connector.Connector
connector.PasswordConnector connector.PasswordConnector
connector.GroupsConnector connector.RefreshConnector
}, error) { }, error) {
requiredFields := []struct { requiredFields := []struct {
...@@ -174,7 +183,7 @@ func (c *Config) OpenConnector() (interface { ...@@ -174,7 +183,7 @@ func (c *Config) OpenConnector() (interface {
} }
} }
tlsConfig := &tls.Config{ServerName: host} tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: c.InsecureSkipVerify}
if c.RootCA != "" || len(c.RootCAData) != 0 { if c.RootCA != "" || len(c.RootCAData) != 0 {
data := c.RootCAData data := c.RootCAData
if len(data) == 0 { if len(data) == 0 {
...@@ -209,12 +218,16 @@ type ldapConnector struct { ...@@ -209,12 +218,16 @@ type ldapConnector struct {
tlsConfig *tls.Config tlsConfig *tls.Config
} }
var _ connector.PasswordConnector = (*ldapConnector)(nil) var (
_ connector.PasswordConnector = (*ldapConnector)(nil)
_ connector.RefreshConnector = (*ldapConnector)(nil)
)
// do initializes a connection to the LDAP directory and passes it to the // do initializes a connection to the LDAP directory and passes it to the
// provided function. It then performs appropriate teardown or reuse before // provided function. It then performs appropriate teardown or reuse before
// returning. // returning.
func (c *ldapConnector) do(f func(c *ldap.Conn) error) error { func (c *ldapConnector) do(ctx context.Context, f func(c *ldap.Conn) error) error {
// TODO(ericchiang): support context here
var ( var (
conn *ldap.Conn conn *ldap.Conn
err error err error
...@@ -253,13 +266,32 @@ func getAttr(e ldap.Entry, name string) string { ...@@ -253,13 +266,32 @@ func getAttr(e ldap.Entry, name string) string {
return "" return ""
} }
func (c *ldapConnector) Login(username, password string) (ident connector.Identity, validPass bool, err error) { func (c *ldapConnector) identityFromEntry(user ldap.Entry) (ident connector.Identity, err error) {
var ( // If we're missing any attributes, such as email or ID, we want to report
// We want to return a different error if the user's password is incorrect vs // an error rather than continuing.
// if there was an error. missing := []string{}
incorrectPass = false
user ldap.Entry // Fill the identity struct using the attributes from the user entry.
) if ident.UserID = getAttr(user, c.UserSearch.IDAttr); ident.UserID == "" {
missing = append(missing, c.UserSearch.IDAttr)
}
if ident.Email = getAttr(user, c.UserSearch.EmailAttr); ident.Email == "" {
missing = append(missing, c.UserSearch.EmailAttr)
}
if c.UserSearch.NameAttr != "" {
if ident.Username = getAttr(user, c.UserSearch.NameAttr); ident.Username == "" {
missing = append(missing, c.UserSearch.NameAttr)
}
}
if len(missing) != 0 {
err := fmt.Errorf("ldap: entry %q missing following required attribute(s): %q", user.DN, missing)
return connector.Identity{}, err
}
return ident, nil
}
func (c *ldapConnector) userEntry(conn *ldap.Conn, username string) (user ldap.Entry, found bool, err error) {
filter := fmt.Sprintf("(%s=%s)", c.UserSearch.Username, ldap.EscapeFilter(username)) filter := fmt.Sprintf("(%s=%s)", c.UserSearch.Username, ldap.EscapeFilter(username))
if c.UserSearch.Filter != "" { if c.UserSearch.Filter != "" {
...@@ -283,24 +315,40 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi ...@@ -283,24 +315,40 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi
if c.UserSearch.NameAttr != "" { if c.UserSearch.NameAttr != "" {
req.Attributes = append(req.Attributes, c.UserSearch.NameAttr) req.Attributes = append(req.Attributes, c.UserSearch.NameAttr)
} }
err = c.do(func(conn *ldap.Conn) error {
resp, err := conn.Search(req) resp, err := conn.Search(req)
if err != nil { if err != nil {
return fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err) return ldap.Entry{}, false, fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err)
} }
switch n := len(resp.Entries); n { switch n := len(resp.Entries); n {
case 0: case 0:
log.Printf("ldap: no results returned for filter: %q", filter) log.Printf("ldap: no results returned for filter: %q", filter)
incorrectPass = true return ldap.Entry{}, false, nil
return nil
case 1: case 1:
return *resp.Entries[0], true, nil
default: default:
return fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter) return ldap.Entry{}, false, fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter)
} }
}
user = *resp.Entries[0] func (c *ldapConnector) Login(ctx context.Context, s connector.Scopes, username, password string) (ident connector.Identity, validPass bool, err error) {
var (
// We want to return a different error if the user's password is incorrect vs
// if there was an error.
incorrectPass = false
user ldap.Entry
)
err = c.do(ctx, func(conn *ldap.Conn) error {
entry, found, err := c.userEntry(conn, username)
if err != nil {
return err
}
if !found {
incorrectPass = true
return nil
}
user = entry
// Try to authenticate as the distinguished name. // Try to authenticate as the distinguished name.
if err := conn.Bind(user.DN, password); err != nil { if err := conn.Bind(user.DN, password); err != nil {
...@@ -323,44 +371,75 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi ...@@ -323,44 +371,75 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi
return connector.Identity{}, false, nil return connector.Identity{}, false, nil
} }
if ident, err = c.identityFromEntry(user); err != nil {
return connector.Identity{}, false, err
}
if s.Groups {
groups, err := c.groups(ctx, user)
if err != nil {
return connector.Identity{}, false, fmt.Errorf("ldap: failed to query groups: %v", err)
}
ident.Groups = groups
}
if s.OfflineAccess {
refresh := refreshData{
Username: username,
Entry: user,
}
// Encode entry for follow up requests such as the groups query and // Encode entry for follow up requests such as the groups query and
// refresh attempts. // refresh attempts.
if ident.ConnectorData, err = json.Marshal(user); err != nil { if ident.ConnectorData, err = json.Marshal(refresh); err != nil {
return connector.Identity{}, false, fmt.Errorf("ldap: marshal entry: %v", err) return connector.Identity{}, false, fmt.Errorf("ldap: marshal entry: %v", err)
} }
}
// If we're missing any attributes, such as email or ID, we want to report return ident, true, nil
// an error rather than continuing. }
missing := []string{}
// Fill the identity struct using the attributes from the user entry. func (c *ldapConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
if ident.UserID = getAttr(user, c.UserSearch.IDAttr); ident.UserID == "" { var data refreshData
missing = append(missing, c.UserSearch.IDAttr) if err := json.Unmarshal(ident.ConnectorData, &data); err != nil {
return ident, fmt.Errorf("ldap: failed to unamrshal internal data: %v", err)
} }
if ident.Email = getAttr(user, c.UserSearch.EmailAttr); ident.Email == "" {
missing = append(missing, c.UserSearch.EmailAttr) var user ldap.Entry
err := c.do(ctx, func(conn *ldap.Conn) error {
entry, found, err := c.userEntry(conn, data.Username)
if err != nil {
return err
} }
if c.UserSearch.NameAttr != "" { if !found {
if ident.Username = getAttr(user, c.UserSearch.NameAttr); ident.Username == "" { return fmt.Errorf("ldap: user not found %q", data.Username)
missing = append(missing, c.UserSearch.NameAttr)
} }
user = entry
return nil
})
if err != nil {
return ident, err
} }
if user.DN != data.Entry.DN {
if len(missing) != 0 { return ident, fmt.Errorf("ldap: refresh for username %q expected DN %q got %q", data.Username, data.Entry.DN, user.DN)
err := fmt.Errorf("ldap: entry %q missing following required attribute(s): %q", user.DN, missing)
return connector.Identity{}, false, err
} }
return ident, true, nil newIdent, err := c.identityFromEntry(user)
} if err != nil {
return ident, err
}
newIdent.ConnectorData = ident.ConnectorData
func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) { if s.Groups {
// Decode the user entry from the identity. groups, err := c.groups(ctx, user)
var user ldap.Entry if err != nil {
if err := json.Unmarshal(ident.ConnectorData, &user); err != nil { return connector.Identity{}, fmt.Errorf("ldap: failed to query groups: %v", err)
return nil, fmt.Errorf("ldap: failed to unmarshal connector data: %v", err)
} }
newIdent.Groups = groups
}
return newIdent, nil
}
func (c *ldapConnector) groups(ctx context.Context, user ldap.Entry) ([]string, error) {
filter := fmt.Sprintf("(%s=%s)", c.GroupSearch.GroupAttr, ldap.EscapeFilter(getAttr(user, c.GroupSearch.UserAttr))) filter := fmt.Sprintf("(%s=%s)", c.GroupSearch.GroupAttr, ldap.EscapeFilter(getAttr(user, c.GroupSearch.UserAttr)))
if c.GroupSearch.Filter != "" { if c.GroupSearch.Filter != "" {
filter = fmt.Sprintf("(&%s%s)", c.GroupSearch.Filter, filter) filter = fmt.Sprintf("(&%s%s)", c.GroupSearch.Filter, filter)
...@@ -374,7 +453,7 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) { ...@@ -374,7 +453,7 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) {
} }
var groups []*ldap.Entry var groups []*ldap.Entry
if err := c.do(func(conn *ldap.Conn) error { if err := c.do(ctx, func(conn *ldap.Conn) error {
resp, err := conn.Search(req) resp, err := conn.Search(req)
if err != nil { if err != nil {
return fmt.Errorf("ldap: search failed: %v", err) return fmt.Errorf("ldap: search failed: %v", err)
...@@ -406,7 +485,3 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) { ...@@ -406,7 +485,3 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) {
} }
return groupNames, nil return groupNames, nil
} }
func (c *ldapConnector) Close() error {
return nil
}
...@@ -2,33 +2,45 @@ ...@@ -2,33 +2,45 @@
package mock package mock
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"golang.org/x/net/context"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
) )
// NewCallbackConnector returns a mock connector which requires no user interaction. It always returns // NewCallbackConnector returns a mock connector which requires no user interaction. It always returns
// the same (fake) identity. // the same (fake) identity.
func NewCallbackConnector() connector.Connector { func NewCallbackConnector() connector.Connector {
return callbackConnector{} return &Callback{
Identity: connector.Identity{
UserID: "0-385-28089-0",
Username: "Kilgore Trout",
Email: "kilgore@kilgore.trout",
EmailVerified: true,
Groups: []string{"authors"},
ConnectorData: connectorData,
},
}
} }
var ( var (
_ connector.CallbackConnector = callbackConnector{} _ connector.CallbackConnector = &Callback{}
_ connector.GroupsConnector = callbackConnector{}
_ connector.PasswordConnector = passwordConnector{} _ connector.PasswordConnector = passwordConnector{}
) )
type callbackConnector struct{} // Callback is a connector that requires no user interaction and always returns the same identity.
type Callback struct {
func (m callbackConnector) Close() error { return nil } // The returned identity.
Identity connector.Identity
}
func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) { // LoginURL returns the URL to redirect the user to login with.
func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
u, err := url.Parse(callbackURL) u, err := url.Parse(callbackURL)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err)
...@@ -41,21 +53,14 @@ func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) { ...@@ -41,21 +53,14 @@ func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) {
var connectorData = []byte("foobar") var connectorData = []byte("foobar")
func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, error) { // HandleCallback parses the request and returns the user's identity
return connector.Identity{ func (m *Callback) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) {
UserID: "0-385-28089-0", return m.Identity, nil
Username: "Kilgore Trout",
Email: "kilgore@kilgore.trout",
EmailVerified: true,
ConnectorData: connectorData,
}, nil
} }
func (m callbackConnector) Groups(identity connector.Identity) ([]string, error) { // Refresh updates the identity during a refresh token request.
if !bytes.Equal(identity.ConnectorData, connectorData) { func (m *Callback) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
return nil, errors.New("connector data mismatch") return m.Identity, nil
}
return []string{"authors"}, nil
} }
// CallbackConfig holds the configuration parameters for a connector which requires no interaction. // CallbackConfig holds the configuration parameters for a connector which requires no interaction.
...@@ -91,7 +96,7 @@ type passwordConnector struct { ...@@ -91,7 +96,7 @@ type passwordConnector struct {
func (p passwordConnector) Close() error { return nil } func (p passwordConnector) Close() error { return nil }
func (p passwordConnector) Login(username, password string) (identity connector.Identity, validPassword bool, err error) { func (p passwordConnector) Login(ctx context.Context, s connector.Scopes, username, password string) (identity connector.Identity, validPassword bool, err error) {
if username == p.username && password == p.password { if username == p.username && password == p.password {
return connector.Identity{ return connector.Identity{
UserID: "0-385-28089-0", UserID: "0-385-28089-0",
......
...@@ -75,7 +75,7 @@ func (c *oidcConnector) Close() error { ...@@ -75,7 +75,7 @@ func (c *oidcConnector) Close() error {
return nil return nil
} }
func (c *oidcConnector) LoginURL(callbackURL, state string) (string, error) { func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
if c.redirectURI != callbackURL { if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL did not match the URL in the config") return "", fmt.Errorf("expected callback URL did not match the URL in the config")
} }
...@@ -94,7 +94,7 @@ func (e *oauth2Error) Error() string { ...@@ -94,7 +94,7 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription return e.error + ": " + e.errorDescription
} }
func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) { func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query() q := r.URL.Query()
if errType := q.Get("error"); errType != "" { if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")} return identity, &oauth2Error{errType, q.Get("error_description")}
......
...@@ -179,7 +179,13 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { ...@@ -179,7 +179,13 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
authReqID := r.FormValue("req") authReqID := r.FormValue("req")
// TODO(ericchiang): cache user identity. authReq, err := s.storage.GetAuthRequest(authReqID)
if err != nil {
log.Printf("Failed to get auth request: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
scopes := parseScopes(authReq.Scopes)
switch r.Method { switch r.Method {
case "GET": case "GET":
...@@ -199,7 +205,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { ...@@ -199,7 +205,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
// Use the auth request ID as the "state" token. // Use the auth request ID as the "state" token.
// //
// TODO(ericchiang): Is this appropriate or should we also be using a nonce? // TODO(ericchiang): Is this appropriate or should we also be using a nonce?
callbackURL, err := conn.LoginURL(s.absURL("/callback"), authReqID) callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReqID)
if err != nil { if err != nil {
log.Printf("Connector %q returned error when creating callback: %v", connID, err) log.Printf("Connector %q returned error when creating callback: %v", connID, err)
s.renderError(w, http.StatusInternalServerError, errServerError, "") s.renderError(w, http.StatusInternalServerError, errServerError, "")
...@@ -221,7 +227,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { ...@@ -221,7 +227,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
username := r.FormValue("login") username := r.FormValue("login")
password := r.FormValue("password") password := r.FormValue("password")
identity, ok, err := passwordConnector.Login(username, password) identity, ok, err := passwordConnector.Login(r.Context(), scopes, username, password)
if err != nil { if err != nil {
log.Printf("Failed to login user: %v", err) log.Printf("Failed to login user: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "") s.renderError(w, http.StatusInternalServerError, errServerError, "")
...@@ -231,12 +237,6 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { ...@@ -231,12 +237,6 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
s.templates.password(w, authReqID, r.URL.String(), username, true) s.templates.password(w, authReqID, r.URL.String(), username, true)
return return
} }
authReq, err := s.storage.GetAuthRequest(authReqID)
if err != nil {
log.Printf("Failed to get auth request: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector)
if err != nil { if err != nil {
log.Printf("Failed to finalize login: %v", err) log.Printf("Failed to finalize login: %v", err)
...@@ -286,7 +286,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) ...@@ -286,7 +286,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
return return
} }
identity, err := callbackConnector.HandleCallback(r) identity, err := callbackConnector.HandleCallback(parseScopes(authReq.Scopes), r)
if err != nil { if err != nil {
log.Printf("Failed to authenticate: %v", err) log.Printf("Failed to authenticate: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "") s.renderError(w, http.StatusInternalServerError, errServerError, "")
...@@ -304,34 +304,12 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) ...@@ -304,34 +304,12 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
} }
func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, error) { func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, error) {
if authReq.ConnectorID == "" {
}
claims := storage.Claims{ claims := storage.Claims{
UserID: identity.UserID, UserID: identity.UserID,
Username: identity.Username, Username: identity.Username,
Email: identity.Email, Email: identity.Email,
EmailVerified: identity.EmailVerified, EmailVerified: identity.EmailVerified,
} Groups: identity.Groups,
groupsConn, ok := conn.(connector.GroupsConnector)
if ok {
reqGroups := func() bool {
for _, scope := range authReq.Scopes {
if scope == scopeGroups {
return true
}
}
return false
}()
if reqGroups {
groups, err := groupsConn.Groups(identity)
if err != nil {
return "", fmt.Errorf("getting groups: %v", err)
}
claims.Groups = groups
}
} }
updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
...@@ -415,6 +393,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe ...@@ -415,6 +393,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
Claims: authReq.Claims, Claims: authReq.Claims,
Expiry: s.now().Add(time.Minute * 30), Expiry: s.now().Add(time.Minute * 30),
RedirectURI: authReq.RedirectURI, RedirectURI: authReq.RedirectURI,
ConnectorData: authReq.ConnectorData,
} }
if err := s.storage.CreateAuthCode(code); err != nil { if err := s.storage.CreateAuthCode(code); err != nil {
log.Printf("Failed to create auth code: %v", err) log.Printf("Failed to create auth code: %v", err)
...@@ -543,6 +522,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s ...@@ -543,6 +522,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
Scopes: authCode.Scopes, Scopes: authCode.Scopes,
Claims: authCode.Claims, Claims: authCode.Claims,
Nonce: authCode.Nonce, Nonce: authCode.Nonce,
ConnectorData: authCode.ConnectorData,
} }
if err := s.storage.CreateRefresh(refresh); err != nil { if err := s.storage.CreateRefresh(refresh); err != nil {
log.Printf("failed to create refresh token: %v", err) log.Printf("failed to create refresh token: %v", err)
...@@ -574,6 +554,10 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie ...@@ -574,6 +554,10 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return return
} }
// Per the OAuth2 spec, if the client has omitted the scopes, default to the original
// authorized scopes.
//
// https://tools.ietf.org/html/rfc6749#section-6
scopes := refresh.Scopes scopes := refresh.Scopes
if scope != "" { if scope != "" {
requestedScopes := strings.Fields(scope) requestedScopes := strings.Fields(scope)
...@@ -601,7 +585,43 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie ...@@ -601,7 +585,43 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
scopes = requestedScopes scopes = requestedScopes
} }
// TODO(ericchiang): re-auth with backends conn, ok := s.connectors[refresh.ConnectorID]
if !ok {
log.Printf("connector ID not found: %q", refresh.ConnectorID)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
// Can the connector refresh the identity? If so, attempt to refresh the data
// in the connector.
//
// TODO(ericchiang): We may want a strict mode where connectors that don't implement
// this interface can't perform refreshing.
if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
ident := connector.Identity{
UserID: refresh.Claims.UserID,
Username: refresh.Claims.Username,
Email: refresh.Claims.Email,
EmailVerified: refresh.Claims.EmailVerified,
Groups: refresh.Claims.Groups,
ConnectorData: refresh.ConnectorData,
}
ident, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident)
if err != nil {
log.Printf("failed to refresh identity: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
// Update the claims of the refresh token.
//
// UserID intentionally ignored for now.
refresh.Claims.Username = ident.Username
refresh.Claims.Email = ident.Email
refresh.Claims.EmailVerified = ident.EmailVerified
refresh.Claims.Groups = ident.Groups
refresh.ConnectorData = ident.ConnectorData
}
idToken, expiry, err := s.newIDToken(client.ID, refresh.Claims, scopes, refresh.Nonce) idToken, expiry, err := s.newIDToken(client.ID, refresh.Claims, scopes, refresh.Nonce)
if err != nil { if err != nil {
...@@ -610,6 +630,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie ...@@ -610,6 +630,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return return
} }
// Refresh tokens are claimed exactly once. Delete the current token and
// create a new one.
if err := s.storage.DeleteRefresh(code); err != nil { if err := s.storage.DeleteRefresh(code); err != nil {
log.Printf("failed to delete auth code: %v", err) log.Printf("failed to delete auth code: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError) tokenErr(w, errServerError, "", http.StatusInternalServerError)
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/storage" "github.com/coreos/dex/storage"
) )
...@@ -93,6 +94,19 @@ const ( ...@@ -93,6 +94,19 @@ const (
responseTypeIDToken = "id_token" // ID Token in url fragment responseTypeIDToken = "id_token" // ID Token in url fragment
) )
func parseScopes(scopes []string) connector.Scopes {
var s connector.Scopes
for _, scope := range scopes {
switch scope {
case scopeOfflineAccess:
s.OfflineAccess = true
case scopeGroups:
s.Groups = true
}
}
return s
}
type audience []string type audience []string
func (a audience) MarshalJSON() ([]byte, error) { func (a audience) MarshalJSON() ([]byte, error) {
......
...@@ -211,9 +211,7 @@ type passwordDB struct { ...@@ -211,9 +211,7 @@ type passwordDB struct {
s storage.Storage s storage.Storage
} }
func (db passwordDB) Close() error { return nil } func (db passwordDB) Login(ctx context.Context, s connector.Scopes, email, password string) (connector.Identity, bool, error) {
func (db passwordDB) Login(email, password string) (connector.Identity, bool, error) {
p, err := db.s.GetPassword(email) p, err := db.s.GetPassword(email)
if err != nil { if err != nil {
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
...@@ -233,6 +231,31 @@ func (db passwordDB) Login(email, password string) (connector.Identity, bool, er ...@@ -233,6 +231,31 @@ func (db passwordDB) Login(email, password string) (connector.Identity, bool, er
}, true, nil }, true, nil
} }
func (db passwordDB) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
// If the user has been deleted, the refresh token will be rejected.
p, err := db.s.GetPassword(identity.Email)
if err != nil {
if err == storage.ErrNotFound {
return connector.Identity{}, errors.New("user not found")
}
return connector.Identity{}, fmt.Errorf("get password: %v", err)
}
// User removed but a new user with the same email exists.
if p.UserID != identity.UserID {
return connector.Identity{}, errors.New("user not found")
}
// If a user has updated their username, that will be reflected in the
// refreshed token.
//
// No other fields are expected to be refreshable as email is effectively used
// as an ID and this implementation doesn't deal with groups.
identity.Username = p.Username
return identity, nil
}
// newKeyCacher returns a storage which caches keys so long as the next // newKeyCacher returns a storage which caches keys so long as the next
func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage { func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage {
if now == nil { if now == nil {
......
...@@ -132,10 +132,13 @@ func TestDiscovery(t *testing.T) { ...@@ -132,10 +132,13 @@ func TestDiscovery(t *testing.T) {
} }
} }
// TestOAuth2CodeFlow runs integration tests against a test server. The tests stand up a server
// which requires no interaction to login, logs in through a test client, then passes the client
// and returned token to the test.
func TestOAuth2CodeFlow(t *testing.T) { func TestOAuth2CodeFlow(t *testing.T) {
clientID := "testclient" clientID := "testclient"
clientSecret := "testclientsecret" clientSecret := "testclientsecret"
requestedScopes := []string{oidc.ScopeOpenID, "email", "offline_access"} requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"}
t0 := time.Now() t0 := time.Now()
...@@ -149,8 +152,14 @@ func TestOAuth2CodeFlow(t *testing.T) { ...@@ -149,8 +152,14 @@ func TestOAuth2CodeFlow(t *testing.T) {
// so tests can compute the expected "expires_in" field. // so tests can compute the expected "expires_in" field.
idTokensValidFor := time.Second * 30 idTokensValidFor := time.Second * 30
// Connector used by the tests.
var conn *mock.Callback
tests := []struct { tests := []struct {
name string name string
// If specified these set of scopes will be used during the test case.
scopes []string
// handleToken provides the OAuth2 token response for the integration test.
handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token) error handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token) error
}{ }{
{ {
...@@ -266,6 +275,7 @@ func TestOAuth2CodeFlow(t *testing.T) { ...@@ -266,6 +275,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
}, },
{ {
name: "refresh with unauthorized scopes", name: "refresh with unauthorized scopes",
scopes: []string{"openid", "email"},
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
v := url.Values{} v := url.Values{}
v.Add("client_id", clientID) v.Add("client_id", clientID)
...@@ -273,7 +283,7 @@ func TestOAuth2CodeFlow(t *testing.T) { ...@@ -273,7 +283,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
v.Add("grant_type", "refresh_token") v.Add("grant_type", "refresh_token")
v.Add("refresh_token", token.RefreshToken) v.Add("refresh_token", token.RefreshToken)
// Request a scope that wasn't requestd initially. // Request a scope that wasn't requestd initially.
v.Add("scope", strings.Join(append(requestedScopes, "profile"), " ")) v.Add("scope", "oidc email profile")
resp, err := http.PostForm(p.TokenURL, v) resp, err := http.PostForm(p.TokenURL, v)
if err != nil { if err != nil {
return err return err
...@@ -289,6 +299,57 @@ func TestOAuth2CodeFlow(t *testing.T) { ...@@ -289,6 +299,57 @@ func TestOAuth2CodeFlow(t *testing.T) {
return nil return nil
}, },
}, },
{
// This test ensures that the connector.RefreshConnector interface is being
// used when clients request a refresh token.
name: "refresh with identity changes",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
// have to use time.Now because the OAuth2 package uses it.
token.Expiry = time.Now().Add(time.Second * -10)
if token.Valid() {
return errors.New("token shouldn't be valid")
}
ident := connector.Identity{
UserID: "fooid",
Username: "foo",
Email: "foo@bar.com",
EmailVerified: true,
Groups: []string{"foo", "bar"},
}
conn.Identity = ident
type claims struct {
Username string `json:"name"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Groups []string `json:"groups"`
}
want := claims{ident.Username, ident.Email, ident.EmailVerified, ident.Groups}
newToken, err := config.TokenSource(ctx, token).Token()
if err != nil {
return fmt.Errorf("failed to refresh token: %v", err)
}
rawIDToken, ok := newToken.Extra("id_token").(string)
if !ok {
return fmt.Errorf("no id_token in refreshed token")
}
idToken, err := p.NewVerifier(ctx).Verify(rawIDToken)
if err != nil {
return fmt.Errorf("failed to verify id token: %v", err)
}
var got claims
if err := idToken.Claims(&got); err != nil {
return fmt.Errorf("failed to unmarshal claims: %v", err)
}
if diff := pretty.Compare(want, got); diff != "" {
return fmt.Errorf("got identity != want identity: %s", diff)
}
return nil
},
},
} }
for _, tc := range tests { for _, tc := range tests {
...@@ -300,6 +361,15 @@ func TestOAuth2CodeFlow(t *testing.T) { ...@@ -300,6 +361,15 @@ func TestOAuth2CodeFlow(t *testing.T) {
c.Issuer = c.Issuer + "/non-root-path" c.Issuer = c.Issuer + "/non-root-path"
c.Now = now c.Now = now
c.IDTokensValidFor = idTokensValidFor c.IDTokensValidFor = idTokensValidFor
// Create a new mock callback connector for each test case.
conn = mock.NewCallbackConnector().(*mock.Callback)
c.Connectors = []Connector{
{
ID: "mock",
DisplayName: "mock",
Connector: conn,
},
}
}) })
defer httpServer.Close() defer httpServer.Close()
...@@ -375,6 +445,9 @@ func TestOAuth2CodeFlow(t *testing.T) { ...@@ -375,6 +445,9 @@ func TestOAuth2CodeFlow(t *testing.T) {
Scopes: requestedScopes, Scopes: requestedScopes,
RedirectURL: redirectURL, RedirectURL: redirectURL,
} }
if len(tc.scopes) != 0 {
oauth2Config.Scopes = tc.scopes
}
resp, err := http.Get(oauth2Server.URL + "/login") resp, err := http.Get(oauth2Server.URL + "/login")
if err != nil { if err != nil {
...@@ -662,7 +735,6 @@ func TestCrossClientScopes(t *testing.T) { ...@@ -662,7 +735,6 @@ func TestCrossClientScopes(t *testing.T) {
func TestPasswordDB(t *testing.T) { func TestPasswordDB(t *testing.T) {
s := memory.New() s := memory.New()
conn := newPasswordDB(s) conn := newPasswordDB(s)
defer conn.Close()
pw := "hi" pw := "hi"
...@@ -712,7 +784,7 @@ func TestPasswordDB(t *testing.T) { ...@@ -712,7 +784,7 @@ func TestPasswordDB(t *testing.T) {
} }
for _, tc := range tests { for _, tc := range tests {
ident, valid, err := conn.Login(tc.username, tc.password) ident, valid, err := conn.Login(context.Background(), connector.Scopes{}, tc.username, tc.password)
if err != nil { if err != nil {
if !tc.wantErr { if !tc.wantErr {
t.Errorf("%s: %v", tc.name, err) t.Errorf("%s: %v", tc.name, err)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment