Commit 066fd859 authored by Yifan Gu's avatar Yifan Gu

session: add 'scope' field in session.

parent d87b5c9b
-- +migrate Up
ALTER TABLE session ADD COLUMN "scope" text;
This diff is collapsed.
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
"strings"
"time" "time"
"github.com/go-gorp/gorp" "github.com/go-gorp/gorp"
...@@ -42,6 +43,7 @@ type sessionModel struct { ...@@ -42,6 +43,7 @@ type sessionModel struct {
UserID string `db:"user_id"` UserID string `db:"user_id"`
Register bool `db:"register"` Register bool `db:"register"`
Nonce string `db:"nonce"` Nonce string `db:"nonce"`
Scope string `db:"scope"`
} }
func (s *sessionModel) session() (*session.Session, error) { func (s *sessionModel) session() (*session.Session, error) {
...@@ -71,6 +73,7 @@ func (s *sessionModel) session() (*session.Session, error) { ...@@ -71,6 +73,7 @@ func (s *sessionModel) session() (*session.Session, error) {
UserID: s.UserID, UserID: s.UserID,
Register: s.Register, Register: s.Register,
Nonce: s.Nonce, Nonce: s.Nonce,
Scope: strings.Fields(s.Scope),
} }
if s.CreatedAt != 0 { if s.CreatedAt != 0 {
...@@ -101,6 +104,7 @@ func newSessionModel(s *session.Session) (*sessionModel, error) { ...@@ -101,6 +104,7 @@ func newSessionModel(s *session.Session) (*sessionModel, error) {
UserID: s.UserID, UserID: s.UserID,
Register: s.Register, Register: s.Register,
Nonce: s.Nonce, Nonce: s.Nonce,
Scope: strings.Join(s.Scope, " "),
} }
if !s.CreatedAt.IsZero() { if !s.CreatedAt.IsZero() {
......
...@@ -196,7 +196,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { ...@@ -196,7 +196,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
// this will actually happen due to some interaction between the // this will actually happen due to some interaction between the
// end-user and a remote identity provider // end-user and a remote identity provider
sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false) sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, nil)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
......
...@@ -332,7 +332,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T ...@@ -332,7 +332,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
nonce := q.Get("nonce") nonce := q.Get("nonce")
key, err := srv.NewSession(connectorID, acr.ClientID, acr.State, redirectURL, nonce, register) key, err := srv.NewSession(connectorID, acr.ClientID, acr.State, redirectURL, nonce, register, acr.Scope)
if err != nil { if err != nil {
log.Errorf("Error creating new session: %v: ", err) log.Errorf("Error creating new session: %v: ", err)
redirectAuthError(w, err, acr.State, redirectURL) redirectAuthError(w, err, acr.State, redirectURL)
......
...@@ -245,7 +245,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) { ...@@ -245,7 +245,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) {
t.Fatalf("case %d: could not make test fixtures: %v", i, err) t.Fatalf("case %d: could not make test fixtures: %v", i, err)
} }
_, err = f.srv.NewSession("local", "XXX", "", f.redirectURL, "", true) _, err = f.srv.NewSession("local", "XXX", "", f.redirectURL, "", true, nil)
if err != nil { if err != nil {
t.Fatalf("case %d: could not create new session: %v", i, err) t.Fatalf("case %d: could not create new session: %v", i, err)
} }
......
...@@ -197,7 +197,7 @@ func TestHandleRegister(t *testing.T) { ...@@ -197,7 +197,7 @@ func TestHandleRegister(t *testing.T) {
t.Fatalf("case %d: could not make test fixtures: %v", i, err) t.Fatalf("case %d: could not make test fixtures: %v", i, err)
} }
key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true) key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true, nil)
t.Logf("case %d: key for NewSession: %v", i, key) t.Logf("case %d: key for NewSession: %v", i, key)
if tt.attachRemote { if tt.attachRemote {
......
...@@ -39,7 +39,7 @@ const ( ...@@ -39,7 +39,7 @@ const (
type OIDCServer interface { type OIDCServer interface {
ClientMetadata(string) (*oidc.ClientMetadata, error) ClientMetadata(string) (*oidc.ClientMetadata, error)
NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool) (string, error) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error)
Login(oidc.Identity, string) (string, error) Login(oidc.Identity, string) (string, error)
// CodeToken exchanges a code for an ID token and a refresh token string on success. // CodeToken exchanges a code for an ID token and a refresh token string on success.
CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error)
...@@ -263,8 +263,8 @@ func (s *Server) ClientMetadata(clientID string) (*oidc.ClientMetadata, error) { ...@@ -263,8 +263,8 @@ func (s *Server) ClientMetadata(clientID string) (*oidc.ClientMetadata, error) {
return s.ClientIdentityRepo.Metadata(clientID) return s.ClientIdentityRepo.Metadata(clientID)
} }
func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL url.URL, nonce string, register bool) (string, error) { func (s *Server) NewSession(ipdcID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
sessionID, err := s.SessionManager.NewSession(ipdcID, clientID, clientState, redirectURL, nonce, register) sessionID, err := s.SessionManager.NewSession(ipdcID, clientID, clientState, redirectURL, nonce, register, scope)
if err != nil { if err != nil {
return "", err return "", err
} }
......
...@@ -139,7 +139,7 @@ func TestServerNewSession(t *testing.T) { ...@@ -139,7 +139,7 @@ func TestServerNewSession(t *testing.T) {
}, },
} }
key, err := srv.NewSession("bogus_idpc", ci.Credentials.ID, state, ci.Metadata.RedirectURLs[0], nonce, false) key, err := srv.NewSession("bogus_idpc", ci.Credentials.ID, state, ci.Metadata.RedirectURLs[0], nonce, false, nil)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
...@@ -195,7 +195,7 @@ func TestServerLogin(t *testing.T) { ...@@ -195,7 +195,7 @@ func TestServerLogin(t *testing.T) {
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm.GenerateCode = staticGenerateCodeFunc("fakecode") sm.GenerateCode = staticGenerateCodeFunc("fakecode")
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURLs[0], "", false) sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURLs[0], "", false, nil)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
...@@ -292,7 +292,7 @@ func TestServerCodeToken(t *testing.T) { ...@@ -292,7 +292,7 @@ func TestServerCodeToken(t *testing.T) {
RefreshTokenRepo: refreshTokenRepo, RefreshTokenRepo: refreshTokenRepo,
} }
sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false) sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, nil)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
...@@ -343,7 +343,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) { ...@@ -343,7 +343,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
ClientIdentityRepo: ciRepo, ClientIdentityRepo: ciRepo,
} }
sessionID, err := sm.NewSession("connector_id", ci.Credentials.ID, "bogus", url.URL{}, "", false) sessionID, err := sm.NewSession("connector_id", ci.Credentials.ID, "bogus", url.URL{}, "", false, nil)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
...@@ -416,7 +416,7 @@ func TestServerTokenFail(t *testing.T) { ...@@ -416,7 +416,7 @@ func TestServerTokenFail(t *testing.T) {
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm.GenerateCode = func() (string, error) { return keyFixture, nil } sm.GenerateCode = func() (string, error) { return keyFixture, nil }
sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false) sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, nil)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
......
...@@ -44,7 +44,7 @@ type SessionManager struct { ...@@ -44,7 +44,7 @@ type SessionManager struct {
keys SessionKeyRepo keys SessionKeyRepo
} }
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool) (string, error) { func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
sID, err := m.GenerateCode() sID, err := m.GenerateCode()
if err != nil { if err != nil {
return "", err return "", err
...@@ -62,6 +62,7 @@ func (m *SessionManager) NewSession(connectorID, clientID, clientState string, r ...@@ -62,6 +62,7 @@ func (m *SessionManager) NewSession(connectorID, clientID, clientState string, r
RedirectURL: redirectURL, RedirectURL: redirectURL,
Register: register, Register: register,
Nonce: nonce, Nonce: nonce,
Scope: scope,
} }
err = m.sessions.Create(s) err = m.sessions.Create(s)
......
...@@ -16,7 +16,7 @@ func staticGenerateCodeFunc(code string) GenerateCodeFunc { ...@@ -16,7 +16,7 @@ func staticGenerateCodeFunc(code string) GenerateCodeFunc {
func TestSessionManagerNewSession(t *testing.T) { func TestSessionManagerNewSession(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sm.GenerateCode = staticGenerateCodeFunc("boo") sm.GenerateCode = staticGenerateCodeFunc("boo")
got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false) got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, nil)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
...@@ -27,7 +27,7 @@ func TestSessionManagerNewSession(t *testing.T) { ...@@ -27,7 +27,7 @@ func TestSessionManagerNewSession(t *testing.T) {
func TestSessionAttachRemoteIdentityTwice(t *testing.T) { func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false) sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, nil)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
...@@ -44,7 +44,7 @@ func TestSessionAttachRemoteIdentityTwice(t *testing.T) { ...@@ -44,7 +44,7 @@ func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
func TestSessionManagerExchangeKey(t *testing.T) { func TestSessionManagerExchangeKey(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false) sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, nil)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
...@@ -80,7 +80,7 @@ func TestSessionManagerGetSessionInStateNoExist(t *testing.T) { ...@@ -80,7 +80,7 @@ func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
func TestSessionManagerGetSessionInStateWrongState(t *testing.T) { func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false) sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, nil)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
...@@ -95,7 +95,7 @@ func TestSessionManagerGetSessionInStateWrongState(t *testing.T) { ...@@ -95,7 +95,7 @@ func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
func TestSessionManagerKill(t *testing.T) { func TestSessionManagerKill(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false) sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, nil)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
......
...@@ -48,6 +48,9 @@ type Session struct { ...@@ -48,6 +48,9 @@ type Session struct {
// Nonce is optionally provided in the initial authorization request, and propogated in such cases to the generated claims. // Nonce is optionally provided in the initial authorization request, and propogated in such cases to the generated claims.
Nonce string Nonce string
// Scope is the 'scope' field in the authentication request. Example scopes are 'openid', 'email', 'offline', etc.
Scope []string
} }
// Claims returns a new set of Claims for the current session. // Claims returns a new set of Claims for the current session.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment