Commit 75473b4c authored by Bobby Rullo's avatar Bobby Rullo

refresh tokens: grant claims based on scopes

Before,  this logic was only in the OIDCServer.CodeToken() method; now it has been
pulled out so that other paths, like OIDCServer.RefreshToken() can use
it.

The net affect, is that now refresh tokens can be used to get
cross-client authenticated ID Tokens.
parent 32a1994a
...@@ -146,6 +146,7 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, ...@@ -146,6 +146,7 @@ func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes,
if len(record.Scopes) > 0 { if len(record.Scopes) > 0 {
scopes = strings.Split(record.Scopes, " ") scopes = strings.Split(record.Scopes, " ")
} }
return record.UserID, scopes, nil return record.UserID, scopes, nil
} }
......
...@@ -147,7 +147,6 @@ func makeUserAPITestFixtures() *userAPITestFixtures { ...@@ -147,7 +147,6 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
} }
refreshRepo := db.NewRefreshTokenRepo(dbMap) refreshRepo := db.NewRefreshTokenRepo(dbMap)
fmt.Println("DEFAULT: ", oidc.DefaultScope)
for _, user := range userUsers { for _, user := range userUsers {
if _, err := refreshRepo.Create(user.User.ID, testClientID, if _, err := refreshRepo.Create(user.User.ID, testClientID,
append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil { append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
......
...@@ -41,7 +41,9 @@ func DefaultRefreshTokenGenerator() ([]byte, error) { ...@@ -41,7 +41,9 @@ func DefaultRefreshTokenGenerator() ([]byte, error) {
type RefreshTokenRepo interface { type RefreshTokenRepo interface {
// Create generates and returns a new refresh token for the given client-user pair. // Create generates and returns a new refresh token for the given client-user pair.
// On success the token will be return. // The scopes will be stored with the refresh token, and used to verify
// against future OIDC refresh requests' scopes.
// On success the token will be returned.
Create(userID, clientID string, scope []string) (string, error) Create(userID, clientID string, scope []string) (string, error)
// Verify verifies that a token belongs to the client. // Verify verifies that a token belongs to the client.
......
...@@ -14,29 +14,24 @@ import ( ...@@ -14,29 +14,24 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/scope" "github.com/coreos/dex/scope"
) )
func makeCrossClientTestFixtures() (*testFixtures, error) { func makeCrossClientTestFixtures() (*testFixtures, error) {
f, err := makeTestFixtures() xClients := []client.LoadableClient{}
if err != nil {
return nil, fmt.Errorf("couldn't make test fixtures: %v", err)
}
for _, cliData := range []struct { for _, cliData := range []struct {
id string id string
authorized []string trustedPeers []string
}{ }{
{ {
id: "client_a", id: "client_a",
}, { }, {
id: "client_b", id: "client_b",
authorized: []string{"client_a"}, trustedPeers: []string{"client_a"},
}, { }, {
id: "client_c", id: "client_c",
authorized: []string{"client_a", "client_b"}, trustedPeers: []string{"client_a", "client_b"},
}, },
} { } {
u := url.URL{ u := url.URL{
...@@ -44,20 +39,27 @@ func makeCrossClientTestFixtures() (*testFixtures, error) { ...@@ -44,20 +39,27 @@ func makeCrossClientTestFixtures() (*testFixtures, error) {
Path: cliData.id, Path: cliData.id,
Host: cliData.id, Host: cliData.id,
} }
cliCreds, err := f.clientManager.New(client.Client{ xClients = append(xClients, client.LoadableClient{
Client: client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: cliData.id, ID: cliData.id,
Secret: base64.URLEncoding.EncodeToString(
[]byte(cliData.id + "_secret")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{u}, RedirectURIs: []url.URL{u},
}, },
}, &clientmanager.ClientOptions{ },
TrustedPeers: cliData.authorized, TrustedPeers: cliData.trustedPeers,
}) })
if err != nil {
return nil, fmt.Errorf("Unexpected error creating clients: %v", err)
} }
f.clientCreds[cliData.id] = *cliCreds
xClients = append(xClients, testClients...)
f, err := makeTestFixturesWithOptions(testFixtureOptions{
clients: xClients,
})
if err != nil {
return nil, fmt.Errorf("couldn't make test fixtures: %v", err)
} }
return f, nil return f, nil
} }
......
...@@ -445,35 +445,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo ...@@ -445,35 +445,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
claims := ses.Claims(s.IssuerURL.String()) claims := ses.Claims(s.IssuerURL.String())
user.AddToClaims(claims) user.AddToClaims(claims)
crossClientIDs := ses.Scope.CrossClientIDs() s.addClaimsFromScope(claims, ses.Scope, ses.ClientID)
if len(crossClientIDs) > 0 {
var aud []string
for _, id := range crossClientIDs {
if ses.ClientID == id {
aud = append(aud, id)
continue
}
allowed, err := s.CrossClientAuthAllowed(ses.ClientID, id)
if err != nil {
log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", ses.ClientID, id, err)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}
if !allowed {
err := oauth2.NewError(oauth2.ErrorInvalidRequest)
err.Description = fmt.Sprintf(
"%q is not authorized to perform cross-client requests for %q",
ses.ClientID, id)
return nil, "", err
}
aud = append(aud, id)
}
if len(aud) == 1 {
claims.Add("aud", aud[0])
} else {
claims.Add("aud", aud)
}
claims.Add("azp", ses.ClientID)
}
jwt, err := jose.NewSignedJWT(claims, signer) jwt, err := jose.NewSignedJWT(claims, signer)
if err != nil { if err != nil {
...@@ -555,6 +527,8 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, ...@@ -555,6 +527,8 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes,
claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt) claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt)
user.AddToClaims(claims) user.AddToClaims(claims)
s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID)
jwt, err := jose.NewSignedJWT(claims, signer) jwt, err := jose.NewSignedJWT(claims, signer)
if err != nil { if err != nil {
log.Errorf("Failed to generate ID token: %v", err) log.Errorf("Failed to generate ID token: %v", err)
...@@ -596,6 +570,41 @@ func (s *Server) JWTVerifierFactory() JWTVerifierFactory { ...@@ -596,6 +570,41 @@ func (s *Server) JWTVerifierFactory() JWTVerifierFactory {
} }
} }
// addClaimsFromScope adds claims that are based on the scopes that the client requested.
// Currently, these include cross-client claims (aud, azp).
func (s *Server) addClaimsFromScope(claims jose.Claims, scopes scope.Scopes, clientID string) error {
crossClientIDs := scopes.CrossClientIDs()
if len(crossClientIDs) > 0 {
var aud []string
for _, id := range crossClientIDs {
if clientID == id {
aud = append(aud, id)
continue
}
allowed, err := s.CrossClientAuthAllowed(clientID, id)
if err != nil {
log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", clientID, id, err)
return oauth2.NewError(oauth2.ErrorServerError)
}
if !allowed {
err := oauth2.NewError(oauth2.ErrorInvalidRequest)
err.Description = fmt.Sprintf(
"%q is not authorized to perform cross-client requests for %q",
clientID, id)
return err
}
aud = append(aud, id)
}
if len(aud) == 1 {
claims.Add("aud", aud[0])
} else {
claims.Add("aud", aud)
}
claims.Add("azp", clientID)
}
return nil
}
type sortableIDPCs []connector.Connector type sortableIDPCs []connector.Connector
func (s sortableIDPCs) Len() int { func (s sortableIDPCs) Len() int {
......
This diff is collapsed.
...@@ -39,6 +39,18 @@ var ( ...@@ -39,6 +39,18 @@ var (
ID: testClientID, ID: testClientID,
Secret: clientTestSecret, Secret: clientTestSecret,
} }
testClients = []client.LoadableClient{
{
Client: client.Client{
Credentials: testClientCredentials,
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
testRedirectURL,
},
},
},
},
}
testConnectorID1 = "IDPC-1" testConnectorID1 = "IDPC-1"
...@@ -169,18 +181,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err ...@@ -169,18 +181,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err
var clients []client.LoadableClient var clients []client.LoadableClient
if options.clients == nil { if options.clients == nil {
clients = []client.LoadableClient{ clients = testClients
{
Client: client.Client{
Credentials: testClientCredentials,
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
testRedirectURL,
},
},
},
},
}
} else { } else {
clients = options.clients clients = options.clients
} }
...@@ -247,6 +248,10 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err ...@@ -247,6 +248,10 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err
srv.absURL(httpPathAcceptInvitation), srv.absURL(httpPathAcceptInvitation),
) )
clientCreds := map[string]oidc.ClientCredentials{}
for _, c := range clients {
clientCreds[c.Client.Credentials.ID] = c.Client.Credentials
}
return &testFixtures{ return &testFixtures{
srv: srv, srv: srv,
redirectURL: testRedirectURL, redirectURL: testRedirectURL,
...@@ -255,9 +260,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err ...@@ -255,9 +260,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err
emailer: emailer, emailer: emailer,
clientRepo: clientRepo, clientRepo: clientRepo,
clientManager: clientManager, clientManager: clientManager,
clientCreds: map[string]oidc.ClientCredentials{ clientCreds: clientCreds,
testClientID: testClientCreds,
},
}, nil }, nil
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment