Unverified Commit 1bddd0c5 authored by Eric Chiang's avatar Eric Chiang Committed by GitHub

Merge pull request #172 from Sean-Q-Sun/distributed-claims

Resolve distributed claims in idToken
parents 1180514e e7de8122
......@@ -261,6 +261,9 @@ type IDToken struct {
// Raw payload of the id_token.
claims []byte
// Map of distributed claim names to claim sources
distributedClaims map[string]claimSource
}
// Claims unmarshals the raw JSON payload of the ID Token into a provided struct.
......@@ -313,13 +316,20 @@ func (i *IDToken) VerifyAccessToken(accessToken string) error {
}
type idToken struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Audience audience `json:"aud"`
Expiry jsonTime `json:"exp"`
IssuedAt jsonTime `json:"iat"`
Nonce string `json:"nonce"`
AtHash string `json:"at_hash"`
Issuer string `json:"iss"`
Subject string `json:"sub"`
Audience audience `json:"aud"`
Expiry jsonTime `json:"exp"`
IssuedAt jsonTime `json:"iat"`
Nonce string `json:"nonce"`
AtHash string `json:"at_hash"`
ClaimNames map[string]string `json:"_claim_names"`
ClaimSources map[string]claimSource `json:"_claim_sources"`
}
type claimSource struct {
Endpoint string `json:"endpoint"`
AccessToken string `json:"access_token"`
}
type audience []string
......
......@@ -155,15 +155,30 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
return nil, fmt.Errorf("oidc: failed to unmarshal claims: %v", err)
}
distributedClaims := make(map[string]claimSource)
//step through the token to map claim names to claim sources"
for cn, src := range token.ClaimNames {
if src == "" {
return nil, fmt.Errorf("oidc: failed to obtain source from claim name")
}
s, ok := token.ClaimSources[src]
if !ok {
return nil, fmt.Errorf("oidc: source does not exist")
}
distributedClaims[cn] = s
}
t := &IDToken{
Issuer: token.Issuer,
Subject: token.Subject,
Audience: []string(token.Audience),
Expiry: time.Time(token.Expiry),
IssuedAt: time.Time(token.IssuedAt),
Nonce: token.Nonce,
AccessTokenHash: token.AtHash,
claims: payload,
Issuer: token.Issuer,
Subject: token.Subject,
Audience: []string(token.Audience),
Expiry: time.Time(token.Expiry),
IssuedAt: time.Time(token.IssuedAt),
Nonce: token.Nonce,
AccessTokenHash: token.AtHash,
claims: payload,
distributedClaims: distributedClaims,
}
// Check issuer.
......
......@@ -3,6 +3,7 @@ package oidc
import (
"context"
"fmt"
"reflect"
"strconv"
"testing"
"time"
......@@ -204,18 +205,143 @@ func TestAccessTokenHash(t *testing.T) {
signKey: newRSAKey(t),
}
t.Run("at_hash", func(t *testing.T) {
tok := vt.runGetToken(t)
if tok != nil {
if tok.AccessTokenHash != atHash {
t.Errorf("access token hash not preserved correctly, want %q got %q", atHash, tok.AccessTokenHash)
}
if tok.sigAlgorithm != RS256 {
t.Errorf("invalid signature algo, want %q got %q", RS256, tok.sigAlgorithm)
}
tok, err := vt.runGetToken(t)
if err != nil {
t.Errorf("parsing token: %v", err)
return
}
if tok.AccessTokenHash != atHash {
t.Errorf("access token hash not preserved correctly, want %q got %q", atHash, tok.AccessTokenHash)
}
if tok.sigAlgorithm != RS256 {
t.Errorf("invalid signature algo, want %q got %q", RS256, tok.sigAlgorithm)
}
})
}
func TestDistributedClaims(t *testing.T) {
tests := []struct {
test verificationTest
want map[string]claimSource
wantErr bool
}{
{
test: verificationTest{
name: "NoDistClaims",
idToken: `{"iss":"https://foo","aud":"client1"}`,
config: Config{
ClientID: "client1",
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
},
want: map[string]claimSource{},
},
{
test: verificationTest{
name: "1DistClaim",
idToken: `{
"iss":"https://foo","aud":"client1",
"_claim_names": {
"address": "src1"
},
"_claim_sources": {
"src1": {"endpoint": "123", "access_token":"1234"}
}
}`,
config: Config{
ClientID: "client1",
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
},
want: map[string]claimSource{
"address": claimSource{Endpoint: "123", AccessToken: "1234"},
},
},
{
test: verificationTest{
name: "2DistClaims1Src",
idToken: `{
"iss":"https://foo","aud":"client1",
"_claim_names": {
"address": "src1",
"phone_number": "src1"
},
"_claim_sources": {
"src1": {"endpoint": "123", "access_token":"1234"}
}
}`,
config: Config{
ClientID: "client1",
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
},
want: map[string]claimSource{
"address": claimSource{Endpoint: "123", AccessToken: "1234"},
"phone_number": claimSource{Endpoint: "123", AccessToken: "1234"},
},
},
{
test: verificationTest{
name: "1Name0Src",
idToken: `{
"iss":"https://foo","aud":"client1",
"_claim_names": {
"address": "src1"
},
"_claim_sources": {
}
}`,
config: Config{
ClientID: "client1",
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
},
wantErr: true,
},
{
test: verificationTest{
name: "NoNames1Src",
idToken: `{
"iss":"https://foo","aud":"client1",
"_claim_names": {
},
"_claim_sources": {
"src1": {"endpoint": "https://foo", "access_token":"1234"}
}
}`,
config: Config{
ClientID: "client1",
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
},
want: map[string]claimSource{},
},
}
for _, test := range tests {
t.Run(test.test.name, func(t *testing.T) {
idToken, err := test.test.runGetToken(t)
if err != nil {
if !test.wantErr {
t.Errorf("parsing token: %v", err)
}
return
}
if test.wantErr {
t.Errorf("expected error parsing token")
return
}
if !reflect.DeepEqual(idToken.distributedClaims, test.want) {
t.Errorf("expected distributed claim: %#v, got: %#v", test.want, idToken.distributedClaims)
}
})
}
}
type verificationTest struct {
// Name of the subtest.
name string
......@@ -236,7 +362,7 @@ type verificationTest struct {
wantErr bool
}
func (v verificationTest) runGetToken(t *testing.T) *IDToken {
func (v verificationTest) runGetToken(t *testing.T) (*IDToken, error) {
token := v.signKey.sign(t, []byte(v.idToken))
ctx, cancel := context.WithCancel(context.Background())
......@@ -254,19 +380,15 @@ func (v verificationTest) runGetToken(t *testing.T) *IDToken {
}
verifier := NewVerifier(issuer, ks, &v.config)
idToken, err := verifier.Verify(ctx, token)
if err != nil {
if !v.wantErr {
t.Errorf("%s: verify %v", v.name, err)
}
} else {
if v.wantErr {
t.Errorf("%s: expected error", v.name)
}
}
return idToken
return verifier.Verify(ctx, token)
}
func (v verificationTest) run(t *testing.T) {
v.runGetToken(t)
_, err := v.runGetToken(t)
if err != nil && !v.wantErr {
t.Errorf("%v", err)
}
if err == nil && v.wantErr {
t.Errorf("expected error")
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment