Commit 95560404 authored by Eric Chiang's avatar Eric Chiang

*: remove in memory refresh repo

parent 7bac93aa
......@@ -83,6 +83,13 @@ func NewRefreshTokenRepo(dbm *gorp.DbMap) refresh.RefreshTokenRepo {
}
}
func NewRefreshTokenRepoWithGenerator(dbm *gorp.DbMap, gen refresh.RefreshTokenGenerator) refresh.RefreshTokenRepo {
return &refreshTokenRepo{
dbMap: dbm,
tokenGenerator: gen,
}
}
func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
if userID == "" {
return "", refresh.ErrorInvalidUserID
......
......@@ -145,10 +145,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
}
passwordInfoRepo := user.NewPasswordInfoRepo()
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
srv := &server.Server{
IssuerURL: issuerURL,
......
......@@ -3,16 +3,17 @@ package refreshtest
import (
"fmt"
"github.com/coreos/dex/db"
"github.com/coreos/dex/refresh"
)
// NewTestRefreshTokenRepo returns a test repo whose tokens monotonically increase.
// The tokens are in the form { refresh-1, refresh-2 ... refresh-n}.
func NewTestRefreshTokenRepo() (refresh.RefreshTokenRepo, error) {
func NewTestRefreshTokenRepo() refresh.RefreshTokenRepo {
var tokenIdx int
tokenGenerator := func() ([]byte, error) {
tokenIdx++
return []byte(fmt.Sprintf("refresh-%d", tokenIdx)), nil
}
return refresh.NewRefreshTokenRepoWithTokenGenerator(tokenGenerator), nil
return db.NewRefreshTokenRepoWithGenerator(db.NewMemDB(), tokenGenerator)
}
package refresh
import (
"bytes"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"strconv"
"strings"
)
const (
......@@ -53,121 +48,3 @@ type RefreshTokenRepo interface {
// Revoke deletes the refresh token if the token belongs to the given userID.
Revoke(userID, token string) error
}
type refreshToken struct {
payload []byte
userID string
clientID string
}
type memRefreshTokenRepo struct {
store map[int]refreshToken
tokenGenerator RefreshTokenGenerator
}
// buildToken combines the token ID and token payload to create a new token.
func buildToken(tokenID int, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
}
// parseToken parses a token and returns the token ID and token payload.
func parseToken(token string) (int, []byte, error) {
parts := strings.SplitN(token, TokenDelimer, 2)
if len(parts) != 2 {
return -1, nil, ErrorInvalidToken
}
id, err := strconv.Atoi(parts[0])
if err != nil {
return -1, nil, ErrorInvalidToken
}
tokenPayload, err := base64.URLEncoding.DecodeString(parts[1])
if err != nil {
return -1, nil, ErrorInvalidToken
}
return id, tokenPayload, nil
}
// NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development.
func NewRefreshTokenRepo() RefreshTokenRepo {
return NewRefreshTokenRepoWithTokenGenerator(DefaultRefreshTokenGenerator)
}
func NewRefreshTokenRepoWithTokenGenerator(tokenGenerator RefreshTokenGenerator) RefreshTokenRepo {
repo := &memRefreshTokenRepo{}
repo.store = make(map[int]refreshToken)
repo.tokenGenerator = tokenGenerator
return repo
}
func (r *memRefreshTokenRepo) Create(userID, clientID string) (string, error) {
// Validate userID.
if userID == "" {
return "", ErrorInvalidUserID
}
// Validate clientID.
if clientID == "" {
return "", ErrorInvalidClientID
}
// Generate and store token.
tokenPayload, err := r.tokenGenerator.Generate()
if err != nil {
return "", err
}
tokenID := len(r.store) // Should only be used in single threaded tests.
// No limits on the number of tokens per user/client for this in-memory repo.
r.store[tokenID] = refreshToken{
payload: tokenPayload,
userID: userID,
clientID: clientID,
}
return buildToken(tokenID, tokenPayload), nil
}
func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return "", err
}
record, ok := r.store[tokenID]
if !ok {
return "", ErrorInvalidToken
}
if !bytes.Equal(record.payload, tokenPayload) {
return "", ErrorInvalidToken
}
if record.clientID != clientID {
return "", ErrorInvalidClientID
}
return record.userID, nil
}
func (r *memRefreshTokenRepo) Revoke(userID, token string) error {
tokenID, tokenPayload, err := parseToken(token)
if err != nil {
return err
}
record, ok := r.store[tokenID]
if !ok {
return ErrorInvalidToken
}
if !bytes.Equal(record.payload, tokenPayload) {
return ErrorInvalidToken
}
if record.userID != userID {
return ErrorInvalidUserID
}
delete(r.store, tokenID)
return nil
}
......@@ -17,7 +17,6 @@ import (
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
"github.com/coreos/dex/email"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/repo"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
......@@ -139,7 +138,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
pwiRepo := user.NewPasswordInfoRepo()
refTokRepo := refresh.NewRefreshTokenRepo()
refTokRepo := db.NewRefreshTokenRepo(db.NewMemDB())
txnFactory := repo.InMemTransactionFactory
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
......
......@@ -351,10 +351,7 @@ func TestServerCodeToken(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
......@@ -376,8 +373,10 @@ func TestServerCodeToken(t *testing.T) {
},
// Have 'offline_access' in scope, should get non-empty refresh token.
{
// NOTE(ericchiang): This test assumes that the database ID of the first
// refresh token will be "1".
scope: []string{"openid", "offline_access"},
refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
},
}
......@@ -475,11 +474,13 @@ func TestServerTokenFail(t *testing.T) {
}{
// control test case to make sure fixtures check out
{
// NOTE(ericchiang): This test assumes that the database ID of the first
// refresh token will be "1".
signer: signerFixture,
argCC: ccFixture,
argKey: keyFixture,
scope: []string{"openid", "offline_access"},
refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
},
// no 'offline_access' in 'scope', should get empty refresh token
......@@ -549,10 +550,7 @@ func TestServerTokenFail(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
srv := &Server{
IssuerURL: issuerURL,
......@@ -600,6 +598,8 @@ func TestServerRefreshToken(t *testing.T) {
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
// NOTE(ericchiang): These tests assume that the database ID of the first
// refresh token will be "1".
tests := []struct {
token string
clientID string // The client that associates with the token.
......@@ -609,7 +609,7 @@ func TestServerRefreshToken(t *testing.T) {
}{
// Everything is good.
{
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX",
credXXX,
signerFixture,
......@@ -625,7 +625,7 @@ func TestServerRefreshToken(t *testing.T) {
},
// Invalid refresh token(invalid payload content).
{
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
"XXX",
credXXX,
signerFixture,
......@@ -633,7 +633,7 @@ func TestServerRefreshToken(t *testing.T) {
},
// Invalid refresh token(invalid ID content).
{
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX",
credXXX,
signerFixture,
......@@ -641,7 +641,7 @@ func TestServerRefreshToken(t *testing.T) {
},
// Invalid client(client is not associated with the token).
{
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX",
credYYY,
signerFixture,
......@@ -649,7 +649,7 @@ func TestServerRefreshToken(t *testing.T) {
},
// Invalid client(no client ID).
{
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX",
oidc.ClientCredentials{ID: "", Secret: "aaa"},
signerFixture,
......@@ -657,7 +657,7 @@ func TestServerRefreshToken(t *testing.T) {
},
// Invalid client(no such client).
{
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX",
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
signerFixture,
......@@ -665,7 +665,7 @@ func TestServerRefreshToken(t *testing.T) {
},
// Invalid client(no secrets).
{
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX",
oidc.ClientCredentials{ID: "XXX"},
signerFixture,
......@@ -673,7 +673,7 @@ func TestServerRefreshToken(t *testing.T) {
},
// Invalid client(invalid secret).
{
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX",
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
signerFixture,
......@@ -681,7 +681,7 @@ func TestServerRefreshToken(t *testing.T) {
},
// Signing operation fails.
{
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX",
credXXX,
&StaticSigner{sig: nil, err: errors.New("fail")},
......@@ -704,10 +704,7 @@ func TestServerRefreshToken(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
srv := &Server{
IssuerURL: issuerURL,
......@@ -764,10 +761,7 @@ func TestServerRefreshToken(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}
refreshTokenRepo, err := refreshtest.NewTestRefreshTokenRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
srv := &Server{
IssuerURL: issuerURL,
......@@ -788,7 +782,7 @@ func TestServerRefreshToken(t *testing.T) {
}
srv.UserRepo = userRepo
_, err = srv.RefreshToken(credXXX, fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
_, err = srv.RefreshToken(credXXX, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment