Commit 312ca749 authored by Eric Chiang's avatar Eric Chiang

storage: add extra fields to refresh token and update method

parent c66cce8b
...@@ -208,10 +208,14 @@ func testClientCRUD(t *testing.T, s storage.Storage) { ...@@ -208,10 +208,14 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
id := storage.NewID() id := storage.NewID()
refresh := storage.RefreshToken{ refresh := storage.RefreshToken{
RefreshToken: id, ID: id,
ClientID: "client_id", Token: "bar",
ConnectorID: "client_secret", Nonce: "foo",
Scopes: []string{"openid", "email", "profile"}, ClientID: "client_id",
ConnectorID: "client_secret",
Scopes: []string{"openid", "email", "profile"},
CreatedAt: time.Now().UTC().Round(time.Millisecond),
LastUsed: time.Now().UTC().Round(time.Millisecond),
Claims: storage.Claims{ Claims: storage.Claims{
UserID: "1", UserID: "1",
Username: "jane", Username: "jane",
...@@ -238,6 +242,20 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { ...@@ -238,6 +242,20 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
getAndCompare(id, refresh) getAndCompare(id, refresh)
updatedAt := time.Now().UTC().Round(time.Millisecond)
updater := func(r storage.RefreshToken) (storage.RefreshToken, error) {
r.Token = "spam"
r.LastUsed = updatedAt
return r, nil
}
if err := s.UpdateRefreshToken(id, updater); err != nil {
t.Errorf("failed to udpate refresh token: %v", err)
}
refresh.Token = "spam"
refresh.LastUsed = updatedAt
getAndCompare(id, refresh)
if err := s.DeleteRefresh(id); err != nil { if err := s.DeleteRefresh(id); err != nil {
t.Fatalf("failed to delete refresh request: %v", err) t.Fatalf("failed to delete refresh request: %v", err)
} }
......
...@@ -153,23 +153,7 @@ func (cli *client) CreatePassword(p storage.Password) error { ...@@ -153,23 +153,7 @@ func (cli *client) CreatePassword(p storage.Password) error {
} }
func (cli *client) CreateRefresh(r storage.RefreshToken) error { func (cli *client) CreateRefresh(r storage.RefreshToken) error {
refresh := RefreshToken{ return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r))
TypeMeta: k8sapi.TypeMeta{
Kind: kindRefreshToken,
APIVersion: cli.apiVersion,
},
ObjectMeta: k8sapi.ObjectMeta{
Name: r.RefreshToken,
Namespace: cli.namespace,
},
ClientID: r.ClientID,
ConnectorID: r.ConnectorID,
Scopes: r.Scopes,
Nonce: r.Nonce,
Claims: fromStorageClaims(r.Claims),
ConnectorData: r.ConnectorData,
}
return cli.post(resourceRefreshToken, refresh)
} }
func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) { func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
...@@ -239,19 +223,16 @@ func (cli *client) GetKeys() (storage.Keys, error) { ...@@ -239,19 +223,16 @@ func (cli *client) GetKeys() (storage.Keys, error) {
} }
func (cli *client) GetRefresh(id string) (storage.RefreshToken, error) { func (cli *client) GetRefresh(id string) (storage.RefreshToken, error) {
var r RefreshToken r, err := cli.getRefreshToken(id)
if err := cli.get(resourceRefreshToken, id, &r); err != nil { if err != nil {
return storage.RefreshToken{}, err return storage.RefreshToken{}, err
} }
return storage.RefreshToken{ return toStorageRefreshToken(r), nil
RefreshToken: r.ObjectMeta.Name, }
ClientID: r.ClientID,
ConnectorID: r.ConnectorID, func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) {
Scopes: r.Scopes, err = cli.get(resourceRefreshToken, id, &r)
Nonce: r.Nonce, return
Claims: toStorageClaims(r.Claims),
ConnectorData: r.ConnectorData,
}, nil
} }
func (cli *client) ListClients() ([]storage.Client, error) { func (cli *client) ListClients() ([]storage.Client, error) {
...@@ -311,6 +292,22 @@ func (cli *client) DeletePassword(email string) error { ...@@ -311,6 +292,22 @@ func (cli *client) DeletePassword(email string) error {
return cli.delete(resourcePassword, p.ObjectMeta.Name) return cli.delete(resourcePassword, p.ObjectMeta.Name)
} }
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
r, err := cli.getRefreshToken(id)
if err != nil {
return err
}
updated, err := updater(toStorageRefreshToken(r))
if err != nil {
return err
}
updated.ID = id
newToken := cli.fromStorageRefreshToken(updated)
newToken.ObjectMeta = r.ObjectMeta
return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken)
}
func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
c, err := cli.getClient(id) c, err := cli.getClient(id)
if err != nil { if err != nil {
......
...@@ -362,9 +362,14 @@ type RefreshToken struct { ...@@ -362,9 +362,14 @@ type RefreshToken struct {
k8sapi.TypeMeta `json:",inline"` k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"` k8sapi.ObjectMeta `json:"metadata,omitempty"`
CreatedAt time.Time
LastUsed time.Time
ClientID string `json:"clientID"` ClientID string `json:"clientID"`
Scopes []string `json:"scopes,omitempty"` Scopes []string `json:"scopes,omitempty"`
Token string `json:"token,omitempty"`
Nonce string `json:"nonce,omitempty"` Nonce string `json:"nonce,omitempty"`
Claims Claims `json:"claims,omitempty"` Claims Claims `json:"claims,omitempty"`
...@@ -379,6 +384,43 @@ type RefreshList struct { ...@@ -379,6 +384,43 @@ type RefreshList struct {
RefreshTokens []RefreshToken `json:"items"` RefreshTokens []RefreshToken `json:"items"`
} }
func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
return storage.RefreshToken{
ID: r.ObjectMeta.Name,
Token: r.Token,
CreatedAt: r.CreatedAt,
LastUsed: r.LastUsed,
ClientID: r.ClientID,
ConnectorID: r.ConnectorID,
ConnectorData: r.ConnectorData,
Scopes: r.Scopes,
Nonce: r.Nonce,
Claims: toStorageClaims(r.Claims),
}
}
func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken {
return RefreshToken{
TypeMeta: k8sapi.TypeMeta{
Kind: kindRefreshToken,
APIVersion: cli.apiVersion,
},
ObjectMeta: k8sapi.ObjectMeta{
Name: r.ID,
Namespace: cli.namespace,
},
Token: r.Token,
CreatedAt: r.CreatedAt,
LastUsed: r.LastUsed,
ClientID: r.ClientID,
ConnectorID: r.ConnectorID,
ConnectorData: r.ConnectorData,
Scopes: r.Scopes,
Nonce: r.Nonce,
Claims: fromStorageClaims(r.Claims),
}
}
// Keys is a mirrored struct from storage with JSON struct tags and Kubernetes // Keys is a mirrored struct from storage with JSON struct tags and Kubernetes
// type metadata. // type metadata.
type Keys struct { type Keys struct {
......
...@@ -98,10 +98,10 @@ func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) { ...@@ -98,10 +98,10 @@ func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) {
func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) { func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) {
s.tx(func() { s.tx(func() {
if _, ok := s.refreshTokens[r.RefreshToken]; ok { if _, ok := s.refreshTokens[r.ID]; ok {
err = storage.ErrAlreadyExists err = storage.ErrAlreadyExists
} else { } else {
s.refreshTokens[r.RefreshToken] = r s.refreshTokens[r.ID] = r
} }
}) })
return return
...@@ -324,3 +324,17 @@ func (s *memStorage) UpdatePassword(email string, updater func(p storage.Passwor ...@@ -324,3 +324,17 @@ func (s *memStorage) UpdatePassword(email string, updater func(p storage.Passwor
}) })
return return
} }
func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.RefreshToken) (storage.RefreshToken, error)) (err error) {
s.tx(func() {
r, ok := s.refreshTokens[id]
if !ok {
err = storage.ErrNotFound
return
}
if r, err = updater(r); err == nil {
s.refreshTokens[id] = r
}
})
return
}
...@@ -244,14 +244,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { ...@@ -244,14 +244,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
id, client_id, scopes, nonce, id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified, claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups, claims_groups,
connector_id, connector_data connector_id, connector_data,
token, created_at, last_used
) )
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11); values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14);
`, `,
r.RefreshToken, r.ClientID, encoder(r.Scopes), r.Nonce, r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified, r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
encoder(r.Claims.Groups), encoder(r.Claims.Groups),
r.ConnectorID, r.ConnectorData, r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed,
) )
if err != nil { if err != nil {
return fmt.Errorf("insert refresh_token: %v", err) return fmt.Errorf("insert refresh_token: %v", err)
...@@ -259,13 +261,57 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { ...@@ -259,13 +261,57 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
return nil return nil
} }
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
return c.ExecTx(func(tx *trans) error {
r, err := getRefresh(tx, id)
if err != nil {
return err
}
if r, err = updater(r); err != nil {
return err
}
_, err = tx.Exec(`
update refresh_token
set
client_id = $1,
scopes = $2,
nonce = $3,
claims_user_id = $4,
claims_username = $5,
claims_email = $6,
claims_email_verified = $7,
claims_groups = $8,
connector_id = $9,
connector_data = $10,
token = $11,
created_at = $12,
last_used = $13
`,
r.ClientID, encoder(r.Scopes), r.Nonce,
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
encoder(r.Claims.Groups),
r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed,
)
if err != nil {
return fmt.Errorf("update refresh token: %v", err)
}
return nil
})
}
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
return scanRefresh(c.QueryRow(` return getRefresh(c, id)
}
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
return scanRefresh(q.QueryRow(`
select select
id, client_id, scopes, nonce, id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified, claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups, claims_groups,
connector_id, connector_data connector_id, connector_data,
token, created_at, last_used
from refresh_token where id = $1; from refresh_token where id = $1;
`, id)) `, id))
} }
...@@ -276,7 +322,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { ...@@ -276,7 +322,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
id, client_id, scopes, nonce, id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified, claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups, claims_groups,
connector_id, connector_data connector_id, connector_data,
token, created_at, last_used
from refresh_token; from refresh_token;
`) `)
if err != nil { if err != nil {
...@@ -298,10 +345,11 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { ...@@ -298,10 +345,11 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
func scanRefresh(s scanner) (r storage.RefreshToken, err error) { func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
err = s.Scan( err = s.Scan(
&r.RefreshToken, &r.ClientID, decoder(&r.Scopes), &r.Nonce, &r.ID, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
&r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified, &r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified,
decoder(&r.Claims.Groups), decoder(&r.Claims.Groups),
&r.ConnectorID, &r.ConnectorData, &r.ConnectorID, &r.ConnectorData,
&r.Token, &r.CreatedAt, &r.LastUsed,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
......
...@@ -155,4 +155,14 @@ var migrations = []migration{ ...@@ -155,4 +155,14 @@ var migrations = []migration{
); );
`, `,
}, },
{
stmt: `
alter table refresh_token
add column token text not null default '';
alter table refresh_token
add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';
alter table refresh_token
add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';
`,
},
} }
...@@ -94,6 +94,7 @@ type Storage interface { ...@@ -94,6 +94,7 @@ type Storage interface {
UpdateClient(id string, updater func(old Client) (Client, error)) error UpdateClient(id string, updater func(old Client) (Client, error)) error
UpdateKeys(updater func(old Keys) (Keys, error)) error UpdateKeys(updater func(old Keys) (Keys, error)) error
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error
UpdatePassword(email string, updater func(p Password) (Password, error)) error UpdatePassword(email string, updater func(p Password) (Password, error)) error
// GarbageCollect deletes all expired AuthCodes and AuthRequests. // GarbageCollect deletes all expired AuthCodes and AuthRequests.
...@@ -216,8 +217,15 @@ type AuthCode struct { ...@@ -216,8 +217,15 @@ type AuthCode struct {
// RefreshToken is an OAuth2 refresh token which allows a client to request new // RefreshToken is an OAuth2 refresh token which allows a client to request new
// tokens on the end user's behalf. // tokens on the end user's behalf.
type RefreshToken struct { type RefreshToken struct {
// The actual refresh token. ID string
RefreshToken string
// A single token that's rotated every time the refresh token is refreshed.
//
// May be empty.
Token string
CreatedAt time.Time
LastUsed time.Time
// Client this refresh token is valid for. // Client this refresh token is valid for.
ClientID string ClientID string
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment