Commit b24351ff authored by Sean Sun's avatar Sean Sun

Verify distributed claim endpoints

parent 1bddd0c5
......@@ -7,6 +7,8 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"
......@@ -118,6 +120,53 @@ func contains(sli []string, ele string) bool {
return false
}
// Returns the Claims from the distributed JWT token
func resolveDistributedClaim(ctx context.Context, verifier *IDTokenVerifier, src claimSource) ([]byte, error) {
req, err := http.NewRequest("GET", src.Endpoint, nil)
if err != nil {
return nil, fmt.Errorf("malformed request: %v", err)
}
if src.AccessToken != "" {
req.Header.Set("Authorization", "Bearer "+src.AccessToken)
}
resp, err := doRequest(ctx, req)
if err != nil {
return nil, fmt.Errorf("oidc: Request to endpoint failed: %v", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("unable to read response body: %v", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("oidc: request failed: %v", resp.StatusCode)
}
token, err := verifier.Verify(ctx, string(body))
if err != nil {
return nil, fmt.Errorf("malformed response body: %v", err)
}
return token.claims, nil
}
func parseClaim(raw []byte, name string, v interface{}) error {
var parsed map[string]json.RawMessage
if err := json.Unmarshal(raw, &parsed); err != nil {
return err
}
val, ok := parsed[name]
if !ok {
return fmt.Errorf("claim doesn't exist: %s", name)
}
return json.Unmarshal([]byte(val), v)
}
// Verify parses a raw ID Token, verifies it's been signed by the provider, preforms
// any additional checks depending on the Config, and returns the payload.
//
......
......@@ -3,6 +3,9 @@ package oidc
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"reflect"
"strconv"
"testing"
......@@ -342,6 +345,150 @@ func TestDistributedClaims(t *testing.T) {
}
}
func TestDistClaimResolver(t *testing.T) {
tests := []resolverTest{
{
name: "noAccessToken",
payload: `{"iss":"https://foo","aud":"client1",
"email":"janedoe@email.com",
"shipping_address": {
"street_address": "1234 Hollywood Blvd.",
"locality": "Los Angeles",
"region": "CA",
"postal_code": "90210",
"country": "US"}
}`,
config: Config{
ClientID: "client1",
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
issuer: "https://foo",
want: map[string]claimSource{},
},
{
name: "rightAccessToken",
payload: `{"iss":"https://foo","aud":"client1",
"email":"janedoe@email.com",
"shipping_address": {
"street_address": "1234 Hollywood Blvd.",
"locality": "Los Angeles",
"region": "CA",
"postal_code": "90210",
"country": "US"}
}`,
config: Config{
ClientID: "client1",
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
accessToken: "1234",
issuer: "https://foo",
want: map[string]claimSource{},
},
{
name: "wrongAccessToken",
payload: `{"iss":"https://foo","aud":"client1",
"email":"janedoe@email.com",
"shipping_address": {
"street_address": "1234 Hollywood Blvd.",
"locality": "Los Angeles",
"region": "CA",
"postal_code": "90210",
"country": "US"}
}`,
config: Config{
ClientID: "client1",
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
accessToken: "12345",
issuer: "https://foo",
wantErr: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
claims, err := test.testEndpoint(t)
if err != nil {
if !test.wantErr {
t.Errorf("%v", err)
}
return
}
if test.wantErr {
t.Errorf("expected error receiving response")
return
}
if !reflect.DeepEqual(string(claims), test.payload) {
t.Errorf("expected dist claim: %#v, got: %v", test.payload, string(claims))
}
})
}
}
type resolverTest struct {
// Name of the subtest.
name string
// issuer will be the endpoint server url
issuer string
// just the payload
payload string
// Key to sign the ID Token with.
signKey *signingKey
// If not provided defaults to signKey. Only useful when
// testing invalid signatures.
verificationKey *signingKey
config Config
wantErr bool
want map[string]claimSource
//this is the access token that the testEndpoint will accept
accessToken string
}
func (v resolverTest) testEndpoint(t *testing.T) ([]byte, error) {
token := v.signKey.sign(t, []byte(v.payload))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got := r.Header.Get("Authorization")
if got != "" && got != "Bearer 1234" {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
io.WriteString(w, token)
}))
defer s.Close()
issuer := v.issuer
var ks KeySet
if v.verificationKey == nil {
ks = &testVerifier{v.signKey.jwk()}
} else {
ks = &testVerifier{v.verificationKey.jwk()}
}
verifier := NewVerifier(issuer, ks, &v.config)
ctx = ClientContext(ctx, s.Client())
src := claimSource{
Endpoint: s.URL + "/",
AccessToken: v.accessToken,
}
return resolveDistributedClaim(ctx, verifier, src)
}
type verificationTest struct {
// Name of the subtest.
name string
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment