Commit 32a1994a authored by Bobby Rullo's avatar Bobby Rullo

refresh tokens: store and validate scopes.

A refresh request must fail if it asks for scopes that were not
originally granted when the refresh token was obtained.

This Commit:

* changes repo to store scopes with tokens
* changes repo interface signatures so that scopes can be stored and
  verified
* updates dependent code to pass along scopes
parent ea2f0a32
......@@ -39,7 +39,8 @@ CREATE TABLE refresh_token (
id integer PRIMARY KEY,
payload_hash blob,
user_id text,
client_id text
client_id text,
scopes text
);
CREATE TABLE remote_identity_mapping (
......
-- +migrate Up
ALTER TABLE refresh_token ADD COLUMN "scopes" text;
UPDATE refresh_token SET scopes = 'openid profile email offline_access';
......@@ -78,5 +78,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{
"-- +migrate Up\nCREATE TABLE IF NOT EXISTS \"trusted_peers\" (\n \"client_id\" text not null,\n \"trusted_client_id\" text not null,\n primary key (\"client_id\", \"trusted_client_id\")) ;\n",
},
},
{
Id: "0013_add_scopes_to_refresh_tokens.sql",
Up: []string{
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n",
},
},
},
}
......@@ -15,6 +15,7 @@ import (
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/scope"
)
const (
......@@ -38,10 +39,9 @@ type refreshTokenRepo struct {
type refreshTokenModel struct {
ID int64 `db:"id"`
PayloadHash []byte `db:"payload_hash"`
// TODO(yifan): Use some sort of foreign key to manage database level
// data integrity.
UserID string `db:"user_id"`
ClientID string `db:"client_id"`
Scopes string `db:"scopes"`
}
// buildToken combines the token ID and token payload to create a new token.
......@@ -89,7 +89,7 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG
}
}
func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (string, error) {
if userID == "" {
return "", refresh.ErrorInvalidUserID
}
......@@ -112,6 +112,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
PayloadHash: payloadHash,
UserID: userID,
ClientID: clientID,
Scopes: strings.Join(scopes, " "),
}
if err := r.executor(nil).Insert(record); err != nil {
......@@ -121,27 +122,31 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
return buildToken(record.ID, tokenPayload), nil
}
func (r *refreshTokenRepo) Verify(clientID, token string) (string, error) {
func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, error) {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return "", err
return "", nil, err
}
record, err := r.get(nil, tokenID)
if err != nil {
return "", err
return "", nil, err
}
if record.ClientID != clientID {
return "", refresh.ErrorInvalidClientID
return "", nil, refresh.ErrorInvalidClientID
}
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return "", err
return "", nil, err
}
return record.UserID, nil
var scopes []string
if len(record.Scopes) > 0 {
scopes = strings.Split(record.Scopes, " ")
}
return record.UserID, scopes, nil
}
func (r *refreshTokenRepo) Revoke(userID, token string) error {
......@@ -190,7 +195,6 @@ func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Cli
q := `SELECT c.* FROM %s as c
INNER JOIN %s as r ON c.id = r.client_id WHERE r.user_id = $1;`
q = fmt.Sprintf(q, r.quote(clientTableName), r.quote(refreshTokenTableName))
var clients []clientModel
if _, err := r.executor(nil).Select(&clients, q, userID); err != nil {
return nil, err
......@@ -206,6 +210,7 @@ func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Cli
// Do not share the secret.
c[i].Credentials.Secret = ""
}
return c, nil
}
......
package functional
import (
"encoding/base64"
"fmt"
"net/url"
"os"
......@@ -16,7 +15,6 @@ import (
"github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/session"
)
......@@ -411,207 +409,3 @@ func TestDBClientAll(t *testing.T) {
t.Fatalf("Retrieved incorrect number of ClientIdentities: want=2 got=%d", count)
}
}
// buildRefreshToken combines the token ID and token payload to create a new token.
// used in the tests to created a refresh token.
func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
}
func TestDBRefreshRepoCreate(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t))
tests := []struct {
userID string
clientID string
err error
}{
{
"",
"client-foo",
refresh.ErrorInvalidUserID,
},
{
"user-foo",
"",
refresh.ErrorInvalidClientID,
},
{
"user-foo",
"client-foo",
nil,
},
}
for i, tt := range tests {
token, err := r.Create(tt.userID, tt.clientID)
if err != nil {
if tt.err == nil {
t.Errorf("case %d: create failed: %v", i, err)
}
continue
}
if tt.err != nil {
t.Errorf("case %d: expected error, didn't get one", i)
continue
}
userID, err := r.Verify(tt.clientID, token)
if err != nil {
t.Errorf("case %d: failed to verify good token: %v", i, err)
continue
}
if userID != tt.userID {
t.Errorf("case %d: want userID=%s, got userID=%s", i, tt.userID, userID)
}
}
}
func TestDBRefreshRepoVerify(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
tokenWithBadID := "404" + token[1:]
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
tests := []struct {
token string
creds oidc.ClientCredentials
err error
expected string
}{
{
"invalid-token-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
"b/invalid-base64-encoded-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
"1/invalid-base64-encoded-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
token + "corrupted-token-payload",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
// The token's ID content is invalid.
tokenWithBadID,
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
// The token's payload content is invalid.
tokenWithBadPayload,
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
token,
oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"},
refresh.ErrorInvalidClientID,
"",
},
{
token,
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
nil,
"user-foo",
},
}
for i, tt := range tests {
result, err := r.Verify(tt.creds.ID, tt.token)
if err != tt.err {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
}
if result != tt.expected {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.expected, result)
}
}
}
func TestDBRefreshRepoRevoke(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
tokenWithBadID := "404" + token[1:]
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
tests := []struct {
token string
userID string
err error
}{
{
"invalid-token-format",
"user-foo",
refresh.ErrorInvalidToken,
},
{
"1/invalid-base64-encoded-format",
"user-foo",
refresh.ErrorInvalidToken,
},
{
token + "corrupted-token-payload",
"user-foo",
refresh.ErrorInvalidToken,
},
{
// The token's ID is invalid.
tokenWithBadID,
"user-foo",
refresh.ErrorInvalidToken,
},
{
// The token's payload is invalid.
tokenWithBadPayload,
"user-foo",
refresh.ErrorInvalidToken,
},
{
token,
"invalid-user",
refresh.ErrorInvalidUserID,
},
{
token,
"user-foo",
nil,
},
}
for i, tt := range tests {
if err := r.Revoke(tt.userID, tt.token); err != tt.err {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
}
}
}
This diff is collapsed.
......@@ -12,7 +12,8 @@ import (
func connect(t *testing.T) *gorp.DbMap {
dsn := os.Getenv("DEX_TEST_DSN")
if dsn == "" {
t.Fatal("DEX_TEST_DSN environment variable not set")
return db.NewMemDB()
}
c, err := db.NewConnection(db.Config{DSN: dsn})
if err != nil {
......
......@@ -231,7 +231,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
// this will actually happen due to some interaction between the
// end-user and a remote identity provider
sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"})
sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access", "email", "profile"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
......
......@@ -147,8 +147,10 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
}
refreshRepo := db.NewRefreshTokenRepo(dbMap)
fmt.Println("DEFAULT: ", oidc.DefaultScope)
for _, user := range userUsers {
if _, err := refreshRepo.Create(user.User.ID, testClientID); err != nil {
if _, err := refreshRepo.Create(user.User.ID, testClientID,
append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
panic("Failed to create refresh token: " + err.Error())
}
}
......
......@@ -5,6 +5,7 @@ import (
"errors"
"github.com/coreos/dex/client"
"github.com/coreos/dex/scope"
)
const (
......@@ -41,11 +42,12 @@ func DefaultRefreshTokenGenerator() ([]byte, error) {
type RefreshTokenRepo interface {
// Create generates and returns a new refresh token for the given client-user pair.
// On success the token will be return.
Create(userID, clientID string) (string, error)
Create(userID, clientID string, scope []string) (string, error)
// Verify verifies that a token belongs to the client, and returns the corresponding user ID.
// Note that this assumes the client validation is currently done in the application layer,
Verify(clientID, token string) (string, error)
// Verify verifies that a token belongs to the client.
// It returns the user ID to which the token belongs, and the scopes stored
// with token.
Verify(clientID, token string) (string, scope.Scopes, error)
// Revoke deletes the refresh token if the token belongs to the given userID.
Revoke(userID, token string) error
......
......@@ -32,3 +32,17 @@ func (s Scopes) CrossClientIDs() []string {
}
return clients
}
func (s Scopes) Contains(other Scopes) bool {
rScopes := map[string]struct{}{}
for _, scope := range s {
rScopes[scope] = struct{}{}
}
for _, scope := range other {
if _, ok := rScopes[scope]; !ok {
return false
}
}
return true
}
......@@ -518,11 +518,12 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
}
case oauth2.GrantTypeRefreshToken:
token := r.PostForm.Get("refresh_token")
scopes := r.PostForm.Get("scope")
if token == "" {
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
return
}
jwt, err = srv.RefreshToken(creds, token)
jwt, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
if err != nil {
writeTokenError(w, err, state)
return
......
......@@ -23,6 +23,7 @@ import (
"github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/scope"
"github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
......@@ -53,7 +54,7 @@ type OIDCServer interface {
// RefreshToken takes a previously generated refresh token and returns a new ID token
// if the token is valid.
RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error)
RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error)
KillSession(string) error
......@@ -487,7 +488,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
if scope == "offline_access" {
log.Infof("Session %s requests offline access, will generate refresh token", sessionID)
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID)
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope)
switch err {
case nil:
break
......@@ -503,7 +504,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
return jwt, refreshToken, nil
}
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) {
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error) {
ok, err := s.ClientManager.Authenticate(creds)
if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
......@@ -514,7 +515,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose
return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
}
userID, err := s.RefreshTokenRepo.Verify(creds.ID, token)
userID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
switch err {
case nil:
break
......@@ -526,6 +527,14 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose
return nil, oauth2.NewError(oauth2.ErrorServerError)
}
if len(scopes) == 0 {
scopes = rtScopes
} else {
if !rtScopes.Contains(scopes) {
return nil, oauth2.NewError(oauth2.ErrorInvalidRequest)
}
}
user, err := s.UserRepo.Get(nil, userID)
if err != nil {
// The error can be user.ErrorNotFound, but we are not deleting
......
......@@ -488,6 +488,8 @@ func TestServerRefreshToken(t *testing.T) {
clientID string // The client that associates with the token.
creds oidc.ClientCredentials
signer jose.Signer
createScopes []string
refreshScopes []string
err error
}{
// Everything is good.
......@@ -496,14 +498,28 @@ func TestServerRefreshToken(t *testing.T) {
testClientID,
testClientCredentials,
signerFixture,
[]string{"openid", "profile"},
[]string{"openid", "profile"},
nil,
},
// Asking for a scope not originally granted to you.
{
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
testClientID,
testClientCredentials,
signerFixture,
[]string{"openid", "profile"},
[]string{"openid", "profile", "extra_scope"},
oauth2.NewError(oauth2.ErrorInvalidRequest),
},
// Invalid refresh token(malformatted).
{
"invalid-token",
testClientID,
testClientCredentials,
signerFixture,
[]string{"openid", "profile"},
[]string{"openid", "profile"},
oauth2.NewError(oauth2.ErrorInvalidRequest),
},
// Invalid refresh token(invalid payload content).
......@@ -512,6 +528,8 @@ func TestServerRefreshToken(t *testing.T) {
testClientID,
testClientCredentials,
signerFixture,
[]string{"openid", "profile"},
[]string{"openid", "profile"},
oauth2.NewError(oauth2.ErrorInvalidRequest),
},
// Invalid refresh token(invalid ID content).
......@@ -520,6 +538,8 @@ func TestServerRefreshToken(t *testing.T) {
testClientID,
testClientCredentials,
signerFixture,
[]string{"openid", "profile"},
[]string{"openid", "profile"},
oauth2.NewError(oauth2.ErrorInvalidRequest),
},
// Invalid client(client is not associated with the token).
......@@ -528,6 +548,8 @@ func TestServerRefreshToken(t *testing.T) {
testClientID,
clientB.Credentials,
signerFixture,
[]string{"openid", "profile"},
[]string{"openid", "profile"},
oauth2.NewError(oauth2.ErrorInvalidClient),
},
// Invalid client(no client ID).
......@@ -536,6 +558,8 @@ func TestServerRefreshToken(t *testing.T) {
testClientID,
oidc.ClientCredentials{ID: "", Secret: "aaa"},
signerFixture,
[]string{"openid", "profile"},
[]string{"openid", "profile"},
oauth2.NewError(oauth2.ErrorInvalidClient),
},
// Invalid client(no such client).
......@@ -544,6 +568,8 @@ func TestServerRefreshToken(t *testing.T) {
testClientID,
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
signerFixture,
[]string{"openid", "profile"},
[]string{"openid", "profile"},
oauth2.NewError(oauth2.ErrorInvalidClient),
},
// Invalid client(no secrets).
......@@ -552,6 +578,8 @@ func TestServerRefreshToken(t *testing.T) {
testClientID,
oidc.ClientCredentials{ID: testClientID},
signerFixture,
[]string{"openid", "profile"},
[]string{"openid", "profile"},
oauth2.NewError(oauth2.ErrorInvalidClient),
},
// Invalid client(invalid secret).
......@@ -560,6 +588,8 @@ func TestServerRefreshToken(t *testing.T) {
testClientID,
oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"},
signerFixture,
[]string{"openid", "profile"},
[]string{"openid", "profile"},
oauth2.NewError(oauth2.ErrorInvalidClient),
},
// Signing operation fails.
......@@ -568,6 +598,8 @@ func TestServerRefreshToken(t *testing.T) {
testClientID,
testClientCredentials,
&StaticSigner{sig: nil, err: errors.New("fail")},
[]string{"openid", "profile"},
[]string{"openid", "profile"},
oauth2.NewError(oauth2.ErrorServerError),
},
}
......@@ -587,11 +619,12 @@ func TestServerRefreshToken(t *testing.T) {
t.Errorf("case %d: error creating other client: %v", i, err)
}
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID); err != nil {
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID,
tt.createScopes); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
jwt, err := f.srv.RefreshToken(tt.creds, tt.token)
jwt, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token)
if !reflect.DeepEqual(err, tt.err) {
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
}
......
......@@ -192,7 +192,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
}
refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, token := range refreshTokens {
if _, err := refreshRepo.Create(token.userID, token.clientID); err != nil {
if _, err := refreshRepo.Create(token.userID, token.clientID, []string{"openid"}); err != nil {
panic("Failed to create refresh token: " + err.Error())
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment