Commit b7e19b6e authored by bobbyrullo's avatar bobbyrullo Committed by GitHub

Merge pull request #465 from bobbyrullo/cross_client_refresh_tokens

Cross client refresh tokens
parents ea2f0a32 75473b4c
...@@ -39,7 +39,8 @@ CREATE TABLE refresh_token ( ...@@ -39,7 +39,8 @@ CREATE TABLE refresh_token (
id integer PRIMARY KEY, id integer PRIMARY KEY,
payload_hash blob, payload_hash blob,
user_id text, user_id text,
client_id text client_id text,
scopes text
); );
CREATE TABLE remote_identity_mapping ( CREATE TABLE remote_identity_mapping (
......
-- +migrate Up
ALTER TABLE refresh_token ADD COLUMN "scopes" text;
UPDATE refresh_token SET scopes = 'openid profile email offline_access';
...@@ -78,5 +78,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{ ...@@ -78,5 +78,11 @@ var PostgresMigrations migrate.MigrationSource = &migrate.MemoryMigrationSource{
"-- +migrate Up\nCREATE TABLE IF NOT EXISTS \"trusted_peers\" (\n \"client_id\" text not null,\n \"trusted_client_id\" text not null,\n primary key (\"client_id\", \"trusted_client_id\")) ;\n", "-- +migrate Up\nCREATE TABLE IF NOT EXISTS \"trusted_peers\" (\n \"client_id\" text not null,\n \"trusted_client_id\" text not null,\n primary key (\"client_id\", \"trusted_client_id\")) ;\n",
}, },
}, },
{
Id: "0013_add_scopes_to_refresh_tokens.sql",
Up: []string{
"-- +migrate Up\nALTER TABLE refresh_token ADD COLUMN \"scopes\" text;\n\nUPDATE refresh_token SET scopes = 'openid profile email offline_access';\n",
},
},
}, },
} }
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/dex/scope"
) )
const ( const (
...@@ -38,10 +39,9 @@ type refreshTokenRepo struct { ...@@ -38,10 +39,9 @@ type refreshTokenRepo struct {
type refreshTokenModel struct { type refreshTokenModel struct {
ID int64 `db:"id"` ID int64 `db:"id"`
PayloadHash []byte `db:"payload_hash"` PayloadHash []byte `db:"payload_hash"`
// TODO(yifan): Use some sort of foreign key to manage database level
// data integrity.
UserID string `db:"user_id"` UserID string `db:"user_id"`
ClientID string `db:"client_id"` ClientID string `db:"client_id"`
Scopes string `db:"scopes"`
} }
// buildToken combines the token ID and token payload to create a new token. // buildToken combines the token ID and token payload to create a new token.
...@@ -89,7 +89,7 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG ...@@ -89,7 +89,7 @@ func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenG
} }
} }
func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) { func (r *refreshTokenRepo) Create(userID, clientID string, scopes []string) (string, error) {
if userID == "" { if userID == "" {
return "", refresh.ErrorInvalidUserID return "", refresh.ErrorInvalidUserID
} }
...@@ -112,6 +112,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) { ...@@ -112,6 +112,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
PayloadHash: payloadHash, PayloadHash: payloadHash,
UserID: userID, UserID: userID,
ClientID: clientID, ClientID: clientID,
Scopes: strings.Join(scopes, " "),
} }
if err := r.executor(nil).Insert(record); err != nil { if err := r.executor(nil).Insert(record); err != nil {
...@@ -121,27 +122,32 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) { ...@@ -121,27 +122,32 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
return buildToken(record.ID, tokenPayload), nil return buildToken(record.ID, tokenPayload), nil
} }
func (r *refreshTokenRepo) Verify(clientID, token string) (string, error) { func (r *refreshTokenRepo) Verify(clientID, token string) (string, scope.Scopes, error) {
tokenID, tokenPayload, err := parseToken(token) tokenID, tokenPayload, err := parseToken(token)
if err != nil { if err != nil {
return "", err return "", nil, err
} }
record, err := r.get(nil, tokenID) record, err := r.get(nil, tokenID)
if err != nil { if err != nil {
return "", err return "", nil, err
} }
if record.ClientID != clientID { if record.ClientID != clientID {
return "", refresh.ErrorInvalidClientID return "", nil, refresh.ErrorInvalidClientID
} }
if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil { if err := checkTokenPayload(record.PayloadHash, tokenPayload); err != nil {
return "", err return "", nil, err
} }
return record.UserID, nil var scopes []string
if len(record.Scopes) > 0 {
scopes = strings.Split(record.Scopes, " ")
}
return record.UserID, scopes, nil
} }
func (r *refreshTokenRepo) Revoke(userID, token string) error { func (r *refreshTokenRepo) Revoke(userID, token string) error {
...@@ -190,7 +196,6 @@ func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Cli ...@@ -190,7 +196,6 @@ func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Cli
q := `SELECT c.* FROM %s as c q := `SELECT c.* FROM %s as c
INNER JOIN %s as r ON c.id = r.client_id WHERE r.user_id = $1;` INNER JOIN %s as r ON c.id = r.client_id WHERE r.user_id = $1;`
q = fmt.Sprintf(q, r.quote(clientTableName), r.quote(refreshTokenTableName)) q = fmt.Sprintf(q, r.quote(clientTableName), r.quote(refreshTokenTableName))
var clients []clientModel var clients []clientModel
if _, err := r.executor(nil).Select(&clients, q, userID); err != nil { if _, err := r.executor(nil).Select(&clients, q, userID); err != nil {
return nil, err return nil, err
...@@ -206,6 +211,7 @@ func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Cli ...@@ -206,6 +211,7 @@ func (r *refreshTokenRepo) ClientsWithRefreshTokens(userID string) ([]client.Cli
// Do not share the secret. // Do not share the secret.
c[i].Credentials.Secret = "" c[i].Credentials.Secret = ""
} }
return c, nil return c, nil
} }
......
package functional package functional
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
...@@ -16,7 +15,6 @@ import ( ...@@ -16,7 +15,6 @@ import (
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/client/manager" "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
) )
...@@ -411,207 +409,3 @@ func TestDBClientAll(t *testing.T) { ...@@ -411,207 +409,3 @@ func TestDBClientAll(t *testing.T) {
t.Fatalf("Retrieved incorrect number of ClientIdentities: want=2 got=%d", count) t.Fatalf("Retrieved incorrect number of ClientIdentities: want=2 got=%d", count)
} }
} }
// buildRefreshToken combines the token ID and token payload to create a new token.
// used in the tests to created a refresh token.
func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
}
func TestDBRefreshRepoCreate(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t))
tests := []struct {
userID string
clientID string
err error
}{
{
"",
"client-foo",
refresh.ErrorInvalidUserID,
},
{
"user-foo",
"",
refresh.ErrorInvalidClientID,
},
{
"user-foo",
"client-foo",
nil,
},
}
for i, tt := range tests {
token, err := r.Create(tt.userID, tt.clientID)
if err != nil {
if tt.err == nil {
t.Errorf("case %d: create failed: %v", i, err)
}
continue
}
if tt.err != nil {
t.Errorf("case %d: expected error, didn't get one", i)
continue
}
userID, err := r.Verify(tt.clientID, token)
if err != nil {
t.Errorf("case %d: failed to verify good token: %v", i, err)
continue
}
if userID != tt.userID {
t.Errorf("case %d: want userID=%s, got userID=%s", i, tt.userID, userID)
}
}
}
func TestDBRefreshRepoVerify(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
tokenWithBadID := "404" + token[1:]
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
tests := []struct {
token string
creds oidc.ClientCredentials
err error
expected string
}{
{
"invalid-token-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
"b/invalid-base64-encoded-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
"1/invalid-base64-encoded-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
token + "corrupted-token-payload",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
// The token's ID content is invalid.
tokenWithBadID,
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
// The token's payload content is invalid.
tokenWithBadPayload,
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
token,
oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"},
refresh.ErrorInvalidClientID,
"",
},
{
token,
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
nil,
"user-foo",
},
}
for i, tt := range tests {
result, err := r.Verify(tt.creds.ID, tt.token)
if err != tt.err {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
}
if result != tt.expected {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.expected, result)
}
}
}
func TestDBRefreshRepoRevoke(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
tokenWithBadID := "404" + token[1:]
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
tests := []struct {
token string
userID string
err error
}{
{
"invalid-token-format",
"user-foo",
refresh.ErrorInvalidToken,
},
{
"1/invalid-base64-encoded-format",
"user-foo",
refresh.ErrorInvalidToken,
},
{
token + "corrupted-token-payload",
"user-foo",
refresh.ErrorInvalidToken,
},
{
// The token's ID is invalid.
tokenWithBadID,
"user-foo",
refresh.ErrorInvalidToken,
},
{
// The token's payload is invalid.
tokenWithBadPayload,
"user-foo",
refresh.ErrorInvalidToken,
},
{
token,
"invalid-user",
refresh.ErrorInvalidUserID,
},
{
token,
"user-foo",
nil,
},
}
for i, tt := range tests {
if err := r.Revoke(tt.userID, tt.token); err != tt.err {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
}
}
}
...@@ -2,13 +2,13 @@ package repo ...@@ -2,13 +2,13 @@ package repo
import ( import (
"encoding/base64" "encoding/base64"
"fmt"
"net/url" "net/url"
"os" "sort"
"testing" "testing"
"time" "time"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
...@@ -17,27 +17,14 @@ import ( ...@@ -17,27 +17,14 @@ import (
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
) )
func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients []client.Client) refresh.RefreshTokenRepo { var (
var dbMap *gorp.DbMap testRefreshClientID = "client1"
if dsn := os.Getenv("DEX_TEST_DSN"); dsn == "" { testRefreshClientID2 = "client2"
dbMap = db.NewMemDB() testRefreshClients = []client.LoadableClient{
} else {
dbMap = connect(t)
}
if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil {
t.Fatalf("Unable to add users: %v", err)
}
return db.NewRefreshTokenRepo(dbMap)
}
func TestRefreshTokenRepo(t *testing.T) {
clientID := "client1"
userID := "user1"
clients := []client.Client{
{ {
Client: client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: clientID, ID: testRefreshClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")), Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
...@@ -46,11 +33,27 @@ func TestRefreshTokenRepo(t *testing.T) { ...@@ -46,11 +33,27 @@ func TestRefreshTokenRepo(t *testing.T) {
}, },
}, },
}, },
},
{
Client: client.Client{
Credentials: oidc.ClientCredentials{
ID: testRefreshClientID2,
Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")),
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "client2.example.com", Path: "/callback"},
},
},
},
},
} }
users := []user.UserWithRemoteIdentities{
testRefreshUserID = "user1"
testRefreshUsers = []user.UserWithRemoteIdentities{
{ {
User: user.User{ User: user.User{
ID: userID, ID: testRefreshUserID,
Email: "Email-1@example.com", Email: "Email-1@example.com",
CreatedAt: time.Now().Truncate(time.Second), CreatedAt: time.Now().Truncate(time.Second),
}, },
...@@ -62,31 +65,318 @@ func TestRefreshTokenRepo(t *testing.T) { ...@@ -62,31 +65,318 @@ func TestRefreshTokenRepo(t *testing.T) {
}, },
}, },
} }
)
func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients []client.LoadableClient) refresh.RefreshTokenRepo {
dbMap := connect(t)
if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil {
t.Fatalf("Unable to add users: %v", err)
}
if _, err := db.NewClientRepoFromClients(dbMap, clients); err != nil {
t.Fatalf("Unable to add clients: %v", err)
}
return db.NewRefreshTokenRepo(dbMap)
}
func TestRefreshTokenRepoCreateVerify(t *testing.T) {
tests := []struct {
createScopes []string
verifyClientID string
wantVerifyErr bool
}{
{
createScopes: []string{"openid", "profile"},
verifyClientID: testRefreshClientID,
},
{
createScopes: []string{},
verifyClientID: testRefreshClientID,
},
{
createScopes: []string{"openid", "profile"},
verifyClientID: "not-a-client",
wantVerifyErr: true,
},
}
for i, tt := range tests {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, tt.createScopes)
if err != nil {
t.Fatalf("case %d: failed to create refresh token: %v", i, err)
}
tokUserID, gotScopes, err := repo.Verify(tt.verifyClientID, tok)
if tt.wantVerifyErr {
if err == nil {
t.Errorf("case %d: want non-nil error.", i)
}
continue
}
if diff := pretty.Compare(tt.createScopes, gotScopes); diff != "" {
t.Errorf("case %d: Compare(want, got): %v", i, diff)
}
if err != nil {
t.Errorf("case %d: Could not verify token: %v", i, err)
} else if tokUserID != testRefreshUserID {
t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i,
testRefreshUserID, tokUserID)
}
}
}
// buildRefreshToken combines the token ID and token payload to create a new token.
// used in the tests to created a refresh token.
func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
}
func TestRefreshRepoVerifyInvalidTokens(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
tokenWithBadID := "404" + token[1:]
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
tests := []struct {
token string
creds oidc.ClientCredentials
err error
expected string
}{
{
"invalid-token-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
"b/invalid-base64-encoded-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
"1/invalid-base64-encoded-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
token + "corrupted-token-payload",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
// The token's ID content is invalid.
tokenWithBadID,
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
// The token's payload content is invalid.
tokenWithBadPayload,
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
token,
oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"},
refresh.ErrorInvalidClientID,
"",
},
{
token,
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
nil,
"user-foo",
},
}
for i, tt := range tests {
result, _, err := r.Verify(tt.creds.ID, tt.token)
if err != tt.err {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
}
if result != tt.expected {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.expected, result)
}
}
}
func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) {
tests := []struct {
clientIDs []string
}{
{clientIDs: []string{"client1", "client2"}},
{clientIDs: []string{"client1"}},
{clientIDs: []string{}},
}
for i, tt := range tests {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
for _, clientID := range tt.clientIDs {
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"})
if err != nil {
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
}
}
repo := newRefreshRepo(t, users, clients) clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID)
tok, err := repo.Create(userID, clientID)
if err != nil { if err != nil {
t.Fatalf("failed to create refresh token: %v", err) t.Fatalf("case %d: unexpected error fetching clients %q", i, err)
}
var clientIDs []string
for _, client := range clients {
clientIDs = append(clientIDs, client.Credentials.ID)
} }
if tokUserID, err := repo.Verify(clientID, tok); err != nil { sort.Strings(clientIDs)
t.Errorf("Could not verify token: %v", err)
} else if tokUserID != userID { if diff := pretty.Compare(clientIDs, tt.clientIDs); diff != "" {
t.Errorf("Verified token returned wrong user id, want=%s, got=%s", userID, tokUserID) t.Errorf("case %d: Compare(want, got): %v", i, diff)
} }
}
}
if userClients, err := repo.ClientsWithRefreshTokens(userID); err != nil { func TestRefreshTokenRepoRevokeForClient(t *testing.T) {
t.Errorf("Failed to get the list of clients the user was logged into: %v", err) tests := []struct {
} else { createIDs []string
if diff := pretty.Compare(userClients, clients); diff == "" { revokeID string
t.Errorf("Clients user logged into: want did not equal got %s", diff) }{
{
createIDs: []string{"client1", "client2"},
revokeID: "client1",
},
{
createIDs: []string{"client2"},
revokeID: "client1",
},
{
createIDs: []string{"client1"},
revokeID: "client1",
},
{
createIDs: []string{},
revokeID: "oops",
},
} }
for i, tt := range tests {
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients)
for _, clientID := range tt.createIDs {
_, err := repo.Create(testRefreshUserID, clientID, []string{"openid"})
if err != nil {
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err)
} }
if err := repo.RevokeTokensForClient(userID, clientID); err != nil { if err := repo.RevokeTokensForClient(testRefreshUserID, tt.revokeID); err != nil {
t.Errorf("Failed to revoke refresh token: %v", err) t.Fatalf("case %d: couldn't revoke refresh token(s): %v", i, err)
}
} }
if _, err := repo.Verify(clientID, tok); err == nil { var wantIDs []string
t.Errorf("Token which should have been revoked was verified") for _, id := range tt.createIDs {
if id != tt.revokeID {
wantIDs = append(wantIDs, id)
}
}
clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID)
if err != nil {
t.Fatalf("case %d: unexpected error fetching clients %q", i, err)
}
var gotIDs []string
for _, client := range clients {
gotIDs = append(gotIDs, client.Credentials.ID)
}
sort.Strings(gotIDs)
if diff := pretty.Compare(wantIDs, gotIDs); diff != "" {
t.Errorf("case %d: Compare(wantIDs, gotIDs): %v", i, diff)
}
}
}
func TestRefreshRepoRevoke(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t))
token, err := r.Create("user-foo", "client-foo", oidc.DefaultScope)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
tokenWithBadID := "404" + token[1:]
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
tests := []struct {
token string
userID string
err error
}{
{
"invalid-token-format",
"user-foo",
refresh.ErrorInvalidToken,
},
{
"1/invalid-base64-encoded-format",
"user-foo",
refresh.ErrorInvalidToken,
},
{
token + "corrupted-token-payload",
"user-foo",
refresh.ErrorInvalidToken,
},
{
// The token's ID is invalid.
tokenWithBadID,
"user-foo",
refresh.ErrorInvalidToken,
},
{
// The token's payload is invalid.
tokenWithBadPayload,
"user-foo",
refresh.ErrorInvalidToken,
},
{
token,
"invalid-user",
refresh.ErrorInvalidUserID,
},
{
token,
"user-foo",
nil,
},
}
for i, tt := range tests {
if err := r.Revoke(tt.userID, tt.token); err != tt.err {
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err)
}
} }
} }
...@@ -12,7 +12,8 @@ import ( ...@@ -12,7 +12,8 @@ import (
func connect(t *testing.T) *gorp.DbMap { func connect(t *testing.T) *gorp.DbMap {
dsn := os.Getenv("DEX_TEST_DSN") dsn := os.Getenv("DEX_TEST_DSN")
if dsn == "" { if dsn == "" {
t.Fatal("DEX_TEST_DSN environment variable not set") return db.NewMemDB()
} }
c, err := db.NewConnection(db.Config{DSN: dsn}) c, err := db.NewConnection(db.Config{DSN: dsn})
if err != nil { if err != nil {
......
...@@ -231,7 +231,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { ...@@ -231,7 +231,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
// this will actually happen due to some interaction between the // this will actually happen due to some interaction between the
// end-user and a remote identity provider // end-user and a remote identity provider
sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"}) sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access", "email", "profile"})
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
......
...@@ -148,7 +148,8 @@ func makeUserAPITestFixtures() *userAPITestFixtures { ...@@ -148,7 +148,8 @@ func makeUserAPITestFixtures() *userAPITestFixtures {
refreshRepo := db.NewRefreshTokenRepo(dbMap) refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, user := range userUsers { for _, user := range userUsers {
if _, err := refreshRepo.Create(user.User.ID, testClientID); err != nil { if _, err := refreshRepo.Create(user.User.ID, testClientID,
append([]string{"offline_access"}, oidc.DefaultScope...)); err != nil {
panic("Failed to create refresh token: " + err.Error()) panic("Failed to create refresh token: " + err.Error())
} }
} }
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"errors" "errors"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/scope"
) )
const ( const (
...@@ -40,12 +41,15 @@ func DefaultRefreshTokenGenerator() ([]byte, error) { ...@@ -40,12 +41,15 @@ func DefaultRefreshTokenGenerator() ([]byte, error) {
type RefreshTokenRepo interface { type RefreshTokenRepo interface {
// Create generates and returns a new refresh token for the given client-user pair. // Create generates and returns a new refresh token for the given client-user pair.
// On success the token will be return. // The scopes will be stored with the refresh token, and used to verify
Create(userID, clientID string) (string, error) // against future OIDC refresh requests' scopes.
// On success the token will be returned.
// Verify verifies that a token belongs to the client, and returns the corresponding user ID. Create(userID, clientID string, scope []string) (string, error)
// Note that this assumes the client validation is currently done in the application layer,
Verify(clientID, token string) (string, error) // Verify verifies that a token belongs to the client.
// It returns the user ID to which the token belongs, and the scopes stored
// with token.
Verify(clientID, token string) (string, scope.Scopes, error)
// Revoke deletes the refresh token if the token belongs to the given userID. // Revoke deletes the refresh token if the token belongs to the given userID.
Revoke(userID, token string) error Revoke(userID, token string) error
......
...@@ -32,3 +32,17 @@ func (s Scopes) CrossClientIDs() []string { ...@@ -32,3 +32,17 @@ func (s Scopes) CrossClientIDs() []string {
} }
return clients return clients
} }
func (s Scopes) Contains(other Scopes) bool {
rScopes := map[string]struct{}{}
for _, scope := range s {
rScopes[scope] = struct{}{}
}
for _, scope := range other {
if _, ok := rScopes[scope]; !ok {
return false
}
}
return true
}
...@@ -14,29 +14,24 @@ import ( ...@@ -14,29 +14,24 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/scope" "github.com/coreos/dex/scope"
) )
func makeCrossClientTestFixtures() (*testFixtures, error) { func makeCrossClientTestFixtures() (*testFixtures, error) {
f, err := makeTestFixtures() xClients := []client.LoadableClient{}
if err != nil {
return nil, fmt.Errorf("couldn't make test fixtures: %v", err)
}
for _, cliData := range []struct { for _, cliData := range []struct {
id string id string
authorized []string trustedPeers []string
}{ }{
{ {
id: "client_a", id: "client_a",
}, { }, {
id: "client_b", id: "client_b",
authorized: []string{"client_a"}, trustedPeers: []string{"client_a"},
}, { }, {
id: "client_c", id: "client_c",
authorized: []string{"client_a", "client_b"}, trustedPeers: []string{"client_a", "client_b"},
}, },
} { } {
u := url.URL{ u := url.URL{
...@@ -44,20 +39,27 @@ func makeCrossClientTestFixtures() (*testFixtures, error) { ...@@ -44,20 +39,27 @@ func makeCrossClientTestFixtures() (*testFixtures, error) {
Path: cliData.id, Path: cliData.id,
Host: cliData.id, Host: cliData.id,
} }
cliCreds, err := f.clientManager.New(client.Client{ xClients = append(xClients, client.LoadableClient{
Client: client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: cliData.id, ID: cliData.id,
Secret: base64.URLEncoding.EncodeToString(
[]byte(cliData.id + "_secret")),
}, },
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{u}, RedirectURIs: []url.URL{u},
}, },
}, &clientmanager.ClientOptions{ },
TrustedPeers: cliData.authorized, TrustedPeers: cliData.trustedPeers,
}) })
if err != nil {
return nil, fmt.Errorf("Unexpected error creating clients: %v", err)
} }
f.clientCreds[cliData.id] = *cliCreds
xClients = append(xClients, testClients...)
f, err := makeTestFixturesWithOptions(testFixtureOptions{
clients: xClients,
})
if err != nil {
return nil, fmt.Errorf("couldn't make test fixtures: %v", err)
} }
return f, nil return f, nil
} }
......
...@@ -518,11 +518,12 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc { ...@@ -518,11 +518,12 @@ func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
} }
case oauth2.GrantTypeRefreshToken: case oauth2.GrantTypeRefreshToken:
token := r.PostForm.Get("refresh_token") token := r.PostForm.Get("refresh_token")
scopes := r.PostForm.Get("scope")
if token == "" { if token == "" {
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state) writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
return return
} }
jwt, err = srv.RefreshToken(creds, token) jwt, err = srv.RefreshToken(creds, strings.Split(scopes, " "), token)
if err != nil { if err != nil {
writeTokenError(w, err, state) writeTokenError(w, err, state)
return return
......
...@@ -23,6 +23,7 @@ import ( ...@@ -23,6 +23,7 @@ import (
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
"github.com/coreos/dex/scope"
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager" sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
...@@ -53,7 +54,7 @@ type OIDCServer interface { ...@@ -53,7 +54,7 @@ type OIDCServer interface {
// RefreshToken takes a previously generated refresh token and returns a new ID token // RefreshToken takes a previously generated refresh token and returns a new ID token
// if the token is valid. // if the token is valid.
RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error)
KillSession(string) error KillSession(string) error
...@@ -444,35 +445,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo ...@@ -444,35 +445,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
claims := ses.Claims(s.IssuerURL.String()) claims := ses.Claims(s.IssuerURL.String())
user.AddToClaims(claims) user.AddToClaims(claims)
crossClientIDs := ses.Scope.CrossClientIDs() s.addClaimsFromScope(claims, ses.Scope, ses.ClientID)
if len(crossClientIDs) > 0 {
var aud []string
for _, id := range crossClientIDs {
if ses.ClientID == id {
aud = append(aud, id)
continue
}
allowed, err := s.CrossClientAuthAllowed(ses.ClientID, id)
if err != nil {
log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", ses.ClientID, id, err)
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
}
if !allowed {
err := oauth2.NewError(oauth2.ErrorInvalidRequest)
err.Description = fmt.Sprintf(
"%q is not authorized to perform cross-client requests for %q",
ses.ClientID, id)
return nil, "", err
}
aud = append(aud, id)
}
if len(aud) == 1 {
claims.Add("aud", aud[0])
} else {
claims.Add("aud", aud)
}
claims.Add("azp", ses.ClientID)
}
jwt, err := jose.NewSignedJWT(claims, signer) jwt, err := jose.NewSignedJWT(claims, signer)
if err != nil { if err != nil {
...@@ -487,7 +460,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo ...@@ -487,7 +460,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
if scope == "offline_access" { if scope == "offline_access" {
log.Infof("Session %s requests offline access, will generate refresh token", sessionID) log.Infof("Session %s requests offline access, will generate refresh token", sessionID)
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID) refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.Scope)
switch err { switch err {
case nil: case nil:
break break
...@@ -503,7 +476,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo ...@@ -503,7 +476,7 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
return jwt, refreshToken, nil return jwt, refreshToken, nil
} }
func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose.JWT, error) { func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error) {
ok, err := s.ClientManager.Authenticate(creds) ok, err := s.ClientManager.Authenticate(creds)
if err != nil { if err != nil {
log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err) log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
...@@ -514,7 +487,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose ...@@ -514,7 +487,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose
return nil, oauth2.NewError(oauth2.ErrorInvalidClient) return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
} }
userID, err := s.RefreshTokenRepo.Verify(creds.ID, token) userID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
switch err { switch err {
case nil: case nil:
break break
...@@ -526,6 +499,14 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose ...@@ -526,6 +499,14 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose
return nil, oauth2.NewError(oauth2.ErrorServerError) return nil, oauth2.NewError(oauth2.ErrorServerError)
} }
if len(scopes) == 0 {
scopes = rtScopes
} else {
if !rtScopes.Contains(scopes) {
return nil, oauth2.NewError(oauth2.ErrorInvalidRequest)
}
}
user, err := s.UserRepo.Get(nil, userID) user, err := s.UserRepo.Get(nil, userID)
if err != nil { if err != nil {
// The error can be user.ErrorNotFound, but we are not deleting // The error can be user.ErrorNotFound, but we are not deleting
...@@ -546,6 +527,8 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose ...@@ -546,6 +527,8 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose
claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt) claims := oidc.NewClaims(s.IssuerURL.String(), user.ID, creds.ID, now, expireAt)
user.AddToClaims(claims) user.AddToClaims(claims)
s.addClaimsFromScope(claims, scope.Scopes(scopes), creds.ID)
jwt, err := jose.NewSignedJWT(claims, signer) jwt, err := jose.NewSignedJWT(claims, signer)
if err != nil { if err != nil {
log.Errorf("Failed to generate ID token: %v", err) log.Errorf("Failed to generate ID token: %v", err)
...@@ -587,6 +570,41 @@ func (s *Server) JWTVerifierFactory() JWTVerifierFactory { ...@@ -587,6 +570,41 @@ func (s *Server) JWTVerifierFactory() JWTVerifierFactory {
} }
} }
// addClaimsFromScope adds claims that are based on the scopes that the client requested.
// Currently, these include cross-client claims (aud, azp).
func (s *Server) addClaimsFromScope(claims jose.Claims, scopes scope.Scopes, clientID string) error {
crossClientIDs := scopes.CrossClientIDs()
if len(crossClientIDs) > 0 {
var aud []string
for _, id := range crossClientIDs {
if clientID == id {
aud = append(aud, id)
continue
}
allowed, err := s.CrossClientAuthAllowed(clientID, id)
if err != nil {
log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", clientID, id, err)
return oauth2.NewError(oauth2.ErrorServerError)
}
if !allowed {
err := oauth2.NewError(oauth2.ErrorInvalidRequest)
err.Description = fmt.Sprintf(
"%q is not authorized to perform cross-client requests for %q",
clientID, id)
return err
}
aud = append(aud, id)
}
if len(aud) == 1 {
claims.Add("aud", aud[0])
} else {
claims.Add("aud", aud)
}
claims.Add("azp", clientID)
}
return nil
}
type sortableIDPCs []connector.Connector type sortableIDPCs []connector.Connector
func (s sortableIDPCs) Len() int { func (s sortableIDPCs) Len() int {
......
...@@ -18,6 +18,7 @@ import ( ...@@ -18,6 +18,7 @@ import (
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/refresh/refreshtest" "github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/scope"
"github.com/coreos/dex/session/manager" "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
) )
...@@ -488,87 +489,193 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -488,87 +489,193 @@ func TestServerRefreshToken(t *testing.T) {
clientID string // The client that associates with the token. clientID string // The client that associates with the token.
creds oidc.ClientCredentials creds oidc.ClientCredentials
signer jose.Signer signer jose.Signer
createScopes []string
refreshScopes []string
expectedAud []string
err error err error
}{ }{
// Everything is good. // Everything is good.
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
testClientID, clientID: testClientID,
testClientCredentials, creds: testClientCredentials,
signerFixture, signer: signerFixture,
nil, createScopes: []string{"openid", "profile"},
refreshScopes: []string{"openid", "profile"},
},
// Asking for a scope not originally granted to you.
{
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientID: testClientID,
creds: testClientCredentials,
signer: signerFixture,
createScopes: []string{"openid", "profile"},
refreshScopes: []string{"openid", "profile", "extra_scope"},
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid refresh token(malformatted). // Invalid refresh token(malformatted).
{ {
"invalid-token", token: "invalid-token",
testClientID, clientID: testClientID,
testClientCredentials, creds: testClientCredentials,
signerFixture, signer: signerFixture,
oauth2.NewError(oauth2.ErrorInvalidRequest), createScopes: []string{"openid", "profile"},
refreshScopes: []string{"openid", "profile"},
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid refresh token(invalid payload content). // Invalid refresh token(invalid payload content).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
testClientID, clientID: testClientID,
testClientCredentials, creds: testClientCredentials,
signerFixture, signer: signerFixture,
oauth2.NewError(oauth2.ErrorInvalidRequest), createScopes: []string{"openid", "profile"},
refreshScopes: []string{"openid", "profile"},
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid refresh token(invalid ID content). // Invalid refresh token(invalid ID content).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
testClientID, clientID: testClientID,
testClientCredentials, creds: testClientCredentials,
signerFixture, signer: signerFixture,
oauth2.NewError(oauth2.ErrorInvalidRequest), createScopes: []string{"openid", "profile"},
refreshScopes: []string{"openid", "profile"},
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid client(client is not associated with the token). // Invalid client(client is not associated with the token).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
testClientID, clientID: testClientID,
clientB.Credentials, creds: clientB.Credentials,
signerFixture, signer: signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), createScopes: []string{"openid", "profile"},
refreshScopes: []string{"openid", "profile"},
err: oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(no client ID). // Invalid client(no client ID).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
testClientID, clientID: testClientID,
oidc.ClientCredentials{ID: "", Secret: "aaa"}, creds: oidc.ClientCredentials{ID: "", Secret: "aaa"},
signerFixture, signer: signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), createScopes: []string{"openid", "profile"},
refreshScopes: []string{"openid", "profile"},
err: oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(no such client). // Invalid client(no such client).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
testClientID, clientID: testClientID,
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, creds: oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
signerFixture, signer: signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), createScopes: []string{"openid", "profile"},
refreshScopes: []string{"openid", "profile"},
err: oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(no secrets). // Invalid client(no secrets).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
testClientID, clientID: testClientID,
oidc.ClientCredentials{ID: testClientID}, creds: oidc.ClientCredentials{ID: testClientID},
signerFixture, signer: signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), createScopes: []string{"openid", "profile"},
refreshScopes: []string{"openid", "profile"},
err: oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(invalid secret). // Invalid client(invalid secret).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
testClientID, clientID: testClientID,
oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"}, creds: oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"},
signerFixture, signer: signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), createScopes: []string{"openid", "profile"},
refreshScopes: []string{"openid", "profile"},
err: oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Signing operation fails. // Signing operation fails.
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
testClientID, clientID: testClientID,
testClientCredentials, creds: testClientCredentials,
&StaticSigner{sig: nil, err: errors.New("fail")}, signer: &StaticSigner{sig: nil, err: errors.New("fail")},
oauth2.NewError(oauth2.ErrorServerError), createScopes: []string{"openid", "profile"},
refreshScopes: []string{"openid", "profile"},
err: oauth2.NewError(oauth2.ErrorServerError),
},
// Valid Cross-Client
{
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientID: "client_a",
creds: oidc.ClientCredentials{
ID: "client_a",
Secret: base64.URLEncoding.EncodeToString(
[]byte("client_a_secret")),
},
signer: signerFixture,
createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"},
refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"},
expectedAud: []string{"client_b"},
},
// Valid Cross-Client - but this time we leave out the scopes in the
// refresh request, which should result in the original stored scopes
// being used.
{
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientID: "client_a",
creds: oidc.ClientCredentials{
ID: "client_a",
Secret: base64.URLEncoding.EncodeToString(
[]byte("client_a_secret")),
},
signer: signerFixture,
createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"},
refreshScopes: []string{},
expectedAud: []string{"client_b"},
},
// Valid Cross-Client - asking for fewer scopes than originally used
// when creating the refresh token, which is ok.
{
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientID: "client_a",
creds: oidc.ClientCredentials{
ID: "client_a",
Secret: base64.URLEncoding.EncodeToString(
[]byte("client_a_secret")),
},
signer: signerFixture,
createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"},
refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"},
expectedAud: []string{"client_b"},
},
// Valid Cross-Client - asking for multiple clients in the audience.
{
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientID: "client_a",
creds: oidc.ClientCredentials{
ID: "client_a",
Secret: base64.URLEncoding.EncodeToString(
[]byte("client_a_secret")),
},
signer: signerFixture,
createScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"},
refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b", scope.ScopeGoogleCrossClient + "client_c"},
expectedAud: []string{"client_b", "client_c"},
},
// Invalid Cross-Client - didn't orignally request cross-client when
// refresh token was created.
{
token: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientID: "client_a",
creds: oidc.ClientCredentials{
ID: "client_a",
Secret: base64.URLEncoding.EncodeToString(
[]byte("client_a_secret")),
},
signer: signerFixture,
createScopes: []string{"openid", "profile"},
refreshScopes: []string{"openid", "profile", scope.ScopeGoogleCrossClient + "client_b"},
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
} }
...@@ -576,7 +683,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -576,7 +683,7 @@ func TestServerRefreshToken(t *testing.T) {
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: tt.signer, signer: tt.signer,
} }
f, err := makeTestFixtures() f, err := makeCrossClientTestFixtures()
if err != nil { if err != nil {
t.Fatalf("error making test fixtures: %v", err) t.Fatalf("error making test fixtures: %v", err)
} }
...@@ -587,11 +694,12 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -587,11 +694,12 @@ func TestServerRefreshToken(t *testing.T) {
t.Errorf("case %d: error creating other client: %v", i, err) t.Errorf("case %d: error creating other client: %v", i, err)
} }
if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID); err != nil { if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID,
tt.createScopes); err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
jwt, err := f.srv.RefreshToken(tt.creds, tt.token) jwt, err := f.srv.RefreshToken(tt.creds, tt.refreshScopes, tt.token)
if !reflect.DeepEqual(err, tt.err) { if !reflect.DeepEqual(err, tt.err) {
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err) t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
} }
...@@ -604,8 +712,27 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -604,8 +712,27 @@ func TestServerRefreshToken(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Case %d: unexpected error: %v", i, err) t.Errorf("Case %d: unexpected error: %v", i, err)
} }
if claims["iss"] != testIssuerURL.String() || claims["sub"] != testUserID1 || claims["aud"] != testClientID {
t.Errorf("Case %d: invalid claims: %v", i, claims) var expectedAud interface{}
if tt.expectedAud == nil {
expectedAud = testClientID
} else if len(tt.expectedAud) == 1 {
expectedAud = tt.expectedAud[0]
} else {
expectedAud = tt.expectedAud
}
if claims["iss"] != testIssuerURL.String() {
t.Errorf("Case %d: want=%v, got=%v", i,
testIssuerURL.String(), claims["iss"])
}
if claims["sub"] != testUserID1 {
t.Errorf("Case %d: want=%v, got=%v", i,
testUserID1, claims["sub"])
}
if diff := pretty.Compare(claims["aud"], expectedAud); diff != "" {
t.Errorf("Case %d: want=%v, got=%v", i,
expectedAud, claims["aud"])
} }
} }
} }
......
...@@ -39,6 +39,18 @@ var ( ...@@ -39,6 +39,18 @@ var (
ID: testClientID, ID: testClientID,
Secret: clientTestSecret, Secret: clientTestSecret,
} }
testClients = []client.LoadableClient{
{
Client: client.Client{
Credentials: testClientCredentials,
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
testRedirectURL,
},
},
},
},
}
testConnectorID1 = "IDPC-1" testConnectorID1 = "IDPC-1"
...@@ -169,18 +181,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err ...@@ -169,18 +181,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err
var clients []client.LoadableClient var clients []client.LoadableClient
if options.clients == nil { if options.clients == nil {
clients = []client.LoadableClient{ clients = testClients
{
Client: client.Client{
Credentials: testClientCredentials,
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
testRedirectURL,
},
},
},
},
}
} else { } else {
clients = options.clients clients = options.clients
} }
...@@ -247,6 +248,10 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err ...@@ -247,6 +248,10 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err
srv.absURL(httpPathAcceptInvitation), srv.absURL(httpPathAcceptInvitation),
) )
clientCreds := map[string]oidc.ClientCredentials{}
for _, c := range clients {
clientCreds[c.Client.Credentials.ID] = c.Client.Credentials
}
return &testFixtures{ return &testFixtures{
srv: srv, srv: srv,
redirectURL: testRedirectURL, redirectURL: testRedirectURL,
...@@ -255,9 +260,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err ...@@ -255,9 +260,7 @@ func makeTestFixturesWithOptions(options testFixtureOptions) (*testFixtures, err
emailer: emailer, emailer: emailer,
clientRepo: clientRepo, clientRepo: clientRepo,
clientManager: clientManager, clientManager: clientManager,
clientCreds: map[string]oidc.ClientCredentials{ clientCreds: clientCreds,
testClientID: testClientCreds,
},
}, nil }, nil
} }
......
...@@ -192,7 +192,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) { ...@@ -192,7 +192,7 @@ func makeTestFixtures() (*UsersAPI, *testEmailer) {
} }
refreshRepo := db.NewRefreshTokenRepo(dbMap) refreshRepo := db.NewRefreshTokenRepo(dbMap)
for _, token := range refreshTokens { for _, token := range refreshTokens {
if _, err := refreshRepo.Create(token.userID, token.clientID); err != nil { if _, err := refreshRepo.Create(token.userID, token.clientID, []string{"openid"}); err != nil {
panic("Failed to create refresh token: " + err.Error()) panic("Failed to create refresh token: " + err.Error())
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment