Commit d836fe72 authored by gotwarlost's avatar gotwarlost

add id token support to verify access token hashes, fixes #126

parent 77e7f201
......@@ -3,9 +3,13 @@ package oidc
import (
"context"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"hash"
"io/ioutil"
"mime"
"net/http"
......@@ -31,6 +35,11 @@ const (
ScopeOfflineAccess = "offline_access"
)
var (
errNoAtHash = errors.New("id token did not have an access token hash")
errInvalidAtHash = errors.New("access token hash does not match value in ID token")
)
// ClientContext returns a new Context that carries the provided HTTP client.
//
// This method sets the same context key used by the golang.org/x/oauth2 package,
......@@ -242,6 +251,14 @@ type IDToken struct {
// and it's the user's responsibility to ensure it contains a valid value.
Nonce string
// at_hash claim, if set in the ID token. Callers can verify an access token
// that corresponds to the ID token using the VerifyAccessToken method.
AccessTokenHash string
// signature algorithm used for ID token, needed to compute a verification hash of an
// access token
sigAlgorithm string
// Raw payload of the id_token.
claims []byte
}
......@@ -267,6 +284,34 @@ func (i *IDToken) Claims(v interface{}) error {
return json.Unmarshal(i.claims, v)
}
// VerifyAccessToken verifies that the hash of the access token that corresponds to the iD token
// matches the hash in the id token. It returns an error if the hashes don't match.
// It is the caller's responsibility to ensure that the optional access token hash is present for the ID token
// before calling this method. See https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
func (i *IDToken) VerifyAccessToken(accessToken string) error {
if i.AccessTokenHash == "" {
return errNoAtHash
}
var h hash.Hash
switch i.sigAlgorithm {
case RS256, ES256, PS256:
h = sha256.New()
case RS384, ES384, PS384:
h = sha512.New384()
case RS512, ES512, PS512:
h = sha512.New()
default:
return fmt.Errorf("oidc: unsupported signing algorithm %q", i.sigAlgorithm)
}
h.Write([]byte(accessToken)) // hash documents that Write will never return an error
sum := h.Sum(nil)[:h.Size()/2]
actual := base64.RawURLEncoding.EncodeToString(sum)
if actual != i.AccessTokenHash {
return errInvalidAtHash
}
return nil
}
type idToken struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
......@@ -274,6 +319,7 @@ type idToken struct {
Expiry jsonTime `json:"exp"`
IssuedAt jsonTime `json:"iat"`
Nonce string `json:"nonce"`
AtHash string `json:"at_hash"`
}
type audience []string
......
package oidc
import (
"fmt"
"testing"
)
const (
// at_hash value and access_token returned by Google.
googleAccessTokenHash = "piwt8oCH-K2D9pXlaS1Y-w"
googleAccessToken = "ya29.CjHSA1l5WUn8xZ6HanHFzzdHdbXm-14rxnC7JHch9eFIsZkQEGoWzaYG4o7k5f6BnPLj"
googleSigningAlg = RS256
// following values computed by own algo for regression testing
computed384TokenHash = "_ILKVQjbEzFKNJjUKC2kz9eReYi0A9Of"
computed512TokenHash = "Spa_APgwBrarSeQbxI-rbragXho6dqFpH5x9PqaPfUI"
)
type accessTokenTest struct {
name string
tok *IDToken
accessToken string
verifier func(err error) error
}
func (a accessTokenTest) run(t *testing.T) {
err := a.tok.VerifyAccessToken(a.accessToken)
result := a.verifier(err)
if result != nil {
t.Error(result)
}
}
func TestAccessTokenVerification(t *testing.T) {
newToken := func(alg, atHash string) *IDToken {
return &IDToken{sigAlgorithm: alg, AccessTokenHash: atHash}
}
assertNil := func(err error) error {
if err != nil {
return fmt.Errorf("want nil error, got %v", err)
}
return nil
}
assertMsg := func(msg string) func(err error) error {
return func(err error) error {
if err == nil {
return fmt.Errorf("expected error, got success")
}
if err.Error() != msg {
return fmt.Errorf("bad error message, %q, (want %q)", err.Error(), msg)
}
return nil
}
}
tests := []accessTokenTest{
{
"goodRS256",
newToken(googleSigningAlg, googleAccessTokenHash),
googleAccessToken,
assertNil,
},
{
"goodES384",
newToken("ES384", computed384TokenHash),
googleAccessToken,
assertNil,
},
{
"goodPS512",
newToken("PS512", computed512TokenHash),
googleAccessToken,
assertNil,
},
{
"badRS256",
newToken("RS256", computed512TokenHash),
googleAccessToken,
assertMsg("access token hash does not match value in ID token"),
},
{
"nohash",
newToken("RS256", ""),
googleAccessToken,
assertMsg("id token did not have an access token hash"),
},
{
"badSignAlgo",
newToken("none", "xxx"),
googleAccessToken,
assertMsg(`oidc: unsupported signing algorithm "none"`),
},
}
for _, test := range tests {
t.Run(test.name, test.run)
}
}
......@@ -170,13 +170,14 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
}
t := &IDToken{
Issuer: token.Issuer,
Subject: token.Subject,
Audience: []string(token.Audience),
Expiry: time.Time(token.Expiry),
IssuedAt: time.Time(token.IssuedAt),
Nonce: token.Nonce,
claims: payload,
Issuer: token.Issuer,
Subject: token.Subject,
Audience: []string(token.Audience),
Expiry: time.Time(token.Expiry),
IssuedAt: time.Time(token.IssuedAt),
Nonce: token.Nonce,
AccessTokenHash: token.AtHash,
claims: payload,
}
// Check issuer.
......@@ -228,6 +229,7 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
if len(v.config.SupportedSigningAlgs) != 0 && !contains(v.config.SupportedSigningAlgs, sig.Header.Algorithm) {
return nil, fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", v.config.SupportedSigningAlgs, sig.Header.Algorithm)
}
t.sigAlgorithm = sig.Header.Algorithm
gotPayload, err := v.keySet.VerifySignature(ctx, rawIDToken)
if err != nil {
......
......@@ -192,6 +192,30 @@ func TestVerifySigningAlg(t *testing.T) {
}
}
func TestAccessTokenHash(t *testing.T) {
atHash := "piwt8oCH-K2D9pXlaS1Y-w"
vt := verificationTest{
name: "preserves token hash and sig algo",
idToken: `{"iss":"https://foo","aud":"client1", "at_hash": "` + atHash + `"}`,
config: Config{
ClientID: "client1",
SkipExpiryCheck: true,
},
signKey: newRSAKey(t),
}
t.Run("at_hash", func(t *testing.T) {
tok := vt.runGetToken(t)
if tok != nil {
if tok.AccessTokenHash != atHash {
t.Errorf("access token hash not preserved correctly, want %q got %q", atHash, tok.AccessTokenHash)
}
if tok.sigAlgorithm != RS256 {
t.Errorf("invalid signature algo, want %q got %q", RS256, tok.sigAlgorithm)
}
}
})
}
type verificationTest struct {
// Name of the subtest.
name string
......@@ -212,7 +236,7 @@ type verificationTest struct {
wantErr bool
}
func (v verificationTest) run(t *testing.T) {
func (v verificationTest) runGetToken(t *testing.T) *IDToken {
token := v.signKey.sign(t, []byte(v.idToken))
ctx, cancel := context.WithCancel(context.Background())
......@@ -230,7 +254,8 @@ func (v verificationTest) run(t *testing.T) {
}
verifier := newVerifier(ks, &v.config, issuer)
if _, err := verifier.Verify(ctx, token); err != nil {
idToken, err := verifier.Verify(ctx, token)
if err != nil {
if !v.wantErr {
t.Errorf("%s: verify %v", v.name, err)
}
......@@ -239,4 +264,9 @@ func (v verificationTest) run(t *testing.T) {
t.Errorf("%s: expected error", v.name)
}
}
return idToken
}
func (v verificationTest) run(t *testing.T) {
v.runGetToken(t)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment