Commit 44c6cb44 authored by Yifan Gu's avatar Yifan Gu

refresh: bcrypt raw bytes rather than base64 encoded string.

This enables us to control the length of the bytes that will be bcrypted,
by default it's 64.

Also changed the token's stored form from string('text') to []byte('bytea')
and added some test cases for different types of invalid tokens.
parent 081bfdd1
-- +migrate Up -- +migrate Up
CREATE TABLE refresh_token ( CREATE TABLE refresh_token (
id bigint NOT NULL, id bigint NOT NULL,
payload_hash text, payload_hash bytea,
user_id text, user_id text,
client_id text client_id text
); );
......
This diff is collapsed.
package db package db
import ( import (
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
...@@ -31,7 +32,7 @@ type refreshTokenRepo struct { ...@@ -31,7 +32,7 @@ type refreshTokenRepo struct {
type refreshTokenModel struct { type refreshTokenModel struct {
ID int64 `db:"id"` ID int64 `db:"id"`
PayloadHash string `db:"payload_hash"` PayloadHash []byte `db:"payload_hash"`
// TODO(yifan): Use some sort of foreign key to manage database level // TODO(yifan): Use some sort of foreign key to manage database level
// data integrity. // data integrity.
UserID string `db:"user_id"` UserID string `db:"user_id"`
...@@ -39,25 +40,29 @@ type refreshTokenModel struct { ...@@ -39,25 +40,29 @@ type refreshTokenModel struct {
} }
// buildToken combines the token ID and token payload to create a new token. // buildToken combines the token ID and token payload to create a new token.
func buildToken(tokenID int64, tokenPayload string) string { func buildToken(tokenID int64, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, tokenPayload) return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
} }
// parseToken parses a token and returns the token ID and token payload. // parseToken parses a token and returns the token ID and token payload.
func parseToken(token string) (int64, string, error) { func parseToken(token string) (int64, []byte, error) {
parts := strings.SplitN(token, refresh.TokenDelimer, 2) parts := strings.SplitN(token, refresh.TokenDelimer, 2)
if len(parts) != 2 { if len(parts) != 2 {
return -1, "", refresh.ErrorInvalidToken return -1, nil, refresh.ErrorInvalidToken
} }
id, err := strconv.ParseInt(parts[0], 0, 64) id, err := strconv.ParseInt(parts[0], 0, 64)
if err != nil { if err != nil {
return -1, "", refresh.ErrorInvalidToken return -1, nil, refresh.ErrorInvalidToken
} }
return id, parts[1], nil tokenPayload, err := base64.URLEncoding.DecodeString(parts[1])
if err != nil {
return -1, nil, refresh.ErrorInvalidToken
}
return id, tokenPayload, nil
} }
func checkTokenPayload(payloadHash, payload string) error { func checkTokenPayload(payloadHash, payload []byte) error {
if err := bcrypt.CompareHashAndPassword([]byte(payloadHash), []byte(payload)); err != nil { if err := bcrypt.CompareHashAndPassword(payloadHash, payload); err != nil {
switch err { switch err {
case bcrypt.ErrMismatchedHashAndPassword: case bcrypt.ErrMismatchedHashAndPassword:
return refresh.ErrorInvalidToken return refresh.ErrorInvalidToken
...@@ -89,13 +94,13 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) { ...@@ -89,13 +94,13 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
return "", err return "", err
} }
payloadHash, err := bcrypt.GenerateFromPassword([]byte(tokenPayload), bcrypt.DefaultCost) payloadHash, err := bcrypt.GenerateFromPassword(tokenPayload, bcrypt.DefaultCost)
if err != nil { if err != nil {
return "", err return "", err
} }
record := &refreshTokenModel{ record := &refreshTokenModel{
PayloadHash: string(payloadHash), PayloadHash: payloadHash,
UserID: userID, UserID: userID,
ClientID: clientID, ClientID: clientID,
} }
...@@ -109,6 +114,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) { ...@@ -109,6 +114,7 @@ func (r *refreshTokenRepo) Create(userID, clientID string) (string, error) {
func (r *refreshTokenRepo) Verify(clientID, token string) (string, error) { func (r *refreshTokenRepo) Verify(clientID, token string) (string, error) {
tokenID, tokenPayload, err := parseToken(token) tokenID, tokenPayload, err := parseToken(token)
if err != nil { if err != nil {
return "", err return "", err
} }
......
package functional package functional
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
...@@ -342,6 +343,12 @@ func TestDBClientIdentityAll(t *testing.T) { ...@@ -342,6 +343,12 @@ func TestDBClientIdentityAll(t *testing.T) {
} }
} }
// buildRefreshToken combines the token ID and token payload to create a new token.
// used in the tests to created a refresh token.
func buildRefreshToken(tokenID int64, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
}
func TestDBRefreshRepoCreate(t *testing.T) { func TestDBRefreshRepoCreate(t *testing.T) {
r := db.NewRefreshTokenRepo(connect(t)) r := db.NewRefreshTokenRepo(connect(t))
...@@ -383,6 +390,13 @@ func TestDBRefreshRepoVerify(t *testing.T) { ...@@ -383,6 +390,13 @@ func TestDBRefreshRepoVerify(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
tokenWithBadID := "404" + token[1:]
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
tests := []struct { tests := []struct {
token string token string
creds oidc.ClientCredentials creds oidc.ClientCredentials
...@@ -390,7 +404,39 @@ func TestDBRefreshRepoVerify(t *testing.T) { ...@@ -390,7 +404,39 @@ func TestDBRefreshRepoVerify(t *testing.T) {
expected string expected string
}{ }{
{ {
"invalid-token-foo", "invalid-token-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
"b/invalid-base64-encoded-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
"1/invalid-base64-encoded-format",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
token + "corrupted-token-payload",
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
// The token's ID content is invalid.
tokenWithBadID,
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken,
"",
},
{
// The token's payload content is invalid.
tokenWithBadPayload,
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"},
refresh.ErrorInvalidToken, refresh.ErrorInvalidToken,
"", "",
...@@ -428,13 +474,42 @@ func TestDBRefreshRepoRevoke(t *testing.T) { ...@@ -428,13 +474,42 @@ func TestDBRefreshRepoRevoke(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
tokenWithBadID := "404" + token[1:]
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload)
tests := []struct { tests := []struct {
token string token string
userID string userID string
err error err error
}{ }{
{ {
"invalid-token-foo", "invalid-token-format",
"user-foo",
refresh.ErrorInvalidToken,
},
{
"1/invalid-base64-encoded-format",
"user-foo",
refresh.ErrorInvalidToken,
},
{
token + "corrupted-token-payload",
"user-foo",
refresh.ErrorInvalidToken,
},
{
// The token's ID is invalid.
tokenWithBadID,
"user-foo",
refresh.ErrorInvalidToken,
},
{
// The token's payload is invalid.
tokenWithBadPayload,
"user-foo", "user-foo",
refresh.ErrorInvalidToken, refresh.ErrorInvalidToken,
}, },
......
...@@ -10,9 +10,9 @@ import ( ...@@ -10,9 +10,9 @@ import (
// The tokens are in the form { refresh-1, refresh-2 ... refresh-n}. // The tokens are in the form { refresh-1, refresh-2 ... refresh-n}.
func NewTestRefreshTokenRepo() (refresh.RefreshTokenRepo, error) { func NewTestRefreshTokenRepo() (refresh.RefreshTokenRepo, error) {
var tokenIdx int var tokenIdx int
tokenGenerator := func() (string, error) { tokenGenerator := func() ([]byte, error) {
tokenIdx++ tokenIdx++
return fmt.Sprintf("refresh-%d", tokenIdx), nil return []byte(fmt.Sprintf("refresh-%d", tokenIdx)), nil
} }
return refresh.NewRefreshTokenRepoWithTokenGenerator(tokenGenerator), nil return refresh.NewRefreshTokenRepoWithTokenGenerator(tokenGenerator), nil
} }
package refresh package refresh
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
...@@ -17,25 +17,26 @@ const ( ...@@ -17,25 +17,26 @@ const (
var ( var (
ErrorInvalidUserID = errors.New("invalid user ID") ErrorInvalidUserID = errors.New("invalid user ID")
ErrorInvalidClientID = errors.New("invalid client ID") ErrorInvalidClientID = errors.New("invalid client ID")
ErrorInvalidToken = errors.New("invalid token") ErrorInvalidToken = errors.New("invalid token")
) )
type RefreshTokenGenerator func() (string, error) type RefreshTokenGenerator func() ([]byte, error)
func (g RefreshTokenGenerator) Generate() (string, error) { func (g RefreshTokenGenerator) Generate() ([]byte, error) {
return g() return g()
} }
func DefaultRefreshTokenGenerator() (string, error) { func DefaultRefreshTokenGenerator() ([]byte, error) {
// TODO(yifan) Remove this duplicated token generate function. // TODO(yifan) Remove this duplicated token generate function.
b := make([]byte, DefaultRefreshTokenPayloadLength) b := make([]byte, DefaultRefreshTokenPayloadLength)
n, err := rand.Read(b) n, err := rand.Read(b)
if err != nil { if err != nil {
return "", err return nil, err
} else if n != DefaultRefreshTokenPayloadLength { } else if n != DefaultRefreshTokenPayloadLength {
return "", errors.New("unable to read enough random bytes") return nil, errors.New("unable to read enough random bytes")
} }
return base64.URLEncoding.EncodeToString(b), nil return b, nil
} }
type RefreshTokenRepo interface { type RefreshTokenRepo interface {
...@@ -52,7 +53,7 @@ type RefreshTokenRepo interface { ...@@ -52,7 +53,7 @@ type RefreshTokenRepo interface {
} }
type refreshToken struct { type refreshToken struct {
payload string payload []byte
userID string userID string
clientID string clientID string
} }
...@@ -63,21 +64,21 @@ type memRefreshTokenRepo struct { ...@@ -63,21 +64,21 @@ type memRefreshTokenRepo struct {
} }
// buildToken combines the token ID and token payload to create a new token. // buildToken combines the token ID and token payload to create a new token.
func buildToken(tokenID int, tokenPayload string) string { func buildToken(tokenID int, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, tokenPayload) return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, tokenPayload)
} }
// parseToken parses a token and returns the token ID and token payload. // parseToken parses a token and returns the token ID and token payload.
func parseToken(token string) (int, string, error) { func parseToken(token string) (int, []byte, error) {
parts := strings.SplitN(token, TokenDelimer, 2) parts := strings.SplitN(token, TokenDelimer, 2)
if len(parts) != 2 { if len(parts) != 2 {
return -1, "", ErrorInvalidToken return -1, nil, ErrorInvalidToken
} }
id, err := strconv.Atoi(parts[0]) id, err := strconv.Atoi(parts[0])
if err != nil { if err != nil {
return -1, "", ErrorInvalidToken return -1, nil, ErrorInvalidToken
} }
return id, parts[1], nil return id, []byte(parts[1]), nil
} }
// NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development. // NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development.
...@@ -131,7 +132,7 @@ func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) { ...@@ -131,7 +132,7 @@ func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) {
return "", ErrorInvalidToken return "", ErrorInvalidToken
} }
if record.payload != tokenPayload { if !bytes.Equal(record.payload, tokenPayload) {
return "", ErrorInvalidToken return "", ErrorInvalidToken
} }
...@@ -153,7 +154,7 @@ func (r *memRefreshTokenRepo) Revoke(userID, token string) error { ...@@ -153,7 +154,7 @@ func (r *memRefreshTokenRepo) Revoke(userID, token string) error {
return ErrorInvalidToken return ErrorInvalidToken
} }
if record.payload != tokenPayload { if !bytes.Equal(record.payload, tokenPayload) {
return ErrorInvalidToken return ErrorInvalidToken
} }
......
...@@ -397,7 +397,7 @@ func TestServerTokenFail(t *testing.T) { ...@@ -397,7 +397,7 @@ func TestServerTokenFail(t *testing.T) {
signer jose.Signer signer jose.Signer
argCC oidc.ClientCredentials argCC oidc.ClientCredentials
argKey string argKey string
err string err error
scope []string scope []string
refreshToken string refreshToken string
}{ }{
...@@ -423,7 +423,7 @@ func TestServerTokenFail(t *testing.T) { ...@@ -423,7 +423,7 @@ func TestServerTokenFail(t *testing.T) {
signer: signerFixture, signer: signerFixture,
argCC: ccFixture, argCC: ccFixture,
argKey: "foo", argKey: "foo",
err: oauth2.ErrorInvalidGrant, err: oauth2.NewError(oauth2.ErrorInvalidGrant),
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
}, },
...@@ -432,7 +432,7 @@ func TestServerTokenFail(t *testing.T) { ...@@ -432,7 +432,7 @@ func TestServerTokenFail(t *testing.T) {
signer: signerFixture, signer: signerFixture,
argCC: oidc.ClientCredentials{ID: "YYY"}, argCC: oidc.ClientCredentials{ID: "YYY"},
argKey: keyFixture, argKey: keyFixture,
err: oauth2.ErrorInvalidClient, err: oauth2.NewError(oauth2.ErrorInvalidClient),
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
}, },
...@@ -441,7 +441,7 @@ func TestServerTokenFail(t *testing.T) { ...@@ -441,7 +441,7 @@ func TestServerTokenFail(t *testing.T) {
signer: &StaticSigner{sig: nil, err: errors.New("fail")}, signer: &StaticSigner{sig: nil, err: errors.New("fail")},
argCC: ccFixture, argCC: ccFixture,
argKey: keyFixture, argKey: keyFixture,
err: oauth2.ErrorServerError, err: oauth2.NewError(oauth2.ErrorServerError),
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
}, },
} }
...@@ -502,18 +502,14 @@ func TestServerTokenFail(t *testing.T) { ...@@ -502,18 +502,14 @@ func TestServerTokenFail(t *testing.T) {
t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
panic("") panic("")
} }
if tt.err == "" { if !reflect.DeepEqual(err, tt.err) {
if err != nil { t.Errorf("case %d: expect %v, got %v", i, tt.err, err)
t.Errorf("case %d: got non-nil error: %v", i, err)
} else if jwt == nil {
t.Errorf("case %d: got nil JWT", i)
} }
} else { if err == nil && jwt == nil {
if err.Error() != tt.err { t.Errorf("case %d: got nil JWT", i)
t.Errorf("case %d: want err %q, got %q", i, tt.err, err.Error())
} else if jwt != nil {
t.Errorf("case %d: got non-nil JWT", i)
} }
if err != nil && jwt != nil {
t.Errorf("case %d: got non-nil JWT %v", i, jwt)
} }
} }
} }
...@@ -537,7 +533,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -537,7 +533,7 @@ func TestServerRefreshToken(t *testing.T) {
clientID string // The client that associates with the token. clientID string // The client that associates with the token.
creds oidc.ClientCredentials creds oidc.ClientCredentials
signer jose.Signer signer jose.Signer
err string err error
}{ }{
// Everything is good. // Everything is good.
{ {
...@@ -545,7 +541,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -545,7 +541,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
"", nil,
}, },
// Invalid refresh token(malformatted). // Invalid refresh token(malformatted).
{ {
...@@ -553,15 +549,23 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -553,15 +549,23 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
oauth2.ErrorInvalidRequest, oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid refresh token. // Invalid refresh token(invalid payload content).
{ {
"0/refresh-1", "0/refresh-2",
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
oauth2.ErrorInvalidRequest, oauth2.NewError(oauth2.ErrorInvalidRequest),
},
// Invalid refresh token(invalid ID content).
{
"1/refresh-2",
"XXX",
credXXX,
signerFixture,
oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid client(client is not associated with the token). // Invalid client(client is not associated with the token).
{ {
...@@ -569,7 +573,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -569,7 +573,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
credYYY, credYYY,
signerFixture, signerFixture,
oauth2.ErrorInvalidClient, oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(no client ID). // Invalid client(no client ID).
{ {
...@@ -577,7 +581,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -577,7 +581,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
oidc.ClientCredentials{ID: "", Secret: "aaa"}, oidc.ClientCredentials{ID: "", Secret: "aaa"},
signerFixture, signerFixture,
oauth2.ErrorInvalidClient, oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(no such client). // Invalid client(no such client).
{ {
...@@ -585,7 +589,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -585,7 +589,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
signerFixture, signerFixture,
oauth2.ErrorInvalidClient, oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(no secrets). // Invalid client(no secrets).
{ {
...@@ -593,7 +597,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -593,7 +597,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
oidc.ClientCredentials{ID: "XXX"}, oidc.ClientCredentials{ID: "XXX"},
signerFixture, signerFixture,
oauth2.ErrorInvalidClient, oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Invalid client(invalid secret). // Invalid client(invalid secret).
{ {
...@@ -601,7 +605,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -601,7 +605,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"}, oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
signerFixture, signerFixture,
oauth2.ErrorInvalidClient, oauth2.NewError(oauth2.ErrorInvalidClient),
}, },
// Signing operation fails. // Signing operation fails.
{ {
...@@ -609,7 +613,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -609,7 +613,7 @@ func TestServerRefreshToken(t *testing.T) {
"XXX", "XXX",
credXXX, credXXX,
&StaticSigner{sig: nil, err: errors.New("fail")}, &StaticSigner{sig: nil, err: errors.New("fail")},
oauth2.ErrorServerError, oauth2.NewError(oauth2.ErrorServerError),
}, },
} }
...@@ -646,11 +650,9 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -646,11 +650,9 @@ func TestServerRefreshToken(t *testing.T) {
} }
jwt, err := srv.RefreshToken(tt.creds, tt.token) jwt, err := srv.RefreshToken(tt.creds, tt.token)
if err != nil { if !reflect.DeepEqual(err, tt.err) {
if err.Error() != tt.err {
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err) t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
} }
}
if jwt != nil { if jwt != nil {
if string(jwt.Signature) != "beer" { if string(jwt.Signature) != "beer" {
...@@ -715,7 +717,7 @@ func TestServerRefreshToken(t *testing.T) { ...@@ -715,7 +717,7 @@ func TestServerRefreshToken(t *testing.T) {
srv.UserRepo = userRepo srv.UserRepo = userRepo
_, err = srv.RefreshToken(credXXX, "0/refresh-1") _, err = srv.RefreshToken(credXXX, "0/refresh-1")
if err == nil || err.Error() != oauth2.ErrorServerError { if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
t.Errorf("Expect: %v, got: %v", oauth2.ErrorServerError, err) t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment