Commit e077803e authored by Yifan Gu's avatar Yifan Gu

Merge pull request #105 from yifan-gu/tests

refresh: bcrypt raw bytes rather than base64 encoded string.
parents ff71593c 44c6cb44
-- +migrate Up -- +migrate Up
CREATE TABLE refresh_token ( CREATE TABLE refresh_token (
id bigint NOT NULL, id bigint NOT NULL,
payload_hash text, payload_hash bytea,
user_id text, user_id text,
client_id text client_id text
); );
......
This diff is collapsed.
package db package db
import ( import (
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
...@@ -31,7 +32,7 @@ type refreshTokenRepo struct { ...@@ -31,7 +32,7 @@ type refreshTokenRepo struct {
type refreshTokenModel struct { type refreshTokenModel struct {
ID int64 `db:"id"` ID int64 `db:"id"`
PayloadHash string `db:"payload_hash"` PayloadHash []byte `db:"payload_hash"`
// TODO(yifan): Use some sort of foreign key to manage database level // TODO(yifan): Use some sort of foreign key to manage database level
// data integrity. // data integrity.
UserID string `db:"user_id"` UserID string `db:"user_id"`
...@@ -39,25 +40,29 @@ type refreshTokenModel struct { ...@@ -39,25 +40,29 @@ type refreshTokenModel struct {
} }
// buildToken combines the token ID and token payload to create a new token. // buildToken combines the token ID and token payload to create a new token.
func buildToken(tokenID int64, tokenPayload string) string { func buildToken(tokenID int64, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, tokenPayload) return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
} }
// parseToken parses a token and returns the token ID and token payload. // parseToken parses a token and returns the token ID and token payload.
func parseToken(token string) (int64, string, error) { func parseToken(token string) (int64, []byte, error) {
parts := strings.SplitN(token, refresh.TokenDelimer, 2) parts := strings.SplitN(token, refresh.TokenDelimer, 2)
if len(parts) != 2 { if len(parts) != 2 {
return -1, "", refresh.ErrorInvalidToken return -1, nil, refresh.ErrorInvalidToken
} }
id, err := strconv.ParseInt(parts[0], 0, 64) id, err := strconv.ParseInt(parts[0], 0, 64)
if err != nil { if err != nil {
return -1, "", refresh.ErrorInvalidToken return -1, nil, refresh.ErrorInvalidToken
} }
return id, parts[1], nil tokenPayload, err := base64.URLEncoding.DecodeString(parts[1])
if err != nil {
return -1, nil, refresh.ErrorInvalidToken
}
return id, tokenPayload, nil
} }
func checkTokenPayload(payloadHash, payload string) error { func checkTokenPayload(payloadHash, payload []byte) error {
if err := bcrypt.CompareHashAndPassword([]byte(payloadHash), []byte(payload)); err != nil { if err := bcrypt.CompareHashAndPassword(payloadHash, payload); err != nil {
switch err { switch err {
case bcrypt.ErrMismatchedHashAndPassword: case bcrypt.ErrMismatchedHashAndPassword:
return refresh.ErrorInvalidToken return refresh.ErrorInvalidToken
...@@ -89,13 +94,13 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) { ...@@ -89,13 +94,13 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
return "", err return "", err
} }
payloadHash, err := bcrypt.GenerateFromPassword([]byte(tokenPayload), bcrypt.DefaultCost) payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost)
if err != nil { if err != nil {
return "", err return "", err
} }
record := &refreshTokenModel{ record := &refreshTokenModel{
PayloadHash: string(payloadHash), PayloadHash: payloadHash,
UserID: userID, UserID: userID,
ClientID: clientID, ClientID: clientID,
} }
...@@ -109,6 +114,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) { ...@@ -109,6 +114,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
func (r *refreshTokenRepo) Verify(clientID, token string) (string, error) { func (r *refreshTokenRepo) Verify(clientID, token string) (string, error) {
tokenID, tokenPayload, err := parseToken(token) tokenID, tokenPayload, err := parseToken(token)
if err != nil { if err != nil {
return "", err return "", err
} }
......
package functional package functional
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
...@@ -342,6 +343,12 @@ func TestDBClientIdentityAll(t *testing.T) { ...@@ -342,6 +343,12 @@ func TestDBClientIdentityAll(t *testing.T) {
} }
} }
// buildRefreshToken combines the token ID and token payload to create a new token.
// used in the tests to created a refresh token.
func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
}
func TestDBRefreshRepoCreate(t *testing.T) { func TestDBRefreshRepoCreate(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t)) r := db.NewRefreshTokenRepo(connect(t))
...@@ -383,6 +390,13 @@ func TestDBRefreshRepoVerify(t *testing.T) { ...@@ -383,6 +390,13 @@ func TestDBRefreshRepoVerify(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
tokenWithBadID := "404" + token[1:]
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
tests := []struct { tests := []struct {
token string token string
creds oidc.ClientCredentials creds oidc.ClientCredentials
...@@ -390,7 +404,39 @@ func TestDBRefreshRepoVerify(t *testing.T) { ...@@ -390,7 +404,39 @@ func TestDBRefreshRepoVerify(t *testing.T) {
expected string expected string
}{ }{
{ {
"invalid-token-foo", "invalid-token-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
"b/invalid-base64-encoded-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
"1/invalid-base64-encoded-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
token + "corrupted-token-payload",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
// The token's ID content is invalid.
tokenWithBadID,
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
// The token's payload content is invalid.
tokenWithBadPayload,
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken, refresh.ErrorInvalidToken,
"", "",
...@@ -428,13 +474,42 @@ func TestDBRefreshRepoRevoke(t *testing.T) { ...@@ -428,13 +474,42 @@ func TestDBRefreshRepoRevoke(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
tokenWithBadID := "404" + token[1:]
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
tests := []struct { tests := []struct {
token string token string
userID string userID string
err error err error
}{ }{
{ {
"invalid-token-foo", "invalid-token-format",
"user-foo",
refresh.ErrorInvalidToken,
},
{
"1/invalid-base64-encoded-format",
"user-foo",
refresh.ErrorInvalidToken,
},
{
token + "corrupted-token-payload",
"user-foo",
refresh.ErrorInvalidToken,
},
{
// The token's ID is invalid.
tokenWithBadID,
"user-foo",
refresh.ErrorInvalidToken,
},
{
// The token's payload is invalid.
tokenWithBadPayload,
"user-foo", "user-foo",
refresh.ErrorInvalidToken, refresh.ErrorInvalidToken,
}, },
......
...@@ -10,9 +10,9 @@ import ( ...@@ -10,9 +10,9 @@ import (
// The tokens are in the form { refresh-1, refresh-2 ... refresh-n}. // The tokens are in the form { refresh-1, refresh-2 ... refresh-n}.
func NewTestRefreshTokenRepo() (refresh.RefreshTokenRepo, error) { func NewTestRefreshTokenRepo() (refresh.RefreshTokenRepo, error) {
var tokenIdx int var tokenIdx int
tokenGenerator := func() (string, error) { tokenGenerator := func() ([]byte, error) {
tokenIdx++ tokenIdx++
return fmt.Sprintf("refresh-%d", tokenIdx), nil return []byte(fmt.Sprintf("refresh-%d", tokenIdx)), nil
} }
return refresh.NewRefreshTokenRepoWithTokenGenerator(tokenGenerator), nil return refresh.NewRefreshTokenRepoWithTokenGenerator(tokenGenerator), nil
} }
package refresh package refresh
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
...@@ -17,25 +17,26 @@ const ( ...@@ -17,25 +17,26 @@ const (
var ( var (
ErrorInvalidUserID = errors.New("invalid user ID") ErrorInvalidUserID = errors.New("invalid user ID")
ErrorInvalidClientID = errors.New("invalid client ID") ErrorInvalidClientID = errors.New("invalid client ID")
ErrorInvalidToken = errors.New("invalid token") ErrorInvalidToken = errors.New("invalid token")
) )
type RefreshTokenGenerator func() (string, error) type RefreshTokenGenerator func() ([]byte, error)
func (g RefreshTokenGenerator) Generate() (string, error) { func (g RefreshTokenGenerator) Generate() ([]byte, error) {
return g() return g()
} }
func DefaultRefreshTokenGenerator() (string, error) { func DefaultRefreshTokenGenerator() ([]byte, error) {
// TODO(yifan) Remove this duplicated token generate function. // TODO(yifan) Remove this duplicated token generate function.
b := make([]byte, DefaultRefreshTokenPayloadLength) b := make([]byte, DefaultRefreshTokenPayloadLength)
n, err := rand.Read(b) n, err := rand.Read(b)
if err != nil { if err != nil {
return "", err return nil, err
} else if n != DefaultRefreshTokenPayloadLength { } else if n != DefaultRefreshTokenPayloadLength {
return "", errors.New("unable to read enough random bytes") return nil, errors.New("unable to read enough random bytes")
} }
return base64.URLEncoding.EncodeToString(b), nil return b, nil
} }
type RefreshTokenRepo interface { type RefreshTokenRepo interface {
...@@ -52,7 +53,7 @@ type RefreshTokenRepo interface { ...@@ -52,7 +53,7 @@ type RefreshTokenRepo interface {
} }
type refreshToken struct { type refreshToken struct {
payload string payload []byte
userID string userID string
clientID string clientID string
} }
...@@ -63,21 +64,21 @@ type memRefreshTokenRepo struct { ...@@ -63,21 +64,21 @@ type memRefreshTokenRepo struct {
} }
// buildToken combines the token ID and token payload to create a new token. // buildToken combines the token ID and token payload to create a new token.
func buildToken(tokenID int, tokenPayload string) string { func buildToken(tokenID int, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, tokenPayload) return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, tokenPayload)
} }
// parseToken parses a token and returns the token ID and token payload. // parseToken parses a token and returns the token ID and token payload.
func parseToken(token string) (int, string, error) { func parseToken(token string) (int, []byte, error) {
parts := strings.SplitN(token, TokenDelimer, 2) parts := strings.SplitN(token, TokenDelimer, 2)
if len(parts) != 2 { if len(parts) != 2 {
return -1, "", ErrorInvalidToken return -1, nil, ErrorInvalidToken
} }
id, err := strconv.Atoi(parts[0]) id, err := strconv.Atoi(parts[0])
if err != nil { if err != nil {
return -1, "", ErrorInvalidToken return -1, nil, ErrorInvalidToken
} }
return id, parts[1], nil return id, []byte(parts[1]), nil
} }
// NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development. // NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development.
...@@ -131,7 +132,7 @@ func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) { ...@@ -131,7 +132,7 @@ func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) {
return "", ErrorInvalidToken return "", ErrorInvalidToken
} }
if record.payload != tokenPayload { if !bytes.Equal(record.payload, tokenPayload) {
return "", ErrorInvalidToken return "", ErrorInvalidToken
} }
...@@ -153,7 +154,7 @@ func (r *memRefreshTokenRepo) Revoke(userID, token string) error { ...@@ -153,7 +154,7 @@ func (r *memRefreshTokenRepo) Revoke(userID, token string) error {
return ErrorInvalidToken return ErrorInvalidToken
} }
if record.payload != tokenPayload { if !bytes.Equal(record.payload, tokenPayload) {
return ErrorInvalidToken return ErrorInvalidToken
} }
......
...@@ -397,7 +397,7 @@ func TestServerTokenFail(t *testing.T) { ...@@ -397,7 +397,7 @@ func TestServerTokenFail(t *testing.T) {
signer jose.Signer signer jose.Signer
argCC oidc.ClientCredentials argCC oidc.ClientCredentials
argKey string argKey string
err string err error
scope []string scope []string
refreshToken string refreshToken string
}{ }{
...@@ -423,7 +423,7 @@ func TestServerTokenFail(t *testing.T) { ...@@ -423,7 +423,7 @@ func TestServerTokenFail(t *testing.T) {
signer: signerFixture, signer: signerFixture,
argCC: ccFixture, argCC: ccFixture,
argKey: "foo", argKey: "foo",
err: oauth2.ErrorInvalidGrant, err: oauth2.NewError(oauth2.ErrorInvalidGrant),
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
}, },
...@@ -432,7 +432,7 @@ func TestServerTokenFail(t *testing.T) { ...@@ -432,7 +432,7 @@ func TestServerTokenFail(t *testing.T) {
signer: signerFixture, signer: signerFixture,
argCC: oidc.ClientCredentials{ID: "YYY"}, argCC: oidc.ClientCredentials{ID: "YYY"},
argKey: keyFixture, argKey: keyFixture,
err: oauth2.ErrorInvalidClient, err: oauth2.NewError(oauth2.ErrorInvalidClient),
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
}, },
...@@ -441,7 +441,7 @@ func TestServerTokenFail(t *testing.T) { ...@@ -441,7 +441,7 @@ func TestServerTokenFail(t *testing.T) {
signer: &StaticSigner{sig: nil, err: errors.New("fail")}, signer: &StaticSigner{sig: nil, err: errors.New("fail")},
argCC: ccFixture, argCC: ccFixture,
argKey: keyFixture, argKey: keyFixture,
err: oauth2.ErrorServerError, err: oauth2.NewError(oauth2.ErrorServerError),
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
}, },
} }
...@@ -502,18 +502,14 @@ func TestServerTokenFail(t *testing.T) { ...@@ -502,18 +502,14 @@ func TestServerTokenFail(t *testing.T) {
t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
panic("") panic("")
} }
if tt.err == "" { if !reflect.DeepEqual(err, tt.err) {
if err != nil { t.Errorf("case %d: expect %v, got %v", i, tt.err, err)
t.Errorf("case %d: got non-nil error: %v", i, err)
} else if jwt == nil {
t.Errorf("case %d: got nil JWT", i)
} }
} else { if err == nil && jwt == nil {
if err.Error() != tt.err { t.Errorf("case %d: got nil JWT", i)
t.Errorf("case %d: want err %q, got %q", i, tt.err, err.Error())
} else if jwt != nil {
t.Errorf("case %d: got non-nil JWT", i)
} }
if err != nil && jwt != nil {
t.Errorf("case %d: got non-nil JWT %v", i, jwt)
} }
} }
} }
...@@ -537,7 +533,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -537,7 +533,7 @@ func TestServerRefreshToken(t *testing.T) {
clientID string // The client that associates with the token. clientID string // The client that associates with the token.
creds oidc.ClientCredentials creds oidc.ClientCredentials
signer jose.Signer signer jose.Signer
err string err error
}{ }{
// Everything is good. // Everything is good.
{ {
...@@ -545,7 +541,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -545,7 +541,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
"", nil,
}, },
// Invalid refresh token(malformatted). // Invalid refresh token(malformatted).
{ {
...@@ -553,15 +549,23 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -553,15 +549,23 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
oauth2.ErrorInvalidRequest, oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid refresh token. // Invalid refresh token(invalid payload content).
{ {
"0/refresh-1", "0/refresh-2",
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
oauth2.ErrorInvalidRequest, oauth2.NewError(oauth2.ErrorInvalidRequest),
},
// Invalid refresh token(invalid ID content).
{
"1/refresh-2",
"XXX",
credXXX,
signerFixture,
oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid client(client is not associated with the token). // Invalid client(client is not associated with the token).
{ {
...@@ -569,7 +573,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -569,7 +573,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
credYYY, credYYY,
signerFixture, signerFixture,
oauth2.ErrorInvalidClient, oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(no client ID). // Invalid client(no client ID).
{ {
...@@ -577,7 +581,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -577,7 +581,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
oidc.ClientCredentials{ID: "", Secret: "aaa"}, oidc.ClientCredentials{ID: "", Secret: "aaa"},
signerFixture, signerFixture,
oauth2.ErrorInvalidClient, oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(no such client). // Invalid client(no such client).
{ {
...@@ -585,7 +589,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -585,7 +589,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
signerFixture, signerFixture,
oauth2.ErrorInvalidClient, oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(no secrets). // Invalid client(no secrets).
{ {
...@@ -593,7 +597,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -593,7 +597,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
oidc.ClientCredentials{ID: "XXX"}, oidc.ClientCredentials{ID: "XXX"},
signerFixture, signerFixture,
oauth2.ErrorInvalidClient, oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(invalid secret). // Invalid client(invalid secret).
{ {
...@@ -601,7 +605,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -601,7 +605,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"}, oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
signerFixture, signerFixture,
oauth2.ErrorInvalidClient, oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Signing operation fails. // Signing operation fails.
{ {
...@@ -609,7 +613,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -609,7 +613,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
credXXX, credXXX,
&StaticSigner{sig: nil, err: errors.New("fail")}, &StaticSigner{sig: nil, err: errors.New("fail")},
oauth2.ErrorServerError, oauth2.NewError(oauth2.ErrorServerError),
}, },
} }
...@@ -646,11 +650,9 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -646,11 +650,9 @@ func TestServerRefreshToken(t *testing.T) {
} }
jwt, err := srv.RefreshToken(tt.creds, tt.token) jwt, err := srv.RefreshToken(tt.creds, tt.token)
if err != nil { if !reflect.DeepEqual(err, tt.err) {
if err.Error() != tt.err {
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err) t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
} }
}
if jwt != nil { if jwt != nil {
if string(jwt.Signature) != "beer" { if string(jwt.Signature) != "beer" {
...@@ -715,7 +717,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -715,7 +717,7 @@ func TestServerRefreshToken(t *testing.T) {
srv.UserRepo = userRepo srv.UserRepo = userRepo
_, err = srv.RefreshToken(credXXX, "0/refresh-1") _, err = srv.RefreshToken(credXXX, "0/refresh-1")
if err == nil || err.Error() != oauth2.ErrorServerError { if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
t.Errorf("Expect: %v, got: %v", oauth2.ErrorServerError, err) t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment