Commit 731ffe63 authored by Eric Chiang's avatar Eric Chiang

*: remove v1 branch code

parent a93f71fd
package http
import "net/http"
type Client interface {
Do(*http.Request) (*http.Response, error)
}
// Package http is DEPRECATED. Use net/http instead.
package http
package http
import (
"encoding/base64"
"encoding/json"
"errors"
"log"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
)
func WriteError(w http.ResponseWriter, code int, msg string) {
e := struct {
Error string `json:"error"`
}{
Error: msg,
}
b, err := json.Marshal(e)
if err != nil {
log.Printf("go-oidc: failed to marshal %#v: %v", e, err)
code = http.StatusInternalServerError
b = []byte(`{"error":"server_error"}`)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
w.Write(b)
}
// BasicAuth parses a username and password from the request's
// Authorization header. This was pulled from golang master:
// https://codereview.appspot.com/76540043
func BasicAuth(r *http.Request) (username, password string, ok bool) {
auth := r.Header.Get("Authorization")
if auth == "" {
return
}
if !strings.HasPrefix(auth, "Basic ") {
return
}
c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic "))
if err != nil {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}
return cs[:s], cs[s+1:], true
}
func cacheControlMaxAge(hdr string) (time.Duration, bool, error) {
for _, field := range strings.Split(hdr, ",") {
parts := strings.SplitN(strings.TrimSpace(field), "=", 2)
k := strings.ToLower(strings.TrimSpace(parts[0]))
if k != "max-age" {
continue
}
if len(parts) == 1 {
return 0, false, errors.New("max-age has no value")
}
v := strings.TrimSpace(parts[1])
if v == "" {
return 0, false, errors.New("max-age has empty value")
}
age, err := strconv.Atoi(v)
if err != nil {
return 0, false, err
}
if age <= 0 {
return 0, false, nil
}
return time.Duration(age) * time.Second, true, nil
}
return 0, false, nil
}
func expires(date, expires string) (time.Duration, bool, error) {
if date == "" || expires == "" {
return 0, false, nil
}
var te time.Time
var err error
if expires == "0" {
return 0, false, nil
}
te, err = time.Parse(time.RFC1123, expires)
if err != nil {
return 0, false, err
}
td, err := time.Parse(time.RFC1123, date)
if err != nil {
return 0, false, err
}
ttl := te.Sub(td)
// headers indicate data already expired, caller should not
// have to care about this case
if ttl <= 0 {
return 0, false, nil
}
return ttl, true, nil
}
func Cacheable(hdr http.Header) (time.Duration, bool, error) {
ttl, ok, err := cacheControlMaxAge(hdr.Get("Cache-Control"))
if err != nil || ok {
return ttl, ok, err
}
return expires(hdr.Get("Date"), hdr.Get("Expires"))
}
// MergeQuery appends additional query values to an existing URL.
func MergeQuery(u url.URL, q url.Values) url.URL {
uv := u.Query()
for k, vs := range q {
for _, v := range vs {
uv.Add(k, v)
}
}
u.RawQuery = uv.Encode()
return u
}
// NewResourceLocation appends a resource id to the end of the requested URL path.
func NewResourceLocation(reqURL *url.URL, id string) string {
var u url.URL
u = *reqURL
u.Path = path.Join(u.Path, id)
u.RawQuery = ""
u.Fragment = ""
return u.String()
}
// CopyRequest returns a clone of the provided *http.Request.
// The returned object is a shallow copy of the struct and a
// deep copy of its Header field.
func CopyRequest(r *http.Request) *http.Request {
r2 := *r
r2.Header = make(http.Header)
for k, s := range r.Header {
r2.Header[k] = s
}
return &r2
}
package http
import (
"net/http"
"net/url"
"reflect"
"testing"
"time"
)
func TestCacheControlMaxAgeSuccess(t *testing.T) {
tests := []struct {
hdr string
wantAge time.Duration
wantOK bool
}{
{"max-age=12", 12 * time.Second, true},
{"max-age=-12", 0, false},
{"max-age=0", 0, false},
{"public, max-age=12", 12 * time.Second, true},
{"public, max-age=40192, must-revalidate", 40192 * time.Second, true},
{"public, not-max-age=12, must-revalidate", time.Duration(0), false},
}
for i, tt := range tests {
maxAge, ok, err := cacheControlMaxAge(tt.hdr)
if err != nil {
t.Errorf("case %d: err=%v", i, err)
}
if tt.wantAge != maxAge {
t.Errorf("case %d: want=%d got=%d", i, tt.wantAge, maxAge)
}
if tt.wantOK != ok {
t.Errorf("case %d: incorrect ok value: want=%t got=%t", i, tt.wantOK, ok)
}
}
}
func TestCacheControlMaxAgeFail(t *testing.T) {
tests := []string{
"max-age=aasdf",
"max-age=",
"max-age",
}
for i, tt := range tests {
_, ok, err := cacheControlMaxAge(tt)
if ok {
t.Errorf("case %d: want ok=false, got true", i)
}
if err == nil {
t.Errorf("case %d: want non-nil err", i)
}
}
}
func TestMergeQuery(t *testing.T) {
tests := []struct {
u string
q url.Values
w string
}{
// No values
{
u: "http://example.com",
q: nil,
w: "http://example.com",
},
// No additional values
{
u: "http://example.com?foo=bar",
q: nil,
w: "http://example.com?foo=bar",
},
// Simple addition
{
u: "http://example.com",
q: url.Values{
"foo": []string{"bar"},
},
w: "http://example.com?foo=bar",
},
// Addition with existing values
{
u: "http://example.com?dog=boo",
q: url.Values{
"foo": []string{"bar"},
},
w: "http://example.com?dog=boo&foo=bar",
},
// Merge
{
u: "http://example.com?dog=boo",
q: url.Values{
"dog": []string{"elroy"},
},
w: "http://example.com?dog=boo&dog=elroy",
},
// Add and merge
{
u: "http://example.com?dog=boo",
q: url.Values{
"dog": []string{"elroy"},
"foo": []string{"bar"},
},
w: "http://example.com?dog=boo&dog=elroy&foo=bar",
},
// Multivalue merge
{
u: "http://example.com?dog=boo",
q: url.Values{
"dog": []string{"elroy", "penny"},
},
w: "http://example.com?dog=boo&dog=elroy&dog=penny",
},
}
for i, tt := range tests {
ur, err := url.Parse(tt.u)
if err != nil {
t.Errorf("case %d: failed parsing test url: %v, error: %v", i, tt.u, err)
}
got := MergeQuery(*ur, tt.q)
want, err := url.Parse(tt.w)
if err != nil {
t.Errorf("case %d: failed parsing want url: %v, error: %v", i, tt.w, err)
}
if !reflect.DeepEqual(*want, got) {
t.Errorf("case %d: want: %v, got: %v", i, *want, got)
}
}
}
func TestExpiresPass(t *testing.T) {
tests := []struct {
date string
exp string
wantTTL time.Duration
wantOK bool
}{
// Expires and Date properly set
{
date: "Thu, 01 Dec 1983 22:00:00 GMT",
exp: "Fri, 02 Dec 1983 01:00:00 GMT",
wantTTL: 10800 * time.Second,
wantOK: true,
},
// empty headers
{
date: "",
exp: "",
wantOK: false,
},
// lack of Expirs short-ciruits Date parsing
{
date: "foo",
exp: "",
wantOK: false,
},
// lack of Date short-ciruits Expires parsing
{
date: "",
exp: "foo",
wantOK: false,
},
// no Date
{
exp: "Thu, 01 Dec 1983 22:00:00 GMT",
wantTTL: 0,
wantOK: false,
},
// no Expires
{
date: "Thu, 01 Dec 1983 22:00:00 GMT",
wantTTL: 0,
wantOK: false,
},
// Expires set to false
{
date: "Thu, 01 Dec 1983 22:00:00 GMT",
exp: "0",
wantTTL: 0,
wantOK: false,
},
// Expires < Date
{
date: "Fri, 02 Dec 1983 01:00:00 GMT",
exp: "Thu, 01 Dec 1983 22:00:00 GMT",
wantTTL: 0,
wantOK: false,
},
}
for i, tt := range tests {
ttl, ok, err := expires(tt.date, tt.exp)
if err != nil {
t.Errorf("case %d: err=%v", i, err)
}
if tt.wantTTL != ttl {
t.Errorf("case %d: want=%d got=%d", i, tt.wantTTL, ttl)
}
if tt.wantOK != ok {
t.Errorf("case %d: incorrect ok value: want=%t got=%t", i, tt.wantOK, ok)
}
}
}
func TestExpiresFail(t *testing.T) {
tests := []struct {
date string
exp string
}{
// malformed Date header
{
date: "foo",
exp: "Fri, 02 Dec 1983 01:00:00 GMT",
},
// malformed exp header
{
date: "Fri, 02 Dec 1983 01:00:00 GMT",
exp: "bar",
},
}
for i, tt := range tests {
_, _, err := expires(tt.date, tt.exp)
if err == nil {
t.Errorf("case %d: expected non-nil error", i)
}
}
}
func TestCacheablePass(t *testing.T) {
tests := []struct {
headers http.Header
wantTTL time.Duration
wantOK bool
}{
// valid Cache-Control
{
headers: http.Header{
"Cache-Control": []string{"max-age=100"},
},
wantTTL: 100 * time.Second,
wantOK: true,
},
// valid Date/Expires
{
headers: http.Header{
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
"Expires": []string{"Fri, 02 Dec 1983 01:00:00 GMT"},
},
wantTTL: 10800 * time.Second,
wantOK: true,
},
// Cache-Control supersedes Date/Expires
{
headers: http.Header{
"Cache-Control": []string{"max-age=100"},
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
"Expires": []string{"Fri, 02 Dec 1983 01:00:00 GMT"},
},
wantTTL: 100 * time.Second,
wantOK: true,
},
// no caching headers
{
headers: http.Header{},
wantOK: false,
},
}
for i, tt := range tests {
ttl, ok, err := Cacheable(tt.headers)
if err != nil {
t.Errorf("case %d: err=%v", i, err)
continue
}
if tt.wantTTL != ttl {
t.Errorf("case %d: want=%d got=%d", i, tt.wantTTL, ttl)
}
if tt.wantOK != ok {
t.Errorf("case %d: incorrect ok value: want=%t got=%t", i, tt.wantOK, ok)
}
}
}
func TestCacheableFail(t *testing.T) {
tests := []http.Header{
// invalid Cache-Control short-circuits
http.Header{
"Cache-Control": []string{"max-age"},
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
"Expires": []string{"Fri, 02 Dec 1983 01:00:00 GMT"},
},
// no Cache-Control, invalid Expires
http.Header{
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
"Expires": []string{"boo"},
},
}
for i, tt := range tests {
_, _, err := Cacheable(tt)
if err == nil {
t.Errorf("case %d: want non-nil err", i)
}
}
}
func TestNewResourceLocation(t *testing.T) {
tests := []struct {
ru *url.URL
id string
want string
}{
{
ru: &url.URL{
Scheme: "http",
Host: "example.com",
},
id: "foo",
want: "http://example.com/foo",
},
// https
{
ru: &url.URL{
Scheme: "https",
Host: "example.com",
},
id: "foo",
want: "https://example.com/foo",
},
// with path
{
ru: &url.URL{
Scheme: "http",
Host: "example.com",
Path: "one/two/three",
},
id: "foo",
want: "http://example.com/one/two/three/foo",
},
// with fragment
{
ru: &url.URL{
Scheme: "http",
Host: "example.com",
Fragment: "frag",
},
id: "foo",
want: "http://example.com/foo",
},
// with query
{
ru: &url.URL{
Scheme: "http",
Host: "example.com",
RawQuery: "dog=elroy",
},
id: "foo",
want: "http://example.com/foo",
},
}
for i, tt := range tests {
got := NewResourceLocation(tt.ru, tt.id)
if tt.want != got {
t.Errorf("case %d: want=%s, got=%s", i, tt.want, got)
}
}
}
package http
import (
"errors"
"net/url"
)
// ParseNonEmptyURL checks that a string is a parsable URL which is also not empty
// since `url.Parse("")` does not return an error. Must contian a scheme and a host.
func ParseNonEmptyURL(u string) (*url.URL, error) {
if u == "" {
return nil, errors.New("url is empty")
}
ur, err := url.Parse(u)
if err != nil {
return nil, err
}
if ur.Scheme == "" {
return nil, errors.New("url scheme is empty")
}
if ur.Host == "" {
return nil, errors.New("url host is empty")
}
return ur, nil
}
package http
import (
"net/url"
"testing"
)
func TestParseNonEmptyURL(t *testing.T) {
tests := []struct {
u string
ok bool
}{
{"", false},
{"http://", false},
{"example.com", false},
{"example", false},
{"http://example", true},
{"http://example:1234", true},
{"http://example.com", true},
{"http://example.com:1234", true},
}
for i, tt := range tests {
u, err := ParseNonEmptyURL(tt.u)
if err != nil {
t.Logf("err: %v", err)
if tt.ok {
t.Errorf("case %d: unexpected error: %v", i, err)
} else {
continue
}
}
if !tt.ok {
t.Errorf("case %d: expected error but got none", i)
continue
}
uu, err := url.Parse(tt.u)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
if uu.String() != u.String() {
t.Errorf("case %d: incorrect url value, want: %q, got: %q", i, uu.String(), u.String())
}
}
}
package jose
import (
"encoding/json"
"fmt"
"math"
"time"
)
type Claims map[string]interface{}
func (c Claims) Add(name string, value interface{}) {
c[name] = value
}
func (c Claims) StringClaim(name string) (string, bool, error) {
cl, ok := c[name]
if !ok {
return "", false, nil
}
v, ok := cl.(string)
if !ok {
return "", false, fmt.Errorf("unable to parse claim as string: %v", name)
}
return v, true, nil
}
func (c Claims) StringsClaim(name string) ([]string, bool, error) {
cl, ok := c[name]
if !ok {
return nil, false, nil
}
if v, ok := cl.([]string); ok {
return v, true, nil
}
// When unmarshaled, []string will become []interface{}.
if v, ok := cl.([]interface{}); ok {
var ret []string
for _, vv := range v {
str, ok := vv.(string)
if !ok {
return nil, false, fmt.Errorf("unable to parse claim as string array: %v", name)
}
ret = append(ret, str)
}
return ret, true, nil
}
return nil, false, fmt.Errorf("unable to parse claim as string array: %v", name)
}
func (c Claims) Int64Claim(name string) (int64, bool, error) {
cl, ok := c[name]
if !ok {
return 0, false, nil
}
v, ok := cl.(int64)
if !ok {
vf, ok := cl.(float64)
if !ok {
return 0, false, fmt.Errorf("unable to parse claim as int64: %v", name)
}
v = int64(vf)
}
return v, true, nil
}
func (c Claims) Float64Claim(name string) (float64, bool, error) {
cl, ok := c[name]
if !ok {
return 0, false, nil
}
v, ok := cl.(float64)
if !ok {
vi, ok := cl.(int64)
if !ok {
return 0, false, fmt.Errorf("unable to parse claim as float64: %v", name)
}
v = float64(vi)
}
return v, true, nil
}
func (c Claims) TimeClaim(name string) (time.Time, bool, error) {
v, ok, err := c.Float64Claim(name)
if !ok || err != nil {
return time.Time{}, ok, err
}
s := math.Trunc(v)
ns := (v - s) * math.Pow(10, 9)
return time.Unix(int64(s), int64(ns)).UTC(), true, nil
}
func decodeClaims(payload []byte) (Claims, error) {
var c Claims
if err := json.Unmarshal(payload, &c); err != nil {
return nil, fmt.Errorf("malformed JWT claims, unable to decode: %v", err)
}
return c, nil
}
func marshalClaims(c Claims) ([]byte, error) {
b, err := json.Marshal(c)
if err != nil {
return nil, err
}
return b, nil
}
func encodeClaims(c Claims) (string, error) {
b, err := marshalClaims(c)
if err != nil {
return "", err
}
return encodeSegment(b), nil
}
package jose
import (
"reflect"
"testing"
"time"
)
func TestString(t *testing.T) {
tests := []struct {
cl Claims
key string
ok bool
err bool
val string
}{
// ok, no err, claim exists
{
cl: Claims{
"foo": "bar",
},
key: "foo",
val: "bar",
ok: true,
err: false,
},
// no claims
{
cl: Claims{},
key: "foo",
val: "",
ok: false,
err: false,
},
// missing claim
{
cl: Claims{
"foo": "bar",
},
key: "xxx",
val: "",
ok: false,
err: false,
},
// unparsable: type
{
cl: Claims{
"foo": struct{}{},
},
key: "foo",
val: "",
ok: false,
err: true,
},
// unparsable: nil value
{
cl: Claims{
"foo": nil,
},
key: "foo",
val: "",
ok: false,
err: true,
},
}
for i, tt := range tests {
val, ok, err := tt.cl.StringClaim(tt.key)
if tt.err && err == nil {
t.Errorf("case %d: want err=non-nil, got err=nil", i)
} else if !tt.err && err != nil {
t.Errorf("case %d: want err=nil, got err=%v", i, err)
}
if tt.ok != ok {
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
}
if tt.val != val {
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
}
}
}
func TestInt64(t *testing.T) {
tests := []struct {
cl Claims
key string
ok bool
err bool
val int64
}{
// ok, no err, claim exists
{
cl: Claims{
"foo": int64(100),
},
key: "foo",
val: int64(100),
ok: true,
err: false,
},
// no claims
{
cl: Claims{},
key: "foo",
val: 0,
ok: false,
err: false,
},
// missing claim
{
cl: Claims{
"foo": "bar",
},
key: "xxx",
val: 0,
ok: false,
err: false,
},
// unparsable: type
{
cl: Claims{
"foo": struct{}{},
},
key: "foo",
val: 0,
ok: false,
err: true,
},
// unparsable: nil value
{
cl: Claims{
"foo": nil,
},
key: "foo",
val: 0,
ok: false,
err: true,
},
}
for i, tt := range tests {
val, ok, err := tt.cl.Int64Claim(tt.key)
if tt.err && err == nil {
t.Errorf("case %d: want err=non-nil, got err=nil", i)
} else if !tt.err && err != nil {
t.Errorf("case %d: want err=nil, got err=%v", i, err)
}
if tt.ok != ok {
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
}
if tt.val != val {
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
}
}
}
func TestTime(t *testing.T) {
now := time.Now().UTC()
unixNow := now.Unix()
tests := []struct {
cl Claims
key string
ok bool
err bool
val time.Time
}{
// ok, no err, claim exists
{
cl: Claims{
"foo": unixNow,
},
key: "foo",
val: time.Unix(now.Unix(), 0).UTC(),
ok: true,
err: false,
},
// no claims
{
cl: Claims{},
key: "foo",
val: time.Time{},
ok: false,
err: false,
},
// missing claim
{
cl: Claims{
"foo": "bar",
},
key: "xxx",
val: time.Time{},
ok: false,
err: false,
},
// unparsable: type
{
cl: Claims{
"foo": struct{}{},
},
key: "foo",
val: time.Time{},
ok: false,
err: true,
},
// unparsable: nil value
{
cl: Claims{
"foo": nil,
},
key: "foo",
val: time.Time{},
ok: false,
err: true,
},
}
for i, tt := range tests {
val, ok, err := tt.cl.TimeClaim(tt.key)
if tt.err && err == nil {
t.Errorf("case %d: want err=non-nil, got err=nil", i)
} else if !tt.err && err != nil {
t.Errorf("case %d: want err=nil, got err=%v", i, err)
}
if tt.ok != ok {
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
}
if tt.val != val {
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
}
}
}
func TestStringArray(t *testing.T) {
tests := []struct {
cl Claims
key string
ok bool
err bool
val []string
}{
// ok, no err, claim exists
{
cl: Claims{
"foo": []string{"bar", "faf"},
},
key: "foo",
val: []string{"bar", "faf"},
ok: true,
err: false,
},
// ok, no err, []interface{}
{
cl: Claims{
"foo": []interface{}{"bar", "faf"},
},
key: "foo",
val: []string{"bar", "faf"},
ok: true,
err: false,
},
// no claims
{
cl: Claims{},
key: "foo",
val: nil,
ok: false,
err: false,
},
// missing claim
{
cl: Claims{
"foo": "bar",
},
key: "xxx",
val: nil,
ok: false,
err: false,
},
// unparsable: type
{
cl: Claims{
"foo": struct{}{},
},
key: "foo",
val: nil,
ok: false,
err: true,
},
// unparsable: nil value
{
cl: Claims{
"foo": nil,
},
key: "foo",
val: nil,
ok: false,
err: true,
},
}
for i, tt := range tests {
val, ok, err := tt.cl.StringsClaim(tt.key)
if tt.err && err == nil {
t.Errorf("case %d: want err=non-nil, got err=nil", i)
} else if !tt.err && err != nil {
t.Errorf("case %d: want err=nil, got err=%v", i, err)
}
if tt.ok != ok {
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
}
if !reflect.DeepEqual(tt.val, val) {
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
}
}
}
// Package jose is DEPRECATED. Use gopkg.in/square/go-jose.v2 instead.
package jose
package jose
import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
)
const (
HeaderMediaType = "typ"
HeaderKeyAlgorithm = "alg"
HeaderKeyID = "kid"
)
const (
// Encryption Algorithm Header Parameter Values for JWS
// See: https://tools.ietf.org/html/draft-ietf-jose-json-web-algorithms-40#page-6
AlgHS256 = "HS256"
AlgHS384 = "HS384"
AlgHS512 = "HS512"
AlgRS256 = "RS256"
AlgRS384 = "RS384"
AlgRS512 = "RS512"
AlgES256 = "ES256"
AlgES384 = "ES384"
AlgES512 = "ES512"
AlgPS256 = "PS256"
AlgPS384 = "PS384"
AlgPS512 = "PS512"
AlgNone = "none"
)
const (
// Algorithm Header Parameter Values for JWE
// See: https://tools.ietf.org/html/draft-ietf-jose-json-web-algorithms-40#section-4.1
AlgRSA15 = "RSA1_5"
AlgRSAOAEP = "RSA-OAEP"
AlgRSAOAEP256 = "RSA-OAEP-256"
AlgA128KW = "A128KW"
AlgA192KW = "A192KW"
AlgA256KW = "A256KW"
AlgDir = "dir"
AlgECDHES = "ECDH-ES"
AlgECDHESA128KW = "ECDH-ES+A128KW"
AlgECDHESA192KW = "ECDH-ES+A192KW"
AlgECDHESA256KW = "ECDH-ES+A256KW"
AlgA128GCMKW = "A128GCMKW"
AlgA192GCMKW = "A192GCMKW"
AlgA256GCMKW = "A256GCMKW"
AlgPBES2HS256A128KW = "PBES2-HS256+A128KW"
AlgPBES2HS384A192KW = "PBES2-HS384+A192KW"
AlgPBES2HS512A256KW = "PBES2-HS512+A256KW"
)
const (
// Encryption Algorithm Header Parameter Values for JWE
// See: https://tools.ietf.org/html/draft-ietf-jose-json-web-algorithms-40#page-22
EncA128CBCHS256 = "A128CBC-HS256"
EncA128CBCHS384 = "A128CBC-HS384"
EncA256CBCHS512 = "A256CBC-HS512"
EncA128GCM = "A128GCM"
EncA192GCM = "A192GCM"
EncA256GCM = "A256GCM"
)
type JOSEHeader map[string]string
func (j JOSEHeader) Validate() error {
if _, exists := j[HeaderKeyAlgorithm]; !exists {
return fmt.Errorf("header missing %q parameter", HeaderKeyAlgorithm)
}
return nil
}
func decodeHeader(seg string) (JOSEHeader, error) {
b, err := decodeSegment(seg)
if err != nil {
return nil, err
}
var h JOSEHeader
err = json.Unmarshal(b, &h)
if err != nil {
return nil, err
}
return h, nil
}
func encodeHeader(h JOSEHeader) (string, error) {
b, err := json.Marshal(h)
if err != nil {
return "", err
}
return encodeSegment(b), nil
}
// Decode JWT specific base64url encoding with padding stripped
func decodeSegment(seg string) ([]byte, error) {
if l := len(seg) % 4; l != 0 {
seg += strings.Repeat("=", 4-l)
}
return base64.URLEncoding.DecodeString(seg)
}
// Encode JWT specific base64url encoding with padding stripped
func encodeSegment(seg []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(seg), "=")
}
package jose
import (
"bytes"
"encoding/base64"
"encoding/binary"
"encoding/json"
"math/big"
"strings"
)
// JSON Web Key
// https://tools.ietf.org/html/draft-ietf-jose-json-web-key-36#page-5
type JWK struct {
ID string
Type string
Alg string
Use string
Exponent int
Modulus *big.Int
Secret []byte
}
type jwkJSON struct {
ID string `json:"kid"`
Type string `json:"kty"`
Alg string `json:"alg"`
Use string `json:"use"`
Exponent string `json:"e"`
Modulus string `json:"n"`
}
func (j *JWK) MarshalJSON() ([]byte, error) {
t := jwkJSON{
ID: j.ID,
Type: j.Type,
Alg: j.Alg,
Use: j.Use,
Exponent: encodeExponent(j.Exponent),
Modulus: encodeModulus(j.Modulus),
}
return json.Marshal(&t)
}
func (j *JWK) UnmarshalJSON(data []byte) error {
var t jwkJSON
err := json.Unmarshal(data, &t)
if err != nil {
return err
}
e, err := decodeExponent(t.Exponent)
if err != nil {
return err
}
n, err := decodeModulus(t.Modulus)
if err != nil {
return err
}
j.ID = t.ID
j.Type = t.Type
j.Alg = t.Alg
j.Use = t.Use
j.Exponent = e
j.Modulus = n
return nil
}
type JWKSet struct {
Keys []JWK `json:"keys"`
}
func decodeExponent(e string) (int, error) {
decE, err := decodeBase64URLPaddingOptional(e)
if err != nil {
return 0, err
}
var eBytes []byte
if len(decE) < 8 {
eBytes = make([]byte, 8-len(decE), 8)
eBytes = append(eBytes, decE...)
} else {
eBytes = decE
}
eReader := bytes.NewReader(eBytes)
var E uint64
err = binary.Read(eReader, binary.BigEndian, &E)
if err != nil {
return 0, err
}
return int(E), nil
}
func encodeExponent(e int) string {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, uint64(e))
var idx int
for ; idx < 8; idx++ {
if b[idx] != 0x0 {
break
}
}
return base64.RawURLEncoding.EncodeToString(b[idx:])
}
// Turns a URL encoded modulus of a key into a big int.
func decodeModulus(n string) (*big.Int, error) {
decN, err := decodeBase64URLPaddingOptional(n)
if err != nil {
return nil, err
}
N := big.NewInt(0)
N.SetBytes(decN)
return N, nil
}
func encodeModulus(n *big.Int) string {
return base64.RawURLEncoding.EncodeToString(n.Bytes())
}
// decodeBase64URLPaddingOptional decodes Base64 whether there is padding or not.
// The stdlib version currently doesn't handle this.
// We can get rid of this is if this bug:
// https://github.com/golang/go/issues/4237
// ever closes.
func decodeBase64URLPaddingOptional(e string) ([]byte, error) {
if m := len(e) % 4; m != 0 {
e += strings.Repeat("=", 4-m)
}
return base64.URLEncoding.DecodeString(e)
}
package jose
import (
"testing"
)
func TestDecodeBase64URLPaddingOptional(t *testing.T) {
tests := []struct {
encoded string
decoded string
err bool
}{
{
// With padding
encoded: "VGVjdG9uaWM=",
decoded: "Tectonic",
},
{
// Without padding
encoded: "VGVjdG9uaWM",
decoded: "Tectonic",
},
{
// Even More padding
encoded: "VGVjdG9uaQ==",
decoded: "Tectoni",
},
{
// And take it away!
encoded: "VGVjdG9uaQ",
decoded: "Tectoni",
},
{
// Too much padding.
encoded: "VGVjdG9uaWNh=",
decoded: "",
err: true,
},
{
// Too much padding.
encoded: "VGVjdG9uaWNh=",
decoded: "",
err: true,
},
}
for i, tt := range tests {
got, err := decodeBase64URLPaddingOptional(tt.encoded)
if tt.err {
if err == nil {
t.Errorf("case %d: expected non-nil err", i)
}
continue
}
if err != nil {
t.Errorf("case %d: want nil err, got: %v", i, err)
}
if string(got) != tt.decoded {
t.Errorf("case %d: want=%q, got=%q", i, tt.decoded, got)
}
}
}
package jose
import (
"fmt"
"strings"
)
type JWS struct {
RawHeader string
Header JOSEHeader
RawPayload string
Payload []byte
Signature []byte
}
// Given a raw encoded JWS token parses it and verifies the structure.
func ParseJWS(raw string) (JWS, error) {
parts := strings.Split(raw, ".")
if len(parts) != 3 {
return JWS{}, fmt.Errorf("malformed JWS, only %d segments", len(parts))
}
rawSig := parts[2]
jws := JWS{
RawHeader: parts[0],
RawPayload: parts[1],
}
header, err := decodeHeader(jws.RawHeader)
if err != nil {
return JWS{}, fmt.Errorf("malformed JWS, unable to decode header, %s", err)
}
if err = header.Validate(); err != nil {
return JWS{}, fmt.Errorf("malformed JWS, %s", err)
}
jws.Header = header
payload, err := decodeSegment(jws.RawPayload)
if err != nil {
return JWS{}, fmt.Errorf("malformed JWS, unable to decode payload: %s", err)
}
jws.Payload = payload
sig, err := decodeSegment(rawSig)
if err != nil {
return JWS{}, fmt.Errorf("malformed JWS, unable to decode signature: %s", err)
}
jws.Signature = sig
return jws, nil
}
package jose
import (
"strings"
"testing"
)
type testCase struct{ t string }
var validInput []testCase
var invalidInput []testCase
func init() {
validInput = []testCase{
{
"eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
},
}
invalidInput = []testCase{
// empty
{
"",
},
// undecodeable
{
"aaa.bbb.ccc",
},
// missing parts
{
"aaa",
},
// missing parts
{
"aaa.bbb",
},
// too many parts
{
"aaa.bbb.ccc.ddd",
},
// invalid header
// EncodeHeader(map[string]string{"foo": "bar"})
{
"eyJmb28iOiJiYXIifQ.bbb.ccc",
},
}
}
func TestParseJWS(t *testing.T) {
for i, tt := range validInput {
jws, err := ParseJWS(tt.t)
if err != nil {
t.Errorf("test: %d. expected: valid, actual: invalid", i)
}
expectedHeader := strings.Split(tt.t, ".")[0]
if jws.RawHeader != expectedHeader {
t.Errorf("test: %d. expected: %s, actual: %s", i, expectedHeader, jws.RawHeader)
}
expectedPayload := strings.Split(tt.t, ".")[1]
if jws.RawPayload != expectedPayload {
t.Errorf("test: %d. expected: %s, actual: %s", i, expectedPayload, jws.RawPayload)
}
}
for i, tt := range invalidInput {
_, err := ParseJWS(tt.t)
if err == nil {
t.Errorf("test: %d. expected: invalid, actual: valid", i)
}
}
}
package jose
import "strings"
type JWT JWS
func ParseJWT(token string) (jwt JWT, err error) {
jws, err := ParseJWS(token)
if err != nil {
return
}
return JWT(jws), nil
}
func NewJWT(header JOSEHeader, claims Claims) (jwt JWT, err error) {
jwt = JWT{}
jwt.Header = header
jwt.Header[HeaderMediaType] = "JWT"
claimBytes, err := marshalClaims(claims)
if err != nil {
return
}
jwt.Payload = claimBytes
eh, err := encodeHeader(header)
if err != nil {
return
}
jwt.RawHeader = eh
ec, err := encodeClaims(claims)
if err != nil {
return
}
jwt.RawPayload = ec
return
}
func (j *JWT) KeyID() (string, bool) {
kID, ok := j.Header[HeaderKeyID]
return kID, ok
}
func (j *JWT) Claims() (Claims, error) {
return decodeClaims(j.Payload)
}
// Encoded data part of the token which may be signed.
func (j *JWT) Data() string {
return strings.Join([]string{j.RawHeader, j.RawPayload}, ".")
}
// Full encoded JWT token string in format: header.claims.signature
func (j *JWT) Encode() string {
d := j.Data()
s := encodeSegment(j.Signature)
return strings.Join([]string{d, s}, ".")
}
func NewSignedJWT(claims Claims, s Signer) (*JWT, error) {
header := JOSEHeader{
HeaderKeyAlgorithm: s.Alg(),
HeaderKeyID: s.ID(),
}
jwt, err := NewJWT(header, claims)
if err != nil {
return nil, err
}
sig, err := s.Sign([]byte(jwt.Data()))
if err != nil {
return nil, err
}
jwt.Signature = sig
return &jwt, nil
}
package jose
import (
"reflect"
"testing"
)
func TestParseJWT(t *testing.T) {
tests := []struct {
r string
h JOSEHeader
c Claims
}{
{
// Example from JWT spec:
// http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html#ExampleJWT
"eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
JOSEHeader{
HeaderMediaType: "JWT",
HeaderKeyAlgorithm: "HS256",
},
Claims{
"iss": "joe",
// NOTE: test numbers must be floats for equality checks to work since values are converted form interface{} to float64 by default.
"exp": 1300819380.0,
"http://example.com/is_root": true,
},
},
}
for i, tt := range tests {
jwt, err := ParseJWT(tt.r)
if err != nil {
t.Errorf("raw token should parse. test: %d. expected: valid, actual: invalid. err=%v", i, err)
}
if !reflect.DeepEqual(tt.h, jwt.Header) {
t.Errorf("JOSE headers should match. test: %d. expected: %v, actual: %v", i, tt.h, jwt.Header)
}
claims, err := jwt.Claims()
if err != nil {
t.Errorf("test: %d. expected: valid claim parsing. err=%v", i, err)
}
if !reflect.DeepEqual(tt.c, claims) {
t.Errorf("claims should match. test: %d. expected: %v, actual: %v", i, tt.c, claims)
}
enc := jwt.Encode()
if enc != tt.r {
t.Errorf("encoded jwt should match raw jwt. test: %d. expected: %v, actual: %v", i, tt.r, enc)
}
}
}
func TestNewJWTHeaderType(t *testing.T) {
jwt, err := NewJWT(JOSEHeader{}, Claims{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
want := "JWT"
got := jwt.Header[HeaderMediaType]
if want != got {
t.Fatalf("Header %q incorrect: want=%s got=%s", HeaderMediaType, want, got)
}
}
func TestNewJWTHeaderKeyID(t *testing.T) {
jwt, err := NewJWT(JOSEHeader{HeaderKeyID: "foo"}, Claims{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
want := "foo"
got, ok := jwt.KeyID()
if !ok {
t.Fatalf("KeyID not set")
} else if want != got {
t.Fatalf("KeyID incorrect: want=%s got=%s", want, got)
}
}
func TestNewJWTHeaderKeyIDNotSet(t *testing.T) {
jwt, err := NewJWT(JOSEHeader{}, Claims{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if _, ok := jwt.KeyID(); ok {
t.Fatalf("KeyID set, but should not be")
}
}
package jose
import (
"fmt"
)
type Verifier interface {
ID() string
Alg() string
Verify(sig []byte, data []byte) error
}
type Signer interface {
Verifier
Sign(data []byte) (sig []byte, err error)
}
func NewVerifier(jwk JWK) (Verifier, error) {
if jwk.Type != "RSA" {
return nil, fmt.Errorf("unsupported key type %q", jwk.Type)
}
return NewVerifierRSA(jwk)
}
package jose
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"fmt"
)
type VerifierRSA struct {
KeyID string
Hash crypto.Hash
PublicKey rsa.PublicKey
}
type SignerRSA struct {
PrivateKey rsa.PrivateKey
VerifierRSA
}
func NewVerifierRSA(jwk JWK) (*VerifierRSA, error) {
if jwk.Alg != "" && jwk.Alg != "RS256" {
return nil, fmt.Errorf("unsupported key algorithm %q", jwk.Alg)
}
v := VerifierRSA{
KeyID: jwk.ID,
PublicKey: rsa.PublicKey{
N: jwk.Modulus,
E: jwk.Exponent,
},
Hash: crypto.SHA256,
}
return &v, nil
}
func NewSignerRSA(kid string, key rsa.PrivateKey) *SignerRSA {
return &SignerRSA{
PrivateKey: key,
VerifierRSA: VerifierRSA{
KeyID: kid,
PublicKey: key.PublicKey,
Hash: crypto.SHA256,
},
}
}
func (v *VerifierRSA) ID() string {
return v.KeyID
}
func (v *VerifierRSA) Alg() string {
return "RS256"
}
func (v *VerifierRSA) Verify(sig []byte, data []byte) error {
h := v.Hash.New()
h.Write(data)
return rsa.VerifyPKCS1v15(&v.PublicKey, v.Hash, h.Sum(nil), sig)
}
func (s *SignerRSA) Sign(data []byte) ([]byte, error) {
h := s.Hash.New()
h.Write(data)
return rsa.SignPKCS1v15(rand.Reader, &s.PrivateKey, s.Hash, h.Sum(nil))
}
// Package key is DEPRECATED. Use github.com/coreos/go-oidc instead.
package key
package key
import (
"crypto/rand"
"crypto/rsa"
"encoding/hex"
"encoding/json"
"io"
"time"
"github.com/coreos/go-oidc/jose"
)
func NewPublicKey(jwk jose.JWK) *PublicKey {
return &PublicKey{jwk: jwk}
}
type PublicKey struct {
jwk jose.JWK
}
func (k *PublicKey) MarshalJSON() ([]byte, error) {
return json.Marshal(&k.jwk)
}
func (k *PublicKey) UnmarshalJSON(data []byte) error {
var jwk jose.JWK
if err := json.Unmarshal(data, &jwk); err != nil {
return err
}
k.jwk = jwk
return nil
}
func (k *PublicKey) ID() string {
return k.jwk.ID
}
func (k *PublicKey) Verifier() (jose.Verifier, error) {
return jose.NewVerifierRSA(k.jwk)
}
type PrivateKey struct {
KeyID string
PrivateKey *rsa.PrivateKey
}
func (k *PrivateKey) ID() string {
return k.KeyID
}
func (k *PrivateKey) Signer() jose.Signer {
return jose.NewSignerRSA(k.ID(), *k.PrivateKey)
}
func (k *PrivateKey) JWK() jose.JWK {
return jose.JWK{
ID: k.KeyID,
Type: "RSA",
Alg: "RS256",
Use: "sig",
Exponent: k.PrivateKey.PublicKey.E,
Modulus: k.PrivateKey.PublicKey.N,
}
}
type KeySet interface {
ExpiresAt() time.Time
}
type PublicKeySet struct {
keys []PublicKey
index map[string]*PublicKey
expiresAt time.Time
}
func NewPublicKeySet(jwks []jose.JWK, exp time.Time) *PublicKeySet {
keys := make([]PublicKey, len(jwks))
index := make(map[string]*PublicKey)
for i, jwk := range jwks {
keys[i] = *NewPublicKey(jwk)
index[keys[i].ID()] = &keys[i]
}
return &PublicKeySet{
keys: keys,
index: index,
expiresAt: exp,
}
}
func (s *PublicKeySet) ExpiresAt() time.Time {
return s.expiresAt
}
func (s *PublicKeySet) Keys() []PublicKey {
return s.keys
}
func (s *PublicKeySet) Key(id string) *PublicKey {
return s.index[id]
}
type PrivateKeySet struct {
keys []*PrivateKey
ActiveKeyID string
expiresAt time.Time
}
func NewPrivateKeySet(keys []*PrivateKey, exp time.Time) *PrivateKeySet {
return &PrivateKeySet{
keys: keys,
ActiveKeyID: keys[0].ID(),
expiresAt: exp.UTC(),
}
}
func (s *PrivateKeySet) Keys() []*PrivateKey {
return s.keys
}
func (s *PrivateKeySet) ExpiresAt() time.Time {
return s.expiresAt
}
func (s *PrivateKeySet) Active() *PrivateKey {
for i, k := range s.keys {
if k.ID() == s.ActiveKeyID {
return s.keys[i]
}
}
return nil
}
type GeneratePrivateKeyFunc func() (*PrivateKey, error)
func GeneratePrivateKey() (*PrivateKey, error) {
pk, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
keyID := make([]byte, 20)
if _, err := io.ReadFull(rand.Reader, keyID); err != nil {
return nil, err
}
k := PrivateKey{
KeyID: hex.EncodeToString(keyID),
PrivateKey: pk,
}
return &k, nil
}
package key
import (
"crypto/rsa"
"math/big"
"reflect"
"testing"
"time"
"github.com/coreos/go-oidc/jose"
)
func TestPrivateRSAKeyJWK(t *testing.T) {
n := big.NewInt(int64(17))
if n == nil {
panic("NewInt returned nil")
}
k := &PrivateKey{
KeyID: "foo",
PrivateKey: &rsa.PrivateKey{
PublicKey: rsa.PublicKey{N: n, E: 65537},
},
}
want := jose.JWK{
ID: "foo",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: n,
Exponent: 65537,
}
got := k.JWK()
if !reflect.DeepEqual(want, got) {
t.Fatalf("JWK mismatch: want=%#v got=%#v", want, got)
}
}
func TestPublicKeySetKey(t *testing.T) {
n := big.NewInt(int64(17))
if n == nil {
panic("NewInt returned nil")
}
k := jose.JWK{
ID: "foo",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: n,
Exponent: 65537,
}
now := time.Now().UTC()
ks := NewPublicKeySet([]jose.JWK{k}, now)
want := &PublicKey{jwk: k}
got := ks.Key("foo")
if !reflect.DeepEqual(want, got) {
t.Errorf("Unexpected response from PublicKeySet.Key: want=%#v got=%#v", want, got)
}
got = ks.Key("bar")
if got != nil {
t.Errorf("Expected nil response from PublicKeySet.Key, got %#v", got)
}
}
func TestPublicKeyMarshalJSON(t *testing.T) {
k := jose.JWK{
ID: "foo",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: big.NewInt(int64(17)),
Exponent: 65537,
}
want := `{"kid":"foo","kty":"RSA","alg":"RS256","use":"sig","e":"AQAB","n":"EQ"}`
pubKey := NewPublicKey(k)
gotBytes, err := pubKey.MarshalJSON()
if err != nil {
t.Fatalf("failed to marshal public key: %v", err)
}
got := string(gotBytes)
if got != want {
t.Errorf("got != want:\n%s\n%s", got, want)
}
}
func TestGeneratePrivateKeyIDs(t *testing.T) {
key1, err := GeneratePrivateKey()
if err != nil {
t.Fatalf("GeneratePrivateKey(): %v", err)
}
key2, err := GeneratePrivateKey()
if err != nil {
t.Fatalf("GeneratePrivateKey(): %v", err)
}
if key1.KeyID == key2.KeyID {
t.Fatalf("expected different keys to have different key IDs")
}
}
package key
import (
"errors"
"time"
"github.com/jonboulle/clockwork"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/pkg/health"
)
type PrivateKeyManager interface {
ExpiresAt() time.Time
Signer() (jose.Signer, error)
JWKs() ([]jose.JWK, error)
PublicKeys() ([]PublicKey, error)
WritableKeySetRepo
health.Checkable
}
func NewPrivateKeyManager() PrivateKeyManager {
return &privateKeyManager{
clock: clockwork.NewRealClock(),
}
}
type privateKeyManager struct {
keySet *PrivateKeySet
clock clockwork.Clock
}
func (m *privateKeyManager) ExpiresAt() time.Time {
if m.keySet == nil {
return m.clock.Now().UTC()
}
return m.keySet.ExpiresAt()
}
func (m *privateKeyManager) Signer() (jose.Signer, error) {
if err := m.Healthy(); err != nil {
return nil, err
}
return m.keySet.Active().Signer(), nil
}
func (m *privateKeyManager) JWKs() ([]jose.JWK, error) {
if err := m.Healthy(); err != nil {
return nil, err
}
keys := m.keySet.Keys()
jwks := make([]jose.JWK, len(keys))
for i, k := range keys {
jwks[i] = k.JWK()
}
return jwks, nil
}
func (m *privateKeyManager) PublicKeys() ([]PublicKey, error) {
jwks, err := m.JWKs()
if err != nil {
return nil, err
}
keys := make([]PublicKey, len(jwks))
for i, jwk := range jwks {
keys[i] = *NewPublicKey(jwk)
}
return keys, nil
}
func (m *privateKeyManager) Healthy() error {
if m.keySet == nil {
return errors.New("private key manager uninitialized")
}
if len(m.keySet.Keys()) == 0 {
return errors.New("private key manager zero keys")
}
if m.keySet.ExpiresAt().Before(m.clock.Now().UTC()) {
return errors.New("private key manager keys expired")
}
return nil
}
func (m *privateKeyManager) Set(keySet KeySet) error {
privKeySet, ok := keySet.(*PrivateKeySet)
if !ok {
return errors.New("unable to cast to PrivateKeySet")
}
m.keySet = privKeySet
return nil
}
package key
import (
"crypto/rsa"
"math/big"
"reflect"
"strconv"
"testing"
"time"
"github.com/jonboulle/clockwork"
"github.com/coreos/go-oidc/jose"
)
var (
jwk1 jose.JWK
jwk2 jose.JWK
jwk3 jose.JWK
)
func init() {
jwk1 = jose.JWK{
ID: "1",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: big.NewInt(1),
Exponent: 65537,
}
jwk2 = jose.JWK{
ID: "2",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: big.NewInt(2),
Exponent: 65537,
}
jwk3 = jose.JWK{
ID: "3",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: big.NewInt(3),
Exponent: 65537,
}
}
func generatePrivateKeyStatic(t *testing.T, idAndN int) *PrivateKey {
n := big.NewInt(int64(idAndN))
if n == nil {
t.Fatalf("Call to NewInt(%d) failed", idAndN)
}
pk := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{N: n, E: 65537},
}
return &PrivateKey{
KeyID: strconv.Itoa(idAndN),
PrivateKey: pk,
}
}
func TestPrivateKeyManagerJWKsRotate(t *testing.T) {
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
km := NewPrivateKeyManager()
err := km.Set(&PrivateKeySet{
keys: []*PrivateKey{k1, k2, k3},
ActiveKeyID: k1.KeyID,
expiresAt: time.Now().Add(time.Minute),
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
want := []jose.JWK{jwk1, jwk2, jwk3}
got, err := km.JWKs()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if !reflect.DeepEqual(want, got) {
t.Fatalf("JWK mismatch: want=%#v got=%#v", want, got)
}
}
func TestPrivateKeyManagerSigner(t *testing.T) {
k := generatePrivateKeyStatic(t, 13)
km := NewPrivateKeyManager()
err := km.Set(&PrivateKeySet{
keys: []*PrivateKey{k},
ActiveKeyID: k.KeyID,
expiresAt: time.Now().Add(time.Minute),
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
signer, err := km.Signer()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
wantID := "13"
gotID := signer.ID()
if wantID != gotID {
t.Fatalf("Signer has incorrect ID: want=%s got=%s", wantID, gotID)
}
}
func TestPrivateKeyManagerHealthyFail(t *testing.T) {
keyFixture := generatePrivateKeyStatic(t, 1)
tests := []*privateKeyManager{
// keySet nil
&privateKeyManager{
keySet: nil,
clock: clockwork.NewRealClock(),
},
// zero keys
&privateKeyManager{
keySet: &PrivateKeySet{
keys: []*PrivateKey{},
expiresAt: time.Now().Add(time.Minute),
},
clock: clockwork.NewRealClock(),
},
// key set expired
&privateKeyManager{
keySet: &PrivateKeySet{
keys: []*PrivateKey{keyFixture},
expiresAt: time.Now().Add(-1 * time.Minute),
},
clock: clockwork.NewRealClock(),
},
}
for i, tt := range tests {
if err := tt.Healthy(); err == nil {
t.Errorf("case %d: nil error", i)
}
}
}
func TestPrivateKeyManagerHealthyFailsOtherMethods(t *testing.T) {
km := NewPrivateKeyManager()
if _, err := km.JWKs(); err == nil {
t.Fatalf("Expected non-nil error")
}
if _, err := km.Signer(); err == nil {
t.Fatalf("Expected non-nil error")
}
}
func TestPrivateKeyManagerExpiresAt(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
k := generatePrivateKeyStatic(t, 17)
km := &privateKeyManager{
clock: fc,
}
want := fc.Now().UTC()
got := km.ExpiresAt()
if want != got {
t.Fatalf("Incorrect expiration time: want=%v got=%v", want, got)
}
err := km.Set(&PrivateKeySet{
keys: []*PrivateKey{k},
ActiveKeyID: k.KeyID,
expiresAt: now.Add(2 * time.Minute),
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
want = fc.Now().UTC().Add(2 * time.Minute)
got = km.ExpiresAt()
if want != got {
t.Fatalf("Incorrect expiration time: want=%v got=%v", want, got)
}
}
func TestPublicKeys(t *testing.T) {
km := NewPrivateKeyManager()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
tests := [][]*PrivateKey{
[]*PrivateKey{k1},
[]*PrivateKey{k1, k2},
[]*PrivateKey{k1, k2, k3},
}
for i, tt := range tests {
ks := &PrivateKeySet{
keys: tt,
expiresAt: time.Now().Add(time.Hour),
}
km.Set(ks)
jwks, err := km.JWKs()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
pks := NewPublicKeySet(jwks, time.Now().Add(time.Hour))
want := pks.Keys()
got, err := km.PublicKeys()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if !reflect.DeepEqual(want, got) {
t.Errorf("case %d: Invalid public keys: want=%v got=%v", i, want, got)
}
}
}
package key
import (
"errors"
"sync"
)
var ErrorNoKeys = errors.New("no keys found")
type WritableKeySetRepo interface {
Set(KeySet) error
}
type ReadableKeySetRepo interface {
Get() (KeySet, error)
}
type PrivateKeySetRepo interface {
WritableKeySetRepo
ReadableKeySetRepo
}
func NewPrivateKeySetRepo() PrivateKeySetRepo {
return &memPrivateKeySetRepo{}
}
type memPrivateKeySetRepo struct {
mu sync.RWMutex
pks PrivateKeySet
}
func (r *memPrivateKeySetRepo) Set(ks KeySet) error {
pks, ok := ks.(*PrivateKeySet)
if !ok {
return errors.New("unable to cast to PrivateKeySet")
} else if pks == nil {
return errors.New("nil KeySet")
}
r.mu.Lock()
defer r.mu.Unlock()
r.pks = *pks
return nil
}
func (r *memPrivateKeySetRepo) Get() (KeySet, error) {
r.mu.RLock()
defer r.mu.RUnlock()
if r.pks.keys == nil {
return nil, ErrorNoKeys
}
return KeySet(&r.pks), nil
}
package key
import (
"errors"
"log"
"time"
ptime "github.com/coreos/pkg/timeutil"
"github.com/jonboulle/clockwork"
)
var (
ErrorPrivateKeysExpired = errors.New("private keys have expired")
)
func NewPrivateKeyRotator(repo PrivateKeySetRepo, ttl time.Duration) *PrivateKeyRotator {
return &PrivateKeyRotator{
repo: repo,
ttl: ttl,
keep: 2,
generateKey: GeneratePrivateKey,
clock: clockwork.NewRealClock(),
}
}
type PrivateKeyRotator struct {
repo PrivateKeySetRepo
generateKey GeneratePrivateKeyFunc
clock clockwork.Clock
keep int
ttl time.Duration
}
func (r *PrivateKeyRotator) expiresAt() time.Time {
return r.clock.Now().UTC().Add(r.ttl)
}
func (r *PrivateKeyRotator) Healthy() error {
pks, err := r.privateKeySet()
if err != nil {
return err
}
if r.clock.Now().After(pks.ExpiresAt()) {
return ErrorPrivateKeysExpired
}
return nil
}
func (r *PrivateKeyRotator) privateKeySet() (*PrivateKeySet, error) {
ks, err := r.repo.Get()
if err != nil {
return nil, err
}
pks, ok := ks.(*PrivateKeySet)
if !ok {
return nil, errors.New("unable to cast to PrivateKeySet")
}
return pks, nil
}
func (r *PrivateKeyRotator) nextRotation() (time.Duration, error) {
pks, err := r.privateKeySet()
if err == ErrorNoKeys {
return 0, nil
}
if err != nil {
return 0, err
}
now := r.clock.Now()
// Ideally, we want to rotate after half the TTL has elapsed.
idealRotationTime := pks.ExpiresAt().Add(-r.ttl / 2)
// If we are past the ideal rotation time, rotate immediatly.
return max(0, idealRotationTime.Sub(now)), nil
}
func max(a, b time.Duration) time.Duration {
if a > b {
return a
}
return b
}
func (r *PrivateKeyRotator) Run() chan struct{} {
attempt := func() {
k, err := r.generateKey()
if err != nil {
log.Printf("go-oidc: failed generating signing key: %v", err)
return
}
exp := r.expiresAt()
if err := rotatePrivateKeys(r.repo, k, r.keep, exp); err != nil {
log.Printf("go-oidc: key rotation failed: %v", err)
return
}
}
stop := make(chan struct{})
go func() {
for {
var nextRotation time.Duration
var sleep time.Duration
var err error
for {
if nextRotation, err = r.nextRotation(); err == nil {
break
}
sleep = ptime.ExpBackoff(sleep, time.Minute)
log.Printf("go-oidc: error getting nextRotation, retrying in %v: %v", sleep, err)
time.Sleep(sleep)
}
select {
case <-r.clock.After(nextRotation):
attempt()
case <-stop:
return
}
}
}()
return stop
}
func rotatePrivateKeys(repo PrivateKeySetRepo, k *PrivateKey, keep int, exp time.Time) error {
ks, err := repo.Get()
if err != nil && err != ErrorNoKeys {
return err
}
var keys []*PrivateKey
if ks != nil {
pks, ok := ks.(*PrivateKeySet)
if !ok {
return errors.New("unable to cast to PrivateKeySet")
}
keys = pks.Keys()
}
keys = append([]*PrivateKey{k}, keys...)
if l := len(keys); l > keep {
keys = keys[0:keep]
}
nks := PrivateKeySet{
keys: keys,
ActiveKeyID: k.ID(),
expiresAt: exp,
}
return repo.Set(KeySet(&nks))
}
package key
import (
"reflect"
"testing"
"time"
"github.com/jonboulle/clockwork"
)
func generatePrivateKeySerialFunc(t *testing.T) GeneratePrivateKeyFunc {
var n int
return func() (*PrivateKey, error) {
n++
return generatePrivateKeyStatic(t, n), nil
}
}
func TestRotate(t *testing.T) {
now := time.Now()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
tests := []struct {
start *PrivateKeySet
key *PrivateKey
keep int
exp time.Time
want *PrivateKeySet
}{
// start with nil keys
{
start: nil,
key: k1,
keep: 2,
exp: now.Add(time.Second),
want: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(time.Second),
},
},
// start with zero keys
{
start: &PrivateKeySet{},
key: k1,
keep: 2,
exp: now.Add(time.Second),
want: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(time.Second),
},
},
// add second key
{
start: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now,
},
key: k2,
keep: 2,
exp: now.Add(time.Second),
want: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(time.Second),
},
},
// rotate in third key
{
start: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now,
},
key: k3,
keep: 2,
exp: now.Add(time.Second),
want: &PrivateKeySet{
keys: []*PrivateKey{k3, k2},
ActiveKeyID: k3.KeyID,
expiresAt: now.Add(time.Second),
},
},
}
for i, tt := range tests {
repo := NewPrivateKeySetRepo()
if tt.start != nil {
err := repo.Set(tt.start)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
}
rotatePrivateKeys(repo, tt.key, tt.keep, tt.exp)
got, err := repo.Get()
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
if !reflect.DeepEqual(tt.want, got) {
t.Errorf("case %d: unexpected result: want=%#v got=%#v", i, tt.want, got)
}
}
}
func TestPrivateKeyRotatorRun(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
k4 := generatePrivateKeyStatic(t, 4)
kRepo := NewPrivateKeySetRepo()
krot := NewPrivateKeyRotator(kRepo, 4*time.Second)
krot.clock = fc
krot.generateKey = generatePrivateKeySerialFunc(t)
steps := []*PrivateKeySet{
&PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(4 * time.Second),
},
&PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(6 * time.Second),
},
&PrivateKeySet{
keys: []*PrivateKey{k3, k2},
ActiveKeyID: k3.KeyID,
expiresAt: now.Add(8 * time.Second),
},
&PrivateKeySet{
keys: []*PrivateKey{k4, k3},
ActiveKeyID: k4.KeyID,
expiresAt: now.Add(10 * time.Second),
},
}
stop := krot.Run()
defer close(stop)
for i, st := range steps {
// wait for the rotater to get sleepy
fc.BlockUntil(1)
got, err := kRepo.Get()
if err != nil {
t.Fatalf("step %d: unexpected error: %v", i, err)
}
if !reflect.DeepEqual(st, got) {
t.Fatalf("step %d: unexpected state: want=%#v got=%#v", i, st, got)
}
fc.Advance(2 * time.Second)
}
}
func TestPrivateKeyRotatorExpiresAt(t *testing.T) {
fc := clockwork.NewFakeClock()
krot := &PrivateKeyRotator{
clock: fc,
ttl: time.Minute,
}
got := krot.expiresAt()
want := fc.Now().UTC().Add(time.Minute)
if !reflect.DeepEqual(want, got) {
t.Errorf("Incorrect expiration time: want=%v got=%v", want, got)
}
}
func TestNextRotation(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
tests := []struct {
expiresAt time.Time
ttl time.Duration
numKeys int
expected time.Duration
}{
{
// closest to prod
expiresAt: now.Add(time.Hour * 24),
ttl: time.Hour * 24,
numKeys: 2,
expected: time.Hour * 12,
},
{
expiresAt: now.Add(time.Hour * 2),
ttl: time.Hour * 4,
numKeys: 2,
expected: 0,
},
{
// No keys.
expiresAt: now.Add(time.Hour * 2),
ttl: time.Hour * 4,
numKeys: 0,
expected: 0,
},
{
// Nil keyset.
expiresAt: now.Add(time.Hour * 2),
ttl: time.Hour * 4,
numKeys: -1,
expected: 0,
},
{
// KeySet expired.
expiresAt: now.Add(time.Hour * -2),
ttl: time.Hour * 4,
numKeys: 2,
expected: 0,
},
{
// Expiry past now + TTL
expiresAt: now.Add(time.Hour * 5),
ttl: time.Hour * 4,
numKeys: 2,
expected: 3 * time.Hour,
},
}
for i, tt := range tests {
kRepo := NewPrivateKeySetRepo()
krot := NewPrivateKeyRotator(kRepo, tt.ttl)
krot.clock = fc
pks := &PrivateKeySet{
expiresAt: tt.expiresAt,
}
if tt.numKeys != -1 {
for n := 0; n < tt.numKeys; n++ {
pks.keys = append(pks.keys, generatePrivateKeyStatic(t, n))
}
err := kRepo.Set(pks)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
}
actual, err := krot.nextRotation()
if err != nil {
t.Errorf("case %d: error calling shouldRotate(): %v", i, err)
}
if actual != tt.expected {
t.Errorf("case %d: actual == %v, want %v", i, actual, tt.expected)
}
}
}
func TestHealthy(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
tests := []struct {
expiresAt time.Time
numKeys int
expected error
}{
{
expiresAt: now.Add(time.Hour),
numKeys: 2,
expected: nil,
},
{
expiresAt: now.Add(time.Hour),
numKeys: -1,
expected: ErrorNoKeys,
},
{
expiresAt: now.Add(time.Hour),
numKeys: 0,
expected: ErrorNoKeys,
},
{
expiresAt: now.Add(-time.Hour),
numKeys: 2,
expected: ErrorPrivateKeysExpired,
},
}
for i, tt := range tests {
kRepo := NewPrivateKeySetRepo()
krot := NewPrivateKeyRotator(kRepo, time.Hour)
krot.clock = fc
pks := &PrivateKeySet{
expiresAt: tt.expiresAt,
}
if tt.numKeys != -1 {
for n := 0; n < tt.numKeys; n++ {
pks.keys = append(pks.keys, generatePrivateKeyStatic(t, n))
}
err := kRepo.Set(pks)
if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err)
}
}
if err := krot.Healthy(); err != tt.expected {
t.Errorf("case %d: got==%q, want==%q", i, err, tt.expected)
}
}
}
package key
import (
"errors"
"log"
"time"
"github.com/jonboulle/clockwork"
"github.com/coreos/pkg/timeutil"
)
func NewKeySetSyncer(r ReadableKeySetRepo, w WritableKeySetRepo) *KeySetSyncer {
return &KeySetSyncer{
readable: r,
writable: w,
clock: clockwork.NewRealClock(),
}
}
type KeySetSyncer struct {
readable ReadableKeySetRepo
writable WritableKeySetRepo
clock clockwork.Clock
}
func (s *KeySetSyncer) Run() chan struct{} {
stop := make(chan struct{})
go func() {
var failing bool
var next time.Duration
for {
exp, err := syncKeySet(s.readable, s.writable, s.clock)
if err != nil || exp == 0 {
if !failing {
failing = true
next = time.Second
} else {
next = timeutil.ExpBackoff(next, time.Minute)
}
if exp == 0 {
log.Printf("Synced to already expired key set, retrying in %v: %v", next, err)
} else {
log.Printf("Failed syncing key set, retrying in %v: %v", next, err)
}
} else {
failing = false
next = exp / 2
}
select {
case <-s.clock.After(next):
continue
case <-stop:
return
}
}
}()
return stop
}
func Sync(r ReadableKeySetRepo, w WritableKeySetRepo) (time.Duration, error) {
return syncKeySet(r, w, clockwork.NewRealClock())
}
// syncKeySet copies the keyset from r to the KeySet at w and returns the duration in which the KeySet will expire.
// If keyset has already expired, returns a zero duration.
func syncKeySet(r ReadableKeySetRepo, w WritableKeySetRepo, clock clockwork.Clock) (exp time.Duration, err error) {
var ks KeySet
ks, err = r.Get()
if err != nil {
return
}
if ks == nil {
err = errors.New("no source KeySet")
return
}
if err = w.Set(ks); err != nil {
return
}
now := clock.Now()
if ks.ExpiresAt().After(now) {
exp = ks.ExpiresAt().Sub(now)
}
return
}
package key
import (
"errors"
"reflect"
"sync"
"testing"
"time"
"github.com/jonboulle/clockwork"
)
type staticReadableKeySetRepo struct {
mu sync.RWMutex
ks KeySet
err error
}
func (r *staticReadableKeySetRepo) Get() (KeySet, error) {
r.mu.RLock()
defer r.mu.RUnlock()
return r.ks, r.err
}
func (r *staticReadableKeySetRepo) set(ks KeySet, err error) {
r.mu.Lock()
defer r.mu.Unlock()
r.ks, r.err = ks, err
}
func TestKeySyncerSync(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
steps := []struct {
fromKS KeySet
fromErr error
advance time.Duration
want *PrivateKeySet
}{
// on startup, first sync should trigger within a second
{
fromKS: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(10 * time.Second),
},
advance: time.Second,
want: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(10 * time.Second),
},
},
// advance halfway into TTL, triggering sync
{
fromKS: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(15 * time.Second),
},
advance: 5 * time.Second,
want: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(15 * time.Second),
},
},
// advance halfway into TTL, triggering sync that fails
{
fromErr: errors.New("fail!"),
advance: 10 * time.Second,
want: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(15 * time.Second),
},
},
// sync retries quickly, and succeeds with fixed data
{
fromKS: &PrivateKeySet{
keys: []*PrivateKey{k3, k2, k1},
ActiveKeyID: k3.KeyID,
expiresAt: now.Add(25 * time.Second),
},
advance: 3 * time.Second,
want: &PrivateKeySet{
keys: []*PrivateKey{k3, k2, k1},
ActiveKeyID: k3.KeyID,
expiresAt: now.Add(25 * time.Second),
},
},
}
from := &staticReadableKeySetRepo{}
to := NewPrivateKeySetRepo()
syncer := NewKeySetSyncer(from, to)
syncer.clock = fc
stop := syncer.Run()
defer close(stop)
for i, st := range steps {
from.set(st.fromKS, st.fromErr)
fc.Advance(st.advance)
fc.BlockUntil(1)
ks, err := to.Get()
if err != nil {
t.Fatalf("step %d: unable to get keys: %v", i, err)
}
if !reflect.DeepEqual(st.want, ks) {
t.Fatalf("step %d: incorrect state: want=%#v got=%#v", i, st.want, ks)
}
}
}
func TestSync(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
tests := []struct {
keySet *PrivateKeySet
want time.Duration
}{
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(time.Minute),
},
want: time.Minute,
},
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(time.Minute),
},
want: time.Minute,
},
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k3, k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(time.Minute),
},
want: time.Minute,
},
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(time.Hour),
},
want: time.Hour,
},
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(-time.Hour),
},
want: 0,
},
}
for i, tt := range tests {
from := NewPrivateKeySetRepo()
to := NewPrivateKeySetRepo()
err := from.Set(tt.keySet)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
exp, err := syncKeySet(from, to, fc)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
if tt.want != exp {
t.Errorf("case %d: want=%v got=%v", i, tt.want, exp)
}
}
}
func TestSyncFail(t *testing.T) {
tests := []error{
nil,
errors.New("fail!"),
}
for i, tt := range tests {
from := &staticReadableKeySetRepo{ks: nil, err: tt}
to := NewPrivateKeySetRepo()
if _, err := syncKeySet(from, to, clockwork.NewFakeClock()); err == nil {
t.Errorf("case %d: expected non-nil error", i)
}
}
}
// Package oauth2 is DEPRECATED. Use golang.org/x/oauth instead.
package oauth2
package oauth2
const (
ErrorAccessDenied = "access_denied"
ErrorInvalidClient = "invalid_client"
ErrorInvalidGrant = "invalid_grant"
ErrorInvalidRequest = "invalid_request"
ErrorServerError = "server_error"
ErrorUnauthorizedClient = "unauthorized_client"
ErrorUnsupportedGrantType = "unsupported_grant_type"
ErrorUnsupportedResponseType = "unsupported_response_type"
)
type Error struct {
Type string `json:"error"`
Description string `json:"error_description,omitempty"`
State string `json:"state,omitempty"`
}
func (e *Error) Error() string {
if e.Description != "" {
return e.Type + ": " + e.Description
}
return e.Type
}
func NewError(typ string) *Error {
return &Error{Type: typ}
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
// This file contains tests which depend on the race detector being enabled.
// +build race
package oidc
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
)
type testProvider struct {
baseURL *url.URL
}
func (p *testProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != discoveryConfigPath {
http.NotFound(w, r)
return
}
cfg := ProviderConfig{
Issuer: p.baseURL,
ExpiresAt: time.Now().Add(time.Second),
}
cfg = fillRequiredProviderFields(cfg)
json.NewEncoder(w).Encode(&cfg)
}
// This test fails by triggering the race detector, not by calling t.Error or t.Fatal.
func TestProviderSyncRace(t *testing.T) {
prov := &testProvider{}
s := httptest.NewServer(prov)
defer s.Close()
u, err := url.Parse(s.URL)
if err != nil {
t.Fatal(err)
}
prov.baseURL = u
prevValue := minimumProviderConfigSyncInterval
defer func() { minimumProviderConfigSyncInterval = prevValue }()
// Reduce the sync interval to increase the write frequencey.
minimumProviderConfigSyncInterval = 5 * time.Millisecond
cliCfg := ClientConfig{
HTTPClient: http.DefaultClient,
}
cli, err := NewClient(cliCfg)
if err != nil {
t.Error(err)
return
}
if !cli.providerConfig.Get().Empty() {
t.Errorf("want c.ProviderConfig == nil, got c.ProviderConfig=%#v")
}
// SyncProviderConfig beings a goroutine which writes to the client's provider config.
c := cli.SyncProviderConfig(s.URL)
if cli.providerConfig.Get().Empty() {
t.Errorf("want c.ProviderConfig != nil")
}
defer func() {
// stop the background process
c <- struct{}{}
}()
for i := 0; i < 10; i++ {
time.Sleep(5 * time.Millisecond)
// Creating an OAuth client reads from the provider config.
cli.OAuthClient()
}
}
This diff is collapsed.
// Package oidc is DEPRECATED. Use github.com/coreos/go-oidc instead.
package oidc
package oidc
import (
"errors"
"time"
"github.com/coreos/go-oidc/jose"
)
type Identity struct {
ID string
Name string
Email string
ExpiresAt time.Time
}
func IdentityFromClaims(claims jose.Claims) (*Identity, error) {
if claims == nil {
return nil, errors.New("nil claim set")
}
var ident Identity
var err error
var ok bool
if ident.ID, ok, err = claims.StringClaim("sub"); err != nil {
return nil, err
} else if !ok {
return nil, errors.New("missing required claim: sub")
}
if ident.Email, _, err = claims.StringClaim("email"); err != nil {
return nil, err
}
exp, ok, err := claims.TimeClaim("exp")
if err != nil {
return nil, err
} else if ok {
ident.ExpiresAt = exp
}
return &ident, nil
}
This diff is collapsed.
package oidc
type LoginFunc func(ident Identity, sessionKey string) (redirectURL string, err error)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment