Commit 07a77e0d authored by LanceH's avatar LanceH

Use connector_id param to skip directly to a specific connector

parent 39dc5dcf
......@@ -233,6 +233,18 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
return
}
// Redirect if a client chooses a specific connector_id
if authReq.ConnectorID != "" {
for _, c := range connectors {
if c.ID == authReq.ConnectorID {
http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound)
return
}
}
s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound)
return
}
if len(connectors) == 1 {
for _, c := range connectors {
// TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter
......
......@@ -100,6 +100,7 @@ const (
errUnsupportedGrantType = "unsupported_grant_type"
errInvalidGrant = "invalid_grant"
errInvalidClient = "invalid_client"
errInvalidConnectorID = "invalid_connector_id"
)
const (
......@@ -391,6 +392,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
clientID := q.Get("client_id")
state := q.Get("state")
nonce := q.Get("nonce")
connectorID := q.Get("connector_id")
// Some clients, like the old go-oidc, provide extra whitespace. Tolerate this.
scopes := strings.Fields(q.Get("scope"))
responseTypes := strings.Fields(q.Get("response_type"))
......@@ -405,6 +407,16 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
return req, &authErr{"", "", errServerError, ""}
}
if connectorID != "" {
connectors, err := s.storage.ListConnectors()
if err != nil {
return req, &authErr{"", "", errServerError, "Unable to retrieve connectors"}
}
if !validateConnectorID(connectors, connectorID) {
return req, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"}
}
}
if !validateRedirectURI(client, redirectURI) {
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
return req, &authErr{"", "", errInvalidRequest, description}
......@@ -509,6 +521,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
Scopes: scopes,
RedirectURI: redirectURI,
ResponseTypes: responseTypes,
ConnectorID: connectorID,
}, nil
}
......@@ -568,6 +581,15 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
return err == nil && host == "localhost"
}
func validateConnectorID(connectors []storage.Connector, connectorID string) bool {
for _, c := range connectors {
if c.ID == connectorID {
return true
}
}
return false
}
// storageKeySet implements the oidc.KeySet interface backed by Dex storage
type storageKeySet struct {
storage.Storage
......
......@@ -10,7 +10,7 @@ import (
"strings"
"testing"
jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
......@@ -145,6 +145,58 @@ func TestParseAuthorizationRequest(t *testing.T) {
},
wantErr: true,
},
{
name: "choose connector_id",
clients: []storage.Client{
{
ID: "bar",
RedirectURIs: []string{"https://example.com/bar"},
},
},
supportedResponseTypes: []string{"code", "id_token", "token"},
queryParams: map[string]string{
"connector_id": "mock",
"client_id": "bar",
"redirect_uri": "https://example.com/bar",
"response_type": "code id_token",
"scope": "openid email profile",
},
},
{
name: "choose second connector_id",
clients: []storage.Client{
{
ID: "bar",
RedirectURIs: []string{"https://example.com/bar"},
},
},
supportedResponseTypes: []string{"code", "id_token", "token"},
queryParams: map[string]string{
"connector_id": "mock2",
"client_id": "bar",
"redirect_uri": "https://example.com/bar",
"response_type": "code id_token",
"scope": "openid email profile",
},
},
{
name: "choose invalid connector_id",
clients: []storage.Client{
{
ID: "bar",
RedirectURIs: []string{"https://example.com/bar"},
},
},
supportedResponseTypes: []string{"code", "id_token", "token"},
queryParams: map[string]string{
"connector_id": "bogus",
"client_id": "bar",
"redirect_uri": "https://example.com/bar",
"response_type": "code id_token",
"scope": "openid email profile",
},
wantErr: true,
},
}
for _, tc := range tests {
......@@ -152,7 +204,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, server := newTestServer(ctx, t, func(c *Config) {
httpServer, server := newTestServerMultipleConnectors(ctx, t, func(c *Config) {
c.SupportedResponseTypes = tc.supportedResponseTypes
c.Storage = storage.WithStaticClients(c.Storage, tc.clients)
})
......@@ -162,7 +214,6 @@ func TestParseAuthorizationRequest(t *testing.T) {
for k, v := range tc.queryParams {
params.Set(k, v)
}
var req *http.Request
if tc.usePOST {
body := strings.NewReader(params.Encode())
......
......@@ -117,6 +117,53 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
return s, server
}
func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
var server *Server
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.ServeHTTP(w, r)
}))
config := Config{
Issuer: s.URL,
Storage: memory.New(logger),
Web: WebConfig{
Dir: "../web",
},
Logger: logger,
PrometheusRegistry: prometheus.NewRegistry(),
}
if updateConfig != nil {
updateConfig(&config)
}
s.URL = config.Issuer
connector := storage.Connector{
ID: "mock",
Type: "mockCallback",
Name: "Mock",
ResourceVersion: "1",
}
connector2 := storage.Connector{
ID: "mock2",
Type: "mockCallback",
Name: "Mock",
ResourceVersion: "1",
}
if err := config.Storage.CreateConnector(connector); err != nil {
t.Fatalf("create connector: %v", err)
}
if err := config.Storage.CreateConnector(connector2); err != nil {
t.Fatalf("create connector: %v", err)
}
var err error
if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
t.Fatal(err)
}
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
return s, server
}
func TestNewTestServer(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment