Commit 442fc736 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

http2/hpack: fix nil pointer dereference crash in huffman decoder

Return an error instead.

Fixes bradfitz/http2#56

Change-Id: I3d1e80a214a8635932479943f0ef9610ee02b233
Reviewed-on: https://go-review.googlesource.com/15738Reviewed-by: 's avatarAndrew Gerrand <adg@golang.org>
parent d8f3c68d
......@@ -473,12 +473,10 @@ func readString(p []byte, wantStr bool) (s string, remain []byte, err error) {
}
if wantStr {
// TODO: optimize this garbage:
var buf bytes.Buffer
if _, err := HuffmanDecode(&buf, p[:strLen]); err != nil {
s, err = HuffmanDecodeToString(p[:strLen])
if err != nil {
return "", nil, err
}
s = buf.String()
}
return s, p[strLen:], nil
}
......@@ -10,11 +10,13 @@ import (
"bytes"
"encoding/hex"
"fmt"
"math/rand"
"reflect"
"regexp"
"strconv"
"strings"
"testing"
"time"
)
func TestStaticTable(t *testing.T) {
......@@ -582,6 +584,77 @@ func TestAppendHuffmanString(t *testing.T) {
}
}
func TestHuffmanRoundtripStress(t *testing.T) {
const Len = 50 // of uncompressed string
input := make([]byte, Len)
var output bytes.Buffer
var huff []byte
n := 5000
if testing.Short() {
n = 100
}
seed := time.Now().UnixNano()
t.Logf("Seed = %v", seed)
src := rand.New(rand.NewSource(seed))
var encSize int64
for i := 0; i < n; i++ {
for l := range input {
input[l] = byte(src.Intn(256))
}
huff = AppendHuffmanString(huff[:0], string(input))
encSize += int64(len(huff))
output.Reset()
if err := huffmanDecode(&output, huff); err != nil {
t.Errorf("Failed to decode %q -> %q -> error %v", input, huff, err)
continue
}
if !bytes.Equal(output.Bytes(), input) {
t.Errorf("Roundtrip failure on %q -> %q -> %q", input, huff, output.Bytes())
}
}
t.Logf("Compressed size of original: %0.02f%% (%v -> %v)", 100*(float64(encSize)/(Len*float64(n))), Len*n, encSize)
}
func TestHuffmanDecodeFuzz(t *testing.T) {
const Len = 50 // of compressed
var buf, zbuf bytes.Buffer
n := 5000
if testing.Short() {
n = 100
}
seed := time.Now().UnixNano()
t.Logf("Seed = %v", seed)
src := rand.New(rand.NewSource(seed))
numFail := 0
for i := 0; i < n; i++ {
zbuf.Reset()
if i == 0 {
// Start with at least one invalid one.
zbuf.WriteString("00\x91\xff\xff\xff\xff\xc8")
} else {
for l := 0; l < Len; l++ {
zbuf.WriteByte(byte(src.Intn(256)))
}
}
buf.Reset()
if err := huffmanDecode(&buf, zbuf.Bytes()); err != nil {
if err == ErrInvalidHuffman {
numFail++
continue
}
t.Errorf("Failed to decode %q: %v", zbuf.Bytes(), err)
continue
}
}
t.Logf("%0.02f%% are invalid (%d / %d)", 100*float64(numFail)/float64(n), numFail, n)
if numFail < 1 {
t.Error("expected at least one invalid huffman encoding (test starts with one)")
}
}
func TestReadVarInt(t *testing.T) {
type res struct {
i uint64
......@@ -637,6 +710,17 @@ func TestReadVarInt(t *testing.T) {
}
}
// Fuzz crash, originally reported at https://github.com/bradfitz/http2/issues/56
func TestHuffmanFuzzCrash(t *testing.T) {
got, err := HuffmanDecodeToString([]byte("00\x91\xff\xff\xff\xff\xc8"))
if got != "" {
t.Errorf("Got %q; want empty string", got)
}
if err != ErrInvalidHuffman {
t.Errorf("Err = %v; want ErrInvalidHuffman", err)
}
}
func dehex(s string) []byte {
s = strings.Replace(s, " ", "", -1)
s = strings.Replace(s, "\n", "", -1)
......
......@@ -7,6 +7,7 @@ package hpack
import (
"bytes"
"errors"
"io"
"sync"
)
......@@ -22,14 +23,39 @@ func HuffmanDecode(w io.Writer, v []byte) (int, error) {
buf := bufPool.Get().(*bytes.Buffer)
buf.Reset()
defer bufPool.Put(buf)
if err := huffmanDecode(buf, v); err != nil {
return 0, err
}
return w.Write(buf.Bytes())
}
// HuffmanDecodeToString decodes the string in v.
func HuffmanDecodeToString(v []byte) (string, error) {
buf := bufPool.Get().(*bytes.Buffer)
buf.Reset()
defer bufPool.Put(buf)
if err := huffmanDecode(buf, v); err != nil {
return "", err
}
return buf.String(), nil
}
// ErrInvalidHuffman is returned for errors found decoding
// Huffman-encoded strings.
var ErrInvalidHuffman = errors.New("hpack: invalid Huffman-encoded data")
func huffmanDecode(buf *bytes.Buffer, v []byte) error {
n := rootHuffmanNode
cur, nbits := uint(0), uint8(0)
for _, b := range v {
cur = cur<<8 | uint(b)
nbits += 8
for nbits >= 8 {
n = n.children[byte(cur>>(nbits-8))]
idx := byte(cur >> (nbits - 8))
n = n.children[idx]
if n == nil {
return ErrInvalidHuffman
}
if n.children == nil {
buf.WriteByte(n.sym)
nbits -= n.codeLen
......@@ -48,7 +74,7 @@ func HuffmanDecode(w io.Writer, v []byte) (int, error) {
nbits -= n.codeLen
n = rootHuffmanNode
}
return w.Write(buf.Bytes())
return nil
}
type node struct {
......@@ -67,10 +93,10 @@ func newInternalNode() *node {
var rootHuffmanNode = newInternalNode()
func init() {
for i, code := range huffmanCodes {
if i > 255 {
panic("too many huffman codes")
if len(huffmanCodes) != 256 {
panic("unexpected size")
}
for i, code := range huffmanCodes {
addDecoderNode(byte(i), code, huffmanCodeLen[i])
}
}
......
......@@ -74,7 +74,7 @@ var staticTable = [...]HeaderField{
pair("www-authenticate", ""),
}
var huffmanCodes = []uint32{
var huffmanCodes = [256]uint32{
0x1ff8,
0x7fffd8,
0xfffffe2,
......@@ -333,7 +333,7 @@ var huffmanCodes = []uint32{
0x3ffffee,
}
var huffmanCodeLen = []uint8{
var huffmanCodeLen = [256]uint8{
13, 23, 28, 28, 28, 28, 28, 28, 28, 24, 30, 28, 28, 30, 28, 28,
28, 28, 28, 28, 28, 28, 30, 28, 28, 28, 28, 28, 28, 28, 28, 28,
6, 10, 10, 12, 13, 6, 8, 11, 10, 10, 8, 11, 8, 6, 6, 6,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment