Commit 95560404 authored by Eric Chiang's avatar Eric Chiang

*: remove in memory refresh repo

parent 7bac93aa
...@@ -83,6 +83,13 @@ func NewRefreshTokenRepo(dbm *gorp.DbMap) refresh.RefreshTokenRepo { ...@@ -83,6 +83,13 @@ func NewRefreshTokenRepo(dbm *gorp.DbMap) refresh.RefreshTokenRepo {
} }
} }
func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenGenerator) refresh.RefreshTokenRepo {
return &refreshTokenRepo{
dbMap: dbm,
tokenGenerator: gen,
}
}
func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) { func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
if userID == "" { if userID == "" {
return "", refresh.ErrorInvalidUserID return "", refresh.ErrorInvalidUserID
......
...@@ -145,10 +145,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { ...@@ -145,10 +145,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
} }
passwordInfoRepo := user.NewPasswordInfoRepo() passwordInfoRepo := user.NewPasswordInfoRepo()
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv := &server.Server{ srv := &server.Server{
IssuerURL: issuerURL, IssuerURL: issuerURL,
......
...@@ -3,16 +3,17 @@ package refreshtest ...@@ -3,16 +3,17 @@ package refreshtest
import ( import (
"fmt" "fmt"
"github.com/coreos/dex/db"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
) )
// NewTestRefreshTokenRepo returns a test repo whose tokens monotonically increase. // NewTestRefreshTokenRepo returns a test repo whose tokens monotonically increase.
// The tokens are in the form { refresh-1, refresh-2 ... refresh-n}. // The tokens are in the form { refresh-1, refresh-2 ... refresh-n}.
func NewTestRefreshTokenRepo() (refresh.RefreshTokenRepo, error) { func NewTestRefreshTokenRepo() refresh.RefreshTokenRepo {
var tokenIdx int var tokenIdx int
tokenGenerator := func() ([]byte, error) { tokenGenerator := func() ([]byte, error) {
tokenIdx++ tokenIdx++
return []byte(fmt.Sprintf("refresh-%d", tokenIdx)), nil return []byte(fmt.Sprintf("refresh-%d", tokenIdx)), nil
} }
return refresh.NewRefreshTokenRepoWithTokenGenerator(tokenGenerator), nil return db.NewRefreshTokenRepoWithGenerator(db.NewMemDB(), tokenGenerator)
} }
package refresh package refresh
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"encoding/base64"
"errors" "errors"
"fmt"
"strconv"
"strings"
) )
const ( const (
...@@ -53,121 +48,3 @@ type RefreshTokenRepo interface { ...@@ -53,121 +48,3 @@ type RefreshTokenRepo interface {
// Revoke deletes the refresh token if the token belongs to the given userID. // Revoke deletes the refresh token if the token belongs to the given userID.
Revoke(userID, token string) error Revoke(userID, token string) error
} }
type refreshToken struct {
payload []byte
userID string
clientID string
}
type memRefreshTokenRepo struct {
store map[int]refreshToken
tokenGenerator RefreshTokenGenerator
}
// buildToken combines the token ID and token payload to create a new token.
func buildToken(tokenID int, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
}
// parseToken parses a token and returns the token ID and token payload.
func parseToken(token string) (int, []byte, error) {
parts := strings.SplitN(token, TokenDelimer, 2)
if len(parts) != 2 {
return -1, nil, ErrorInvalidToken
}
id, err := strconv.Atoi(parts[0])
if err != nil {
return -1, nil, ErrorInvalidToken
}
tokenPayload, err := base64.URLEncoding.DecodeString(parts[1])
if err != nil {
return -1, nil, ErrorInvalidToken
}
return id, tokenPayload, nil
}
// NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development.
func NewRefreshTokenRepo() RefreshTokenRepo {
return NewRefreshTokenRepoWithTokenGenerator(DefaultRefreshTokenGenerator)
}
func NewRefreshTokenRepoWithTokenGenerator(tokenGenerator RefreshTokenGenerator) RefreshTokenRepo {
repo := &memRefreshTokenRepo{}
repo.store = make(map[int]refreshToken)
repo.tokenGenerator = tokenGenerator
return repo
}
func (r *memRefreshTokenRepo) Create(userID, clientID string) (string, error) {
// Validate userID.
if userID == "" {
return "", ErrorInvalidUserID
}
// Validate clientID.
if clientID == "" {
return "", ErrorInvalidClientID
}
// Generate and store token.
tokenPayload, err := r.tokenGenerator.Generate()
if err != nil {
return "", err
}
tokenID := len(r.store) // Should only be used in single threaded tests.
// No limits on the number of tokens per user/client for this in-memory repo.
r.store[tokenID] = refreshToken{
payload: tokenPayload,
userID: userID,
clientID: clientID,
}
return buildToken(tokenID, tokenPayload), nil
}
func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return "", err
}
record, ok := r.store[tokenID]
if !ok {
return "", ErrorInvalidToken
}
if !bytes.Equal(record.payload, tokenPayload) {
return "", ErrorInvalidToken
}
if record.clientID != clientID {
return "", ErrorInvalidClientID
}
return record.userID, nil
}
func (r *memRefreshTokenRepo) Revoke(userID, token string) error {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return err
}
record, ok := r.store[tokenID]
if !ok {
return ErrorInvalidToken
}
if !bytes.Equal(record.payload, tokenPayload) {
return ErrorInvalidToken
}
if record.userID != userID {
return ErrorInvalidUserID
}
delete(r.store, tokenID)
return nil
}
...@@ -17,7 +17,6 @@ import ( ...@@ -17,7 +17,6 @@ import (
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/email" "github.com/coreos/dex/email"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
sessionmanager "github.com/coreos/dex/session/manager" sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
...@@ -139,7 +138,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { ...@@ -139,7 +138,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
pwiRepo := user.NewPasswordInfoRepo() pwiRepo := user.NewPasswordInfoRepo()
refTokRepo := refresh.NewRefreshTokenRepo() refTokRepo := db.NewRefreshTokenRepo(db.NewMemDB())
txnFactory := repo.InMemTransactionFactory txnFactory := repo.InMemTransactionFactory
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{}) userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
......
...@@ -351,10 +351,7 @@ func TestServerCodeToken(t *testing.T) { ...@@ -351,10 +351,7 @@ func TestServerCodeToken(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv := &Server{ srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
...@@ -376,8 +373,10 @@ func TestServerCodeToken(t *testing.T) { ...@@ -376,8 +373,10 @@ func TestServerCodeToken(t *testing.T) {
}, },
// Have 'offline_access' in scope, should get non-empty refresh token. // Have 'offline_access' in scope, should get non-empty refresh token.
{ {
// NOTE(ericchiang): This test assumes that the database ID of the first
// refresh token will be "1".
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
}, },
} }
...@@ -475,11 +474,13 @@ func TestServerTokenFail(t *testing.T) { ...@@ -475,11 +474,13 @@ func TestServerTokenFail(t *testing.T) {
}{ }{
// control test case to make sure fixtures check out // control test case to make sure fixtures check out
{ {
// NOTE(ericchiang): This test assumes that the database ID of the first
// refresh token will be "1".
signer: signerFixture, signer: signerFixture,
argCC: ccFixture, argCC: ccFixture,
argKey: keyFixture, argKey: keyFixture,
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
}, },
// no 'offline_access' in 'scope', should get empty refresh token // no 'offline_access' in 'scope', should get empty refresh token
...@@ -549,10 +550,7 @@ func TestServerTokenFail(t *testing.T) { ...@@ -549,10 +550,7 @@ func TestServerTokenFail(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv := &Server{ srv := &Server{
IssuerURL: issuerURL, IssuerURL: issuerURL,
...@@ -600,6 +598,8 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -600,6 +598,8 @@ func TestServerRefreshToken(t *testing.T) {
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
// NOTE(ericchiang): These tests assume that the database ID of the first
// refresh token will be "1".
tests := []struct { tests := []struct {
token string token string
clientID string // The client that associates with the token. clientID string // The client that associates with the token.
...@@ -609,7 +609,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -609,7 +609,7 @@ func TestServerRefreshToken(t *testing.T) {
}{ }{
// Everything is good. // Everything is good.
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
...@@ -625,7 +625,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -625,7 +625,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid refresh token(invalid payload content). // Invalid refresh token(invalid payload content).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
...@@ -633,7 +633,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -633,7 +633,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid refresh token(invalid ID content). // Invalid refresh token(invalid ID content).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
...@@ -641,7 +641,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -641,7 +641,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(client is not associated with the token). // Invalid client(client is not associated with the token).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
credYYY, credYYY,
signerFixture, signerFixture,
...@@ -649,7 +649,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -649,7 +649,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(no client ID). // Invalid client(no client ID).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
oidc.ClientCredentials{ID: "", Secret: "aaa"}, oidc.ClientCredentials{ID: "", Secret: "aaa"},
signerFixture, signerFixture,
...@@ -657,7 +657,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -657,7 +657,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(no such client). // Invalid client(no such client).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
signerFixture, signerFixture,
...@@ -665,7 +665,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -665,7 +665,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(no secrets). // Invalid client(no secrets).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
oidc.ClientCredentials{ID: "XXX"}, oidc.ClientCredentials{ID: "XXX"},
signerFixture, signerFixture,
...@@ -673,7 +673,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -673,7 +673,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(invalid secret). // Invalid client(invalid secret).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"}, oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
signerFixture, signerFixture,
...@@ -681,7 +681,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -681,7 +681,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Signing operation fails. // Signing operation fails.
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
credXXX, credXXX,
&StaticSigner{sig: nil, err: errors.New("fail")}, &StaticSigner{sig: nil, err: errors.New("fail")},
...@@ -704,10 +704,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -704,10 +704,7 @@ func TestServerRefreshToken(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv := &Server{ srv := &Server{
IssuerURL: issuerURL, IssuerURL: issuerURL,
...@@ -764,10 +761,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -764,10 +761,7 @@ func TestServerRefreshToken(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo() refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv := &Server{ srv := &Server{
IssuerURL: issuerURL, IssuerURL: issuerURL,
...@@ -788,7 +782,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -788,7 +782,7 @@ func TestServerRefreshToken(t *testing.T) {
} }
srv.UserRepo = userRepo srv.UserRepo = userRepo
_, err = srv.RefreshToken(credXXX, fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1")))) _, err = srv.RefreshToken(credXXX, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) { if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err) t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment