Commit 3da456ef authored by Yifan Gu's avatar Yifan Gu Committed by Yifan Gu

dex-worker: add TLS support.

Add two new flags '--cert-file' and '--key-file'.
If scheme == 'https', then we will use the two new flags to get
the cert/key pair for TLS connection.

Also add '--ca-file' to the example app to allow TLS connection to the
dex-worker using a specified ca file.
parent 5abc7633
...@@ -27,8 +27,11 @@ func init() { ...@@ -27,8 +27,11 @@ func init() {
func main() { func main() {
fs := flag.NewFlagSet("dex-worker", flag.ExitOnError) fs := flag.NewFlagSet("dex-worker", flag.ExitOnError)
listen := fs.String("listen", "http://0.0.0.0:5556", "") listen := fs.String("listen", "http://127.0.0.1:5556", "the address that the server will listen on")
issuer := fs.String("issuer", "http://127.0.0.1:5556", "") issuer := fs.String("issuer", "http://127.0.0.1:5556", "the issuer's location")
certFile := fs.String("tls-cert-file", "", "the server's certificate file for TLS connection")
keyFile := fs.String("tls-key-file", "", "the server's private key file for TLS connection")
templates := fs.String("html-assets", "./static/html", "directory of html template files") templates := fs.String("html-assets", "./static/html", "directory of html template files")
emailTemplateDirs := flagutil.StringSliceFlag{"./static/email"} emailTemplateDirs := flagutil.StringSliceFlag{"./static/email"}
...@@ -75,13 +78,30 @@ func main() { ...@@ -75,13 +78,30 @@ func main() {
log.EnableTimestamps() log.EnableTimestamps()
} }
// Validate listen address.
lu, err := url.Parse(*listen) lu, err := url.Parse(*listen)
if err != nil { if err != nil {
log.Fatalf("Unable to use --listen flag: %v", err) log.Fatalf("Invalid listen address %q: %v", *listen, err)
} }
if lu.Scheme != "http" { switch lu.Scheme {
log.Fatalf("Unable to listen using scheme %s", lu.Scheme) case "http":
case "https":
if *certFile == "" || *keyFile == "" {
log.Fatalf("Must provide certificate file and private key file")
}
default:
log.Fatalf("Only 'http' and 'https' schemes are supported")
}
// Validate issuer address.
iu, err := url.Parse(*issuer)
if err != nil {
log.Fatalf("Invalid issuer URL %q: %v", *issuer, err)
}
if iu.Scheme != "http" && iu.Scheme != "https" {
log.Fatalf("Only 'http' and 'https' schemes are supported")
} }
scfg := server.ServerConfig{ scfg := server.ServerConfig{
...@@ -145,7 +165,11 @@ func main() { ...@@ -145,7 +165,11 @@ func main() {
log.Infof("Binding to %s...", httpsrv.Addr) log.Infof("Binding to %s...", httpsrv.Addr)
go func() { go func() {
if lu.Scheme == "http" {
log.Fatal(httpsrv.ListenAndServe()) log.Fatal(httpsrv.ListenAndServe())
} else {
log.Fatal(httpsrv.ListenAndServeTLS(*certFile, *keyFile))
}
}() }()
<-srv.Run() <-srv.Run()
......
...@@ -2,9 +2,12 @@ package main ...@@ -2,9 +2,12 @@ package main
import ( import (
"bytes" "bytes"
"crypto/tls"
"crypto/x509"
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
...@@ -27,6 +30,8 @@ func main() { ...@@ -27,6 +30,8 @@ func main() {
redirectURL := fs.String("redirect-url", "http://127.0.0.1:5555/callback", "") redirectURL := fs.String("redirect-url", "http://127.0.0.1:5555/callback", "")
clientID := fs.String("client-id", "", "") clientID := fs.String("client-id", "", "")
clientSecret := fs.String("client-secret", "", "") clientSecret := fs.String("client-secret", "", "")
caFile := fs.String("ca-file", "", "the TLS CA file, if empty then the host's root CA will be used")
discovery := fs.String("discovery", "https://accounts.google.com", "") discovery := fs.String("discovery", "https://accounts.google.com", "")
logDebug := fs.Bool("log-debug", false, "log debug-level information") logDebug := fs.Bool("log-debug", false, "log debug-level information")
logTimestamps := fs.Bool("log-timestamps", false, "prefix log lines with timestamps") logTimestamps := fs.Bool("log-timestamps", false, "prefix log lines with timestamps")
...@@ -71,9 +76,22 @@ func main() { ...@@ -71,9 +76,22 @@ func main() {
Secret: *clientSecret, Secret: *clientSecret,
} }
var tlsConfig tls.Config
if *caFile != "" {
roots := x509.NewCertPool()
pemBlock, err := ioutil.ReadFile(*caFile)
if err != nil {
log.Fatalf("Unable to read ca file: %v", err)
}
roots.AppendCertsFromPEM(pemBlock)
tlsConfig.RootCAs = roots
}
httpClient := &http.Client{Transport: &http.Transport{TLSClientConfig: &tlsConfig}}
var cfg oidc.ProviderConfig var cfg oidc.ProviderConfig
for { for {
cfg, err = oidc.FetchProviderConfig(http.DefaultClient, *discovery) cfg, err = oidc.FetchProviderConfig(httpClient, *discovery)
if err == nil { if err == nil {
break break
} }
...@@ -86,6 +104,7 @@ func main() { ...@@ -86,6 +104,7 @@ func main() {
log.Infof("Fetched provider config from %s: %#v", *discovery, cfg) log.Infof("Fetched provider config from %s: %#v", *discovery, cfg)
ccfg := oidc.ClientConfig{ ccfg := oidc.ClientConfig{
HTTPClient: httpClient,
ProviderConfig: cfg, ProviderConfig: cfg,
Credentials: cc, Credentials: cc,
RedirectURL: *redirectURL, RedirectURL: *redirectURL,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment