Commit 17cac69e authored by Eric Chiang's avatar Eric Chiang

Godeps: updated github.com/coreos/go-oidc/...

Update Godeps to include:

* coreso/go-oidc#24: reqiured for GitHub connector
* coreso/go-oidc#26: better error messages when connectors are misconfigured
parent 74f84734
...@@ -21,23 +21,23 @@ ...@@ -21,23 +21,23 @@
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/http", "ImportPath": "github.com/coreos/go-oidc/http",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "faf70c34f9c411f234eb96d23c518c087cd96d79"
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/jose", "ImportPath": "github.com/coreos/go-oidc/jose",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "faf70c34f9c411f234eb96d23c518c087cd96d79"
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/key", "ImportPath": "github.com/coreos/go-oidc/key",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "faf70c34f9c411f234eb96d23c518c087cd96d79"
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/oauth2", "ImportPath": "github.com/coreos/go-oidc/oauth2",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "faf70c34f9c411f234eb96d23c518c087cd96d79"
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/oidc", "ImportPath": "github.com/coreos/go-oidc/oidc",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "faf70c34f9c411f234eb96d23c518c087cd96d79"
}, },
{ {
"ImportPath": "github.com/coreos/pkg/capnslog", "ImportPath": "github.com/coreos/pkg/capnslog",
......
This diff is collapsed.
CoreOS Project
Copyright 2014 CoreOS, Inc
This product includes software developed at CoreOS, Inc.
(http://www.coreos.com/).
package http
import (
"net/http"
"net/url"
"reflect"
"strings"
"testing"
"time"
)
func TestCacheControlMaxAgeSuccess(t *testing.T) {
tests := []struct {
hdr string
wantAge time.Duration
wantOK bool
}{
{"max-age=12", 12 * time.Second, true},
{"max-age=-12", 0, false},
{"max-age=0", 0, false},
{"public, max-age=12", 12 * time.Second, true},
{"public, max-age=40192, must-revalidate", 40192 * time.Second, true},
{"public, not-max-age=12, must-revalidate", time.Duration(0), false},
}
for i, tt := range tests {
maxAge, ok, err := cacheControlMaxAge(tt.hdr)
if err != nil {
t.Errorf("case %d: err=%v", i, err)
}
if tt.wantAge != maxAge {
t.Errorf("case %d: want=%d got=%d", i, tt.wantAge, maxAge)
}
if tt.wantOK != ok {
t.Errorf("case %d: incorrect ok value: want=%t got=%t", i, tt.wantOK, ok)
}
}
}
func TestCacheControlMaxAgeFail(t *testing.T) {
tests := []string{
"max-age=aasdf",
"max-age=",
"max-age",
}
for i, tt := range tests {
_, ok, err := cacheControlMaxAge(tt)
if ok {
t.Errorf("case %d: want ok=false, got true", i)
}
if err == nil {
t.Errorf("case %d: want non-nil err", i)
}
}
}
func TestMergeQuery(t *testing.T) {
tests := []struct {
u string
q url.Values
w string
}{
// No values
{
u: "http://example.com",
q: nil,
w: "http://example.com",
},
// No additional values
{
u: "http://example.com?foo=bar",
q: nil,
w: "http://example.com?foo=bar",
},
// Simple addition
{
u: "http://example.com",
q: url.Values{
"foo": []string{"bar"},
},
w: "http://example.com?foo=bar",
},
// Addition with existing values
{
u: "http://example.com?dog=boo",
q: url.Values{
"foo": []string{"bar"},
},
w: "http://example.com?dog=boo&foo=bar",
},
// Merge
{
u: "http://example.com?dog=boo",
q: url.Values{
"dog": []string{"elroy"},
},
w: "http://example.com?dog=boo&dog=elroy",
},
// Add and merge
{
u: "http://example.com?dog=boo",
q: url.Values{
"dog": []string{"elroy"},
"foo": []string{"bar"},
},
w: "http://example.com?dog=boo&dog=elroy&foo=bar",
},
// Multivalue merge
{
u: "http://example.com?dog=boo",
q: url.Values{
"dog": []string{"elroy", "penny"},
},
w: "http://example.com?dog=boo&dog=elroy&dog=penny",
},
}
for i, tt := range tests {
ur, err := url.Parse(tt.u)
if err != nil {
t.Errorf("case %d: failed parsing test url: %v, error: %v", i, tt.u, err)
}
got := MergeQuery(*ur, tt.q)
want, err := url.Parse(tt.w)
if err != nil {
t.Errorf("case %d: failed parsing want url: %v, error: %v", i, tt.w, err)
}
if !reflect.DeepEqual(*want, got) {
t.Errorf("case %d: want: %v, got: %v", i, *want, got)
}
}
}
func TestExpiresPass(t *testing.T) {
tests := []struct {
date string
exp string
wantTTL time.Duration
wantOK bool
}{
// Expires and Date properly set
{
date: "Thu, 01 Dec 1983 22:00:00 GMT",
exp: "Fri, 02 Dec 1983 01:00:00 GMT",
wantTTL: 10800 * time.Second,
wantOK: true,
},
// empty headers
{
date: "",
exp: "",
wantOK: false,
},
// lack of Expirs short-ciruits Date parsing
{
date: "foo",
exp: "",
wantOK: false,
},
// lack of Date short-ciruits Expires parsing
{
date: "",
exp: "foo",
wantOK: false,
},
// no Date
{
exp: "Thu, 01 Dec 1983 22:00:00 GMT",
wantTTL: 0,
wantOK: false,
},
// no Expires
{
date: "Thu, 01 Dec 1983 22:00:00 GMT",
wantTTL: 0,
wantOK: false,
},
// Expires < Date
{
date: "Fri, 02 Dec 1983 01:00:00 GMT",
exp: "Thu, 01 Dec 1983 22:00:00 GMT",
wantTTL: 0,
wantOK: false,
},
}
for i, tt := range tests {
ttl, ok, err := expires(tt.date, tt.exp)
if err != nil {
t.Errorf("case %d: err=%v", i, err)
}
if tt.wantTTL != ttl {
t.Errorf("case %d: want=%d got=%d", i, tt.wantTTL, ttl)
}
if tt.wantOK != ok {
t.Errorf("case %d: incorrect ok value: want=%t got=%t", i, tt.wantOK, ok)
}
}
}
func TestExpiresFail(t *testing.T) {
tests := []struct {
date string
exp string
}{
// malformed Date header
{
date: "foo",
exp: "Fri, 02 Dec 1983 01:00:00 GMT",
},
// malformed exp header
{
date: "Fri, 02 Dec 1983 01:00:00 GMT",
exp: "bar",
},
}
for i, tt := range tests {
_, _, err := expires(tt.date, tt.exp)
if err == nil {
t.Errorf("case %d: expected non-nil error", i)
}
}
}
func TestCacheablePass(t *testing.T) {
tests := []struct {
headers http.Header
wantTTL time.Duration
wantOK bool
}{
// valid Cache-Control
{
headers: http.Header{
"Cache-Control": []string{"max-age=100"},
},
wantTTL: 100 * time.Second,
wantOK: true,
},
// valid Date/Expires
{
headers: http.Header{
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
"Expires": []string{"Fri, 02 Dec 1983 01:00:00 GMT"},
},
wantTTL: 10800 * time.Second,
wantOK: true,
},
// Cache-Control supersedes Date/Expires
{
headers: http.Header{
"Cache-Control": []string{"max-age=100"},
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
"Expires": []string{"Fri, 02 Dec 1983 01:00:00 GMT"},
},
wantTTL: 100 * time.Second,
wantOK: true,
},
// no caching headers
{
headers: http.Header{},
wantOK: false,
},
}
for i, tt := range tests {
ttl, ok, err := Cacheable(tt.headers)
if err != nil {
t.Errorf("case %d: err=%v", i, err)
continue
}
if tt.wantTTL != ttl {
t.Errorf("case %d: want=%d got=%d", i, tt.wantTTL, ttl)
}
if tt.wantOK != ok {
t.Errorf("case %d: incorrect ok value: want=%t got=%t", i, tt.wantOK, ok)
}
}
}
func TestCacheableFail(t *testing.T) {
tests := []http.Header{
// invalid Cache-Control short-circuits
http.Header{
"Cache-Control": []string{"max-age"},
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
"Expires": []string{"Fri, 02 Dec 1983 01:00:00 GMT"},
},
// no Cache-Control, invalid Expires
http.Header{
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
"Expires": []string{"boo"},
},
}
for i, tt := range tests {
_, _, err := Cacheable(tt)
if err == nil {
t.Errorf("case %d: want non-nil err", i)
}
}
}
func TestNewResourceLocation(t *testing.T) {
tests := []struct {
ru *url.URL
id string
want string
}{
{
ru: &url.URL{
Scheme: "http",
Host: "example.com",
},
id: "foo",
want: "http://example.com/foo",
},
// https
{
ru: &url.URL{
Scheme: "https",
Host: "example.com",
},
id: "foo",
want: "https://example.com/foo",
},
// with path
{
ru: &url.URL{
Scheme: "http",
Host: "example.com",
Path: "one/two/three",
},
id: "foo",
want: "http://example.com/one/two/three/foo",
},
// with fragment
{
ru: &url.URL{
Scheme: "http",
Host: "example.com",
Fragment: "frag",
},
id: "foo",
want: "http://example.com/foo",
},
// with query
{
ru: &url.URL{
Scheme: "http",
Host: "example.com",
RawQuery: "dog=elroy",
},
id: "foo",
want: "http://example.com/foo",
},
}
for i, tt := range tests {
got := NewResourceLocation(tt.ru, tt.id)
if tt.want != got {
t.Errorf("case %d: want=%s, got=%s", i, tt.want, got)
}
}
}
func TestCopyRequest(t *testing.T) {
r1, err := http.NewRequest("GET", "http://example.com", strings.NewReader("foo"))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
r2 := CopyRequest(r1)
if !reflect.DeepEqual(r1, r2) {
t.Fatalf("Result of CopyRequest incorrect: %#v != %#v", r1, r2)
}
}
package http
import (
"net/url"
"testing"
)
func TestParseNonEmptyURL(t *testing.T) {
tests := []struct {
u string
ok bool
}{
{"", false},
{"http://", false},
{"example.com", false},
{"example", false},
{"http://example", true},
{"http://example:1234", true},
{"http://example.com", true},
{"http://example.com:1234", true},
}
for i, tt := range tests {
u, err := ParseNonEmptyURL(tt.u)
if err != nil {
t.Logf("err: %v", err)
if tt.ok {
t.Errorf("case %d: unexpected error: %v", i, err)
} else {
continue
}
}
if !tt.ok {
t.Errorf("case %d: expected error but got none", i)
continue
}
uu, err := url.Parse(tt.u)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
if uu.String() != u.String() {
t.Errorf("case %d: incorrect url value, want: %q, got: %q", i, uu.String(), u.String())
}
}
}
...@@ -26,6 +26,32 @@ func (c Claims) StringClaim(name string) (string, bool, error) { ...@@ -26,6 +26,32 @@ func (c Claims) StringClaim(name string) (string, bool, error) {
return v, true, nil return v, true, nil
} }
func (c Claims) StringsClaim(name string) ([]string, bool, error) {
cl, ok := c[name]
if !ok {
return nil, false, nil
}
if v, ok := cl.([]string); ok {
return v, true, nil
}
// When unmarshaled, []string will become []interface{}.
if v, ok := cl.([]interface{}); ok {
var ret []string
for _, vv := range v {
str, ok := vv.(string)
if !ok {
return nil, false, fmt.Errorf("unable to parse claim as string array: %v", name)
}
ret = append(ret, str)
}
return ret, true, nil
}
return nil, false, fmt.Errorf("unable to parse claim as string array: %v", name)
}
func (c Claims) Int64Claim(name string) (int64, bool, error) { func (c Claims) Int64Claim(name string) (int64, bool, error) {
cl, ok := c[name] cl, ok := c[name]
if !ok { if !ok {
......
package jose
import (
"testing"
"time"
)
func TestString(t *testing.T) {
tests := []struct {
cl Claims
key string
ok bool
err bool
val string
}{
// ok, no err, claim exists
{
cl: Claims{
"foo": "bar",
},
key: "foo",
val: "bar",
ok: true,
err: false,
},
// no claims
{
cl: Claims{},
key: "foo",
val: "",
ok: false,
err: false,
},
// missing claim
{
cl: Claims{
"foo": "bar",
},
key: "xxx",
val: "",
ok: false,
err: false,
},
// unparsable: type
{
cl: Claims{
"foo": struct{}{},
},
key: "foo",
val: "",
ok: false,
err: true,
},
// unparsable: nil value
{
cl: Claims{
"foo": nil,
},
key: "foo",
val: "",
ok: false,
err: true,
},
}
for i, tt := range tests {
val, ok, err := tt.cl.StringClaim(tt.key)
if tt.err && err == nil {
t.Errorf("case %d: want err=non-nil, got err=nil", i)
} else if !tt.err && err != nil {
t.Errorf("case %d: want err=nil, got err=%v", i, err)
}
if tt.ok != ok {
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
}
if tt.val != val {
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
}
}
}
func TestInt64(t *testing.T) {
tests := []struct {
cl Claims
key string
ok bool
err bool
val int64
}{
// ok, no err, claim exists
{
cl: Claims{
"foo": int64(100),
},
key: "foo",
val: int64(100),
ok: true,
err: false,
},
// no claims
{
cl: Claims{},
key: "foo",
val: 0,
ok: false,
err: false,
},
// missing claim
{
cl: Claims{
"foo": "bar",
},
key: "xxx",
val: 0,
ok: false,
err: false,
},
// unparsable: type
{
cl: Claims{
"foo": struct{}{},
},
key: "foo",
val: 0,
ok: false,
err: true,
},
// unparsable: nil value
{
cl: Claims{
"foo": nil,
},
key: "foo",
val: 0,
ok: false,
err: true,
},
}
for i, tt := range tests {
val, ok, err := tt.cl.Int64Claim(tt.key)
if tt.err && err == nil {
t.Errorf("case %d: want err=non-nil, got err=nil", i)
} else if !tt.err && err != nil {
t.Errorf("case %d: want err=nil, got err=%v", i, err)
}
if tt.ok != ok {
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
}
if tt.val != val {
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
}
}
}
func TestTime(t *testing.T) {
now := time.Now().UTC()
unixNow := now.Unix()
tests := []struct {
cl Claims
key string
ok bool
err bool
val time.Time
}{
// ok, no err, claim exists
{
cl: Claims{
"foo": unixNow,
},
key: "foo",
val: time.Unix(now.Unix(), 0).UTC(),
ok: true,
err: false,
},
// no claims
{
cl: Claims{},
key: "foo",
val: time.Time{},
ok: false,
err: false,
},
// missing claim
{
cl: Claims{
"foo": "bar",
},
key: "xxx",
val: time.Time{},
ok: false,
err: false,
},
// unparsable: type
{
cl: Claims{
"foo": struct{}{},
},
key: "foo",
val: time.Time{},
ok: false,
err: true,
},
// unparsable: nil value
{
cl: Claims{
"foo": nil,
},
key: "foo",
val: time.Time{},
ok: false,
err: true,
},
}
for i, tt := range tests {
val, ok, err := tt.cl.TimeClaim(tt.key)
if tt.err && err == nil {
t.Errorf("case %d: want err=non-nil, got err=nil", i)
} else if !tt.err && err != nil {
t.Errorf("case %d: want err=nil, got err=%v", i, err)
}
if tt.ok != ok {
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
}
if tt.val != val {
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
}
}
}
package jose
import (
"testing"
)
func TestDecodeBase64URLPaddingOptional(t *testing.T) {
tests := []struct {
encoded string
decoded string
err bool
}{
{
// With padding
encoded: "VGVjdG9uaWM=",
decoded: "Tectonic",
},
{
// Without padding
encoded: "VGVjdG9uaWM",
decoded: "Tectonic",
},
{
// Even More padding
encoded: "VGVjdG9uaQ==",
decoded: "Tectoni",
},
{
// And take it away!
encoded: "VGVjdG9uaQ",
decoded: "Tectoni",
},
{
// Too much padding.
encoded: "VGVjdG9uaWNh=",
decoded: "",
err: true,
},
{
// Too much padding.
encoded: "VGVjdG9uaWNh=",
decoded: "",
err: true,
},
}
for i, tt := range tests {
got, err := decodeBase64URLPaddingOptional(tt.encoded)
if tt.err {
if err == nil {
t.Errorf("case %d: expected non-nil err", i)
}
continue
}
if err != nil {
t.Errorf("case %d: want nil err, got: %v", i, err)
}
if string(got) != tt.decoded {
t.Errorf("case %d: want=%q, got=%q", i, tt.decoded, got)
}
}
}
package jose
import (
"strings"
"testing"
)
type testCase struct{ t string }
var validInput []testCase
var invalidInput []testCase
func init() {
validInput = []testCase{
{
"eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
},
}
invalidInput = []testCase{
// empty
{
"",
},
// undecodeable
{
"aaa.bbb.ccc",
},
// missing parts
{
"aaa",
},
// missing parts
{
"aaa.bbb",
},
// too many parts
{
"aaa.bbb.ccc.ddd",
},
// invalid header
// EncodeHeader(map[string]string{"foo": "bar"})
{
"eyJmb28iOiJiYXIifQ.bbb.ccc",
},
}
}
func TestParseJWS(t *testing.T) {
for i, tt := range validInput {
jws, err := ParseJWS(tt.t)
if err != nil {
t.Errorf("test: %d. expected: valid, actual: invalid", i)
}
expectedHeader := strings.Split(tt.t, ".")[0]
if jws.RawHeader != expectedHeader {
t.Errorf("test: %d. expected: %s, actual: %s", i, expectedHeader, jws.RawHeader)
}
expectedPayload := strings.Split(tt.t, ".")[1]
if jws.RawPayload != expectedPayload {
t.Errorf("test: %d. expected: %s, actual: %s", i, expectedPayload, jws.RawPayload)
}
}
for i, tt := range invalidInput {
_, err := ParseJWS(tt.t)
if err == nil {
t.Errorf("test: %d. expected: invalid, actual: valid", i)
}
}
}
package jose package jose
import ( import "strings"
"strings"
)
type JWT JWS type JWT JWS
...@@ -63,13 +61,13 @@ func (j *JWT) Encode() string { ...@@ -63,13 +61,13 @@ func (j *JWT) Encode() string {
return strings.Join([]string{d, s}, ".") return strings.Join([]string{d, s}, ".")
} }
func NewSignedJWT(claims map[string]interface{}, s Signer) (*JWT, error) { func NewSignedJWT(claims Claims, s Signer) (*JWT, error) {
header := JOSEHeader{ header := JOSEHeader{
HeaderKeyAlgorithm: s.Alg(), HeaderKeyAlgorithm: s.Alg(),
HeaderKeyID: s.ID(), HeaderKeyID: s.ID(),
} }
jwt, err := NewJWT(header, Claims(claims)) jwt, err := NewJWT(header, claims)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
package jose
import (
"reflect"
"testing"
)
func TestParseJWT(t *testing.T) {
tests := []struct {
r string
h JOSEHeader
c Claims
}{
{
// Example from JWT spec:
// http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html#ExampleJWT
"eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
JOSEHeader{
HeaderMediaType: "JWT",
HeaderKeyAlgorithm: "HS256",
},
Claims{
"iss": "joe",
// NOTE: test numbers must be floats for equality checks to work since values are converted form interface{} to float64 by default.
"exp": 1300819380.0,
"http://example.com/is_root": true,
},
},
}
for i, tt := range tests {
jwt, err := ParseJWT(tt.r)
if err != nil {
t.Errorf("raw token should parse. test: %d. expected: valid, actual: invalid. err=%v", i, err)
}
if !reflect.DeepEqual(tt.h, jwt.Header) {
t.Errorf("JOSE headers should match. test: %d. expected: %v, actual: %v", i, tt.h, jwt.Header)
}
claims, err := jwt.Claims()
if err != nil {
t.Errorf("test: %d. expected: valid claim parsing. err=%v", i, err)
}
if !reflect.DeepEqual(tt.c, claims) {
t.Errorf("claims should match. test: %d. expected: %v, actual: %v", i, tt.c, claims)
}
enc := jwt.Encode()
if enc != tt.r {
t.Errorf("encoded jwt should match raw jwt. test: %d. expected: %v, actual: %v", i, tt.r, enc)
}
}
}
func TestNewJWTHeaderTyp(t *testing.T) {
jwt, err := NewJWT(JOSEHeader{}, Claims{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
want := "JWT"
got := jwt.Header[HeaderMediaType]
if want != got {
t.Fatalf("Header %q incorrect: want=%s got=%s", HeaderMediaType, want, got)
}
}
func TestNewJWTHeaderKeyID(t *testing.T) {
jwt, err := NewJWT(JOSEHeader{HeaderKeyID: "foo"}, Claims{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
want := "foo"
got, ok := jwt.KeyID()
if !ok {
t.Fatalf("KeyID not set")
} else if want != got {
t.Fatalf("KeyID incorrect: want=%s got=%s", want, got)
}
}
func TestNewJWTHeaderKeyIDNotSet(t *testing.T) {
jwt, err := NewJWT(JOSEHeader{}, Claims{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if _, ok := jwt.KeyID(); ok {
t.Fatalf("KeyID set, but should not be")
}
}
package jose
import (
"bytes"
"encoding/base64"
"testing"
)
var hmacTestCases = []struct {
data string
sig string
jwk JWK
valid bool
desc string
}{
{
"test",
"Aymga2LNFrM-tnkr6MYLFY2Jou46h2_Omogeu0iMCRQ=",
JWK{
ID: "fake-key",
Alg: "HS256",
Secret: []byte("secret"),
},
true,
"valid case",
},
{
"test",
"Aymga2LNFrM-tnkr6MYLFY2Jou46h2_Omogeu0iMCRQ=",
JWK{
ID: "different-key",
Alg: "HS256",
Secret: []byte("secret"),
},
true,
"invalid: different key, should not match",
},
{
"test sig and non-matching data",
"Aymga2LNFrM-tnkr6MYLFY2Jou46h2_Omogeu0iMCRQ=",
JWK{
ID: "fake-key",
Alg: "HS256",
Secret: []byte("secret"),
},
false,
"invalid: sig and data should not match",
},
}
func TestVerify(t *testing.T) {
for _, tt := range hmacTestCases {
v, err := NewVerifierHMAC(tt.jwk)
if err != nil {
t.Errorf("should construct hmac verifier. test: %s. err=%v", tt.desc, err)
}
decSig, _ := base64.URLEncoding.DecodeString(tt.sig)
err = v.Verify(decSig, []byte(tt.data))
if err == nil && !tt.valid {
t.Errorf("verify failure. test: %s. expected: invalid, actual: valid.", tt.desc)
}
if err != nil && tt.valid {
t.Errorf("verify failure. test: %s. expected: valid, actual: invalid. err=%v", tt.desc, err)
}
}
}
func TestSign(t *testing.T) {
for _, tt := range hmacTestCases {
s := NewSignerHMAC("test", tt.jwk.Secret)
sig, err := s.Sign([]byte(tt.data))
if err != nil {
t.Errorf("sign failure. test: %s. err=%v", tt.desc, err)
}
expSig, _ := base64.URLEncoding.DecodeString(tt.sig)
if tt.valid && !bytes.Equal(sig, expSig) {
t.Errorf("sign failure. test: %s. expected: %s, actual: %s.", tt.desc, tt.sig, base64.URLEncoding.EncodeToString(sig))
}
if !tt.valid && bytes.Equal(sig, expSig) {
t.Errorf("sign failure. test: %s. expected: invalid signature.", tt.desc)
}
}
}
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"encoding/base64" "encoding/base64"
"encoding/json"
"math/big" "math/big"
"time" "time"
...@@ -18,6 +19,19 @@ type PublicKey struct { ...@@ -18,6 +19,19 @@ type PublicKey struct {
jwk jose.JWK jwk jose.JWK
} }
func (k *PublicKey) MarshalJSON() ([]byte, error) {
return json.Marshal(k.jwk)
}
func (k *PublicKey) UnmarshalJSON(data []byte) error {
var jwk jose.JWK
if err := json.Unmarshal(data, &jwk); err != nil {
return err
}
k.jwk = jwk
return nil
}
func (k *PublicKey) ID() string { func (k *PublicKey) ID() string {
return k.jwk.ID return k.jwk.ID
} }
......
package key
import (
"crypto/rsa"
"math/big"
"reflect"
"testing"
"time"
"github.com/coreos/go-oidc/jose"
)
func TestPrivateRSAKeyJWK(t *testing.T) {
n := big.NewInt(int64(17))
if n == nil {
panic("NewInt returned nil")
}
k := &PrivateKey{
KeyID: "foo",
PrivateKey: &rsa.PrivateKey{
PublicKey: rsa.PublicKey{N: n, E: 65537},
},
}
want := jose.JWK{
ID: "foo",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: n,
Exponent: 65537,
}
got := k.JWK()
if !reflect.DeepEqual(want, got) {
t.Fatalf("JWK mismatch: want=%#v got=%#v", want, got)
}
}
func TestPublicKeySetKey(t *testing.T) {
n := big.NewInt(int64(17))
if n == nil {
panic("NewInt returned nil")
}
k := jose.JWK{
ID: "foo",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: n,
Exponent: 65537,
}
now := time.Now().UTC()
ks := NewPublicKeySet([]jose.JWK{k}, now)
want := &PublicKey{jwk: k}
got := ks.Key("foo")
if !reflect.DeepEqual(want, got) {
t.Errorf("Unexpected response from PublicKeySet.Key: want=%#v got=%#v", want, got)
}
got = ks.Key("bar")
if got != nil {
t.Errorf("Expected nil response from PublicKeySet.Key, got %#v", got)
}
}
package key
import (
"crypto/rsa"
"math/big"
"reflect"
"strconv"
"testing"
"time"
"github.com/jonboulle/clockwork"
"github.com/coreos/go-oidc/jose"
)
var (
jwk1 jose.JWK
jwk2 jose.JWK
jwk3 jose.JWK
)
func init() {
jwk1 = jose.JWK{
ID: "1",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: big.NewInt(1),
Exponent: 65537,
}
jwk2 = jose.JWK{
ID: "2",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: big.NewInt(2),
Exponent: 65537,
}
jwk3 = jose.JWK{
ID: "3",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: big.NewInt(3),
Exponent: 65537,
}
}
func generatePrivateKeyStatic(t *testing.T, idAndN int) *PrivateKey {
n := big.NewInt(int64(idAndN))
if n == nil {
t.Fatalf("Call to NewInt(%d) failed", idAndN)
}
pk := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{N: n, E: 65537},
}
return &PrivateKey{
KeyID: strconv.Itoa(idAndN),
PrivateKey: pk,
}
}
func TestPrivateKeyManagerJWKsRotate(t *testing.T) {
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
km := NewPrivateKeyManager()
err := km.Set(&PrivateKeySet{
keys: []*PrivateKey{k1, k2, k3},
ActiveKeyID: k1.KeyID,
expiresAt: time.Now().Add(time.Minute),
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
want := []jose.JWK{jwk1, jwk2, jwk3}
got, err := km.JWKs()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if !reflect.DeepEqual(want, got) {
t.Fatalf("JWK mismatch: want=%#v got=%#v", want, got)
}
}
func TestPrivateKeyManagerSigner(t *testing.T) {
k := generatePrivateKeyStatic(t, 13)
km := NewPrivateKeyManager()
err := km.Set(&PrivateKeySet{
keys: []*PrivateKey{k},
ActiveKeyID: k.KeyID,
expiresAt: time.Now().Add(time.Minute),
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
signer, err := km.Signer()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
wantID := "13"
gotID := signer.ID()
if wantID != gotID {
t.Fatalf("Signer has incorrect ID: want=%s got=%s", wantID, gotID)
}
}
func TestPrivateKeyManagerHealthyFail(t *testing.T) {
keyFixture := generatePrivateKeyStatic(t, 1)
tests := []*privateKeyManager{
// keySet nil
&privateKeyManager{
keySet: nil,
clock: clockwork.NewRealClock(),
},
// zero keys
&privateKeyManager{
keySet: &PrivateKeySet{
keys: []*PrivateKey{},
expiresAt: time.Now().Add(time.Minute),
},
clock: clockwork.NewRealClock(),
},
// key set expired
&privateKeyManager{
keySet: &PrivateKeySet{
keys: []*PrivateKey{keyFixture},
expiresAt: time.Now().Add(-1 * time.Minute),
},
clock: clockwork.NewRealClock(),
},
}
for i, tt := range tests {
if err := tt.Healthy(); err == nil {
t.Errorf("case %d: nil error", i)
}
}
}
func TestPrivateKeyManagerHealthyFailsOtherMethods(t *testing.T) {
km := NewPrivateKeyManager()
if _, err := km.JWKs(); err == nil {
t.Fatalf("Expected non-nil error")
}
if _, err := km.Signer(); err == nil {
t.Fatalf("Expected non-nil error")
}
}
func TestPrivateKeyManagerExpiresAt(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
k := generatePrivateKeyStatic(t, 17)
km := &privateKeyManager{
clock: fc,
}
want := fc.Now().UTC()
got := km.ExpiresAt()
if want != got {
t.Fatalf("Incorrect expiration time: want=%v got=%v", want, got)
}
err := km.Set(&PrivateKeySet{
keys: []*PrivateKey{k},
ActiveKeyID: k.KeyID,
expiresAt: now.Add(2 * time.Minute),
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
want = fc.Now().UTC().Add(2 * time.Minute)
got = km.ExpiresAt()
if want != got {
t.Fatalf("Incorrect expiration time: want=%v got=%v", want, got)
}
}
func TestPublicKeys(t *testing.T) {
km := NewPrivateKeyManager()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
tests := [][]*PrivateKey{
[]*PrivateKey{k1},
[]*PrivateKey{k1, k2},
[]*PrivateKey{k1, k2, k3},
}
for i, tt := range tests {
ks := &PrivateKeySet{
keys: tt,
expiresAt: time.Now().Add(time.Hour),
}
km.Set(ks)
jwks, err := km.JWKs()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
pks := NewPublicKeySet(jwks, time.Now().Add(time.Hour))
want := pks.Keys()
got, err := km.PublicKeys()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if !reflect.DeepEqual(want, got) {
t.Errorf("case %d: Invalid public keys: want=%v got=%v", i, want, got)
}
}
}
package key
import (
"reflect"
"testing"
"time"
"github.com/jonboulle/clockwork"
)
func generatePrivateKeySerialFunc(t *testing.T) GeneratePrivateKeyFunc {
var n int
return func() (*PrivateKey, error) {
n++
return generatePrivateKeyStatic(t, n), nil
}
}
func TestRotate(t *testing.T) {
now := time.Now()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
tests := []struct {
start *PrivateKeySet
key *PrivateKey
keep int
exp time.Time
want *PrivateKeySet
}{
// start with nil keys
{
start: nil,
key: k1,
keep: 2,
exp: now.Add(time.Second),
want: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(time.Second),
},
},
// start with zero keys
{
start: &PrivateKeySet{},
key: k1,
keep: 2,
exp: now.Add(time.Second),
want: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(time.Second),
},
},
// add second key
{
start: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now,
},
key: k2,
keep: 2,
exp: now.Add(time.Second),
want: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(time.Second),
},
},
// rotate in third key
{
start: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now,
},
key: k3,
keep: 2,
exp: now.Add(time.Second),
want: &PrivateKeySet{
keys: []*PrivateKey{k3, k2},
ActiveKeyID: k3.KeyID,
expiresAt: now.Add(time.Second),
},
},
}
for i, tt := range tests {
repo := NewPrivateKeySetRepo()
if tt.start != nil {
err := repo.Set(tt.start)
if err != nil {
log.Fatalf("case %d: unexpected error: %v", i, err)
}
}
rotatePrivateKeys(repo, tt.key, tt.keep, tt.exp)
got, err := repo.Get()
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
if !reflect.DeepEqual(tt.want, got) {
t.Errorf("case %d: unexpected result: want=%#v got=%#v", i, tt.want, got)
}
}
}
func TestPrivateKeyRotatorRun(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
k4 := generatePrivateKeyStatic(t, 4)
kRepo := NewPrivateKeySetRepo()
krot := NewPrivateKeyRotator(kRepo, 4*time.Second)
krot.clock = fc
krot.generateKey = generatePrivateKeySerialFunc(t)
steps := []*PrivateKeySet{
&PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(4 * time.Second),
},
&PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(6 * time.Second),
},
&PrivateKeySet{
keys: []*PrivateKey{k3, k2},
ActiveKeyID: k3.KeyID,
expiresAt: now.Add(8 * time.Second),
},
&PrivateKeySet{
keys: []*PrivateKey{k4, k3},
ActiveKeyID: k4.KeyID,
expiresAt: now.Add(10 * time.Second),
},
}
stop := krot.Run()
defer close(stop)
for i, st := range steps {
// wait for the rotater to get sleepy
fc.BlockUntil(1)
got, err := kRepo.Get()
if err != nil {
t.Fatalf("step %d: unexpected error: %v", i, err)
}
if !reflect.DeepEqual(st, got) {
t.Fatalf("step %d: unexpected state: want=%#v got=%#v", i, st, got)
}
fc.Advance(2 * time.Second)
}
}
func TestPrivateKeyRotatorExpiresAt(t *testing.T) {
fc := clockwork.NewFakeClock()
krot := &PrivateKeyRotator{
clock: fc,
ttl: time.Minute,
}
got := krot.expiresAt()
want := fc.Now().UTC().Add(time.Minute)
if !reflect.DeepEqual(want, got) {
t.Errorf("Incorrect expiration time: want=%v got=%v", want, got)
}
}
func TestNextRotation(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
tests := []struct {
expiresAt time.Time
ttl time.Duration
numKeys int
expected time.Duration
}{
{
// closest to prod
expiresAt: now.Add(time.Hour * 24),
ttl: time.Hour * 24,
numKeys: 2,
expected: time.Hour * 12,
},
{
expiresAt: now.Add(time.Hour * 2),
ttl: time.Hour * 4,
numKeys: 2,
expected: 0,
},
{
// No keys.
expiresAt: now.Add(time.Hour * 2),
ttl: time.Hour * 4,
numKeys: 0,
expected: 0,
},
{
// Nil keyset.
expiresAt: now.Add(time.Hour * 2),
ttl: time.Hour * 4,
numKeys: -1,
expected: 0,
},
{
// KeySet expired.
expiresAt: now.Add(time.Hour * -2),
ttl: time.Hour * 4,
numKeys: 2,
expected: 0,
},
{
// Expiry past now + TTL
expiresAt: now.Add(time.Hour * 5),
ttl: time.Hour * 4,
numKeys: 2,
expected: 3 * time.Hour,
},
}
for i, tt := range tests {
kRepo := NewPrivateKeySetRepo()
krot := NewPrivateKeyRotator(kRepo, tt.ttl)
krot.clock = fc
pks := &PrivateKeySet{
expiresAt: tt.expiresAt,
}
if tt.numKeys != -1 {
for n := 0; n < tt.numKeys; n++ {
pks.keys = append(pks.keys, generatePrivateKeyStatic(t, n))
}
err := kRepo.Set(pks)
if err != nil {
log.Fatalf("case %d: unexpected error: %v", i, err)
}
}
actual, err := krot.nextRotation()
if err != nil {
t.Errorf("case %d: error calling shouldRotate(): %v", i, err)
}
if actual != tt.expected {
t.Errorf("case %d: actual == %v, want %v", i, actual, tt.expected)
}
}
}
func TestHealthy(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
tests := []struct {
expiresAt time.Time
numKeys int
expected error
}{
{
expiresAt: now.Add(time.Hour),
numKeys: 2,
expected: nil,
},
{
expiresAt: now.Add(time.Hour),
numKeys: -1,
expected: ErrorNoKeys,
},
{
expiresAt: now.Add(time.Hour),
numKeys: 0,
expected: ErrorNoKeys,
},
{
expiresAt: now.Add(-time.Hour),
numKeys: 2,
expected: ErrorPrivateKeysExpired,
},
}
for i, tt := range tests {
kRepo := NewPrivateKeySetRepo()
krot := NewPrivateKeyRotator(kRepo, time.Hour)
krot.clock = fc
pks := &PrivateKeySet{
expiresAt: tt.expiresAt,
}
if tt.numKeys != -1 {
for n := 0; n < tt.numKeys; n++ {
pks.keys = append(pks.keys, generatePrivateKeyStatic(t, n))
}
err := kRepo.Set(pks)
if err != nil {
log.Fatalf("case %d: unexpected error: %v", i, err)
}
}
if err := krot.Healthy(); err != tt.expected {
t.Errorf("case %d: got==%q, want==%q", i, err, tt.expected)
}
}
}
package key
import (
"errors"
"reflect"
"testing"
"time"
"github.com/jonboulle/clockwork"
)
type staticReadableKeySetRepo struct {
ks KeySet
err error
}
func (r *staticReadableKeySetRepo) Get() (KeySet, error) {
return r.ks, r.err
}
func TestKeySyncerSync(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
steps := []struct {
fromKS KeySet
fromErr error
advance time.Duration
want *PrivateKeySet
}{
// on startup, first sync should trigger within a second
{
fromKS: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(10 * time.Second),
},
advance: time.Second,
want: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(10 * time.Second),
},
},
// advance halfway into TTL, triggering sync
{
fromKS: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(15 * time.Second),
},
advance: 5 * time.Second,
want: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(15 * time.Second),
},
},
// advance halfway into TTL, triggering sync that fails
{
fromErr: errors.New("fail!"),
advance: 10 * time.Second,
want: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(15 * time.Second),
},
},
// sync retries quickly, and succeeds with fixed data
{
fromKS: &PrivateKeySet{
keys: []*PrivateKey{k3, k2, k1},
ActiveKeyID: k3.KeyID,
expiresAt: now.Add(25 * time.Second),
},
advance: 3 * time.Second,
want: &PrivateKeySet{
keys: []*PrivateKey{k3, k2, k1},
ActiveKeyID: k3.KeyID,
expiresAt: now.Add(25 * time.Second),
},
},
}
from := &staticReadableKeySetRepo{}
to := NewPrivateKeySetRepo()
syncer := NewKeySetSyncer(from, to)
syncer.clock = fc
stop := syncer.Run()
defer close(stop)
for i, st := range steps {
from.ks = st.fromKS
from.err = st.fromErr
fc.Advance(st.advance)
fc.BlockUntil(1)
ks, err := to.Get()
if err != nil {
t.Fatalf("step %d: unable to get keys: %v", i, err)
}
if !reflect.DeepEqual(st.want, ks) {
t.Fatalf("step %d: incorrect state: want=%#v got=%#v", i, st.want, ks)
}
}
}
func TestSync(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
tests := []struct {
keySet *PrivateKeySet
want time.Duration
}{
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(time.Minute),
},
want: time.Minute,
},
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(time.Minute),
},
want: time.Minute,
},
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k3, k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(time.Minute),
},
want: time.Minute,
},
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(time.Hour),
},
want: time.Hour,
},
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(-time.Hour),
},
want: 0,
},
}
for i, tt := range tests {
from := NewPrivateKeySetRepo()
to := NewPrivateKeySetRepo()
err := from.Set(tt.keySet)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
exp, err := sync(from, to, fc)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
if tt.want != exp {
t.Errorf("case %d: want=%v got=%v", i, tt.want, exp)
}
}
}
func TestSyncFail(t *testing.T) {
tests := []error{
nil,
errors.New("fail!"),
}
for i, tt := range tests {
from := &staticReadableKeySetRepo{ks: nil, err: tt}
to := NewPrivateKeySetRepo()
if _, err := sync(from, to, clockwork.NewFakeClock()); err == nil {
t.Errorf("case %d: expected non-nil error", i)
}
}
}
package oauth2 package oauth2
import (
"encoding/json"
"fmt"
)
const ( const (
ErrorAccessDenied = "access_denied" ErrorAccessDenied = "access_denied"
ErrorInvalidClient = "invalid_client" ErrorInvalidClient = "invalid_client"
...@@ -17,23 +12,18 @@ const ( ...@@ -17,23 +12,18 @@ const (
) )
type Error struct { type Error struct {
Type string `json:"error"` Type string `json:"error"`
State string `json:"state,omitempty"` Description string `json:"error_description,omitempty"`
State string `json:"state,omitempty"`
} }
func (e *Error) Error() string { func (e *Error) Error() string {
if e.Description != "" {
return e.Type + ": " + e.Description
}
return e.Type return e.Type
} }
func NewError(typ string) *Error { func NewError(typ string) *Error {
return &Error{Type: typ} return &Error{Type: typ}
} }
func unmarshalError(b []byte) error {
var oerr Error
err := json.Unmarshal(b, &oerr)
if err != nil {
return fmt.Errorf("unrecognized error: %s", string(b))
}
return &oerr
}
package oauth2
import (
"fmt"
"reflect"
"testing"
)
func TestUnmarshalError(t *testing.T) {
tests := []struct {
b []byte
e *Error
o bool
}{
{
b: []byte("{ \"error\": \"invalid_client\", \"state\": \"foo\" }"),
e: &Error{Type: ErrorInvalidClient, State: "foo"},
o: true,
},
{
b: []byte("{ \"error\": \"invalid_grant\", \"state\": \"bar\" }"),
e: &Error{Type: ErrorInvalidGrant, State: "bar"},
o: true,
},
{
b: []byte("{ \"error\": \"invalid_request\", \"state\": \"\" }"),
e: &Error{Type: ErrorInvalidRequest, State: ""},
o: true,
},
{
b: []byte("{ \"error\": \"server_error\", \"state\": \"elroy\" }"),
e: &Error{Type: ErrorServerError, State: "elroy"},
o: true,
},
{
b: []byte("{ \"error\": \"unsupported_grant_type\", \"state\": \"\" }"),
e: &Error{Type: ErrorUnsupportedGrantType, State: ""},
o: true,
},
{
b: []byte("{ \"error\": \"unsupported_response_type\", \"state\": \"\" }"),
e: &Error{Type: ErrorUnsupportedResponseType, State: ""},
o: true,
},
// Should fail json unmarshal
{
b: nil,
e: nil,
o: false,
},
{
b: []byte("random string"),
e: nil,
o: false,
},
}
for i, tt := range tests {
err := unmarshalError(tt.b)
oerr, ok := err.(*Error)
if ok != tt.o {
t.Errorf("%v != %v, %v", ok, tt.o, oerr)
t.Errorf("case %d: want=%+v, got=%+v", i, tt.e, oerr)
}
if ok && !reflect.DeepEqual(tt.e, oerr) {
t.Errorf("case %d: want=%+v, got=%+v", i, tt.e, oerr)
}
if !ok && tt.e != nil {
want := fmt.Sprintf("unrecognized error: %s", string(tt.b))
got := tt.e.Error()
if want != got {
t.Errorf("case %d: want=%+v, got=%+v", i, want, got)
}
}
}
}
...@@ -220,11 +220,7 @@ func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) { ...@@ -220,11 +220,7 @@ func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) {
if err != nil { if err != nil {
return return
} }
badStatusCode := resp.StatusCode < 200 || resp.StatusCode > 299
if resp.StatusCode < 200 || resp.StatusCode > 299 {
err = unmarshalError(body)
return
}
contentType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) contentType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err != nil { if err != nil {
...@@ -235,42 +231,69 @@ func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) { ...@@ -235,42 +231,69 @@ func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) {
RawBody: body, RawBody: body,
} }
newError := func(typ, desc, state string) error {
if typ == "" {
return fmt.Errorf("unrecognized error %s", body)
}
return &Error{typ, desc, state}
}
if contentType == "application/x-www-form-urlencoded" || contentType == "text/plain" { if contentType == "application/x-www-form-urlencoded" || contentType == "text/plain" {
var vals url.Values var vals url.Values
vals, err = url.ParseQuery(string(body)) vals, err = url.ParseQuery(string(body))
if err != nil { if err != nil {
return return
} }
if error := vals.Get("error"); error != "" || badStatusCode {
err = newError(error, vals.Get("error_description"), vals.Get("state"))
return
}
e := vals.Get("expires_in")
if e == "" {
e = vals.Get("expires")
}
if e != "" {
result.Expires, err = strconv.Atoi(e)
if err != nil {
return
}
}
result.AccessToken = vals.Get("access_token") result.AccessToken = vals.Get("access_token")
result.TokenType = vals.Get("token_type") result.TokenType = vals.Get("token_type")
result.IDToken = vals.Get("id_token") result.IDToken = vals.Get("id_token")
result.RefreshToken = vals.Get("refresh_token") result.RefreshToken = vals.Get("refresh_token")
result.Scope = vals.Get("scope") result.Scope = vals.Get("scope")
e := vals.Get("expires_in") } else {
if e == "" { var r struct {
e = vals.Get("expires") AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
State string `json:"state"`
ExpiresIn int `json:"expires_in"`
Expires int `json:"expires"`
Error string `json:"error"`
Desc string `json:"error_description"`
} }
result.Expires, err = strconv.Atoi(e) if err = json.Unmarshal(body, &r); err != nil {
if err != nil {
return return
} }
} else { if r.Error != "" || badStatusCode {
b := make(map[string]interface{}) err = newError(r.Error, r.Desc, r.State)
if err = json.Unmarshal(body, &b); err != nil {
return return
} }
result.AccessToken, _ = b["access_token"].(string) result.AccessToken = r.AccessToken
result.TokenType, _ = b["token_type"].(string) result.TokenType = r.TokenType
result.IDToken, _ = b["id_token"].(string) result.IDToken = r.IDToken
result.RefreshToken, _ = b["refresh_token"].(string) result.RefreshToken = r.RefreshToken
result.Scope, _ = b["scope"].(string) result.Scope = r.Scope
e, ok := b["expires_in"].(int) if r.ExpiresIn == 0 {
if !ok { result.Expires = r.Expires
e, _ = b["expires"].(int) } else {
result.Expires = r.ExpiresIn
} }
result.Expires = e
} }
return return
} }
......
package oauth2
import (
"errors"
"net/url"
"reflect"
"strings"
"testing"
phttp "github.com/coreos/go-oidc/http"
)
func TestParseAuthCodeRequest(t *testing.T) {
tests := []struct {
query url.Values
wantACR AuthCodeRequest
wantErr error
}{
// no redirect_uri
{
query: url.Values{
"response_type": []string{"code"},
"scope": []string{"foo bar baz"},
"client_id": []string{"XXX"},
"state": []string{"pants"},
},
wantACR: AuthCodeRequest{
ResponseType: "code",
ClientID: "XXX",
Scope: []string{"foo", "bar", "baz"},
State: "pants",
RedirectURL: nil,
},
},
// with redirect_uri
{
query: url.Values{
"response_type": []string{"code"},
"redirect_uri": []string{"https://127.0.0.1:5555/callback?foo=bar"},
"scope": []string{"foo bar baz"},
"client_id": []string{"XXX"},
"state": []string{"pants"},
},
wantACR: AuthCodeRequest{
ResponseType: "code",
ClientID: "XXX",
Scope: []string{"foo", "bar", "baz"},
State: "pants",
RedirectURL: &url.URL{
Scheme: "https",
Host: "127.0.0.1:5555",
Path: "/callback",
RawQuery: "foo=bar",
},
},
},
// unsupported response_type doesn't trigger error
{
query: url.Values{
"response_type": []string{"token"},
"redirect_uri": []string{"https://127.0.0.1:5555/callback?foo=bar"},
"scope": []string{"foo bar baz"},
"client_id": []string{"XXX"},
"state": []string{"pants"},
},
wantACR: AuthCodeRequest{
ResponseType: "token",
ClientID: "XXX",
Scope: []string{"foo", "bar", "baz"},
State: "pants",
RedirectURL: &url.URL{
Scheme: "https",
Host: "127.0.0.1:5555",
Path: "/callback",
RawQuery: "foo=bar",
},
},
},
// unparseable redirect_uri
{
query: url.Values{
"response_type": []string{"code"},
"redirect_uri": []string{":"},
"scope": []string{"foo bar baz"},
"client_id": []string{"XXX"},
"state": []string{"pants"},
},
wantACR: AuthCodeRequest{
ResponseType: "code",
ClientID: "XXX",
Scope: []string{"foo", "bar", "baz"},
State: "pants",
},
wantErr: NewError(ErrorInvalidRequest),
},
// no client_id, redirect_uri not parsed
{
query: url.Values{
"response_type": []string{"code"},
"redirect_uri": []string{"https://127.0.0.1:5555/callback?foo=bar"},
"scope": []string{"foo bar baz"},
"client_id": []string{},
"state": []string{"pants"},
},
wantACR: AuthCodeRequest{
ResponseType: "code",
ClientID: "",
Scope: []string{"foo", "bar", "baz"},
State: "pants",
RedirectURL: nil,
},
wantErr: NewError(ErrorInvalidRequest),
},
}
for i, tt := range tests {
got, err := ParseAuthCodeRequest(tt.query)
if !reflect.DeepEqual(tt.wantErr, err) {
t.Errorf("case %d: incorrect error value: want=%q got=%q", i, tt.wantErr, err)
}
if !reflect.DeepEqual(tt.wantACR, got) {
t.Errorf("case %d: incorrect AuthCodeRequest value: want=%#v got=%#v", i, tt.wantACR, got)
}
}
}
func TestClientCredsToken(t *testing.T) {
hc := &phttp.RequestRecorder{Error: errors.New("error")}
cfg := Config{
Credentials: ClientCredentials{ID: "cid", Secret: "csecret"},
Scope: []string{"foo-scope", "bar-scope"},
TokenURL: "http://example.com/token",
AuthMethod: AuthMethodClientSecretBasic,
RedirectURL: "http://example.com/redirect",
AuthURL: "http://example.com/auth",
}
c, err := NewClient(hc, cfg)
if err != nil {
t.Errorf("unexpected error %v", err)
}
scope := []string{"openid"}
c.ClientCredsToken(scope)
if hc.Request == nil {
t.Error("request is empty")
}
tu := hc.Request.URL.String()
if cfg.TokenURL != tu {
t.Errorf("wrong token url, want=%v, got=%v", cfg.TokenURL, tu)
}
ct := hc.Request.Header.Get("Content-Type")
if ct != "application/x-www-form-urlencoded" {
t.Errorf("wrong content-type, want=application/x-www-form-urlencoded, got=%v", ct)
}
cid, secret, ok := phttp.BasicAuth(hc.Request)
if !ok {
t.Error("unexpected error parsing basic auth")
}
if cfg.Credentials.ID != cid {
t.Errorf("wrong client ID, want=%v, got=%v", cfg.Credentials.ID, cid)
}
if cfg.Credentials.Secret != secret {
t.Errorf("wrong client secret, want=%v, got=%v", cfg.Credentials.Secret, secret)
}
err = hc.Request.ParseForm()
if err != nil {
t.Error("unexpected error parsing form")
}
gt := hc.Request.PostForm.Get("grant_type")
if gt != GrantTypeClientCreds {
t.Errorf("wrong grant_type, want=client_credentials, got=%v", gt)
}
sc := strings.Split(hc.Request.PostForm.Get("scope"), " ")
if !reflect.DeepEqual(scope, sc) {
t.Errorf("wrong scope, want=%v, got=%v", scope, sc)
}
}
func TestNewAuthenticatedRequest(t *testing.T) {
tests := []struct {
authMethod string
url string
values url.Values
}{
{
authMethod: AuthMethodClientSecretBasic,
url: "http://example.com/token",
values: url.Values{},
},
{
authMethod: AuthMethodClientSecretPost,
url: "http://example.com/token",
values: url.Values{},
},
}
for i, tt := range tests {
hc := &phttp.HandlerClient{}
cfg := Config{
Credentials: ClientCredentials{ID: "cid", Secret: "csecret"},
Scope: []string{"foo-scope", "bar-scope"},
TokenURL: "http://example.com/token",
AuthURL: "http://example.com/auth",
RedirectURL: "http://example.com/redirect",
AuthMethod: tt.authMethod,
}
c, err := NewClient(hc, cfg)
req, err := c.newAuthenticatedRequest(tt.url, tt.values)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
err = req.ParseForm()
if err != nil {
t.Errorf("case %d: want nil err, got %v", i, err)
}
if tt.authMethod == AuthMethodClientSecretBasic {
cid, secret, ok := phttp.BasicAuth(req)
if !ok {
t.Errorf("case %d: !ok parsing Basic Auth headers", i)
continue
}
if cid != cfg.Credentials.ID {
t.Errorf("case %d: want CID == %q, got CID == %q", i, cfg.Credentials.ID, cid)
}
if secret != cfg.Credentials.Secret {
t.Errorf("case %d: want secret == %q, got secret == %q", i, cfg.Credentials.Secret, secret)
}
} else if tt.authMethod == AuthMethodClientSecretPost {
if req.PostFormValue("client_secret") != cfg.Credentials.Secret {
t.Errorf("case %d: want client_secret == %q, got client_secret == %q",
i, cfg.Credentials.Secret, req.PostFormValue("client_secret"))
}
}
for k, v := range tt.values {
if !reflect.DeepEqual(v, req.PostForm[k]) {
t.Errorf("case %d: key:%q want==%q, got==%q", i, k, v, req.PostForm[k])
}
}
if req.URL.String() != tt.url {
t.Errorf("case %d: want URL==%q, got URL==%q", i, tt.url, req.URL.String())
}
}
}
package oidc
import (
"net/url"
"reflect"
"testing"
"time"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
"github.com/coreos/go-oidc/oauth2"
)
func TestNewClientScopeDefault(t *testing.T) {
tests := []struct {
c ClientConfig
e []string
}{
{
// No scope
c: ClientConfig{RedirectURL: "http://example.com/redirect"},
e: DefaultScope,
},
{
// Nil scope
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: nil},
e: DefaultScope,
},
{
// Empty scope
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{}},
e: []string{},
},
{
// Custom scope equal to default
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{"openid", "email", "profile"}},
e: DefaultScope,
},
{
// Custom scope not including defaults
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{"foo", "bar"}},
e: []string{"foo", "bar"},
},
{
// Custom scopes overlapping with defaults
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{"openid", "foo"}},
e: []string{"openid", "foo"},
},
}
for i, tt := range tests {
c, err := NewClient(tt.c)
if err != nil {
t.Errorf("case %d: unexpected error from NewClient: %v", i, err)
continue
}
if !reflect.DeepEqual(tt.e, c.scope) {
t.Errorf("case %d: want: %v, got: %v", i, tt.e, c.scope)
}
}
}
func TestHealthy(t *testing.T) {
now := time.Now().UTC()
tests := []struct {
c *Client
h bool
}{
// all ok
{
c: &Client{
providerConfig: ProviderConfig{
Issuer: "http://example.com",
ExpiresAt: now.Add(time.Hour),
},
},
h: true,
},
// zero-value ProviderConfig.ExpiresAt
{
c: &Client{
providerConfig: ProviderConfig{
Issuer: "http://example.com",
},
},
h: true,
},
// expired ProviderConfig
{
c: &Client{
providerConfig: ProviderConfig{
Issuer: "http://example.com",
ExpiresAt: now.Add(time.Hour * -1),
},
},
h: false,
},
// empty ProviderConfig
{
c: &Client{},
h: false,
},
}
for i, tt := range tests {
err := tt.c.Healthy()
want := tt.h
got := (err == nil)
if want != got {
t.Errorf("case %d: want: healthy=%v, got: healhty=%v, err: %v", i, want, got, err)
}
}
}
func TestClientKeysFuncAll(t *testing.T) {
priv1, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key, error=%v", err)
}
priv2, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key, error=%v", err)
}
now := time.Now()
future := now.Add(time.Hour)
past := now.Add(-1 * time.Hour)
tests := []struct {
keySet *key.PublicKeySet
want []key.PublicKey
}{
// two keys, non-expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, future),
want: []key.PublicKey{*key.NewPublicKey(priv2.JWK()), *key.NewPublicKey(priv1.JWK())},
},
// no keys, non-expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{}, future),
want: []key.PublicKey{},
},
// two keys, expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, past),
want: []key.PublicKey{},
},
// no keys, expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{}, past),
want: []key.PublicKey{},
},
}
for i, tt := range tests {
var c Client
c.keySet = *tt.keySet
keysFunc := c.keysFuncAll()
got := keysFunc()
if !reflect.DeepEqual(tt.want, got) {
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, got)
}
}
}
func TestClientKeysFuncWithID(t *testing.T) {
priv1, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key, error=%v", err)
}
priv2, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key, error=%v", err)
}
now := time.Now()
future := now.Add(time.Hour)
past := now.Add(-1 * time.Hour)
tests := []struct {
keySet *key.PublicKeySet
argID string
want []key.PublicKey
}{
// two keys, match, non-expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, future),
argID: priv2.ID(),
want: []key.PublicKey{*key.NewPublicKey(priv2.JWK())},
},
// two keys, no match, non-expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, future),
argID: "XXX",
want: []key.PublicKey{},
},
// no keys, no match, non-expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{}, future),
argID: priv2.ID(),
want: []key.PublicKey{},
},
// two keys, match, expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, past),
argID: priv2.ID(),
want: []key.PublicKey{},
},
// no keys, no match, expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{}, past),
argID: priv2.ID(),
want: []key.PublicKey{},
},
}
for i, tt := range tests {
var c Client
c.keySet = *tt.keySet
keysFunc := c.keysFuncWithID(tt.argID)
got := keysFunc()
if !reflect.DeepEqual(tt.want, got) {
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, got)
}
}
}
func TestClientMetadataValid(t *testing.T) {
tests := []ClientMetadata{
// one RedirectURL
ClientMetadata{
RedirectURLs: []url.URL{url.URL{Scheme: "http", Host: "example.com"}},
},
// one RedirectURL w/ nonempty path
ClientMetadata{
RedirectURLs: []url.URL{url.URL{Scheme: "http", Host: "example.com", Path: "/foo"}},
},
// two RedirectURLs
ClientMetadata{
RedirectURLs: []url.URL{
url.URL{Scheme: "http", Host: "foo.example.com"},
url.URL{Scheme: "http", Host: "bar.example.com"},
},
},
}
for i, tt := range tests {
if err := tt.Valid(); err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
}
}
func TestClientMetadataInvalid(t *testing.T) {
tests := []ClientMetadata{
// nil RedirectURls slice
ClientMetadata{
RedirectURLs: nil,
},
// empty RedirectURLs slice
ClientMetadata{
RedirectURLs: []url.URL{},
},
// empty url.URL
ClientMetadata{
RedirectURLs: []url.URL{url.URL{}},
},
// empty url.URL following OK item
ClientMetadata{
RedirectURLs: []url.URL{url.URL{Scheme: "http", Host: "example.com"}, url.URL{}},
},
// url.URL with empty Host
ClientMetadata{
RedirectURLs: []url.URL{url.URL{Scheme: "http", Host: ""}},
},
// url.URL with empty Scheme
ClientMetadata{
RedirectURLs: []url.URL{url.URL{Scheme: "", Host: "example.com"}},
},
// url.URL with non-HTTP(S) Scheme
ClientMetadata{
RedirectURLs: []url.URL{url.URL{Scheme: "tcp", Host: "127.0.0.1"}},
},
}
for i, tt := range tests {
if err := tt.Valid(); err == nil {
t.Errorf("case %d: expected non-nil error", i)
}
}
}
func TestChooseAuthMethod(t *testing.T) {
tests := []struct {
supported []string
chosen string
err bool
}{
{
supported: []string{},
chosen: oauth2.AuthMethodClientSecretBasic,
},
{
supported: []string{oauth2.AuthMethodClientSecretBasic},
chosen: oauth2.AuthMethodClientSecretBasic,
},
{
supported: []string{oauth2.AuthMethodClientSecretPost},
chosen: oauth2.AuthMethodClientSecretPost,
},
{
supported: []string{oauth2.AuthMethodClientSecretPost, oauth2.AuthMethodClientSecretBasic},
chosen: oauth2.AuthMethodClientSecretPost,
},
{
supported: []string{oauth2.AuthMethodClientSecretBasic, oauth2.AuthMethodClientSecretPost},
chosen: oauth2.AuthMethodClientSecretBasic,
},
{
supported: []string{oauth2.AuthMethodClientSecretJWT, oauth2.AuthMethodClientSecretPost},
chosen: oauth2.AuthMethodClientSecretPost,
},
{
supported: []string{oauth2.AuthMethodClientSecretJWT},
chosen: "",
err: true,
},
}
for i, tt := range tests {
client := Client{
providerConfig: ProviderConfig{
TokenEndpointAuthMethodsSupported: tt.supported,
},
}
got, err := client.chooseAuthMethod()
if tt.err {
if err == nil {
t.Errorf("case %d: expected non-nil err", i)
}
continue
}
if got != tt.chosen {
t.Errorf("case %d: want=%q, got=%q", i, tt.chosen, got)
}
}
}
package oidc
import (
"reflect"
"testing"
"time"
"github.com/coreos/go-oidc/jose"
)
func TestIdentityFromClaims(t *testing.T) {
tests := []struct {
claims jose.Claims
want Identity
}{
{
claims: jose.Claims{
"sub": "123850281",
"name": "Elroy",
"email": "elroy@example.com",
"exp": float64(1.416935146e+09),
},
want: Identity{
ID: "123850281",
Name: "",
Email: "elroy@example.com",
ExpiresAt: time.Date(2014, time.November, 25, 17, 05, 46, 0, time.UTC),
},
},
{
claims: jose.Claims{
"sub": "123850281",
"name": "Elroy",
"exp": float64(1.416935146e+09),
},
want: Identity{
ID: "123850281",
Name: "",
Email: "",
ExpiresAt: time.Date(2014, time.November, 25, 17, 05, 46, 0, time.UTC),
},
},
{
claims: jose.Claims{
"sub": "123850281",
"name": "Elroy",
"email": "elroy@example.com",
"exp": int64(1416935146),
},
want: Identity{
ID: "123850281",
Name: "",
Email: "elroy@example.com",
ExpiresAt: time.Date(2014, time.November, 25, 17, 05, 46, 0, time.UTC),
},
},
{
claims: jose.Claims{
"sub": "123850281",
"name": "Elroy",
"email": "elroy@example.com",
},
want: Identity{
ID: "123850281",
Name: "",
Email: "elroy@example.com",
ExpiresAt: time.Time{},
},
},
}
for i, tt := range tests {
got, err := IdentityFromClaims(tt.claims)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
if !reflect.DeepEqual(tt.want, *got) {
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, *got)
}
}
}
func TestIdentityFromClaimsFail(t *testing.T) {
tests := []jose.Claims{
// sub incorrect type
jose.Claims{
"sub": 123,
"name": "foo",
"email": "elroy@example.com",
},
// email incorrect type
jose.Claims{
"sub": "123850281",
"name": "Elroy",
"email": false,
},
// exp incorrect type
jose.Claims{
"sub": "123850281",
"name": "Elroy",
"email": "elroy@example.com",
"exp": "2014-11-25 18:05:46 +0000 UTC",
},
}
for i, tt := range tests {
_, err := IdentityFromClaims(tt)
if err == nil {
t.Errorf("case %d: expected non-nil error", i)
}
}
}
package oidc
import (
"errors"
"net/http"
"reflect"
"testing"
phttp "github.com/coreos/go-oidc/http"
"github.com/coreos/go-oidc/jose"
)
type staticTokenRefresher struct {
verify func(jose.JWT) error
refresh func() (jose.JWT, error)
}
func (s *staticTokenRefresher) Verify(jwt jose.JWT) error {
return s.verify(jwt)
}
func (s *staticTokenRefresher) Refresh() (jose.JWT, error) {
return s.refresh()
}
func TestAuthenticatedTransportVerifiedJWT(t *testing.T) {
tests := []struct {
refresher TokenRefresher
startJWT jose.JWT
wantJWT jose.JWT
wantError error
}{
// verification succeeds, so refresh is not called
{
refresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return nil },
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
},
startJWT: jose.JWT{RawPayload: "1"},
wantJWT: jose.JWT{RawPayload: "1"},
},
// verification fails, refresh succeeds so cached JWT changes
{
refresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return errors.New("fail!") },
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
},
startJWT: jose.JWT{RawPayload: "1"},
wantJWT: jose.JWT{RawPayload: "2"},
},
// verification succeeds, so failing refresh isn't attempted
{
refresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return nil },
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
},
startJWT: jose.JWT{RawPayload: "1"},
wantJWT: jose.JWT{RawPayload: "1"},
},
// verification fails, but refresh fails, too
{
refresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return errors.New("fail!") },
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
},
startJWT: jose.JWT{RawPayload: "1"},
wantJWT: jose.JWT{},
wantError: errors.New("unable to acquire valid JWT: fail!"),
},
}
for i, tt := range tests {
at := &AuthenticatedTransport{
TokenRefresher: tt.refresher,
jwt: tt.startJWT,
}
gotJWT, err := at.verifiedJWT()
if !reflect.DeepEqual(tt.wantError, err) {
t.Errorf("#%d: unexpected error: want=%#v got=%#v", i, tt.wantError, err)
}
if !reflect.DeepEqual(tt.wantJWT, gotJWT) {
t.Errorf("#%d: incorrect JWT returned from verifiedJWT: want=%#v got=%#v", i, tt.wantJWT, gotJWT)
}
}
}
func TestAuthenticatedTransportJWTCaching(t *testing.T) {
at := &AuthenticatedTransport{
TokenRefresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return errors.New("fail!") },
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
},
jwt: jose.JWT{RawPayload: "1"},
}
wantJWT := jose.JWT{RawPayload: "2"}
gotJWT, err := at.verifiedJWT()
if err != nil {
t.Fatalf("got non-nil error: %#v", err)
}
if !reflect.DeepEqual(wantJWT, gotJWT) {
t.Fatalf("incorrect JWT returned from verifiedJWT: want=%#v got=%#v", wantJWT, gotJWT)
}
at.TokenRefresher = &staticTokenRefresher{
verify: func(jose.JWT) error { return nil },
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "3"}, nil },
}
// the previous JWT should still be cached on the AuthenticatedTransport since
// it is still valid, even though there's a new token ready to refresh
gotJWT, err = at.verifiedJWT()
if err != nil {
t.Fatalf("got non-nil error: %#v", err)
}
if !reflect.DeepEqual(wantJWT, gotJWT) {
t.Fatalf("incorrect JWT returned from verifiedJWT: want=%#v got=%#v", wantJWT, gotJWT)
}
}
func TestAuthenticatedTransportRoundTrip(t *testing.T) {
rr := &phttp.RequestRecorder{Response: &http.Response{StatusCode: http.StatusOK}}
at := &AuthenticatedTransport{
TokenRefresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return nil },
},
RoundTripper: rr,
jwt: jose.JWT{RawPayload: "1"},
}
req := http.Request{}
_, err := at.RoundTrip(&req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !reflect.DeepEqual(req, http.Request{}) {
t.Errorf("http.Request object was modified")
}
want := []string{"Bearer .1."}
got := rr.Request.Header["Authorization"]
if !reflect.DeepEqual(want, got) {
t.Errorf("incorrect Authorization header: want=%#v got=%#v", want, got)
}
}
func TestAuthenticatedTransportRoundTripRefreshFail(t *testing.T) {
rr := &phttp.RequestRecorder{Response: &http.Response{StatusCode: http.StatusOK}}
at := &AuthenticatedTransport{
TokenRefresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return errors.New("fail!") },
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
},
RoundTripper: rr,
jwt: jose.JWT{RawPayload: "1"},
}
_, err := at.RoundTrip(&http.Request{})
if err == nil {
t.Errorf("expected non-nil error")
}
}
...@@ -53,7 +53,7 @@ func CookieTokenExtractor(cookieName string) RequestTokenExtractor { ...@@ -53,7 +53,7 @@ func CookieTokenExtractor(cookieName string) RequestTokenExtractor {
} }
} }
func NewClaims(iss, sub, aud string, iat, exp time.Time) jose.Claims { func NewClaims(iss, sub string, aud interface{}, iat, exp time.Time) jose.Claims {
return jose.Claims{ return jose.Claims{
// required // required
"iss": iss, "iss": iss,
......
package oidc
import (
"fmt"
"net/http"
"reflect"
"testing"
"time"
"github.com/coreos/go-oidc/jose"
)
func TestCookieTokenExtractorInvalid(t *testing.T) {
ckName := "tokenCookie"
tests := []*http.Cookie{
&http.Cookie{},
&http.Cookie{Name: ckName},
&http.Cookie{Name: ckName, Value: ""},
}
for i, tt := range tests {
r, _ := http.NewRequest("", "", nil)
r.AddCookie(tt)
_, err := CookieTokenExtractor(ckName)(r)
if err == nil {
t.Errorf("case %d: want: error for invalid cookie token, got: no error.", i)
}
}
}
func TestCookieTokenExtractorValid(t *testing.T) {
validToken := "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
ckName := "tokenCookie"
tests := []*http.Cookie{
&http.Cookie{Name: ckName, Value: "some non-empty value"},
&http.Cookie{Name: ckName, Value: validToken},
}
for i, tt := range tests {
r, _ := http.NewRequest("", "", nil)
r.AddCookie(tt)
_, err := CookieTokenExtractor(ckName)(r)
if err != nil {
t.Errorf("case %d: want: valid cookie with no error, got: %v", i, err)
}
}
}
func TestExtractBearerTokenInvalid(t *testing.T) {
tests := []string{"", "x", "Bearer", "xxxxxxx", "Bearer "}
for i, tt := range tests {
r, _ := http.NewRequest("", "", nil)
r.Header.Add("Authorization", tt)
_, err := ExtractBearerToken(r)
if err == nil {
t.Errorf("case %d: want: invalid Authorization header, got: valid Authorization header.", i)
}
}
}
func TestExtractBearerTokenValid(t *testing.T) {
validToken := "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
tests := []string{
fmt.Sprintf("Bearer %s", validToken),
}
for i, tt := range tests {
r, _ := http.NewRequest("", "", nil)
r.Header.Add("Authorization", tt)
_, err := ExtractBearerToken(r)
if err != nil {
t.Errorf("case %d: want: valid Authorization header, got: invalid Authorization header: %v.", i, err)
}
}
}
func TestNewClaims(t *testing.T) {
issAt := time.Date(2, time.January, 1, 0, 0, 0, 0, time.UTC)
expAt := time.Date(2, time.January, 1, 1, 0, 0, 0, time.UTC)
want := jose.Claims{
"iss": "https://example.com",
"sub": "user-123",
"aud": "client-abc",
"iat": float64(issAt.Unix()),
"exp": float64(expAt.Unix()),
}
got := NewClaims("https://example.com", "user-123", "client-abc", issAt, expAt)
if !reflect.DeepEqual(want, got) {
t.Fatalf("want=%#v got=%#v", want, got)
}
}
...@@ -25,6 +25,17 @@ func VerifySignature(jwt jose.JWT, keys []key.PublicKey) (bool, error) { ...@@ -25,6 +25,17 @@ func VerifySignature(jwt jose.JWT, keys []key.PublicKey) (bool, error) {
return false, nil return false, nil
} }
// containsString returns true if the given string(needle) is found
// in the string array(haystack).
func containsString(needle string, haystack []string) bool {
for _, v := range haystack {
if v == needle {
return true
}
}
return false
}
// Verify claims in accordance with OIDC spec // Verify claims in accordance with OIDC spec
// http://openid.net/specs/openid-connect-basic-1_0.html#IDTokenValidation // http://openid.net/specs/openid-connect-basic-1_0.html#IDTokenValidation
func VerifyClaims(jwt jose.JWT, issuer, clientID string) error { func VerifyClaims(jwt jose.JWT, issuer, clientID string) error {
...@@ -45,7 +56,8 @@ func VerifyClaims(jwt jose.JWT, issuer, clientID string) error { ...@@ -45,7 +56,8 @@ func VerifyClaims(jwt jose.JWT, issuer, clientID string) error {
} }
// iss REQUIRED. Issuer Identifier for the Issuer of the response. // iss REQUIRED. Issuer Identifier for the Issuer of the response.
// The iss value is a case sensitive URL using the https scheme that contains scheme, host, and optionally, port number and path components and no query or fragment components. // The iss value is a case sensitive URL using the https scheme that contains scheme,
// host, and optionally, port number and path components and no query or fragment components.
if iss, exists := claims["iss"].(string); exists { if iss, exists := claims["iss"].(string); exists {
if !urlEqual(iss, issuer) { if !urlEqual(iss, issuer) {
return fmt.Errorf("invalid claim value: 'iss'. expected=%s, found=%s.", issuer, iss) return fmt.Errorf("invalid claim value: 'iss'. expected=%s, found=%s.", issuer, iss)
...@@ -55,19 +67,27 @@ func VerifyClaims(jwt jose.JWT, issuer, clientID string) error { ...@@ -55,19 +67,27 @@ func VerifyClaims(jwt jose.JWT, issuer, clientID string) error {
} }
// iat REQUIRED. Time at which the JWT was issued. // iat REQUIRED. Time at which the JWT was issued.
// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. // Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z
// as measured in UTC until the date/time.
if _, exists := claims["iat"].(float64); !exists { if _, exists := claims["iat"].(float64); !exists {
return errors.New("missing claim: 'iat'") return errors.New("missing claim: 'iat'")
} }
// aud REQUIRED. Audience(s) that this ID Token is intended for. // aud REQUIRED. Audience(s) that this ID Token is intended for.
// It MUST contain the OAuth 2.0 client_id of the Relying Party as an audience value. It MAY also contain identifiers for other audiences. In the general case, the aud value is an array of case sensitive strings. In the common special case when there is one audience, the aud value MAY be a single case sensitive string. // It MUST contain the OAuth 2.0 client_id of the Relying Party as an audience value.
if aud, exists := claims["aud"].(string); exists { // It MAY also contain identifiers for other audiences. In the general case, the aud
// value is an array of case sensitive strings. In the common special case when there
// is one audience, the aud value MAY be a single case sensitive string.
if aud, ok, err := claims.StringClaim("aud"); err == nil && ok {
if aud != clientID { if aud != clientID {
return errors.New("invalid claim value: 'aud'") return fmt.Errorf("invalid claims, 'aud' claim and 'client_id' do not match, aud=%s, client_id=%s", aud, clientID)
}
} else if aud, ok, err := claims.StringsClaim("aud"); err == nil && ok {
if !containsString(clientID, aud) {
return fmt.Errorf("invalid claims, cannot find 'client_id' in 'aud' claim, aud=%v, client_id=%s", aud, clientID)
} }
} else { } else {
return errors.New("missing claim: 'aud'") return errors.New("invalid claim value: 'aud' is required, and should be either string or string array")
} }
return nil return nil
...@@ -97,15 +117,16 @@ func VerifyClientClaims(jwt jose.JWT, issuer string) (string, error) { ...@@ -97,15 +117,16 @@ func VerifyClientClaims(jwt jose.JWT, issuer string) (string, error) {
return "", errors.New("missing required 'sub' claim") return "", errors.New("missing required 'sub' claim")
} }
aud, ok, err := claims.StringClaim("aud") if aud, ok, err := claims.StringClaim("aud"); err == nil && ok {
if err != nil { if aud != sub {
return "", fmt.Errorf("failed to parse 'aud' claim: %v", err) return "", fmt.Errorf("invalid claims, 'aud' claim and 'sub' claim do not match, aud=%s, sub=%s", aud, sub)
} else if !ok { }
return "", errors.New("missing required 'aud' claim") } else if aud, ok, err := claims.StringsClaim("aud"); err == nil && ok {
} if !containsString(sub, aud) {
return "", fmt.Errorf("invalid claims, cannot find 'sud' in 'aud' claim, aud=%v, sub=%s", aud, sub)
if sub != aud { }
return "", fmt.Errorf("invalid claims, 'aud' claim and 'sub' claim do not match, aud=%s, sub=%s", aud, sub) } else {
return "", errors.New("invalid claim value: 'aud' is required, and should be either string or string array")
} }
now := time.Now().UTC() now := time.Now().UTC()
......
package oidc
import (
"testing"
"time"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
)
func TestVerifyClientClaims(t *testing.T) {
validIss := "https://example.com"
validClientID := "valid-client"
now := time.Now()
tomorrow := now.Add(24 * time.Hour)
header := jose.JOSEHeader{
jose.HeaderKeyAlgorithm: "test-alg",
jose.HeaderKeyID: "1",
}
tests := []struct {
claims jose.Claims
ok bool
}{
// valid token
{
claims: jose.Claims{
"iss": validIss,
"sub": validClientID,
"aud": validClientID,
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: true,
},
// missing 'iss' claim
{
claims: jose.Claims{
"sub": validClientID,
"aud": validClientID,
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: false,
},
// invalid 'iss' claim
{
claims: jose.Claims{
"iss": "INVALID",
"sub": validClientID,
"aud": validClientID,
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: false,
},
// missing 'sub' claim
{
claims: jose.Claims{
"iss": validIss,
"aud": validClientID,
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: false,
},
// invalid 'sub' claim
{
claims: jose.Claims{
"iss": validIss,
"sub": "INVALID",
"aud": validClientID,
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: false,
},
// missing 'aud' claim
{
claims: jose.Claims{
"iss": validIss,
"sub": validClientID,
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: false,
},
// invalid 'aud' claim
{
claims: jose.Claims{
"iss": validIss,
"sub": validClientID,
"aud": "INVALID",
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: false,
},
// expired
{
claims: jose.Claims{
"iss": validIss,
"sub": validClientID,
"aud": validClientID,
"iat": float64(now.Unix()),
"exp": float64(now.Unix()),
},
ok: false,
},
}
for i, tt := range tests {
jwt, err := jose.NewJWT(header, tt.claims)
if err != nil {
t.Fatalf("case %d: Failed to generate JWT, error=%v", i, err)
}
got, err := VerifyClientClaims(jwt, validIss)
if tt.ok {
if err != nil {
t.Errorf("case %d: unexpected error, err=%v", i, err)
}
if got != validClientID {
t.Errorf("case %d: incorrect client ID, want=%s, got=%s", i, validClientID, got)
}
} else if err == nil {
t.Errorf("case %d: expected error but err is nil", i)
}
}
}
func TestJWTVerifier(t *testing.T) {
iss := "http://example.com"
now := time.Now()
future12 := now.Add(12 * time.Hour)
past36 := now.Add(-36 * time.Hour)
past12 := now.Add(-12 * time.Hour)
priv1, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key, error=%v", err)
}
pk1 := *key.NewPublicKey(priv1.JWK())
priv2, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key, error=%v", err)
}
pk2 := *key.NewPublicKey(priv2.JWK())
jwtPK1, err := jose.NewSignedJWT(NewClaims(iss, "XXX", "XXX", past12, future12), priv1.Signer())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
jwtPK1BadClaims, err := jose.NewSignedJWT(NewClaims(iss, "XXX", "YYY", past12, future12), priv1.Signer())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
jwtExpired, err := jose.NewSignedJWT(NewClaims(iss, "XXX", "XXX", past36, past12), priv1.Signer())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
jwtPK2, err := jose.NewSignedJWT(NewClaims(iss, "XXX", "XXX", past12, future12), priv2.Signer())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
tests := []struct {
verifier JWTVerifier
jwt jose.JWT
wantErr bool
}{
// JWT signed with available key
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() []key.PublicKey {
return []key.PublicKey{pk1}
},
},
jwt: *jwtPK1,
wantErr: false,
},
// JWT signed with available key, with bad claims
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() []key.PublicKey {
return []key.PublicKey{pk1}
},
},
jwt: *jwtPK1BadClaims,
wantErr: true,
},
// expired JWT signed with available key
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() []key.PublicKey {
return []key.PublicKey{pk1}
},
},
jwt: *jwtExpired,
wantErr: true,
},
// JWT signed with unrecognized key, verifiable after sync
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() func() []key.PublicKey {
var i int
return func() []key.PublicKey {
defer func() { i++ }()
return [][]key.PublicKey{
[]key.PublicKey{pk1},
[]key.PublicKey{pk2},
}[i]
}
}(),
},
jwt: *jwtPK2,
wantErr: false,
},
// JWT signed with unrecognized key, not verifiable after sync
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() []key.PublicKey {
return []key.PublicKey{pk1}
},
},
jwt: *jwtPK2,
wantErr: true,
},
// verifier gets no keys from keysFunc, still not verifiable after sync
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() []key.PublicKey {
return []key.PublicKey{}
},
},
jwt: *jwtPK1,
wantErr: true,
},
// verifier gets no keys from keysFunc, verifiable after sync
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() func() []key.PublicKey {
var i int
return func() []key.PublicKey {
defer func() { i++ }()
return [][]key.PublicKey{
[]key.PublicKey{},
[]key.PublicKey{pk2},
}[i]
}
}(),
},
jwt: *jwtPK2,
wantErr: false,
},
}
for i, tt := range tests {
err := tt.verifier.Verify(tt.jwt)
if tt.wantErr && (err == nil) {
t.Errorf("case %d: wanted non-nil error", i)
} else if !tt.wantErr && (err != nil) {
t.Errorf("case %d: wanted nil error, got %v", i, err)
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment