Commit 76b9eb1d authored by rithu john's avatar rithu john

connector/github: add support for github enterprise.

parent 943253fe
...@@ -31,4 +31,37 @@ connectors: ...@@ -31,4 +31,37 @@ connectors:
org: my-oranization org: my-oranization
``` ```
## GitHub Enterprise
Users can use their GitHub Enterprise account to login to dex. The following configuration can be used to enable a GitHub Enterprise connector on dex:
```yaml
connectors:
- type: github
# Required field for connector id.
id: github
# Required field for connector name.
name: GitHub
config:
# Required fields. Dex must be pre-registered with GitHub Enterprise
# to get the following values.
# Credentials can be string literals or pulled from the environment.
clientID: $GITHUB_CLIENT_ID
clientSecret: $GITHUB_CLIENT_SECRET
redirectURI: http://127.0.0.1:5556/dex/callback
# Optional organization to pull teams from, communicate through the
# "groups" scope.
#
# NOTE: This is an EXPERIMENTAL config option and will likely change.
org: my-oranization
# Required ONLY for GitHub Enterprise.
# This is the Hostname of the GitHub Enterprise account listed on the
# management console. Ensure this domain is routable on your network.
hostName: git.example.com
# ONLY for GitHub Enterprise. Optional field.
# Used to support self-signed or untrusted CA root certificates.
rootCA: /etc/dex/ca.crt
```
[github-oauth2]: https://github.com/settings/applications/new [github-oauth2]: https://github.com/settings/applications/new
...@@ -3,13 +3,18 @@ package github ...@@ -3,13 +3,18 @@ package github
import ( import (
"context" "context"
"crypto/tls"
"crypto/x509"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"regexp" "regexp"
"strconv" "strconv"
"strings"
"time"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/github" "golang.org/x/oauth2/github"
...@@ -19,7 +24,7 @@ import ( ...@@ -19,7 +24,7 @@ import (
) )
const ( const (
baseURL = "https://api.github.com" apiURL = "https://api.github.com"
scopeEmail = "user:email" scopeEmail = "user:email"
scopeOrgs = "read:org" scopeOrgs = "read:org"
) )
...@@ -30,17 +35,45 @@ type Config struct { ...@@ -30,17 +35,45 @@ type Config struct {
ClientSecret string `json:"clientSecret"` ClientSecret string `json:"clientSecret"`
RedirectURI string `json:"redirectURI"` RedirectURI string `json:"redirectURI"`
Org string `json:"org"` Org string `json:"org"`
HostName string `json:"hostName"`
RootCA string `json:"rootCA"`
} }
// Open returns a strategy for logging in through GitHub. // Open returns a strategy for logging in through GitHub.
func (c *Config) Open(logger logrus.FieldLogger) (connector.Connector, error) { func (c *Config) Open(logger logrus.FieldLogger) (connector.Connector, error) {
return &githubConnector{ g := githubConnector{
redirectURI: c.RedirectURI, redirectURI: c.RedirectURI,
org: c.Org, org: c.Org,
clientID: c.ClientID, clientID: c.ClientID,
clientSecret: c.ClientSecret, clientSecret: c.ClientSecret,
apiURL: apiURL,
logger: logger, logger: logger,
}, nil }
if c.HostName != "" {
// ensure this is a hostname and not a URL or path.
if strings.Contains(c.HostName, "/") {
return nil, errors.New("invalid hostname: hostname cannot contain `/`")
}
g.hostName = c.HostName
g.apiURL = "https://" + c.HostName + "/api/v3"
}
if c.RootCA != "" {
if c.HostName == "" {
return nil, errors.New("invalid connector config: Host name field required for a root certificate file")
}
g.rootCA = c.RootCA
var err error
if g.httpClient, err = newHTTPClient(g.rootCA); err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %v", err)
}
}
return &g, nil
} }
type connectorData struct { type connectorData struct {
...@@ -59,6 +92,14 @@ type githubConnector struct { ...@@ -59,6 +92,14 @@ type githubConnector struct {
clientID string clientID string
clientSecret string clientSecret string
logger logrus.FieldLogger logger logrus.FieldLogger
// apiURL defaults to "https://api.github.com"
apiURL string
// hostName of the GitHub enterprise account.
hostName string
// Used to support untrusted/self-signed CA certs.
rootCA string
// HTTP Client that trusts the custom delcared rootCA cert.
httpClient *http.Client
} }
func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config {
...@@ -68,10 +109,21 @@ func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { ...@@ -68,10 +109,21 @@ func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config {
} else { } else {
githubScopes = []string{scopeEmail} githubScopes = []string{scopeEmail}
} }
endpoint := github.Endpoint
// case when it is a GitHub Enterprise account.
if c.hostName != "" {
endpoint = oauth2.Endpoint{
AuthURL: "https://" + c.hostName + "/login/oauth/authorize",
TokenURL: "https://" + c.hostName + "/login/oauth/access_token",
}
}
return &oauth2.Config{ return &oauth2.Config{
ClientID: c.clientID, ClientID: c.clientID,
ClientSecret: c.clientSecret, ClientSecret: c.clientSecret,
Endpoint: github.Endpoint, Endpoint: endpoint,
Scopes: githubScopes, Scopes: githubScopes,
} }
} }
...@@ -80,6 +132,7 @@ func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state s ...@@ -80,6 +132,7 @@ func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state s
if c.redirectURI != callbackURL { if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL did not match the URL in the config") return "", fmt.Errorf("expected callback URL did not match the URL in the config")
} }
return c.oauth2Config(scopes).AuthCodeURL(state), nil return c.oauth2Config(scopes).AuthCodeURL(state), nil
} }
...@@ -95,6 +148,34 @@ func (e *oauth2Error) Error() string { ...@@ -95,6 +148,34 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription return e.error + ": " + e.errorDescription
} }
// newHTTPClient returns a new HTTP client that trusts the custom delcared rootCA cert.
func newHTTPClient(rootCA string) (*http.Client, error) {
tlsConfig := tls.Config{RootCAs: x509.NewCertPool()}
rootCABytes, err := ioutil.ReadFile(rootCA)
if err != nil {
return nil, fmt.Errorf("failed to read root-ca: %v", err)
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) {
return nil, fmt.Errorf("no certs found in root CA file %q", rootCA)
}
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tlsConfig,
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
}, nil
}
func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query() q := r.URL.Query()
if errType := q.Get("error"); errType != "" { if errType := q.Get("error"); errType != "" {
...@@ -102,7 +183,12 @@ func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (i ...@@ -102,7 +183,12 @@ func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (i
} }
oauth2Config := c.oauth2Config(s) oauth2Config := c.oauth2Config(s)
ctx := r.Context() ctx := r.Context()
// GitHub Enterprise account
if c.httpClient != nil {
ctx = context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient)
}
token, err := oauth2Config.Exchange(ctx, q.Get("code")) token, err := oauth2Config.Exchange(ctx, q.Get("code"))
if err != nil { if err != nil {
...@@ -192,7 +278,7 @@ type user struct { ...@@ -192,7 +278,7 @@ type user struct {
// a bearer token as part of the request. // a bearer token as part of the request.
func (c *githubConnector) user(ctx context.Context, client *http.Client) (user, error) { func (c *githubConnector) user(ctx context.Context, client *http.Client) (user, error) {
var u user var u user
req, err := http.NewRequest("GET", baseURL+"/user", nil) req, err := http.NewRequest("GET", c.apiURL+"/user", nil)
if err != nil { if err != nil {
return u, fmt.Errorf("github: new req: %v", err) return u, fmt.Errorf("github: new req: %v", err)
} }
...@@ -228,7 +314,7 @@ func (c *githubConnector) teams(ctx context.Context, client *http.Client, org st ...@@ -228,7 +314,7 @@ func (c *githubConnector) teams(ctx context.Context, client *http.Client, org st
// https://developer.github.com/v3/#pagination // https://developer.github.com/v3/#pagination
reNext := regexp.MustCompile("<(.*)>; rel=\"next\"") reNext := regexp.MustCompile("<(.*)>; rel=\"next\"")
reLast := regexp.MustCompile("<(.*)>; rel=\"last\"") reLast := regexp.MustCompile("<(.*)>; rel=\"last\"")
apiURL := baseURL + "/user/teams" apiURL := c.apiURL + "/user/teams"
for { for {
req, err := http.NewRequest("GET", apiURL, nil) req, err := http.NewRequest("GET", apiURL, nil)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment