Commit 7bac93aa authored by Eric Chiang's avatar Eric Chiang

*: remove in memory session repos

Move manager to it's own package so it can import db. Move all
references to the in memory session repos to use sqlite3.
parent 5052d800
......@@ -101,3 +101,15 @@ func rollback(tx *gorp.Transaction) {
log.Errorf("unable to rollback: %v", err)
}
}
// NewMemDB creates a new in memory sqlite3 database.
func NewMemDB() *gorp.DbMap {
dbMap, err := NewConnection(Config{DSN: "sqlite3://:memory:"})
if err != nil {
panic("Failed to create in memory database: " + err.Error())
}
if _, err := MigrateToLatest(dbMap); err != nil {
panic("In memory database migration failed: " + err.Error())
}
return dbMap
}
......@@ -65,7 +65,7 @@ CREATE TABLE session (
);
CREATE TABLE session_key (
key text NOT NULL UNIQUE,
key text NOT NULL,
session_id text,
expires_at bigint,
stale integer
......
......@@ -15,7 +15,7 @@ import (
func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
clock := clockwork.NewFakeClock()
if os.Getenv("DEX_TEST_DSN") == "" {
return session.NewSessionRepoWithClock(clock), clock
return db.NewSessionRepoWithClock(db.NewMemDB(), clock), clock
}
dbMap := connect(t)
return db.NewSessionRepoWithClock(dbMap, clock), clock
......@@ -24,7 +24,7 @@ func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
func newSessionKeyRepo(t *testing.T) (session.SessionKeyRepo, clockwork.FakeClock) {
clock := clockwork.NewFakeClock()
if os.Getenv("DEX_TEST_DSN") == "" {
return session.NewSessionKeyRepoWithClock(clock), clock
return db.NewSessionKeyRepoWithClock(db.NewMemDB(), clock), clock
}
dbMap := connect(t)
return db.NewSessionKeyRepoWithClock(dbMap, clock), clock
......
......@@ -10,10 +10,11 @@ import (
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
phttp "github.com/coreos/dex/pkg/http"
"github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/server"
"github.com/coreos/dex/session"
"github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
......@@ -33,7 +34,7 @@ func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) {
return nil, err
}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &server.Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km,
......@@ -120,7 +121,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
k, err := key.GeneratePrivateKey()
if err != nil {
......
......@@ -19,10 +19,10 @@ import (
"github.com/coreos/dex/email"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
usermanager "github.com/coreos/dex/user/manager"
)
type ServerConfig struct {
......@@ -128,9 +128,9 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
}
cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs)
sRepo := session.NewSessionRepo()
skRepo := session.NewSessionKeyRepo()
sm := session.NewSessionManager(sRepo, skRepo)
sRepo := db.NewSessionRepo(db.NewMemDB())
skRepo := db.NewSessionKeyRepo(db.NewMemDB())
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile)
if err != nil {
......@@ -142,7 +142,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
refTokRepo := refresh.NewRefreshTokenRepo()
txnFactory := repo.InMemTransactionFactory
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{})
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
srv.ClientIdentityRepo = ciRepo
srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo
......@@ -180,10 +180,10 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
cfgRepo := db.NewConnectorConfigRepo(dbc)
userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc)
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
sm := session.NewSessionManager(sRepo, skRepo)
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
srv.ClientIdentityRepo = ciRepo
srv.KeySetRepo = kRepo
......
......@@ -17,7 +17,8 @@ import (
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/session"
"github.com/coreos/dex/db"
"github.com/coreos/dex/session/manager"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
......@@ -75,7 +76,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
}
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()),
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{
......@@ -198,7 +199,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
}
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()),
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{
......
......@@ -9,10 +9,10 @@ import (
"github.com/coreos/dex/client"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
usermanager "github.com/coreos/dex/user/manager"
)
type sendResetPasswordEmailData struct {
......@@ -28,7 +28,7 @@ type sendResetPasswordEmailData struct {
type SendResetPasswordEmailHandler struct {
tpl *template.Template
emailer *useremail.UserEmailer
sm *session.SessionManager
sm *sessionmanager.SessionManager
cr client.ClientIdentityRepo
}
......@@ -182,7 +182,7 @@ type resetPasswordTemplateData struct {
type ResetPasswordHandler struct {
tpl *template.Template
issuerURL url.URL
um *manager.UserManager
um *usermanager.UserManager
keysFunc func() ([]key.PublicKey, error)
}
......@@ -238,7 +238,7 @@ func (r *resetPasswordRequest) handlePOST() {
cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext)
if err != nil {
switch err {
case manager.ErrorPasswordAlreadyChanged:
case usermanager.ErrorPasswordAlreadyChanged:
r.data.Error = "Link Expired"
r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email."
r.data.DontShowForm = true
......
......@@ -10,8 +10,9 @@ import (
"github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
usermanager "github.com/coreos/dex/user/manager"
"github.com/coreos/go-oidc/oidc"
)
......@@ -274,7 +275,7 @@ func makeClientRedirectURL(baseRedirURL url.URL, code, clientState string) *url.
return &ru
}
func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) {
func registerFromLocalConnector(userManager *usermanager.UserManager, sessionManager *sessionmanager.SessionManager, ses *session.Session, email, password string) (string, error) {
userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID)
if err != nil {
return "", err
......@@ -289,7 +290,7 @@ func registerFromLocalConnector(userManager *manager.UserManager, sessionManager
return userID, nil
}
func registerFromRemoteConnector(userManager *manager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
func registerFromRemoteConnector(userManager *usermanager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
if ses.Identity.ID == "" {
return "", errors.New("No Identity found in session.")
}
......
......@@ -22,10 +22,11 @@ import (
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
usersapi "github.com/coreos/dex/user/api"
useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
usermanager "github.com/coreos/dex/user/manager"
)
const (
......@@ -57,7 +58,7 @@ type Server struct {
IssuerURL url.URL
KeyManager key.PrivateKeyManager
KeySetRepo key.PrivateKeySetRepo
SessionManager *session.SessionManager
SessionManager *sessionmanager.SessionManager
ClientIdentityRepo client.ClientIdentityRepo
ConnectorConfigRepo connector.ConnectorConfigRepo
Templates *template.Template
......@@ -69,7 +70,7 @@ type Server struct {
HealthChecks []health.Checkable
Connectors []connector.Connector
UserRepo user.UserRepo
UserManager *manager.UserManager
UserManager *usermanager.UserManager
PasswordInfoRepo user.PasswordInfoRepo
RefreshTokenRepo refresh.RefreshTokenRepo
UserEmailer *useremail.UserEmailer
......
......@@ -10,8 +10,9 @@ import (
"time"
"github.com/coreos/dex/client"
"github.com/coreos/dex/db"
"github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/session"
"github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
......@@ -68,7 +69,7 @@ func (ss *StaticSigner) JWK() jose.JWK {
return jose.JWK{}
}
func staticGenerateCodeFunc(code string) session.GenerateCodeFunc {
func staticGenerateCodeFunc(code string) manager.GenerateCodeFunc {
return func() (string, error) {
return code, nil
}
......@@ -120,7 +121,7 @@ func TestServerProviderConfig(t *testing.T) {
}
func TestServerNewSession(t *testing.T) {
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{
SessionManager: sm,
}
......@@ -197,7 +198,7 @@ func TestServerLogin(t *testing.T) {
signer: &StaticSigner{sig: []byte("beer"), err: nil},
}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
if err != nil {
......@@ -245,7 +246,7 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
km := &StaticKeyManager{
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km,
......@@ -286,7 +287,7 @@ func TestServerLoginDisabledUser(t *testing.T) {
signer: &StaticSigner{sig: []byte("beer"), err: nil},
}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
if err != nil {
......@@ -343,7 +344,7 @@ func TestServerCodeToken(t *testing.T) {
km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil},
}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
userRepo, err := makeNewUserRepo()
if err != nil {
......@@ -424,7 +425,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil},
}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
......@@ -518,7 +519,7 @@ func TestServerTokenFail(t *testing.T) {
}
for i, tt := range tests {
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sm.GenerateCode = func() (string, error) { return keyFixture, nil }
sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope)
......
......@@ -10,12 +10,13 @@ import (
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
"github.com/coreos/dex/email"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
usermanager "github.com/coreos/dex/user/manager"
)
const (
......@@ -75,13 +76,13 @@ var (
type testFixtures struct {
srv *Server
userRepo user.UserRepo
sessionManager *session.SessionManager
sessionManager *sessionmanager.SessionManager
emailer *email.TemplatizedEmailer
redirectURL url.URL
clientIdentityRepo client.ClientIdentityRepo
}
func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
x := 0
return func() (string, error) {
x += 1
......@@ -113,9 +114,9 @@ func makeTestFixtures() (*testFixtures, error) {
}
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, usermanager.ManagerOptions{})
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
emailer, err := email.NewTemplatizedEmailerFromGlobs(
......
package session
package manager
import (
"crypto/rand"
......@@ -10,6 +10,7 @@ import (
"github.com/jonboulle/clockwork"
"github.com/coreos/dex/session"
"github.com/coreos/go-oidc/oidc"
)
......@@ -27,11 +28,11 @@ func DefaultGenerateCode() (string, error) {
return base64.URLEncoding.EncodeToString(b), nil
}
func NewSessionManager(sRepo SessionRepo, skRepo SessionKeyRepo) *SessionManager {
func NewSessionManager(sRepo session.SessionRepo, skRepo session.SessionKeyRepo) *SessionManager {
return &SessionManager{
GenerateCode: DefaultGenerateCode,
Clock: clockwork.NewRealClock(),
ValidityWindow: DefaultSessionValidityWindow,
ValidityWindow: session.DefaultSessionValidityWindow,
sessions: sRepo,
keys: skRepo,
}
......@@ -41,8 +42,8 @@ type SessionManager struct {
GenerateCode GenerateCodeFunc
Clock clockwork.Clock
ValidityWindow time.Duration
sessions SessionRepo
keys SessionKeyRepo
sessions session.SessionRepo
keys session.SessionKeyRepo
}
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
......@@ -52,10 +53,10 @@ func (m *SessionManager) NewSession(connectorID, clientID, clientState string, r
}
now := m.Clock.Now()
s := Session{
s := session.Session{
ConnectorID: connectorID,
ID: sID,
State: SessionStateNew,
State: session.SessionStateNew,
CreatedAt: now,
ExpiresAt: now.Add(m.ValidityWindow),
ClientID: clientID,
......@@ -80,11 +81,12 @@ func (m *SessionManager) NewSessionKey(sessionID string) (string, error) {
return "", err
}
k := SessionKey{
k := session.SessionKey{
Key: key,
SessionID: sessionID,
}
sessionKeyValidityWindow := 10 * time.Minute //RFC6749
err = m.keys.Push(k, sessionKeyValidityWindow)
if err != nil {
return "", err
......@@ -97,7 +99,7 @@ func (m *SessionManager) ExchangeKey(key string) (string, error) {
return m.keys.Pop(key)
}
func (m *SessionManager) getSessionInState(sessionID string, state SessionState) (*Session, error) {
func (m *SessionManager) getSessionInState(sessionID string, state session.SessionState) (*session.Session, error) {
s, err := m.sessions.Get(sessionID)
if err != nil {
return nil, err
......@@ -110,14 +112,14 @@ func (m *SessionManager) getSessionInState(sessionID string, state SessionState)
return s, nil
}
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*Session, error) {
s, err := m.getSessionInState(sessionID, SessionStateNew)
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*session.Session, error) {
s, err := m.getSessionInState(sessionID, session.SessionStateNew)
if err != nil {
return nil, err
}
s.Identity = ident
s.State = SessionStateRemoteAttached
s.State = session.SessionStateRemoteAttached
if err = m.sessions.Update(*s); err != nil {
return nil, err
......@@ -126,14 +128,14 @@ func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Ident
return s, nil
}
func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, error) {
s, err := m.getSessionInState(sessionID, SessionStateRemoteAttached)
func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.Session, error) {
s, err := m.getSessionInState(sessionID, session.SessionStateRemoteAttached)
if err != nil {
return nil, err
}
s.UserID = userID
s.State = SessionStateIdentified
s.State = session.SessionStateIdentified
if err = m.sessions.Update(*s); err != nil {
return nil, err
......@@ -142,13 +144,13 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session,
return s, nil
}
func (m *SessionManager) Kill(sessionID string) (*Session, error) {
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
s, err := m.sessions.Get(sessionID)
if err != nil {
return nil, err
}
s.State = SessionStateDead
s.State = session.SessionStateDead
if err = m.sessions.Update(*s); err != nil {
return nil, err
......@@ -157,6 +159,6 @@ func (m *SessionManager) Kill(sessionID string) (*Session, error) {
return s, nil
}
func (m *SessionManager) Get(sessionID string) (*Session, error) {
func (m *SessionManager) Get(sessionID string) (*session.Session, error) {
return m.sessions.Get(sessionID)
}
package session
package manager
import (
"net/url"
"testing"
"github.com/coreos/dex/db"
"github.com/coreos/dex/session"
"github.com/coreos/go-oidc/oidc"
)
......@@ -13,8 +15,13 @@ func staticGenerateCodeFunc(code string) GenerateCodeFunc {
}
}
func newManager(t *testing.T) *SessionManager {
dbMap := db.NewMemDB()
return NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
}
func TestSessionManagerNewSession(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sm := newManager(t)
sm.GenerateCode = staticGenerateCodeFunc("boo")
got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil {
......@@ -26,7 +33,7 @@ func TestSessionManagerNewSession(t *testing.T) {
}
func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sm := newManager(t)
sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
......@@ -43,7 +50,7 @@ func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
}
func TestSessionManagerExchangeKey(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sm := newManager(t)
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
......@@ -68,8 +75,8 @@ func TestSessionManagerExchangeKey(t *testing.T) {
}
func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
ses, err := sm.getSessionInState("123", SessionStateNew)
sm := newManager(t)
ses, err := sm.getSessionInState("123", session.SessionStateNew)
if err == nil {
t.Errorf("Expected non-nil error")
}
......@@ -79,12 +86,12 @@ func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
}
func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sm := newManager(t)
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
ses, err := sm.getSessionInState(sessionID, SessionStateDead)
ses, err := sm.getSessionInState(sessionID, session.SessionStateDead)
if err == nil {
t.Errorf("Expected non-nil error")
}
......@@ -94,7 +101,7 @@ func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
}
func TestSessionManagerKill(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sm := newManager(t)
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
......
package session
import (
"errors"
"time"
"github.com/jonboulle/clockwork"
)
import "time"
type SessionRepo interface {
Get(string) (*Session, error)
......@@ -17,87 +12,3 @@ type SessionKeyRepo interface {
Push(SessionKey, time.Duration) error
Pop(string) (string, error)
}
func NewSessionRepo() SessionRepo {
return NewSessionRepoWithClock(clockwork.NewRealClock())
}
func NewSessionRepoWithClock(clock clockwork.Clock) SessionRepo {
return &memSessionRepo{
store: make(map[string]Session),
clock: clock,
}
}
type memSessionRepo struct {
store map[string]Session
clock clockwork.Clock
}
func (m *memSessionRepo) Get(sessionID string) (*Session, error) {
s, ok := m.store[sessionID]
if !ok || s.ExpiresAt.Before(m.clock.Now()) {
return nil, errors.New("unrecognized ID")
}
return &s, nil
}
func (m *memSessionRepo) Create(s Session) error {
if _, ok := m.store[s.ID]; ok {
return errors.New("ID exists")
}
m.store[s.ID] = s
return nil
}
func (m *memSessionRepo) Update(s Session) error {
if _, ok := m.store[s.ID]; !ok {
return errors.New("unrecognized ID")
}
m.store[s.ID] = s
return nil
}
type expiringSessionKey struct {
SessionKey
expiresAt time.Time
}
func NewSessionKeyRepo() SessionKeyRepo {
return NewSessionKeyRepoWithClock(clockwork.NewRealClock())
}
func NewSessionKeyRepoWithClock(clock clockwork.Clock) SessionKeyRepo {
return &memSessionKeyRepo{
store: make(map[string]expiringSessionKey),
clock: clock,
}
}
type memSessionKeyRepo struct {
store map[string]expiringSessionKey
clock clockwork.Clock
}
func (m *memSessionKeyRepo) Pop(key string) (string, error) {
esk, ok := m.store[key]
if !ok {
return "", errors.New("unrecognized key")
}
defer delete(m.store, key)
if esk.expiresAt.Before(m.clock.Now()) {
return "", errors.New("expired key")
}
return esk.SessionKey.SessionID, nil
}
func (m *memSessionKeyRepo) Push(sk SessionKey, ttl time.Duration) error {
m.store[sk.Key] = expiringSessionKey{
SessionKey: sk,
expiresAt: m.clock.Now().Add(ttl),
}
return nil
}
......@@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"}
source ./build
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email admin"
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session session/manager user user/api user/manager email admin"
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
# user has not provided PKG override
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment