Commit 04fa8354 authored by bobbyrullo's avatar bobbyrullo

Merge pull request #353 from fnordahl/issue/156

vendor: update go-oidc to latest
parents a846016c bbaea52e
hash: 2542cf3debe2db35e04c4cd29d8db83f7b52d90167f4e55be5a2b24f5b74c65b hash: 7e7b258be2aa7e03d1bb15c96f856a9b7982d850a1faeba699a0b11bcb0c6936
updated: 2016-04-25T17:06:03.25154684-07:00 updated: 2016-05-27T00:22:02.656877293+02:00
imports: imports:
- name: github.com/andybalholm/cascadia - name: github.com/andybalholm/cascadia
version: 6122e68c2642b7b75c538a63b15168c6c80fb757 version: 6122e68c2642b7b75c538a63b15168c6c80fb757
- name: github.com/coreos/go-oidc - name: github.com/coreos/go-oidc
version: 46fc3c20dcad27d27596c8832b354621bdd8b88f version: e6174c764e906bd60c76fdfc33faf5e0bdc875d6
subpackages: subpackages:
- http - http
- jose - jose
......
...@@ -5,7 +5,7 @@ import: ...@@ -5,7 +5,7 @@ import:
- package: github.com/andybalholm/cascadia - package: github.com/andybalholm/cascadia
version: 6122e68c2642b7b75c538a63b15168c6c80fb757 version: 6122e68c2642b7b75c538a63b15168c6c80fb757
- package: github.com/coreos/go-oidc - package: github.com/coreos/go-oidc
version: 46fc3c20dcad27d27596c8832b354621bdd8b88f version: e6174c764e906bd60c76fdfc33faf5e0bdc875d6
subpackages: subpackages:
- http - http
- jose - jose
......
...@@ -159,9 +159,9 @@ func TestInvitationHandler(t *testing.T) { ...@@ -159,9 +159,9 @@ func TestInvitationHandler(t *testing.T) {
t.Errorf("case %d: password token is invalid: %v", i, err) t.Errorf("case %d: password token is invalid: %v", i, err)
} }
expTime := pwrReset.Claims["exp"].(float64) expTime := pwrReset.Claims["exp"].(int64)
if expTime > float64(tZero.Add(handler.redirectValidityWindow).Unix()) || if expTime > tZero.Add(handler.redirectValidityWindow).Unix() ||
expTime < float64(tZero.Unix()) { expTime < tZero.Unix() {
t.Errorf("case %d: funny expiration time detected: %d", i, pwrReset.Claims["exp"]) t.Errorf("case %d: funny expiration time detected: %d", i, pwrReset.Claims["exp"])
} }
......
...@@ -34,8 +34,8 @@ func TestSessionClaims(t *testing.T) { ...@@ -34,8 +34,8 @@ func TestSessionClaims(t *testing.T) {
"iss": issuerURL, "iss": issuerURL,
"sub": "elroy-id", "sub": "elroy-id",
"aud": "XXX", "aud": "XXX",
"iat": float64(now.Unix()), "iat": now.Unix(),
"exp": float64(now.Add(time.Hour).Unix()), "exp": now.Add(time.Hour).Unix(),
}, },
}, },
...@@ -57,8 +57,8 @@ func TestSessionClaims(t *testing.T) { ...@@ -57,8 +57,8 @@ func TestSessionClaims(t *testing.T) {
"iss": issuerURL, "iss": issuerURL,
"sub": "elroy-id", "sub": "elroy-id",
"aud": "XXX", "aud": "XXX",
"iat": float64(now.Unix()), "iat": now.Unix(),
"exp": float64(now.Add(time.Hour).Unix()), "exp": now.Add(time.Hour).Unix(),
}, },
}, },
// Nonce gets propagated. // Nonce gets propagated.
...@@ -79,8 +79,8 @@ func TestSessionClaims(t *testing.T) { ...@@ -79,8 +79,8 @@ func TestSessionClaims(t *testing.T) {
"iss": issuerURL, "iss": issuerURL,
"sub": "elroy-id", "sub": "elroy-id",
"aud": "XXX", "aud": "XXX",
"iat": float64(now.Unix()), "iat": now.Unix(),
"exp": float64(now.Add(time.Hour).Unix()), "exp": now.Add(time.Hour).Unix(),
"nonce": "oncenay", "nonce": "oncenay",
}, },
}, },
......
...@@ -45,9 +45,9 @@ func TestNewEmailVerification(t *testing.T) { ...@@ -45,9 +45,9 @@ func TestNewEmailVerification(t *testing.T) {
"aud": clientID, "aud": clientID,
ClaimEmailVerificationCallback: callback, ClaimEmailVerificationCallback: callback,
ClaimEmailVerificationEmail: usr.Email, ClaimEmailVerificationEmail: usr.Email,
"exp": float64(now.Add(expires).Unix()), "exp": now.Add(expires).Unix(),
"sub": usr.ID, "sub": usr.ID,
"iat": float64(now.Unix()), "iat": now.Unix(),
}, },
}, },
} }
......
...@@ -106,9 +106,9 @@ func TestNewPasswordReset(t *testing.T) { ...@@ -106,9 +106,9 @@ func TestNewPasswordReset(t *testing.T) {
"aud": clientID, "aud": clientID,
ClaimPasswordResetCallback: callback, ClaimPasswordResetCallback: callback,
ClaimPasswordResetPassword: string(password), ClaimPasswordResetPassword: string(password),
"exp": float64(now.Add(expires).Unix()), "exp": now.Add(expires).Unix(),
"sub": usr.ID, "sub": usr.ID,
"iat": float64(now.Unix()), "iat": now.Unix(),
}, },
}, },
} }
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"math"
"time" "time"
"net/mail" "net/mail"
...@@ -259,5 +260,27 @@ func parseAndVerifyTokenClaims(token string, issuer url.URL, keys []key.PublicKe ...@@ -259,5 +260,27 @@ func parseAndVerifyTokenClaims(token string, issuer url.URL, keys []key.PublicKe
return TokenClaims{}, err return TokenClaims{}, err
} }
timeClaimsToInt(claims)
return TokenClaims{claims}, nil return TokenClaims{claims}, nil
} }
// timeClaimsToInt converts float64 time claims to ints.
// This is unfortunately neccessary for interop as some clients incorrectly fail
// to marshal floats as times.
func timeClaimsToInt(claims jose.Claims) {
for _, k := range []string{"exp", "iat"} {
v, ok := claims[k]
if !ok {
continue
}
fVal, ok := v.(float64)
if !ok {
continue
}
// round
claims[k] = int64(fVal + math.Copysign(0.5, fVal))
}
}
...@@ -2,12 +2,12 @@ language: go ...@@ -2,12 +2,12 @@ language: go
go: go:
- 1.4.3 - 1.4.3
- 1.5.2 - 1.5.4
- 1.6.1
install: install:
- go get -v -t ./... - go get -v -t ./...
- go get golang.org/x/tools/cmd/cover - go get golang.org/x/tools/cmd/cover
- go get golang.org/x/tools/cmd/vet
script: script:
- ./test - ./test
......
package http package http
import ( import "net/http"
"io/ioutil"
"net/http"
"net/http/httptest"
)
type Client interface { type Client interface {
Do(*http.Request) (*http.Response, error) Do(*http.Request) (*http.Response, error)
} }
type HandlerClient struct {
Handler http.Handler
}
func (hc *HandlerClient) Do(r *http.Request) (*http.Response, error) {
w := httptest.NewRecorder()
hc.Handler.ServeHTTP(w, r)
resp := http.Response{
StatusCode: w.Code,
Header: w.Header(),
Body: ioutil.NopCloser(w.Body),
}
return &resp, nil
}
type RequestRecorder struct {
Response *http.Response
Error error
Request *http.Request
}
func (rr *RequestRecorder) Do(req *http.Request) (*http.Response, error) {
rr.Request = req
if rr.Response == nil && rr.Error == nil {
panic("RequestRecorder Response and Error cannot both be nil")
} else if rr.Response != nil && rr.Error != nil {
panic("RequestRecorder Response and Error cannot both be non-nil")
}
return rr.Response, rr.Error
}
func (rr *RequestRecorder) RoundTrip(req *http.Request) (*http.Response, error) {
return rr.Do(req)
}
...@@ -3,6 +3,7 @@ package jose ...@@ -3,6 +3,7 @@ package jose
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"math"
"time" "time"
) )
...@@ -70,13 +71,33 @@ func (c Claims) Int64Claim(name string) (int64, bool, error) { ...@@ -70,13 +71,33 @@ func (c Claims) Int64Claim(name string) (int64, bool, error) {
return v, true, nil return v, true, nil
} }
func (c Claims) Float64Claim(name string) (float64, bool, error) {
cl, ok := c[name]
if !ok {
return 0, false, nil
}
v, ok := cl.(float64)
if !ok {
vi, ok := cl.(int64)
if !ok {
return 0, false, fmt.Errorf("unable to parse claim as float64: %v", name)
}
v = float64(vi)
}
return v, true, nil
}
func (c Claims) TimeClaim(name string) (time.Time, bool, error) { func (c Claims) TimeClaim(name string) (time.Time, bool, error) {
v, ok, err := c.Int64Claim(name) v, ok, err := c.Float64Claim(name)
if !ok || err != nil { if !ok || err != nil {
return time.Time{}, ok, err return time.Time{}, ok, err
} }
return time.Unix(v, 0).UTC(), true, nil s := math.Trunc(v)
ns := (v - s) * math.Pow(10, 9)
return time.Unix(int64(s), int64(ns)).UTC(), true, nil
} }
func decodeClaims(payload []byte) (Claims, error) { func decodeClaims(payload []byte) (Claims, error) {
......
...@@ -2,7 +2,6 @@ package jose ...@@ -2,7 +2,6 @@ package jose
import ( import (
"fmt" "fmt"
"strings"
) )
type Verifier interface { type Verifier interface {
...@@ -17,7 +16,7 @@ type Signer interface { ...@@ -17,7 +16,7 @@ type Signer interface {
} }
func NewVerifier(jwk JWK) (Verifier, error) { func NewVerifier(jwk JWK) (Verifier, error) {
if strings.ToUpper(jwk.Type) != "RSA" { if jwk.Type != "RSA" {
return nil, fmt.Errorf("unsupported key type %q", jwk.Type) return nil, fmt.Errorf("unsupported key type %q", jwk.Type)
} }
......
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
_ "crypto/sha256" _ "crypto/sha256"
"errors" "errors"
"fmt" "fmt"
"strings"
) )
type VerifierHMAC struct { type VerifierHMAC struct {
...@@ -21,7 +20,7 @@ type SignerHMAC struct { ...@@ -21,7 +20,7 @@ type SignerHMAC struct {
} }
func NewVerifierHMAC(jwk JWK) (*VerifierHMAC, error) { func NewVerifierHMAC(jwk JWK) (*VerifierHMAC, error) {
if strings.ToUpper(jwk.Alg) != "HS256" { if jwk.Alg != "" && jwk.Alg != "HS256" {
return nil, fmt.Errorf("unsupported key algorithm %q", jwk.Alg) return nil, fmt.Errorf("unsupported key algorithm %q", jwk.Alg)
} }
......
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"fmt" "fmt"
"strings"
) )
type VerifierRSA struct { type VerifierRSA struct {
...@@ -20,7 +19,7 @@ type SignerRSA struct { ...@@ -20,7 +19,7 @@ type SignerRSA struct {
} }
func NewVerifierRSA(jwk JWK) (*VerifierRSA, error) { func NewVerifierRSA(jwk JWK) (*VerifierRSA, error) {
if strings.ToUpper(jwk.Alg) != "RS256" { if jwk.Alg != "" && jwk.Alg != "RS256" {
return nil, fmt.Errorf("unsupported key algorithm %q", jwk.Alg) return nil, fmt.Errorf("unsupported key algorithm %q", jwk.Alg)
} }
......
...@@ -20,7 +20,7 @@ type PublicKey struct { ...@@ -20,7 +20,7 @@ type PublicKey struct {
} }
func (k *PublicKey) MarshalJSON() ([]byte, error) { func (k *PublicKey) MarshalJSON() ([]byte, error) {
return json.Marshal(k.jwk) return json.Marshal(&k.jwk)
} }
func (k *PublicKey) UnmarshalJSON(data []byte) error { func (k *PublicKey) UnmarshalJSON(data []byte) error {
......
...@@ -66,3 +66,24 @@ func TestPublicKeySetKey(t *testing.T) { ...@@ -66,3 +66,24 @@ func TestPublicKeySetKey(t *testing.T) {
t.Errorf("Expected nil response from PublicKeySet.Key, got %#v", got) t.Errorf("Expected nil response from PublicKeySet.Key, got %#v", got)
} }
} }
func TestPublicKeyMarshalJSON(t *testing.T) {
k := jose.JWK{
ID: "foo",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: big.NewInt(int64(17)),
Exponent: 65537,
}
want := `{"kid":"foo","kty":"RSA","alg":"RS256","use":"sig","e":"AQAB","n":"EQ=="}`
pubKey := NewPublicKey(k)
gotBytes, err := pubKey.MarshalJSON()
if err != nil {
t.Fatalf("failed to marshal public key: %v", err)
}
got := string(gotBytes)
if got != want {
t.Errorf("got != want:\n%s\n%s", got, want)
}
}
...@@ -56,6 +56,7 @@ const ( ...@@ -56,6 +56,7 @@ const (
const ( const (
GrantTypeAuthCode = "authorization_code" GrantTypeAuthCode = "authorization_code"
GrantTypeClientCreds = "client_credentials" GrantTypeClientCreds = "client_credentials"
GrantTypeUserCreds = "password"
GrantTypeImplicit = "implicit" GrantTypeImplicit = "implicit"
GrantTypeRefreshToken = "refresh_token" GrantTypeRefreshToken = "refresh_token"
...@@ -140,6 +141,11 @@ func NewClient(hc phttp.Client, cfg Config) (c *Client, err error) { ...@@ -140,6 +141,11 @@ func NewClient(hc phttp.Client, cfg Config) (c *Client, err error) {
return return
} }
// Return the embedded HTTP client
func (c *Client) HttpClient() phttp.Client {
return c.hc
}
// Generate the url for initial redirect to oauth provider. // Generate the url for initial redirect to oauth provider.
func (c *Client) AuthCodeURL(state, accessType, prompt string) string { func (c *Client) AuthCodeURL(state, accessType, prompt string) string {
v := c.commonURLValues() v := c.commonURLValues()
...@@ -171,22 +177,24 @@ func (c *Client) commonURLValues() url.Values { ...@@ -171,22 +177,24 @@ func (c *Client) commonURLValues() url.Values {
} }
} }
func (c *Client) newAuthenticatedRequest(url string, values url.Values) (*http.Request, error) { func (c *Client) newAuthenticatedRequest(urlToken string, values url.Values) (*http.Request, error) {
var req *http.Request var req *http.Request
var err error var err error
switch c.authMethod { switch c.authMethod {
case AuthMethodClientSecretPost: case AuthMethodClientSecretPost:
values.Set("client_secret", c.creds.Secret) values.Set("client_secret", c.creds.Secret)
req, err = http.NewRequest("POST", url, strings.NewReader(values.Encode())) req, err = http.NewRequest("POST", urlToken, strings.NewReader(values.Encode()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
case AuthMethodClientSecretBasic: case AuthMethodClientSecretBasic:
req, err = http.NewRequest("POST", url, strings.NewReader(values.Encode())) req, err = http.NewRequest("POST", urlToken, strings.NewReader(values.Encode()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.SetBasicAuth(c.creds.ID, c.creds.Secret) encodedID := url.QueryEscape(c.creds.ID)
encodedSecret := url.QueryEscape(c.creds.Secret)
req.SetBasicAuth(encodedID, encodedSecret)
default: default:
panic("misconfigured client: auth method not supported") panic("misconfigured client: auth method not supported")
} }
...@@ -218,6 +226,30 @@ func (c *Client) ClientCredsToken(scope []string) (result TokenResponse, err err ...@@ -218,6 +226,30 @@ func (c *Client) ClientCredsToken(scope []string) (result TokenResponse, err err
return parseTokenResponse(resp) return parseTokenResponse(resp)
} }
// UserCredsToken posts the username and password to obtain a token scoped to the OAuth2 client via the "password" grant_type
// May not be supported by all OAuth2 servers.
func (c *Client) UserCredsToken(username, password string) (result TokenResponse, err error) {
v := url.Values{
"scope": {strings.Join(c.scope, " ")},
"grant_type": {GrantTypeUserCreds},
"username": {username},
"password": {password},
}
req, err := c.newAuthenticatedRequest(c.tokenURL.String(), v)
if err != nil {
return
}
resp, err := c.hc.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
return parseTokenResponse(resp)
}
// RequestToken requests a token from the Token Endpoint with the specified grantType. // RequestToken requests a token from the Token Endpoint with the specified grantType.
// If 'grantType' == GrantTypeAuthCode, then 'value' should be the authorization code. // If 'grantType' == GrantTypeAuthCode, then 'value' should be the authorization code.
// If 'grantType' == GrantTypeRefreshToken, then 'value' should be the refresh token. // If 'grantType' == GrantTypeRefreshToken, then 'value' should be the refresh token.
......
...@@ -158,10 +158,20 @@ func TestParseAuthCodeRequest(t *testing.T) { ...@@ -158,10 +158,20 @@ func TestParseAuthCodeRequest(t *testing.T) {
} }
} }
type fakeBadClient struct {
Request *http.Request
err error
}
func (f *fakeBadClient) Do(r *http.Request) (*http.Response, error) {
f.Request = r
return nil, f.err
}
func TestClientCredsToken(t *testing.T) { func TestClientCredsToken(t *testing.T) {
hc := &phttp.RequestRecorder{Error: errors.New("error")} hc := &fakeBadClient{nil, errors.New("error")}
cfg := Config{ cfg := Config{
Credentials: ClientCredentials{ID: "cid", Secret: "csecret"}, Credentials: ClientCredentials{ID: "c#id", Secret: "c secret"},
Scope: []string{"foo-scope", "bar-scope"}, Scope: []string{"foo-scope", "bar-scope"},
TokenURL: "http://example.com/token", TokenURL: "http://example.com/token",
AuthMethod: AuthMethodClientSecretBasic, AuthMethod: AuthMethodClientSecretBasic,
...@@ -195,11 +205,11 @@ func TestClientCredsToken(t *testing.T) { ...@@ -195,11 +205,11 @@ func TestClientCredsToken(t *testing.T) {
t.Error("unexpected error parsing basic auth") t.Error("unexpected error parsing basic auth")
} }
if cfg.Credentials.ID != cid { if url.QueryEscape(cfg.Credentials.ID) != cid {
t.Errorf("wrong client ID, want=%v, got=%v", cfg.Credentials.ID, cid) t.Errorf("wrong client ID, want=%v, got=%v", cfg.Credentials.ID, cid)
} }
if cfg.Credentials.Secret != secret { if url.QueryEscape(cfg.Credentials.Secret) != secret {
t.Errorf("wrong client secret, want=%v, got=%v", cfg.Credentials.Secret, secret) t.Errorf("wrong client secret, want=%v, got=%v", cfg.Credentials.Secret, secret)
} }
...@@ -210,7 +220,7 @@ func TestClientCredsToken(t *testing.T) { ...@@ -210,7 +220,7 @@ func TestClientCredsToken(t *testing.T) {
gt := hc.Request.PostForm.Get("grant_type") gt := hc.Request.PostForm.Get("grant_type")
if gt != GrantTypeClientCreds { if gt != GrantTypeClientCreds {
t.Errorf("wrong grant_type, want=client_credentials, got=%v", gt) t.Errorf("wrong grant_type, want=%v, got=%v", GrantTypeClientCreds, gt)
} }
sc := strings.Split(hc.Request.PostForm.Get("scope"), " ") sc := strings.Split(hc.Request.PostForm.Get("scope"), " ")
...@@ -219,6 +229,66 @@ func TestClientCredsToken(t *testing.T) { ...@@ -219,6 +229,66 @@ func TestClientCredsToken(t *testing.T) {
} }
} }
func TestUserCredsToken(t *testing.T) {
hc := &fakeBadClient{nil, errors.New("error")}
cfg := Config{
Credentials: ClientCredentials{ID: "c#id", Secret: "c secret"},
Scope: []string{"foo-scope", "bar-scope"},
TokenURL: "http://example.com/token",
AuthMethod: AuthMethodClientSecretBasic,
RedirectURL: "http://example.com/redirect",
AuthURL: "http://example.com/auth",
}
c, err := NewClient(hc, cfg)
if err != nil {
t.Errorf("unexpected error %v", err)
}
c.UserCredsToken("username", "password")
if hc.Request == nil {
t.Error("request is empty")
}
tu := hc.Request.URL.String()
if cfg.TokenURL != tu {
t.Errorf("wrong token url, want=%v, got=%v", cfg.TokenURL, tu)
}
ct := hc.Request.Header.Get("Content-Type")
if ct != "application/x-www-form-urlencoded" {
t.Errorf("wrong content-type, want=application/x-www-form-urlencoded, got=%v", ct)
}
cid, secret, ok := phttp.BasicAuth(hc.Request)
if !ok {
t.Error("unexpected error parsing basic auth")
}
if url.QueryEscape(cfg.Credentials.ID) != cid {
t.Errorf("wrong client ID, want=%v, got=%v", cfg.Credentials.ID, cid)
}
if url.QueryEscape(cfg.Credentials.Secret) != secret {
t.Errorf("wrong client secret, want=%v, got=%v", cfg.Credentials.Secret, secret)
}
err = hc.Request.ParseForm()
if err != nil {
t.Error("unexpected error parsing form")
}
gt := hc.Request.PostForm.Get("grant_type")
if gt != GrantTypeUserCreds {
t.Errorf("wrong grant_type, want=%v, got=%v", GrantTypeUserCreds, gt)
}
sc := strings.Split(hc.Request.PostForm.Get("scope"), " ")
if !reflect.DeepEqual(c.scope, sc) {
t.Errorf("wrong scope, want=%v, got=%v", c.scope, sc)
}
}
func TestNewAuthenticatedRequest(t *testing.T) { func TestNewAuthenticatedRequest(t *testing.T) {
tests := []struct { tests := []struct {
authMethod string authMethod string
...@@ -238,16 +308,15 @@ func TestNewAuthenticatedRequest(t *testing.T) { ...@@ -238,16 +308,15 @@ func TestNewAuthenticatedRequest(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
hc := &phttp.HandlerClient{}
cfg := Config{ cfg := Config{
Credentials: ClientCredentials{ID: "cid", Secret: "csecret"}, Credentials: ClientCredentials{ID: "c#id", Secret: "c secret"},
Scope: []string{"foo-scope", "bar-scope"}, Scope: []string{"foo-scope", "bar-scope"},
TokenURL: "http://example.com/token", TokenURL: "http://example.com/token",
AuthURL: "http://example.com/auth", AuthURL: "http://example.com/auth",
RedirectURL: "http://example.com/redirect", RedirectURL: "http://example.com/redirect",
AuthMethod: tt.authMethod, AuthMethod: tt.authMethod,
} }
c, err := NewClient(hc, cfg) c, err := NewClient(nil, cfg)
req, err := c.newAuthenticatedRequest(tt.url, tt.values) req, err := c.newAuthenticatedRequest(tt.url, tt.values)
if err != nil { if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err) t.Errorf("case %d: unexpected error: %v", i, err)
...@@ -264,10 +333,10 @@ func TestNewAuthenticatedRequest(t *testing.T) { ...@@ -264,10 +333,10 @@ func TestNewAuthenticatedRequest(t *testing.T) {
t.Errorf("case %d: !ok parsing Basic Auth headers", i) t.Errorf("case %d: !ok parsing Basic Auth headers", i)
continue continue
} }
if cid != cfg.Credentials.ID { if cid != url.QueryEscape(cfg.Credentials.ID) {
t.Errorf("case %d: want CID == %q, got CID == %q", i, cfg.Credentials.ID, cid) t.Errorf("case %d: want CID == %q, got CID == %q", i, cfg.Credentials.ID, cid)
} }
if secret != cfg.Credentials.Secret { if secret != url.QueryEscape(cfg.Credentials.Secret) {
t.Errorf("case %d: want secret == %q, got secret == %q", i, cfg.Credentials.Secret, secret) t.Errorf("case %d: want secret == %q, got secret == %q", i, cfg.Credentials.Secret, secret)
} }
} else if tt.authMethod == AuthMethodClientSecretPost { } else if tt.authMethod == AuthMethodClientSecretPost {
......
...@@ -11,6 +11,11 @@ import ( ...@@ -11,6 +11,11 @@ import (
"github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/key"
) )
// DefaultPublicKeySetTTL is the default TTL set on the PublicKeySet if no
// Cache-Control header is provided by the JWK Set document endpoint.
const DefaultPublicKeySetTTL = 24 * time.Hour
// NewRemotePublicKeyRepo is responsible for fetching the JWK Set document.
func NewRemotePublicKeyRepo(hc phttp.Client, ep string) *remotePublicKeyRepo { func NewRemotePublicKeyRepo(hc phttp.Client, ep string) *remotePublicKeyRepo {
return &remotePublicKeyRepo{hc: hc, ep: ep} return &remotePublicKeyRepo{hc: hc, ep: ep}
} }
...@@ -20,6 +25,11 @@ type remotePublicKeyRepo struct { ...@@ -20,6 +25,11 @@ type remotePublicKeyRepo struct {
ep string ep string
} }
// Get returns a PublicKeySet fetched from the JWK Set document endpoint. A TTL
// is set on the Key Set to avoid it having to be re-retrieved for every
// encryption event. This TTL is typically controlled by the endpoint returning
// a Cache-Control header, but defaults to 24 hours if no Cache-Control header
// is found.
func (r *remotePublicKeyRepo) Get() (key.KeySet, error) { func (r *remotePublicKeyRepo) Get() (key.KeySet, error) {
req, err := http.NewRequest("GET", r.ep, nil) req, err := http.NewRequest("GET", r.ep, nil)
if err != nil { if err != nil {
...@@ -48,7 +58,7 @@ func (r *remotePublicKeyRepo) Get() (key.KeySet, error) { ...@@ -48,7 +58,7 @@ func (r *remotePublicKeyRepo) Get() (key.KeySet, error) {
return nil, err return nil, err
} }
if !ok { if !ok {
return nil, errors.New("HTTP cache headers not set") ttl = DefaultPublicKeySetTTL
} }
exp := time.Now().UTC().Add(ttl) exp := time.Now().UTC().Add(ttl)
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"sync" "sync"
"time" "time"
...@@ -536,7 +537,7 @@ func (s *ProviderConfigSyncer) sync() (time.Duration, error) { ...@@ -536,7 +537,7 @@ func (s *ProviderConfigSyncer) sync() (time.Duration, error) {
s.initialSyncDone = true s.initialSyncDone = true
} }
log.Infof("Updating provider config: config=%#v", cfg) log.Debugf("Updating provider config: config=%#v", cfg)
return nextSyncAfter(cfg.ExpiresAt, s.clock), nil return nextSyncAfter(cfg.ExpiresAt, s.clock), nil
} }
...@@ -618,7 +619,11 @@ func NewHTTPProviderConfigGetter(hc phttp.Client, issuerURL string) *httpProvide ...@@ -618,7 +619,11 @@ func NewHTTPProviderConfigGetter(hc phttp.Client, issuerURL string) *httpProvide
} }
func (r *httpProviderConfigGetter) Get() (cfg ProviderConfig, err error) { func (r *httpProviderConfigGetter) Get() (cfg ProviderConfig, err error) {
req, err := http.NewRequest("GET", r.issuerURL+discoveryConfigPath, nil) // If the Issuer value contains a path component, any terminating / MUST be removed before
// appending /.well-known/openid-configuration.
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationRequest
discoveryURL := strings.TrimSuffix(r.issuerURL, "/") + discoveryConfigPath
req, err := http.NewRequest("GET", discoveryURL, nil)
if err != nil { if err != nil {
return return
} }
......
...@@ -17,7 +17,6 @@ import ( ...@@ -17,7 +17,6 @@ import (
"github.com/kylelemons/godebug/diff" "github.com/kylelemons/godebug/diff"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
phttp "github.com/coreos/go-oidc/http"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/oauth2" "github.com/coreos/go-oidc/oauth2"
) )
...@@ -495,9 +494,26 @@ func TestProviderConfigRequiredFields(t *testing.T) { ...@@ -495,9 +494,26 @@ func TestProviderConfigRequiredFields(t *testing.T) {
} }
} }
type handlerClient struct {
Handler http.Handler
}
func (hc *handlerClient) Do(r *http.Request) (*http.Response, error) {
w := httptest.NewRecorder()
hc.Handler.ServeHTTP(w, r)
resp := http.Response{
StatusCode: w.Code,
Header: w.Header(),
Body: ioutil.NopCloser(w.Body),
}
return &resp, nil
}
func TestHTTPProviderConfigGetter(t *testing.T) { func TestHTTPProviderConfigGetter(t *testing.T) {
svr := &fakeProviderConfigHandler{} svr := &fakeProviderConfigHandler{}
hc := &phttp.HandlerClient{Handler: svr} hc := &handlerClient{Handler: svr}
fc := clockwork.NewFakeClock() fc := clockwork.NewFakeClock()
now := fc.Now().UTC() now := fc.Now().UTC()
...@@ -887,6 +903,14 @@ func TestProviderConfigSupportsGrantType(t *testing.T) { ...@@ -887,6 +903,14 @@ func TestProviderConfigSupportsGrantType(t *testing.T) {
} }
} }
type fakeClient struct {
resp *http.Response
}
func (f *fakeClient) Do(req *http.Request) (*http.Response, error) {
return f.resp, nil
}
func TestWaitForProviderConfigImmediateSuccess(t *testing.T) { func TestWaitForProviderConfigImmediateSuccess(t *testing.T) {
cfg := newValidProviderConfig() cfg := newValidProviderConfig()
b, err := json.Marshal(&cfg) b, err := json.Marshal(&cfg)
...@@ -895,7 +919,7 @@ func TestWaitForProviderConfigImmediateSuccess(t *testing.T) { ...@@ -895,7 +919,7 @@ func TestWaitForProviderConfigImmediateSuccess(t *testing.T) {
} }
resp := http.Response{Body: ioutil.NopCloser(bytes.NewBuffer(b))} resp := http.Response{Body: ioutil.NopCloser(bytes.NewBuffer(b))}
hc := &phttp.RequestRecorder{Response: &resp} hc := &fakeClient{&resp}
fc := clockwork.NewFakeClock() fc := clockwork.NewFakeClock()
reschan := make(chan ProviderConfig) reschan := make(chan ProviderConfig)
......
...@@ -67,6 +67,15 @@ func (t *AuthenticatedTransport) verifiedJWT() (jose.JWT, error) { ...@@ -67,6 +67,15 @@ func (t *AuthenticatedTransport) verifiedJWT() (jose.JWT, error) {
return t.jwt, nil return t.jwt, nil
} }
// SetJWT sets the JWT held by the Transport.
// This is useful for cases in which you want to set an initial JWT.
func (t *AuthenticatedTransport) SetJWT(jwt jose.JWT) {
t.mu.Lock()
defer t.mu.Unlock()
t.jwt = jwt
}
func (t *AuthenticatedTransport) RoundTrip(r *http.Request) (*http.Response, error) { func (t *AuthenticatedTransport) RoundTrip(r *http.Request) (*http.Response, error) {
jwt, err := t.verifiedJWT() jwt, err := t.verifiedJWT()
if err != nil { if err != nil {
......
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"reflect" "reflect"
"testing" "testing"
phttp "github.com/coreos/go-oidc/http"
"github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/jose"
) )
...@@ -75,8 +74,8 @@ func TestAuthenticatedTransportVerifiedJWT(t *testing.T) { ...@@ -75,8 +74,8 @@ func TestAuthenticatedTransportVerifiedJWT(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
at := &AuthenticatedTransport{ at := &AuthenticatedTransport{
TokenRefresher: tt.refresher, TokenRefresher: tt.refresher,
jwt: tt.startJWT,
} }
at.SetJWT(tt.startJWT)
gotJWT, err := at.verifiedJWT() gotJWT, err := at.verifiedJWT()
if !reflect.DeepEqual(tt.wantError, err) { if !reflect.DeepEqual(tt.wantError, err) {
...@@ -122,8 +121,18 @@ func TestAuthenticatedTransportJWTCaching(t *testing.T) { ...@@ -122,8 +121,18 @@ func TestAuthenticatedTransportJWTCaching(t *testing.T) {
} }
} }
type fakeRoundTripper struct {
Request *http.Request
resp *http.Response
}
func (r *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
r.Request = req
return r.resp, nil
}
func TestAuthenticatedTransportRoundTrip(t *testing.T) { func TestAuthenticatedTransportRoundTrip(t *testing.T) {
rr := &phttp.RequestRecorder{Response: &http.Response{StatusCode: http.StatusOK}} rr := &fakeRoundTripper{nil, &http.Response{StatusCode: http.StatusOK}}
at := &AuthenticatedTransport{ at := &AuthenticatedTransport{
TokenRefresher: &staticTokenRefresher{ TokenRefresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return nil }, verify: func(jose.JWT) error { return nil },
...@@ -150,7 +159,7 @@ func TestAuthenticatedTransportRoundTrip(t *testing.T) { ...@@ -150,7 +159,7 @@ func TestAuthenticatedTransportRoundTrip(t *testing.T) {
} }
func TestAuthenticatedTransportRoundTripRefreshFail(t *testing.T) { func TestAuthenticatedTransportRoundTripRefreshFail(t *testing.T) {
rr := &phttp.RequestRecorder{Response: &http.Response{StatusCode: http.StatusOK}} rr := &fakeRoundTripper{nil, &http.Response{StatusCode: http.StatusOK}}
at := &AuthenticatedTransport{ at := &AuthenticatedTransport{
TokenRefresher: &staticTokenRefresher{ TokenRefresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return errors.New("fail!") }, verify: func(jose.JWT) error { return errors.New("fail!") },
......
...@@ -59,8 +59,8 @@ func NewClaims(iss, sub string, aud interface{}, iat, exp time.Time) jose.Claims ...@@ -59,8 +59,8 @@ func NewClaims(iss, sub string, aud interface{}, iat, exp time.Time) jose.Claims
"iss": iss, "iss": iss,
"sub": sub, "sub": sub,
"aud": aud, "aud": aud,
"iat": float64(iat.Unix()), "iat": iat.Unix(),
"exp": float64(exp.Unix()), "exp": exp.Unix(),
} }
} }
......
...@@ -83,8 +83,8 @@ func TestNewClaims(t *testing.T) { ...@@ -83,8 +83,8 @@ func TestNewClaims(t *testing.T) {
"iss": "https://example.com", "iss": "https://example.com",
"sub": "user-123", "sub": "user-123",
"aud": "client-abc", "aud": "client-abc",
"iat": float64(issAt.Unix()), "iat": issAt.Unix(),
"exp": float64(expAt.Unix()), "exp": expAt.Unix(),
} }
got := NewClaims("https://example.com", "user-123", "client-abc", issAt, expAt) got := NewClaims("https://example.com", "user-123", "client-abc", issAt, expAt)
...@@ -97,8 +97,8 @@ func TestNewClaims(t *testing.T) { ...@@ -97,8 +97,8 @@ func TestNewClaims(t *testing.T) {
"iss": "https://example.com", "iss": "https://example.com",
"sub": "user-123", "sub": "user-123",
"aud": []string{"client-abc", "client-def"}, "aud": []string{"client-abc", "client-def"},
"iat": float64(issAt.Unix()), "iat": issAt.Unix(),
"exp": float64(expAt.Unix()), "exp": expAt.Unix(),
} }
got2 := NewClaims("https://example.com", "user-123", []string{"client-abc", "client-def"}, issAt, expAt) got2 := NewClaims("https://example.com", "user-123", []string{"client-abc", "client-def"}, issAt, expAt)
......
...@@ -48,11 +48,15 @@ if [ -n "${fmtRes}" ]; then ...@@ -48,11 +48,15 @@ if [ -n "${fmtRes}" ]; then
exit 255 exit 255
fi fi
echo "Checking govet..." if [[ -z "$TRAVIS_GO_VERSION" || "$TRAVIS_GO_VERSION" != "1.4.3" ]]; then
vetRes=$(go vet $TEST) echo "Checking govet..."
if [ -n "${vetRes}" ]; then vetRes=$(go vet $TEST)
echo -e "govet checking failed:\n${vetRes}" if [ -n "${vetRes}" ]; then
exit 255 echo -e "govet checking failed:\n${vetRes}"
exit 255
fi
else
echo "Skipping govet (Go 1.4)"
fi fi
echo "Success" echo "Success"
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment