Commit ba5b09f7 authored by Jukka-Pekka Kekkonen's avatar Jukka-Pekka Kekkonen Committed by Russ Cox

crypto/hmac: make Sum idempotent

Fixes #978.

R=rsc
CC=golang-dev
https://golang.org/cl/1967045
parent 33cb4690
......@@ -34,10 +34,9 @@ const (
)
type hmac struct {
size int
key []byte
tmp []byte
inner hash.Hash
size int
key, tmp []byte
outer, inner hash.Hash
}
func (h *hmac) tmpPad(xor byte) {
......@@ -50,14 +49,14 @@ func (h *hmac) tmpPad(xor byte) {
}
func (h *hmac) Sum() []byte {
h.tmpPad(0x5c)
sum := h.inner.Sum()
h.tmpPad(0x5c)
for i, b := range sum {
h.tmp[padSize+i] = b
}
h.inner.Reset()
h.inner.Write(h.tmp)
return h.inner.Sum()
h.outer.Reset()
h.outer.Write(h.tmp)
return h.outer.Sum()
}
func (h *hmac) Write(p []byte) (n int, err os.Error) {
......@@ -72,27 +71,26 @@ func (h *hmac) Reset() {
h.inner.Write(h.tmp[0:padSize])
}
// New returns a new HMAC hash using the given hash and key.
func New(h hash.Hash, key []byte) hash.Hash {
// New returns a new HMAC hash using the given hash generator and key.
func New(h func() hash.Hash, key []byte) hash.Hash {
hm := new(hmac)
hm.outer = h()
hm.inner = h()
hm.size = hm.inner.Size()
hm.tmp = make([]byte, padSize+hm.size)
if len(key) > padSize {
// If key is too big, hash it.
h.Write(key)
key = h.Sum()
hm.outer.Write(key)
key = hm.outer.Sum()
}
hm := new(hmac)
hm.inner = h
hm.size = h.Size()
hm.key = make([]byte, len(key))
for i, k := range key {
hm.key[i] = k
}
hm.tmp = make([]byte, padSize+hm.size)
copy(hm.key, key)
hm.Reset()
return hm
}
// NewMD5 returns a new HMAC-MD5 hash using the given key.
func NewMD5(key []byte) hash.Hash { return New(md5.New(), key) }
func NewMD5(key []byte) hash.Hash { return New(md5.New, key) }
// NewSHA1 returns a new HMAC-SHA1 hash using the given key.
func NewSHA1(key []byte) hash.Hash { return New(sha1.New(), key) }
func NewSHA1(key []byte) hash.Hash { return New(sha1.New, key) }
......@@ -84,9 +84,13 @@ func TestHMAC(t *testing.T) {
t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
continue
}
sum := fmt.Sprintf("%x", h.Sum())
if sum != tt.out {
t.Errorf("test %d.%d: have %s want %s\n", i, j, sum, tt.out)
// Repetive Sum() calls should return the same value
for k := 0; k < 2; k++ {
sum := fmt.Sprintf("%x", h.Sum())
if sum != tt.out {
t.Errorf("test %d.%d.%d: have %s want %s\n", i, j, k, sum, tt.out)
}
}
// Second iteration: make sure reset works.
......
......@@ -8,7 +8,6 @@ import (
"crypto/hmac"
"crypto/rc4"
"crypto/rsa"
"crypto/sha1"
"crypto/subtle"
"crypto/x509"
"io"
......@@ -226,7 +225,7 @@ func (c *Conn) clientHandshake() os.Error {
cipher, _ := rc4.NewCipher(clientKey)
c.out.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC))
c.out.prepareCipherSpec(cipher, hmac.NewSHA1(clientMAC))
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
finished := new(finishedMsg)
......@@ -235,7 +234,7 @@ func (c *Conn) clientHandshake() os.Error {
c.writeRecord(recordTypeHandshake, finished.marshal())
cipher2, _ := rc4.NewCipher(serverKey)
c.in.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC))
c.in.prepareCipherSpec(cipher2, hmac.NewSHA1(serverMAC))
c.readRecord(recordTypeChangeCipherSpec)
if c.err != nil {
return c.err
......
......@@ -16,7 +16,6 @@ import (
"crypto/hmac"
"crypto/rc4"
"crypto/rsa"
"crypto/sha1"
"crypto/subtle"
"crypto/x509"
"io"
......@@ -227,7 +226,7 @@ func (c *Conn) serverHandshake() os.Error {
keysFromPreMasterSecret11(preMasterSecret, clientHello.random, hello.random, suite.hashLength, suite.cipherKeyLength)
cipher, _ := rc4.NewCipher(clientKey)
c.in.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC))
c.in.prepareCipherSpec(cipher, hmac.NewSHA1(clientMAC))
c.readRecord(recordTypeChangeCipherSpec)
if err := c.error(); err != nil {
return err
......@@ -264,7 +263,7 @@ func (c *Conn) serverHandshake() os.Error {
finishedHash.Write(clientFinished.marshal())
cipher2, _ := rc4.NewCipher(serverKey)
c.out.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC))
c.out.prepareCipherSpec(cipher2, hmac.NewSHA1(serverMAC))
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
finished := new(finishedMsg)
......
......@@ -20,7 +20,7 @@ func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
}
// pHash implements the P_hash function, as defined in RFC 4346, section 5.
func pHash(result, secret, seed []byte, hash hash.Hash) {
func pHash(result, secret, seed []byte, hash func() hash.Hash) {
h := hmac.New(hash, secret)
h.Write(seed)
a := h.Sum()
......@@ -46,8 +46,8 @@ func pHash(result, secret, seed []byte, hash hash.Hash) {
// pRF11 implements the TLS 1.1 pseudo-random function, as defined in RFC 4346, section 5.
func pRF11(result, secret, label, seed []byte) {
hashSHA1 := sha1.New()
hashMD5 := md5.New()
hashSHA1 := sha1.New
hashMD5 := md5.New
labelAndSeed := make([]byte, len(label)+len(seed))
copy(labelAndSeed, label)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment