Commit 53e38367 authored by rithu leena john's avatar rithu leena john Committed by GitHub

Merge pull request #793 from rithujohn191/token-revocation

storage: Add OfflineSession object to backend storage.
parents 49f446c1 d928ac06
...@@ -682,6 +682,75 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s ...@@ -682,6 +682,75 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
// deleteToken determines if we need to delete the newly created refresh token
// due to a failure in updating/creating the OfflineSession object for the
// corresponding user.
var deleteToken bool
defer func() {
if deleteToken {
// Delete newly created refresh token from storage.
if err := s.storage.DeleteRefresh(refresh.ID); err != nil {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
}
}()
tokenRef := storage.RefreshTokenRef{
ID: refresh.ID,
ClientID: refresh.ClientID,
CreatedAt: refresh.CreatedAt,
LastUsed: refresh.LastUsed,
}
// Try to retrieve an existing OfflineSession object for the corresponding user.
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
offlineSessions := storage.OfflineSessions{
UserID: refresh.Claims.UserID,
ConnID: refresh.ConnectorID,
Refresh: make(map[string]*storage.RefreshTokenRef),
}
offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef
// Create a new OfflineSession object for the user and add a reference object for
// the newly recieved refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
} else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
}
// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
old.Refresh[tokenRef.ClientID] = &tokenRef
return old, nil
}); err != nil {
s.logger.Errorf("failed to update offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
}
} }
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry) s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
} }
...@@ -815,6 +884,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie ...@@ -815,6 +884,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return return
} }
lastUsed := s.now()
updater := func(old storage.RefreshToken) (storage.RefreshToken, error) { updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
if old.Token != refresh.Token { if old.Token != refresh.Token {
return old, errors.New("refresh token claimed twice") return old, errors.New("refresh token claimed twice")
...@@ -828,14 +898,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie ...@@ -828,14 +898,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
old.Claims.EmailVerified = ident.EmailVerified old.Claims.EmailVerified = ident.EmailVerified
old.Claims.Groups = ident.Groups old.Claims.Groups = ident.Groups
old.ConnectorData = ident.ConnectorData old.ConnectorData = ident.ConnectorData
old.LastUsed = s.now() old.LastUsed = lastUsed
return old, nil return old, nil
} }
// Update LastUsed time stamp in refresh token reference object
// in offline session for the user.
if err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if old.Refresh[refresh.ClientID].ID != refresh.ID {
return old, errors.New("refresh token invalid")
}
old.Refresh[refresh.ClientID].LastUsed = lastUsed
return old, nil
}); err != nil {
s.logger.Errorf("failed to update offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
// Update refresh token in the storage.
if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil { if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil {
s.logger.Errorf("failed to update refresh token: %v", err) s.logger.Errorf("failed to update refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry) s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
} }
......
...@@ -971,3 +971,108 @@ func TestKeyCacher(t *testing.T) { ...@@ -971,3 +971,108 @@ func TestKeyCacher(t *testing.T) {
} }
} }
} }
type oauth2Client struct {
config *oauth2.Config
token *oauth2.Token
server *httptest.Server
}
// TestRefreshTokenFlow tests the refresh token code flow for oauth2. The test verifies
// that only valid refresh tokens can be used to refresh an expired token.
func TestRefreshTokenFlow(t *testing.T) {
state := "state"
now := func() time.Time { return time.Now() }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, s := newTestServer(ctx, t, func(c *Config) {
c.Now = now
})
defer httpServer.Close()
p, err := oidc.NewProvider(ctx, httpServer.URL)
if err != nil {
t.Fatalf("failed to get provider: %v", err)
}
var oauth2Client oauth2Client
oauth2Client.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/callback" {
// User is visiting app first time. Redirect to dex.
http.Redirect(w, r, oauth2Client.config.AuthCodeURL(state), http.StatusSeeOther)
return
}
// User is at '/callback' so they were just redirected _from_ dex.
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
if desc := q.Get("error_description"); desc != "" {
t.Errorf("got error from server %s: %s", errType, desc)
} else {
t.Errorf("got error from server %s", errType)
}
w.WriteHeader(http.StatusInternalServerError)
return
}
// Grab code, exchange for token.
if code := q.Get("code"); code != "" {
token, err := oauth2Client.config.Exchange(ctx, code)
if err != nil {
t.Errorf("failed to exchange code for token: %v", err)
return
}
oauth2Client.token = token
}
// Ensure state matches.
if gotState := q.Get("state"); gotState != state {
t.Errorf("state did not match, want=%q got=%q", state, gotState)
}
w.WriteHeader(http.StatusOK)
return
}))
defer oauth2Client.server.Close()
// Register the client above with dex.
redirectURL := oauth2Client.server.URL + "/callback"
client := storage.Client{
ID: "testclient",
Secret: "testclientsecret",
RedirectURIs: []string{redirectURL},
}
if err := s.storage.CreateClient(client); err != nil {
t.Fatalf("failed to create client: %v", err)
}
oauth2Client.config = &oauth2.Config{
ClientID: client.ID,
ClientSecret: client.Secret,
Endpoint: p.Endpoint(),
Scopes: []string{oidc.ScopeOpenID, "email", "offline_access"},
RedirectURL: redirectURL,
}
if _, err = http.Get(oauth2Client.server.URL + "/login"); err != nil {
t.Fatalf("get failed: %v", err)
}
tok := &oauth2.Token{
RefreshToken: oauth2Client.token.RefreshToken,
Expiry: time.Now().Add(-time.Hour),
}
// Login in again to recieve a new token.
if _, err = http.Get(oauth2Client.server.URL + "/login"); err != nil {
t.Fatalf("get failed: %v", err)
}
// try to refresh expired token with old refresh token.
newToken, err := oauth2Client.config.TokenSource(ctx, tok).Token()
if newToken != nil {
t.Errorf("Token refreshed with invalid refresh token.")
}
}
...@@ -47,6 +47,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) { ...@@ -47,6 +47,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
{"RefreshTokenCRUD", testRefreshTokenCRUD}, {"RefreshTokenCRUD", testRefreshTokenCRUD},
{"PasswordCRUD", testPasswordCRUD}, {"PasswordCRUD", testPasswordCRUD},
{"KeysCRUD", testKeysCRUD}, {"KeysCRUD", testKeysCRUD},
{"OfflineSessionCRUD", testOfflineSessionCRUD},
{"GarbageCollection", testGC}, {"GarbageCollection", testGC},
{"TimezoneSupport", testTimezones}, {"TimezoneSupport", testTimezones},
}) })
...@@ -340,6 +341,60 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { ...@@ -340,6 +341,60 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
} }
func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
session := storage.OfflineSessions{
UserID: "User",
ConnID: "Conn",
Refresh: make(map[string]*storage.RefreshTokenRef),
}
// Creating an OfflineSession with an empty Refresh list to ensure that
// an empty map is translated as expected by the storage.
if err := s.CreateOfflineSessions(session); err != nil {
t.Fatalf("create offline session: %v", err)
}
getAndCompare := func(userID string, connID string, want storage.OfflineSessions) {
gr, err := s.GetOfflineSessions(userID, connID)
if err != nil {
t.Errorf("get offline session: %v", err)
return
}
if diff := pretty.Compare(want, gr); diff != "" {
t.Errorf("offline session retrieved from storage did not match: %s", diff)
}
}
getAndCompare("User", "Conn", session)
id := storage.NewID()
tokenRef := storage.RefreshTokenRef{
ID: id,
ClientID: "client_id",
CreatedAt: time.Now().UTC().Round(time.Millisecond),
LastUsed: time.Now().UTC().Round(time.Millisecond),
}
session.Refresh[tokenRef.ClientID] = &tokenRef
if err := s.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
old.Refresh[tokenRef.ClientID] = &tokenRef
return old, nil
}); err != nil {
t.Fatalf("failed to update offline session: %v", err)
}
getAndCompare("User", "Conn", session)
if err := s.DeleteOfflineSessions(session.UserID, session.ConnID); err != nil {
t.Fatalf("failed to delete offline session: %v", err)
}
if _, err := s.GetOfflineSessions(session.UserID, session.ConnID); err != storage.ErrNotFound {
t.Errorf("after deleting offline session expected storage.ErrNotFound, got %v", err)
}
}
func testKeysCRUD(t *testing.T, s storage.Storage) { func testKeysCRUD(t *testing.T, s storage.Storage) {
updateAndCompare := func(k storage.Keys) { updateAndCompare := func(k storage.Keys) {
err := s.UpdateKeys(func(oldKeys storage.Keys) (storage.Keys, error) { err := s.UpdateKeys(func(oldKeys storage.Keys) (storage.Keys, error) {
......
...@@ -58,6 +58,12 @@ func (c *client) idToName(s string) string { ...@@ -58,6 +58,12 @@ func (c *client) idToName(s string) string {
return idToName(s, c.hash) return idToName(s, c.hash)
} }
// offlineTokenName maps two arbitrary IDs, to a single Kubernetes object name.
// This is used when more than one field is used to uniquely identify the object.
func (c *client) offlineTokenName(userID string, connID string) string {
return offlineTokenName(userID, connID, c.hash)
}
// Kubernetes names must match the regexp '[a-z0-9]([-a-z0-9]*[a-z0-9])?'. // Kubernetes names must match the regexp '[a-z0-9]([-a-z0-9]*[a-z0-9])?'.
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567") var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
...@@ -65,6 +71,12 @@ func idToName(s string, h func() hash.Hash) string { ...@@ -65,6 +71,12 @@ func idToName(s string, h func() hash.Hash) string {
return strings.TrimRight(encoding.EncodeToString(h().Sum([]byte(s))), "=") return strings.TrimRight(encoding.EncodeToString(h().Sum([]byte(s))), "=")
} }
func offlineTokenName(userID string, connID string, h func() hash.Hash) string {
h().Write([]byte(userID))
h().Write([]byte(connID))
return strings.TrimRight(encoding.EncodeToString(h().Sum(nil)), "=")
}
func (c *client) urlFor(apiVersion, namespace, resource, name string) string { func (c *client) urlFor(apiVersion, namespace, resource, name string) string {
basePath := "apis/" basePath := "apis/"
if apiVersion == "v1" { if apiVersion == "v1" {
......
...@@ -15,21 +15,23 @@ import ( ...@@ -15,21 +15,23 @@ import (
) )
const ( const (
kindAuthCode = "AuthCode" kindAuthCode = "AuthCode"
kindAuthRequest = "AuthRequest" kindAuthRequest = "AuthRequest"
kindClient = "OAuth2Client" kindClient = "OAuth2Client"
kindRefreshToken = "RefreshToken" kindRefreshToken = "RefreshToken"
kindKeys = "SigningKey" kindKeys = "SigningKey"
kindPassword = "Password" kindPassword = "Password"
kindOfflineSessions = "OfflineSessions"
) )
const ( const (
resourceAuthCode = "authcodes" resourceAuthCode = "authcodes"
resourceAuthRequest = "authrequests" resourceAuthRequest = "authrequests"
resourceClient = "oauth2clients" resourceClient = "oauth2clients"
resourceRefreshToken = "refreshtokens" resourceRefreshToken = "refreshtokens"
resourceKeys = "signingkeies" // Kubernetes attempts to pluralize. resourceKeys = "signingkeies" // Kubernetes attempts to pluralize.
resourcePassword = "passwords" resourcePassword = "passwords"
resourceOfflineSessions = "offlinesessions"
) )
// Config values for the Kubernetes storage type. // Config values for the Kubernetes storage type.
...@@ -156,6 +158,10 @@ func (cli *client) CreateRefresh(r storage.RefreshToken) error { ...@@ -156,6 +158,10 @@ func (cli *client) CreateRefresh(r storage.RefreshToken) error {
return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r)) return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r))
} }
func (cli *client) CreateOfflineSessions(o storage.OfflineSessions) error {
return cli.post(resourceOfflineSessions, cli.fromStorageOfflineSessions(o))
}
func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) { func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
var req AuthRequest var req AuthRequest
if err := cli.get(resourceAuthRequest, id, &req); err != nil { if err := cli.get(resourceAuthRequest, id, &req); err != nil {
...@@ -235,6 +241,25 @@ func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) { ...@@ -235,6 +241,25 @@ func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) {
return return
} }
func (cli *client) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
o, err := cli.getOfflineSessions(userID, connID)
if err != nil {
return storage.OfflineSessions{}, err
}
return toStorageOfflineSessions(o), nil
}
func (cli *client) getOfflineSessions(userID string, connID string) (o OfflineSessions, err error) {
name := cli.offlineTokenName(userID, connID)
if err = cli.get(resourceOfflineSessions, name, &o); err != nil {
return OfflineSessions{}, err
}
if userID != o.UserID || connID != o.ConnID {
return OfflineSessions{}, fmt.Errorf("get offline session: wrong session retrieved")
}
return o, nil
}
func (cli *client) ListClients() ([]storage.Client, error) { func (cli *client) ListClients() ([]storage.Client, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
...@@ -292,6 +317,15 @@ func (cli *client) DeletePassword(email string) error { ...@@ -292,6 +317,15 @@ func (cli *client) DeletePassword(email string) error {
return cli.delete(resourcePassword, p.ObjectMeta.Name) return cli.delete(resourcePassword, p.ObjectMeta.Name)
} }
func (cli *client) DeleteOfflineSessions(userID string, connID string) error {
// Check for hash collition.
o, err := cli.getOfflineSessions(userID, connID)
if err != nil {
return err
}
return cli.delete(resourceOfflineSessions, o.ObjectMeta.Name)
}
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
r, err := cli.getRefreshToken(id) r, err := cli.getRefreshToken(id)
if err != nil { if err != nil {
...@@ -342,6 +376,22 @@ func (cli *client) UpdatePassword(email string, updater func(old storage.Passwor ...@@ -342,6 +376,22 @@ func (cli *client) UpdatePassword(email string, updater func(old storage.Passwor
return cli.put(resourcePassword, p.ObjectMeta.Name, newPassword) return cli.put(resourcePassword, p.ObjectMeta.Name, newPassword)
} }
func (cli *client) UpdateOfflineSessions(userID string, connID string, updater func(old storage.OfflineSessions) (storage.OfflineSessions, error)) error {
o, err := cli.getOfflineSessions(userID, connID)
if err != nil {
return err
}
updated, err := updater(toStorageOfflineSessions(o))
if err != nil {
return err
}
newOfflineSessions := cli.fromStorageOfflineSessions(updated)
newOfflineSessions.ObjectMeta = o.ObjectMeta
return cli.put(resourceOfflineSessions, o.ObjectMeta.Name, newOfflineSessions)
}
func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
firstUpdate := false firstUpdate := false
var keys Keys var keys Keys
......
...@@ -66,6 +66,14 @@ var thirdPartyResources = []k8sapi.ThirdPartyResource{ ...@@ -66,6 +66,14 @@ var thirdPartyResources = []k8sapi.ThirdPartyResource{
Description: "Passwords managed by the OIDC server.", Description: "Passwords managed by the OIDC server.",
Versions: []k8sapi.APIVersion{{Name: "v1"}}, Versions: []k8sapi.APIVersion{{Name: "v1"}},
}, },
{
ObjectMeta: k8sapi.ObjectMeta{
Name: "offline-sessions.oidc.coreos.com",
},
TypeMeta: tprMeta,
Description: "User sessions with an active refresh token.",
Versions: []k8sapi.APIVersion{{Name: "v1"}},
},
} }
// There will only ever be a single keys resource. Maintain this by setting a // There will only ever be a single keys resource. Maintain this by setting a
...@@ -465,3 +473,38 @@ func toStorageKeys(keys Keys) storage.Keys { ...@@ -465,3 +473,38 @@ func toStorageKeys(keys Keys) storage.Keys {
NextRotation: keys.NextRotation, NextRotation: keys.NextRotation,
} }
} }
// OfflineSessions is a mirrored struct from storage with JSON struct tags and Kubernetes
// type metadata.
type OfflineSessions struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
UserID string `json:"userID,omitempty"`
ConnID string `json:"connID,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
}
func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
return OfflineSessions{
TypeMeta: k8sapi.TypeMeta{
Kind: kindOfflineSessions,
APIVersion: cli.apiVersion,
},
ObjectMeta: k8sapi.ObjectMeta{
Name: cli.offlineTokenName(o.UserID, o.ConnID),
Namespace: cli.namespace,
},
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
}
}
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
return storage.OfflineSessions{
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
}
}
...@@ -13,12 +13,13 @@ import ( ...@@ -13,12 +13,13 @@ import (
// New returns an in memory storage. // New returns an in memory storage.
func New(logger logrus.FieldLogger) storage.Storage { func New(logger logrus.FieldLogger) storage.Storage {
return &memStorage{ return &memStorage{
clients: make(map[string]storage.Client), clients: make(map[string]storage.Client),
authCodes: make(map[string]storage.AuthCode), authCodes: make(map[string]storage.AuthCode),
refreshTokens: make(map[string]storage.RefreshToken), refreshTokens: make(map[string]storage.RefreshToken),
authReqs: make(map[string]storage.AuthRequest), authReqs: make(map[string]storage.AuthRequest),
passwords: make(map[string]storage.Password), passwords: make(map[string]storage.Password),
logger: logger, offlineSessions: make(map[offlineSessionID]storage.OfflineSessions),
logger: logger,
} }
} }
...@@ -37,17 +38,23 @@ func (c *Config) Open(logger logrus.FieldLogger) (storage.Storage, error) { ...@@ -37,17 +38,23 @@ func (c *Config) Open(logger logrus.FieldLogger) (storage.Storage, error) {
type memStorage struct { type memStorage struct {
mu sync.Mutex mu sync.Mutex
clients map[string]storage.Client clients map[string]storage.Client
authCodes map[string]storage.AuthCode authCodes map[string]storage.AuthCode
refreshTokens map[string]storage.RefreshToken refreshTokens map[string]storage.RefreshToken
authReqs map[string]storage.AuthRequest authReqs map[string]storage.AuthRequest
passwords map[string]storage.Password passwords map[string]storage.Password
offlineSessions map[offlineSessionID]storage.OfflineSessions
keys storage.Keys keys storage.Keys
logger logrus.FieldLogger logger logrus.FieldLogger
} }
type offlineSessionID struct {
userID string
connID string
}
func (s *memStorage) tx(f func()) { func (s *memStorage) tx(f func()) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
...@@ -130,6 +137,32 @@ func (s *memStorage) CreatePassword(p storage.Password) (err error) { ...@@ -130,6 +137,32 @@ func (s *memStorage) CreatePassword(p storage.Password) (err error) {
return return
} }
func (s *memStorage) CreateOfflineSessions(o storage.OfflineSessions) (err error) {
id := offlineSessionID{
userID: o.UserID,
connID: o.ConnID,
}
s.tx(func() {
if _, ok := s.offlineSessions[id]; ok {
err = storage.ErrAlreadyExists
} else {
s.offlineSessions[id] = o
}
})
return
}
func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
s.tx(func() {
var ok bool
if c, ok = s.authCodes[id]; !ok {
err = storage.ErrNotFound
return
}
})
return
}
func (s *memStorage) GetPassword(email string) (p storage.Password, err error) { func (s *memStorage) GetPassword(email string) (p storage.Password, err error) {
email = strings.ToLower(email) email = strings.ToLower(email)
s.tx(func() { s.tx(func() {
...@@ -156,10 +189,10 @@ func (s *memStorage) GetKeys() (keys storage.Keys, err error) { ...@@ -156,10 +189,10 @@ func (s *memStorage) GetKeys() (keys storage.Keys, err error) {
return return
} }
func (s *memStorage) GetRefresh(token string) (tok storage.RefreshToken, err error) { func (s *memStorage) GetRefresh(id string) (tok storage.RefreshToken, err error) {
s.tx(func() { s.tx(func() {
var ok bool var ok bool
if tok, ok = s.refreshTokens[token]; !ok { if tok, ok = s.refreshTokens[id]; !ok {
err = storage.ErrNotFound err = storage.ErrNotFound
return return
} }
...@@ -178,6 +211,21 @@ func (s *memStorage) GetAuthRequest(id string) (req storage.AuthRequest, err err ...@@ -178,6 +211,21 @@ func (s *memStorage) GetAuthRequest(id string) (req storage.AuthRequest, err err
return return
} }
func (s *memStorage) GetOfflineSessions(userID string, connID string) (o storage.OfflineSessions, err error) {
id := offlineSessionID{
userID: userID,
connID: connID,
}
s.tx(func() {
var ok bool
if o, ok = s.offlineSessions[id]; !ok {
err = storage.ErrNotFound
return
}
})
return
}
func (s *memStorage) ListClients() (clients []storage.Client, err error) { func (s *memStorage) ListClients() (clients []storage.Client, err error) {
s.tx(func() { s.tx(func() {
for _, client := range s.clients { for _, client := range s.clients {
...@@ -228,13 +276,13 @@ func (s *memStorage) DeleteClient(id string) (err error) { ...@@ -228,13 +276,13 @@ func (s *memStorage) DeleteClient(id string) (err error) {
return return
} }
func (s *memStorage) DeleteRefresh(token string) (err error) { func (s *memStorage) DeleteRefresh(id string) (err error) {
s.tx(func() { s.tx(func() {
if _, ok := s.refreshTokens[token]; !ok { if _, ok := s.refreshTokens[id]; !ok {
err = storage.ErrNotFound err = storage.ErrNotFound
return return
} }
delete(s.refreshTokens, token) delete(s.refreshTokens, id)
}) })
return return
} }
...@@ -261,13 +309,17 @@ func (s *memStorage) DeleteAuthRequest(id string) (err error) { ...@@ -261,13 +309,17 @@ func (s *memStorage) DeleteAuthRequest(id string) (err error) {
return return
} }
func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) { func (s *memStorage) DeleteOfflineSessions(userID string, connID string) (err error) {
id := offlineSessionID{
userID: userID,
connID: connID,
}
s.tx(func() { s.tx(func() {
var ok bool if _, ok := s.offlineSessions[id]; !ok {
if c, ok = s.authCodes[id]; !ok {
err = storage.ErrNotFound err = storage.ErrNotFound
return return
} }
delete(s.offlineSessions, id)
}) })
return return
} }
...@@ -338,3 +390,21 @@ func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.Refres ...@@ -338,3 +390,21 @@ func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.Refres
}) })
return return
} }
func (s *memStorage) UpdateOfflineSessions(userID string, connID string, updater func(o storage.OfflineSessions) (storage.OfflineSessions, error)) (err error) {
id := offlineSessionID{
userID: userID,
connID: connID,
}
s.tx(func() {
r, ok := s.offlineSessions[id]
if !ok {
err = storage.ErrNotFound
return
}
if r, err = updater(r); err == nil {
s.offlineSessions[id] = r
}
})
return
}
...@@ -624,6 +624,75 @@ func scanPassword(s scanner) (p storage.Password, err error) { ...@@ -624,6 +624,75 @@ func scanPassword(s scanner) (p storage.Password, err error) {
return p, nil return p, nil
} }
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
_, err := c.Exec(`
insert into offline_session (
user_id, conn_id, refresh
)
values (
$1, $2, $3
);
`,
s.UserID, s.ConnID, encoder(s.Refresh),
)
if err != nil {
return fmt.Errorf("insert offline session: %v", err)
}
return nil
}
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
return c.ExecTx(func(tx *trans) error {
s, err := getOfflineSessions(tx, userID, connID)
if err != nil {
return err
}
newSession, err := updater(s)
if err != nil {
return err
}
_, err = tx.Exec(`
update offline_session
set
refresh = $1
where user_id = $2 AND conn_id = $3;
`,
encoder(newSession.Refresh), s.UserID, s.ConnID,
)
if err != nil {
return fmt.Errorf("update offline session: %v", err)
}
return nil
})
}
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
return getOfflineSessions(c, userID, connID)
}
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
return scanOfflineSessions(q.QueryRow(`
select
user_id, conn_id, refresh
from offline_session
where user_id = $1 AND conn_id = $2;
`, userID, connID))
}
func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
err = s.Scan(
&o.UserID, &o.ConnID, decoder(&o.Refresh),
)
if err != nil {
if err == sql.ErrNoRows {
return o, storage.ErrNotFound
}
return o, fmt.Errorf("select offline session: %v", err)
}
return o, nil
}
func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) } func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) }
func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", "id", id) } func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", "id", id) }
func (c *conn) DeleteClient(id string) error { return c.delete("client", "id", id) } func (c *conn) DeleteClient(id string) error { return c.delete("client", "id", id) }
...@@ -632,6 +701,24 @@ func (c *conn) DeletePassword(email string) error { ...@@ -632,6 +701,24 @@ func (c *conn) DeletePassword(email string) error {
return c.delete("password", "email", strings.ToLower(email)) return c.delete("password", "email", strings.ToLower(email))
} }
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID)
if err != nil {
return fmt.Errorf("delete offline_session: user_id = %s, conn_id = %s", userID, connID)
}
// For now mandate that the driver implements RowsAffected. If we ever need to support
// a driver that doesn't implement this, we can run this in a transaction with a get beforehand.
n, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("rows affected: %v", err)
}
if n < 1 {
return storage.ErrNotFound
}
return nil
}
// Do NOT call directly. Does not escape table. // Do NOT call directly. Does not escape table.
func (c *conn) delete(table, field, id string) error { func (c *conn) delete(table, field, id string) error {
result, err := c.Exec(`delete from `+table+` where `+field+` = $1`, id) result, err := c.Exec(`delete from `+table+` where `+field+` = $1`, id)
......
...@@ -153,6 +153,7 @@ var migrations = []migration{ ...@@ -153,6 +153,7 @@ var migrations = []migration{
signing_key_pub bytea not null, -- JSON object signing_key_pub bytea not null, -- JSON object
next_rotation timestamptz not null next_rotation timestamptz not null
); );
`, `,
}, },
{ {
...@@ -165,4 +166,14 @@ var migrations = []migration{ ...@@ -165,4 +166,14 @@ var migrations = []migration{
add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC'; add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';
`, `,
}, },
{
stmt: `
create table offline_session (
user_id text not null,
conn_id text not null,
refresh bytea not null,
PRIMARY KEY (user_id, conn_id)
);
`,
},
} }
...@@ -52,6 +52,7 @@ type Storage interface { ...@@ -52,6 +52,7 @@ type Storage interface {
CreateAuthCode(c AuthCode) error CreateAuthCode(c AuthCode) error
CreateRefresh(r RefreshToken) error CreateRefresh(r RefreshToken) error
CreatePassword(p Password) error CreatePassword(p Password) error
CreateOfflineSessions(s OfflineSessions) error
// TODO(ericchiang): return (T, bool, error) so we can indicate not found // TODO(ericchiang): return (T, bool, error) so we can indicate not found
// requests that way instead of using ErrNotFound. // requests that way instead of using ErrNotFound.
...@@ -61,6 +62,7 @@ type Storage interface { ...@@ -61,6 +62,7 @@ type Storage interface {
GetKeys() (Keys, error) GetKeys() (Keys, error)
GetRefresh(id string) (RefreshToken, error) GetRefresh(id string) (RefreshToken, error)
GetPassword(email string) (Password, error) GetPassword(email string) (Password, error)
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
ListClients() ([]Client, error) ListClients() ([]Client, error)
ListRefreshTokens() ([]RefreshToken, error) ListRefreshTokens() ([]RefreshToken, error)
...@@ -72,6 +74,7 @@ type Storage interface { ...@@ -72,6 +74,7 @@ type Storage interface {
DeleteClient(id string) error DeleteClient(id string) error
DeleteRefresh(id string) error DeleteRefresh(id string) error
DeletePassword(email string) error DeletePassword(email string) error
DeleteOfflineSessions(userID string, connID string) error
// Update methods take a function for updating an object then performs that update within // Update methods take a function for updating an object then performs that update within
// a transaction. "updater" functions may be called multiple times by a single update call. // a transaction. "updater" functions may be called multiple times by a single update call.
...@@ -92,6 +95,7 @@ type Storage interface { ...@@ -92,6 +95,7 @@ type Storage interface {
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error
UpdatePassword(email string, updater func(p Password) (Password, error)) error UpdatePassword(email string, updater func(p Password) (Password, error)) error
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
// GarbageCollect deletes all expired AuthCodes and AuthRequests. // GarbageCollect deletes all expired AuthCodes and AuthRequests.
GarbageCollect(now time.Time) (GCResult, error) GarbageCollect(now time.Time) (GCResult, error)
...@@ -241,6 +245,30 @@ type RefreshToken struct { ...@@ -241,6 +245,30 @@ type RefreshToken struct {
Nonce string Nonce string
} }
// RefreshTokenRef is a reference object that contains metadata about refresh tokens.
type RefreshTokenRef struct {
ID string
// Client the refresh token is valid for.
ClientID string
CreatedAt time.Time
LastUsed time.Time
}
// OfflineSessions objects are sessions pertaining to users with refresh tokens.
type OfflineSessions struct {
// UserID of an end user who has logged in to the server.
UserID string
// The ID of the connector used to login the user.
ConnID string
// Refresh is a hash table of refresh token reference objects
// indexed by the ClientID of the refresh token.
Refresh map[string]*RefreshTokenRef
}
// Password is an email to password mapping managed by the storage. // Password is an email to password mapping managed by the storage.
type Password struct { type Password struct {
// Email and identifying name of the password. Emails are assumed to be valid and // Email and identifying name of the password. Emails are assumed to be valid and
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment