Commit 7bac93aa authored by Eric Chiang's avatar Eric Chiang

*: remove in memory session repos

Move manager to it's own package so it can import db. Move all
references to the in memory session repos to use sqlite3.
parent 5052d800
...@@ -101,3 +101,15 @@ func rollback(tx *gorp.Transaction) { ...@@ -101,3 +101,15 @@ func rollback(tx *gorp.Transaction) {
log.Errorf("unable to rollback: %v", err) log.Errorf("unable to rollback: %v", err)
} }
} }
// NewMemDB creates a new in memory sqlite3 database.
func NewMemDB() *gorp.DbMap {
dbMap, err := NewConnection(Config{DSN: "sqlite3://:memory:"})
if err != nil {
panic("Failed to create in memory database: " + err.Error())
}
if _, err := MigrateToLatest(dbMap); err != nil {
panic("In memory database migration failed: " + err.Error())
}
return dbMap
}
...@@ -65,7 +65,7 @@ CREATE TABLE session ( ...@@ -65,7 +65,7 @@ CREATE TABLE session (
); );
CREATE TABLE session_key ( CREATE TABLE session_key (
key text NOT NULL UNIQUE, key text NOT NULL,
session_id text, session_id text,
expires_at bigint, expires_at bigint,
stale integer stale integer
......
...@@ -15,7 +15,7 @@ import ( ...@@ -15,7 +15,7 @@ import (
func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) { func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
clock := clockwork.NewFakeClock() clock := clockwork.NewFakeClock()
if os.Getenv("DEX_TEST_DSN") == "" { if os.Getenv("DEX_TEST_DSN") == "" {
return session.NewSessionRepoWithClock(clock), clock return db.NewSessionRepoWithClock(db.NewMemDB(), clock), clock
} }
dbMap := connect(t) dbMap := connect(t)
return db.NewSessionRepoWithClock(dbMap, clock), clock return db.NewSessionRepoWithClock(dbMap, clock), clock
...@@ -24,7 +24,7 @@ func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) { ...@@ -24,7 +24,7 @@ func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
func newSessionKeyRepo(t *testing.T) (session.SessionKeyRepo, clockwork.FakeClock) { func newSessionKeyRepo(t *testing.T) (session.SessionKeyRepo, clockwork.FakeClock) {
clock := clockwork.NewFakeClock() clock := clockwork.NewFakeClock()
if os.Getenv("DEX_TEST_DSN") == "" { if os.Getenv("DEX_TEST_DSN") == "" {
return session.NewSessionKeyRepoWithClock(clock), clock return db.NewSessionKeyRepoWithClock(db.NewMemDB(), clock), clock
} }
dbMap := connect(t) dbMap := connect(t)
return db.NewSessionKeyRepoWithClock(dbMap, clock), clock return db.NewSessionKeyRepoWithClock(dbMap, clock), clock
......
...@@ -10,10 +10,11 @@ import ( ...@@ -10,10 +10,11 @@ import (
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
phttp "github.com/coreos/dex/pkg/http" phttp "github.com/coreos/dex/pkg/http"
"github.com/coreos/dex/refresh/refreshtest" "github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/server" "github.com/coreos/dex/server"
"github.com/coreos/dex/session" "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
...@@ -33,7 +34,7 @@ func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) { ...@@ -33,7 +34,7 @@ func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) {
return nil, err return nil, err
} }
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &server.Server{ srv := &server.Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km, KeyManager: km,
...@@ -120,7 +121,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { ...@@ -120,7 +121,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci}) cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"} issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
k, err := key.GeneratePrivateKey() k, err := key.GeneratePrivateKey()
if err != nil { if err != nil {
......
...@@ -19,10 +19,10 @@ import ( ...@@ -19,10 +19,10 @@ import (
"github.com/coreos/dex/email" "github.com/coreos/dex/email"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/dex/session" sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
) )
type ServerConfig struct { type ServerConfig struct {
...@@ -128,9 +128,9 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { ...@@ -128,9 +128,9 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
} }
cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs) cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs)
sRepo := session.NewSessionRepo() sRepo := db.NewSessionRepo(db.NewMemDB())
skRepo := session.NewSessionKeyRepo() skRepo := db.NewSessionKeyRepo(db.NewMemDB())
sm := session.NewSessionManager(sRepo, skRepo) sm := sessionmanager.NewSessionManager(sRepo, skRepo)
userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile) userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile)
if err != nil { if err != nil {
...@@ -142,7 +142,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { ...@@ -142,7 +142,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
refTokRepo := refresh.NewRefreshTokenRepo() refTokRepo := refresh.NewRefreshTokenRepo()
txnFactory := repo.InMemTransactionFactory txnFactory := repo.InMemTransactionFactory
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{}) userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
srv.ClientIdentityRepo = ciRepo srv.ClientIdentityRepo = ciRepo
srv.KeySetRepo = kRepo srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo srv.ConnectorConfigRepo = cfgRepo
...@@ -180,10 +180,10 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error { ...@@ -180,10 +180,10 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
cfgRepo := db.NewConnectorConfigRepo(dbc) cfgRepo := db.NewConnectorConfigRepo(dbc)
userRepo := db.NewUserRepo(dbc) userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc) pwiRepo := db.NewPasswordInfoRepo(dbc)
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{}) userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
refreshTokenRepo := db.NewRefreshTokenRepo(dbc) refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
sm := session.NewSessionManager(sRepo, skRepo) sm := sessionmanager.NewSessionManager(sRepo, skRepo)
srv.ClientIdentityRepo = ciRepo srv.ClientIdentityRepo = ciRepo
srv.KeySetRepo = kRepo srv.KeySetRepo = kRepo
......
...@@ -17,7 +17,8 @@ import ( ...@@ -17,7 +17,8 @@ import (
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/session" "github.com/coreos/dex/db"
"github.com/coreos/dex/session/manager"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/oauth2" "github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
...@@ -75,7 +76,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { ...@@ -75,7 +76,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
} }
srv := &Server{ srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()), SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
...@@ -198,7 +199,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) { ...@@ -198,7 +199,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
} }
srv := &Server{ srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()), SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
oidc.ClientIdentity{ oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
......
...@@ -9,10 +9,10 @@ import ( ...@@ -9,10 +9,10 @@ import (
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session" sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
) )
type sendResetPasswordEmailData struct { type sendResetPasswordEmailData struct {
...@@ -28,7 +28,7 @@ type sendResetPasswordEmailData struct { ...@@ -28,7 +28,7 @@ type sendResetPasswordEmailData struct {
type SendResetPasswordEmailHandler struct { type SendResetPasswordEmailHandler struct {
tpl *template.Template tpl *template.Template
emailer *useremail.UserEmailer emailer *useremail.UserEmailer
sm *session.SessionManager sm *sessionmanager.SessionManager
cr client.ClientIdentityRepo cr client.ClientIdentityRepo
} }
...@@ -182,7 +182,7 @@ type resetPasswordTemplateData struct { ...@@ -182,7 +182,7 @@ type resetPasswordTemplateData struct {
type ResetPasswordHandler struct { type ResetPasswordHandler struct {
tpl *template.Template tpl *template.Template
issuerURL url.URL issuerURL url.URL
um *manager.UserManager um *usermanager.UserManager
keysFunc func() ([]key.PublicKey, error) keysFunc func() ([]key.PublicKey, error)
} }
...@@ -238,7 +238,7 @@ func (r *resetPasswordRequest) handlePOST() { ...@@ -238,7 +238,7 @@ func (r *resetPasswordRequest) handlePOST() {
cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext) cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext)
if err != nil { if err != nil {
switch err { switch err {
case manager.ErrorPasswordAlreadyChanged: case usermanager.ErrorPasswordAlreadyChanged:
r.data.Error = "Link Expired" r.data.Error = "Link Expired"
r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email." r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email."
r.data.DontShowForm = true r.data.DontShowForm = true
......
...@@ -10,8 +10,9 @@ import ( ...@@ -10,8 +10,9 @@ import (
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
) )
...@@ -274,7 +275,7 @@ func makeClientRedirectURL(baseRedirURL url.URL, code, clientState string) *url. ...@@ -274,7 +275,7 @@ func makeClientRedirectURL(baseRedirURL url.URL, code, clientState string) *url.
return &ru return &ru
} }
func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) { func registerFromLocalConnector(userManager *usermanager.UserManager, sessionManager *sessionmanager.SessionManager, ses *session.Session, email, password string) (string, error) {
userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID) userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID)
if err != nil { if err != nil {
return "", err return "", err
...@@ -289,7 +290,7 @@ func registerFromLocalConnector(userManager *manager.UserManager, sessionManager ...@@ -289,7 +290,7 @@ func registerFromLocalConnector(userManager *manager.UserManager, sessionManager
return userID, nil return userID, nil
} }
func registerFromRemoteConnector(userManager *manager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) { func registerFromRemoteConnector(userManager *usermanager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
if ses.Identity.ID == "" { if ses.Identity.ID == "" {
return "", errors.New("No Identity found in session.") return "", errors.New("No Identity found in session.")
} }
......
...@@ -22,10 +22,11 @@ import ( ...@@ -22,10 +22,11 @@ import (
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh" "github.com/coreos/dex/refresh"
"github.com/coreos/dex/session" "github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
usersapi "github.com/coreos/dex/user/api" usersapi "github.com/coreos/dex/user/api"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
) )
const ( const (
...@@ -57,7 +58,7 @@ type Server struct { ...@@ -57,7 +58,7 @@ type Server struct {
IssuerURL url.URL IssuerURL url.URL
KeyManager key.PrivateKeyManager KeyManager key.PrivateKeyManager
KeySetRepo key.PrivateKeySetRepo KeySetRepo key.PrivateKeySetRepo
SessionManager *session.SessionManager SessionManager *sessionmanager.SessionManager
ClientIdentityRepo client.ClientIdentityRepo ClientIdentityRepo client.ClientIdentityRepo
ConnectorConfigRepo connector.ConnectorConfigRepo ConnectorConfigRepo connector.ConnectorConfigRepo
Templates *template.Template Templates *template.Template
...@@ -69,7 +70,7 @@ type Server struct { ...@@ -69,7 +70,7 @@ type Server struct {
HealthChecks []health.Checkable HealthChecks []health.Checkable
Connectors []connector.Connector Connectors []connector.Connector
UserRepo user.UserRepo UserRepo user.UserRepo
UserManager *manager.UserManager UserManager *usermanager.UserManager
PasswordInfoRepo user.PasswordInfoRepo PasswordInfoRepo user.PasswordInfoRepo
RefreshTokenRepo refresh.RefreshTokenRepo RefreshTokenRepo refresh.RefreshTokenRepo
UserEmailer *useremail.UserEmailer UserEmailer *useremail.UserEmailer
......
...@@ -10,8 +10,9 @@ import ( ...@@ -10,8 +10,9 @@ import (
"time" "time"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/db"
"github.com/coreos/dex/refresh/refreshtest" "github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/session" "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
...@@ -68,7 +69,7 @@ func (ss *StaticSigner) JWK() jose.JWK { ...@@ -68,7 +69,7 @@ func (ss *StaticSigner) JWK() jose.JWK {
return jose.JWK{} return jose.JWK{}
} }
func staticGenerateCodeFunc(code string) session.GenerateCodeFunc { func staticGenerateCodeFunc(code string) manager.GenerateCodeFunc {
return func() (string, error) { return func() (string, error) {
return code, nil return code, nil
} }
...@@ -120,7 +121,7 @@ func TestServerProviderConfig(t *testing.T) { ...@@ -120,7 +121,7 @@ func TestServerProviderConfig(t *testing.T) {
} }
func TestServerNewSession(t *testing.T) { func TestServerNewSession(t *testing.T) {
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{ srv := &Server{
SessionManager: sm, SessionManager: sm,
} }
...@@ -197,7 +198,7 @@ func TestServerLogin(t *testing.T) { ...@@ -197,7 +198,7 @@ func TestServerLogin(t *testing.T) {
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sm.GenerateCode = staticGenerateCodeFunc("fakecode") sm.GenerateCode = staticGenerateCodeFunc("fakecode")
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"}) sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
if err != nil { if err != nil {
...@@ -245,7 +246,7 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) { ...@@ -245,7 +246,7 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: nil, err: errors.New("fail")}, signer: &StaticSigner{sig: nil, err: errors.New("fail")},
} }
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{ srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km, KeyManager: km,
...@@ -286,7 +287,7 @@ func TestServerLoginDisabledUser(t *testing.T) { ...@@ -286,7 +287,7 @@ func TestServerLoginDisabledUser(t *testing.T) {
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sm.GenerateCode = staticGenerateCodeFunc("fakecode") sm.GenerateCode = staticGenerateCodeFunc("fakecode")
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"}) sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
if err != nil { if err != nil {
...@@ -343,7 +344,7 @@ func TestServerCodeToken(t *testing.T) { ...@@ -343,7 +344,7 @@ func TestServerCodeToken(t *testing.T) {
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
userRepo, err := makeNewUserRepo() userRepo, err := makeNewUserRepo()
if err != nil { if err != nil {
...@@ -424,7 +425,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) { ...@@ -424,7 +425,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil}, signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{ srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
...@@ -518,7 +519,7 @@ func TestServerTokenFail(t *testing.T) { ...@@ -518,7 +519,7 @@ func TestServerTokenFail(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sm.GenerateCode = func() (string, error) { return keyFixture, nil } sm.GenerateCode = func() (string, error) { return keyFixture, nil }
sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope) sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope)
......
...@@ -10,12 +10,13 @@ import ( ...@@ -10,12 +10,13 @@ import (
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
"github.com/coreos/dex/email" "github.com/coreos/dex/email"
"github.com/coreos/dex/repo" "github.com/coreos/dex/repo"
"github.com/coreos/dex/session" sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user" "github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email" useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager" usermanager "github.com/coreos/dex/user/manager"
) )
const ( const (
...@@ -75,13 +76,13 @@ var ( ...@@ -75,13 +76,13 @@ var (
type testFixtures struct { type testFixtures struct {
srv *Server srv *Server
userRepo user.UserRepo userRepo user.UserRepo
sessionManager *session.SessionManager sessionManager *sessionmanager.SessionManager
emailer *email.TemplatizedEmailer emailer *email.TemplatizedEmailer
redirectURL url.URL redirectURL url.URL
clientIdentityRepo client.ClientIdentityRepo clientIdentityRepo client.ClientIdentityRepo
} }
func sequentialGenerateCodeFunc() session.GenerateCodeFunc { func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
x := 0 x := 0
return func() (string, error) { return func() (string, error) {
x += 1 x += 1
...@@ -113,9 +114,9 @@ func makeTestFixtures() (*testFixtures, error) { ...@@ -113,9 +114,9 @@ func makeTestFixtures() (*testFixtures, error) {
} }
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs) connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{}) manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, usermanager.ManagerOptions{})
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sessionManager.GenerateCode = sequentialGenerateCodeFunc() sessionManager.GenerateCode = sequentialGenerateCodeFunc()
emailer, err := email.NewTemplatizedEmailerFromGlobs( emailer, err := email.NewTemplatizedEmailerFromGlobs(
......
package session package manager
import ( import (
"crypto/rand" "crypto/rand"
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"github.com/jonboulle/clockwork" "github.com/jonboulle/clockwork"
"github.com/coreos/dex/session"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
) )
...@@ -27,11 +28,11 @@ func DefaultGenerateCode() (string, error) { ...@@ -27,11 +28,11 @@ func DefaultGenerateCode() (string, error) {
return base64.URLEncoding.EncodeToString(b), nil return base64.URLEncoding.EncodeToString(b), nil
} }
func NewSessionManager(sRepo SessionRepo, skRepo SessionKeyRepo) *SessionManager { func NewSessionManager(sRepo session.SessionRepo, skRepo session.SessionKeyRepo) *SessionManager {
return &SessionManager{ return &SessionManager{
GenerateCode: DefaultGenerateCode, GenerateCode: DefaultGenerateCode,
Clock: clockwork.NewRealClock(), Clock: clockwork.NewRealClock(),
ValidityWindow: DefaultSessionValidityWindow, ValidityWindow: session.DefaultSessionValidityWindow,
sessions: sRepo, sessions: sRepo,
keys: skRepo, keys: skRepo,
} }
...@@ -41,8 +42,8 @@ type SessionManager struct { ...@@ -41,8 +42,8 @@ type SessionManager struct {
GenerateCode GenerateCodeFunc GenerateCode GenerateCodeFunc
Clock clockwork.Clock Clock clockwork.Clock
ValidityWindow time.Duration ValidityWindow time.Duration
sessions SessionRepo sessions session.SessionRepo
keys SessionKeyRepo keys session.SessionKeyRepo
} }
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) { func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
...@@ -52,10 +53,10 @@ func (m *SessionManager) NewSession(connectorID, clientID, clientState string, r ...@@ -52,10 +53,10 @@ func (m *SessionManager) NewSession(connectorID, clientID, clientState string, r
} }
now := m.Clock.Now() now := m.Clock.Now()
s := Session{ s := session.Session{
ConnectorID: connectorID, ConnectorID: connectorID,
ID: sID, ID: sID,
State: SessionStateNew, State: session.SessionStateNew,
CreatedAt: now, CreatedAt: now,
ExpiresAt: now.Add(m.ValidityWindow), ExpiresAt: now.Add(m.ValidityWindow),
ClientID: clientID, ClientID: clientID,
...@@ -80,11 +81,12 @@ func (m *SessionManager) NewSessionKey(sessionID string) (string, error) { ...@@ -80,11 +81,12 @@ func (m *SessionManager) NewSessionKey(sessionID string) (string, error) {
return "", err return "", err
} }
k := SessionKey{ k := session.SessionKey{
Key: key, Key: key,
SessionID: sessionID, SessionID: sessionID,
} }
sessionKeyValidityWindow := 10 * time.Minute //RFC6749
err = m.keys.Push(k, sessionKeyValidityWindow) err = m.keys.Push(k, sessionKeyValidityWindow)
if err != nil { if err != nil {
return "", err return "", err
...@@ -97,7 +99,7 @@ func (m *SessionManager) ExchangeKey(key string) (string, error) { ...@@ -97,7 +99,7 @@ func (m *SessionManager) ExchangeKey(key string) (string, error) {
return m.keys.Pop(key) return m.keys.Pop(key)
} }
func (m *SessionManager) getSessionInState(sessionID string, state SessionState) (*Session, error) { func (m *SessionManager) getSessionInState(sessionID string, state session.SessionState) (*session.Session, error) {
s, err := m.sessions.Get(sessionID) s, err := m.sessions.Get(sessionID)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -110,14 +112,14 @@ func (m *SessionManager) getSessionInState(sessionID string, state SessionState) ...@@ -110,14 +112,14 @@ func (m *SessionManager) getSessionInState(sessionID string, state SessionState)
return s, nil return s, nil
} }
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*Session, error) { func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*session.Session, error) {
s, err := m.getSessionInState(sessionID, SessionStateNew) s, err := m.getSessionInState(sessionID, session.SessionStateNew)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.Identity = ident s.Identity = ident
s.State = SessionStateRemoteAttached s.State = session.SessionStateRemoteAttached
if err = m.sessions.Update(*s); err != nil { if err = m.sessions.Update(*s); err != nil {
return nil, err return nil, err
...@@ -126,14 +128,14 @@ func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Ident ...@@ -126,14 +128,14 @@ func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Ident
return s, nil return s, nil
} }
func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, error) { func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.Session, error) {
s, err := m.getSessionInState(sessionID, SessionStateRemoteAttached) s, err := m.getSessionInState(sessionID, session.SessionStateRemoteAttached)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.UserID = userID s.UserID = userID
s.State = SessionStateIdentified s.State = session.SessionStateIdentified
if err = m.sessions.Update(*s); err != nil { if err = m.sessions.Update(*s); err != nil {
return nil, err return nil, err
...@@ -142,13 +144,13 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, ...@@ -142,13 +144,13 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session,
return s, nil return s, nil
} }
func (m *SessionManager) Kill(sessionID string) (*Session, error) { func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
s, err := m.sessions.Get(sessionID) s, err := m.sessions.Get(sessionID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.State = SessionStateDead s.State = session.SessionStateDead
if err = m.sessions.Update(*s); err != nil { if err = m.sessions.Update(*s); err != nil {
return nil, err return nil, err
...@@ -157,6 +159,6 @@ func (m *SessionManager) Kill(sessionID string) (*Session, error) { ...@@ -157,6 +159,6 @@ func (m *SessionManager) Kill(sessionID string) (*Session, error) {
return s, nil return s, nil
} }
func (m *SessionManager) Get(sessionID string) (*Session, error) { func (m *SessionManager) Get(sessionID string) (*session.Session, error) {
return m.sessions.Get(sessionID) return m.sessions.Get(sessionID)
} }
package session package manager
import ( import (
"net/url" "net/url"
"testing" "testing"
"github.com/coreos/dex/db"
"github.com/coreos/dex/session"
"github.com/coreos/go-oidc/oidc" "github.com/coreos/go-oidc/oidc"
) )
...@@ -13,8 +15,13 @@ func staticGenerateCodeFunc(code string) GenerateCodeFunc { ...@@ -13,8 +15,13 @@ func staticGenerateCodeFunc(code string) GenerateCodeFunc {
} }
} }
func newManager(t *testing.T) *SessionManager {
dbMap := db.NewMemDB()
return NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
}
func TestSessionManagerNewSession(t *testing.T) { func TestSessionManagerNewSession(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := newManager(t)
sm.GenerateCode = staticGenerateCodeFunc("boo") sm.GenerateCode = staticGenerateCodeFunc("boo")
got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil { if err != nil {
...@@ -26,7 +33,7 @@ func TestSessionManagerNewSession(t *testing.T) { ...@@ -26,7 +33,7 @@ func TestSessionManagerNewSession(t *testing.T) {
} }
func TestSessionAttachRemoteIdentityTwice(t *testing.T) { func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := newManager(t)
sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
...@@ -43,7 +50,7 @@ func TestSessionAttachRemoteIdentityTwice(t *testing.T) { ...@@ -43,7 +50,7 @@ func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
} }
func TestSessionManagerExchangeKey(t *testing.T) { func TestSessionManagerExchangeKey(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := newManager(t)
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
...@@ -68,8 +75,8 @@ func TestSessionManagerExchangeKey(t *testing.T) { ...@@ -68,8 +75,8 @@ func TestSessionManagerExchangeKey(t *testing.T) {
} }
func TestSessionManagerGetSessionInStateNoExist(t *testing.T) { func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := newManager(t)
ses, err := sm.getSessionInState("123", SessionStateNew) ses, err := sm.getSessionInState("123", session.SessionStateNew)
if err == nil { if err == nil {
t.Errorf("Expected non-nil error") t.Errorf("Expected non-nil error")
} }
...@@ -79,12 +86,12 @@ func TestSessionManagerGetSessionInStateNoExist(t *testing.T) { ...@@ -79,12 +86,12 @@ func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
} }
func TestSessionManagerGetSessionInStateWrongState(t *testing.T) { func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := newManager(t)
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
ses, err := sm.getSessionInState(sessionID, SessionStateDead) ses, err := sm.getSessionInState(sessionID, session.SessionStateDead)
if err == nil { if err == nil {
t.Errorf("Expected non-nil error") t.Errorf("Expected non-nil error")
} }
...@@ -94,7 +101,7 @@ func TestSessionManagerGetSessionInStateWrongState(t *testing.T) { ...@@ -94,7 +101,7 @@ func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
} }
func TestSessionManagerKill(t *testing.T) { func TestSessionManagerKill(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := newManager(t)
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
......
package session package session
import ( import "time"
"errors"
"time"
"github.com/jonboulle/clockwork"
)
type SessionRepo interface { type SessionRepo interface {
Get(string) (*Session, error) Get(string) (*Session, error)
...@@ -17,87 +12,3 @@ type SessionKeyRepo interface { ...@@ -17,87 +12,3 @@ type SessionKeyRepo interface {
Push(SessionKey, time.Duration) error Push(SessionKey, time.Duration) error
Pop(string) (string, error) Pop(string) (string, error)
} }
func NewSessionRepo() SessionRepo {
return NewSessionRepoWithClock(clockwork.NewRealClock())
}
func NewSessionRepoWithClock(clock clockwork.Clock) SessionRepo {
return &memSessionRepo{
store: make(map[string]Session),
clock: clock,
}
}
type memSessionRepo struct {
store map[string]Session
clock clockwork.Clock
}
func (m *memSessionRepo) Get(sessionID string) (*Session, error) {
s, ok := m.store[sessionID]
if !ok || s.ExpiresAt.Before(m.clock.Now()) {
return nil, errors.New("unrecognized ID")
}
return &s, nil
}
func (m *memSessionRepo) Create(s Session) error {
if _, ok := m.store[s.ID]; ok {
return errors.New("ID exists")
}
m.store[s.ID] = s
return nil
}
func (m *memSessionRepo) Update(s Session) error {
if _, ok := m.store[s.ID]; !ok {
return errors.New("unrecognized ID")
}
m.store[s.ID] = s
return nil
}
type expiringSessionKey struct {
SessionKey
expiresAt time.Time
}
func NewSessionKeyRepo() SessionKeyRepo {
return NewSessionKeyRepoWithClock(clockwork.NewRealClock())
}
func NewSessionKeyRepoWithClock(clock clockwork.Clock) SessionKeyRepo {
return &memSessionKeyRepo{
store: make(map[string]expiringSessionKey),
clock: clock,
}
}
type memSessionKeyRepo struct {
store map[string]expiringSessionKey
clock clockwork.Clock
}
func (m *memSessionKeyRepo) Pop(key string) (string, error) {
esk, ok := m.store[key]
if !ok {
return "", errors.New("unrecognized key")
}
defer delete(m.store, key)
if esk.expiresAt.Before(m.clock.Now()) {
return "", errors.New("expired key")
}
return esk.SessionKey.SessionID, nil
}
func (m *memSessionKeyRepo) Push(sk SessionKey, ttl time.Duration) error {
m.store[sk.Key] = expiringSessionKey{
SessionKey: sk,
expiresAt: m.clock.Now().Add(ttl),
}
return nil
}
...@@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"} ...@@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"}
source ./build source ./build
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email admin" TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session session/manager user user/api user/manager email admin"
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log" FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
# user has not provided PKG override # user has not provided PKG override
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment