Commit 48b3b38c authored by bobbyrullo's avatar bobbyrullo

Merge pull request #144 from bobbyrullo/no_register

server,cmd: Add flag for disabling registation
parents ac78c8f4 d3d6a75b
...@@ -44,6 +44,8 @@ func main() { ...@@ -44,6 +44,8 @@ func main() {
emailFrom := fs.String("email-from", "no-reply@coreos.com", "emails sent from dex will come from this address") emailFrom := fs.String("email-from", "no-reply@coreos.com", "emails sent from dex will come from this address")
emailConfig := fs.String("email-cfg", "./static/fixtures/emailer.json", "configures emailer.") emailConfig := fs.String("email-cfg", "./static/fixtures/emailer.json", "configures emailer.")
enableRegistration := fs.Bool("enable-registration", false, "Allows users to self-register")
noDB := fs.Bool("no-db", false, "manage entities in-process w/o any encryption, used only for single-node testing") noDB := fs.Bool("no-db", false, "manage entities in-process w/o any encryption, used only for single-node testing")
// UI-related: // UI-related:
...@@ -120,6 +122,7 @@ func main() { ...@@ -120,6 +122,7 @@ func main() {
EmailerConfigFile: *emailConfig, EmailerConfigFile: *emailConfig,
IssuerName: *issuerName, IssuerName: *issuerName,
IssuerLogoURL: *issuerLogoURL, IssuerLogoURL: *issuerLogoURL,
EnableRegistration: *enableRegistration,
} }
if *noDB { if *noDB {
......
...@@ -33,6 +33,7 @@ type ServerConfig struct { ...@@ -33,6 +33,7 @@ type ServerConfig struct {
EmailFromAddress string EmailFromAddress string
EmailerConfigFile string EmailerConfigFile string
StateConfig StateConfigurer StateConfig StateConfigurer
EnableRegistration bool
} }
type StateConfigurer interface { type StateConfigurer interface {
...@@ -56,7 +57,7 @@ func (cfg *ServerConfig) Server() (*Server, error) { ...@@ -56,7 +57,7 @@ func (cfg *ServerConfig) Server() (*Server, error) {
return nil, err return nil, err
} }
tpl, err := getTemplates(cfg.IssuerName, cfg.IssuerLogoURL, cfg.TemplateDir) tpl, err := getTemplates(cfg.IssuerName, cfg.IssuerLogoURL, cfg.EnableRegistration, cfg.TemplateDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -69,6 +70,8 @@ func (cfg *ServerConfig) Server() (*Server, error) { ...@@ -69,6 +70,8 @@ func (cfg *ServerConfig) Server() (*Server, error) {
HealthChecks: []health.Checkable{km}, HealthChecks: []health.Checkable{km},
Connectors: []connector.Connector{}, Connectors: []connector.Connector{},
EnableRegistration: cfg.EnableRegistration,
} }
err = cfg.StateConfig.Configure(&srv) err = cfg.StateConfig.Configure(&srv)
...@@ -183,7 +186,8 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error { ...@@ -183,7 +186,8 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
return nil return nil
} }
func getTemplates(issuerName, issuerLogoURL string, dir string) (*template.Template, error) { func getTemplates(issuerName, issuerLogoURL string,
enableRegister bool, dir string) (*template.Template, error) {
tpl := template.New("").Funcs(map[string]interface{}{ tpl := template.New("").Funcs(map[string]interface{}{
"issuerName": func() string { "issuerName": func() string {
return issuerName return issuerName
...@@ -191,6 +195,9 @@ func getTemplates(issuerName, issuerLogoURL string, dir string) (*template.Templ ...@@ -191,6 +195,9 @@ func getTemplates(issuerName, issuerLogoURL string, dir string) (*template.Templ
"issuerLogoURL": func() string { "issuerLogoURL": func() string {
return issuerLogoURL return issuerLogoURL
}, },
"enableRegister": func() bool {
return enableRegister
},
}) })
return tpl.ParseGlob(dir + "/*.html") return tpl.ParseGlob(dir + "/*.html")
......
...@@ -254,7 +254,7 @@ func renderLoginPage(w http.ResponseWriter, r *http.Request, srv OIDCServer, idp ...@@ -254,7 +254,7 @@ func renderLoginPage(w http.ResponseWriter, r *http.Request, srv OIDCServer, idp
execTemplate(w, tpl, td) execTemplate(w, tpl, td)
} }
func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template) http.HandlerFunc { func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc {
idx := makeConnectorMap(idpcs) idx := makeConnectorMap(idpcs)
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" { if r.Method != "GET" {
...@@ -264,7 +264,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T ...@@ -264,7 +264,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
} }
q := r.URL.Query() q := r.URL.Query()
register := q.Get("register") == "1" register := q.Get("register") == "1" && registrationEnabled
e := q.Get("error") e := q.Get("error")
if e != "" { if e != "" {
sessionKey := q.Get("state") sessionKey := q.Get("state")
......
...@@ -51,7 +51,7 @@ func (c *fakeConnector) TrustedEmailProvider() bool { ...@@ -51,7 +51,7 @@ func (c *fakeConnector) TrustedEmailProvider() bool {
func TestHandleAuthFuncMethodNotAllowed(t *testing.T) { func TestHandleAuthFuncMethodNotAllowed(t *testing.T) {
for _, m := range []string{"POST", "PUT", "DELETE"} { for _, m := range []string{"POST", "PUT", "DELETE"} {
hdlr := handleAuthFunc(nil, nil, nil) hdlr := handleAuthFunc(nil, nil, nil, true)
req, err := http.NewRequest(m, "http://example.com", nil) req, err := http.NewRequest(m, "http://example.com", nil)
if err != nil { if err != nil {
t.Errorf("case %s: unable to create HTTP request: %v", m, err) t.Errorf("case %s: unable to create HTTP request: %v", m, err)
...@@ -170,7 +170,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { ...@@ -170,7 +170,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
hdlr := handleAuthFunc(srv, idpcs, nil) hdlr := handleAuthFunc(srv, idpcs, nil, true)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode()) u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode())
req, err := http.NewRequest("GET", u, nil) req, err := http.NewRequest("GET", u, nil)
...@@ -271,7 +271,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) { ...@@ -271,7 +271,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
hdlr := handleAuthFunc(srv, idpcs, nil) hdlr := handleAuthFunc(srv, idpcs, nil, true)
w := httptest.NewRecorder() w := httptest.NewRecorder()
u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode()) u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode())
req, err := http.NewRequest("GET", u, nil) req, err := http.NewRequest("GET", u, nil)
......
...@@ -72,6 +72,7 @@ type Server struct { ...@@ -72,6 +72,7 @@ type Server struct {
PasswordInfoRepo user.PasswordInfoRepo PasswordInfoRepo user.PasswordInfoRepo
RefreshTokenRepo refresh.RefreshTokenRepo RefreshTokenRepo refresh.RefreshTokenRepo
UserEmailer *useremail.UserEmailer UserEmailer *useremail.UserEmailer
EnableRegistration bool
localConnectorID string localConnectorID string
} }
...@@ -198,11 +199,15 @@ func (s *Server) HTTPHandler() http.Handler { ...@@ -198,11 +199,15 @@ func (s *Server) HTTPHandler() http.Handler {
clock := clockwork.NewRealClock() clock := clockwork.NewRealClock()
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc(httpPathDiscovery, handleDiscoveryFunc(s.ProviderConfig())) mux.HandleFunc(httpPathDiscovery, handleDiscoveryFunc(s.ProviderConfig()))
mux.HandleFunc(httpPathAuth, handleAuthFunc(s, s.Connectors, s.LoginTemplate)) mux.HandleFunc(httpPathAuth, handleAuthFunc(s, s.Connectors, s.LoginTemplate, s.EnableRegistration))
mux.HandleFunc(httpPathToken, handleTokenFunc(s)) mux.HandleFunc(httpPathToken, handleTokenFunc(s))
mux.HandleFunc(httpPathKeys, handleKeysFunc(s.KeyManager, clock)) mux.HandleFunc(httpPathKeys, handleKeysFunc(s.KeyManager, clock))
mux.Handle(httpPathHealth, makeHealthHandler(checks)) mux.Handle(httpPathHealth, makeHealthHandler(checks))
if s.EnableRegistration {
mux.HandleFunc(httpPathRegister, handleRegisterFunc(s)) mux.HandleFunc(httpPathRegister, handleRegisterFunc(s))
}
mux.HandleFunc(httpPathEmailVerify, handleEmailVerifyFunc(s.VerifyEmailTemplate, mux.HandleFunc(httpPathEmailVerify, handleEmailVerifyFunc(s.VerifyEmailTemplate,
s.IssuerURL, s.KeyManager.PublicKeys, s.UserManager)) s.IssuerURL, s.KeyManager.PublicKeys, s.UserManager))
......
...@@ -126,7 +126,8 @@ func makeTestFixtures() (*testFixtures, error) { ...@@ -126,7 +126,8 @@ func makeTestFixtures() (*testFixtures, error) {
return nil, err return nil, err
} }
tpl, err := getTemplates("dex", "https://coreos.com/assets/images/brand/coreos-mark-30px.png", templatesLocation) tpl, err := getTemplates("dex", "https://coreos.com/assets/images/brand/coreos-mark-30px.png",
true, templatesLocation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -70,8 +70,10 @@ ...@@ -70,8 +70,10 @@
{{ if .Register }} {{ if .Register }}
Already have an account? <a href="{{ .RegisterOrLoginURL }}">Log in</a> Already have an account? <a href="{{ .RegisterOrLoginURL }}">Log in</a>
{{ else }} {{ else }}
{{ if enableRegister }}
Don't have an account yet? <a href="{{ .RegisterOrLoginURL }}">Register</a> Don't have an account yet? <a href="{{ .RegisterOrLoginURL }}">Register</a>
{{ end }} {{ end }}
{{ end }}
</div> </div>
{{ end }} {{ end }}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment