Commit c0f5de1a authored by Joe Tsai's avatar Joe Tsai Committed by Chris Broadfoot

[release-branch.go1.7] compress/flate: make huffmanBitWriter errors persistent

For persistent error handling, the methods of huffmanBitWriter have to be
consistent about how they check errors. It must either consistently
check error *before* every operation OR immediately *after* every
operation. Since most of the current logic uses the previous approach,
we apply the same style of error checking to writeBits and all calls
to Write such that they only operate if w.err is already nil going
into them.

The error handling approach is brittle and easily broken by future commits to
the code. In the near future, we should switch the logic to use panic at the
lowest levels and a recover at the edge of the public API to ensure
that errors are always persistent.

Fixes #16749

Change-Id: Ie1d83e4ed8842f6911a31e23311cd3cbf38abe8c
Reviewed-on: https://go-review.googlesource.com/27200Reviewed-by: 's avatarMatthew Dempsky <mdempsky@google.com>
Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
Reviewed-on: https://go-review.googlesource.com/28634Reviewed-by: 's avatarJoe Tsai <thebrokentoaster@gmail.com>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
parent 1861ac15
......@@ -724,7 +724,7 @@ func (w *Writer) Close() error {
// the result of NewWriter or NewWriterDict called with dst
// and w's level and dictionary.
func (w *Writer) Reset(dst io.Writer) {
if dw, ok := w.d.w.w.(*dictWriter); ok {
if dw, ok := w.d.w.writer.(*dictWriter); ok {
// w was created with NewWriterDict
dw.w = dst
w.d.reset(dw)
......
......@@ -6,6 +6,7 @@ package flate
import (
"bytes"
"errors"
"fmt"
"internal/testenv"
"io"
......@@ -631,3 +632,52 @@ func TestBestSpeed(t *testing.T) {
}
}
}
var errIO = errors.New("IO error")
// failWriter fails with errIO exactly at the nth call to Write.
type failWriter struct{ n int }
func (w *failWriter) Write(b []byte) (int, error) {
w.n--
if w.n == -1 {
return 0, errIO
}
return len(b), nil
}
func TestWriterPersistentError(t *testing.T) {
d, err := ioutil.ReadFile("../testdata/Mark.Twain-Tom.Sawyer.txt")
if err != nil {
t.Fatalf("ReadFile: %v", err)
}
d = d[:10000] // Keep this test short
zw, err := NewWriter(nil, DefaultCompression)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
// Sweep over the threshold at which an error is returned.
// The variable i makes it such that the ith call to failWriter.Write will
// return errIO. Since failWriter errors are not persistent, we must ensure
// that flate.Writer errors are persistent.
for i := 0; i < 1000; i++ {
fw := &failWriter{i}
zw.Reset(fw)
_, werr := zw.Write(d)
cerr := zw.Close()
if werr != errIO && werr != nil {
t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO)
}
if cerr != errIO && fw.n < 0 {
t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO)
}
if fw.n >= 0 {
// At this point, the failure threshold was sufficiently high enough
// that we wrote the whole stream without any errors.
return
}
}
}
......@@ -77,7 +77,11 @@ var offsetBase = []uint32{
var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
type huffmanBitWriter struct {
w io.Writer
// writer is the underlying writer.
// Do not use it directly; use the write method, which ensures
// that Write errors are sticky.
writer io.Writer
// Data waiting to be written is bytes[0:nbytes]
// and then the low nbits of bits.
bits uint64
......@@ -96,7 +100,7 @@ type huffmanBitWriter struct {
func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
return &huffmanBitWriter{
w: w,
writer: w,
literalFreq: make([]int32, maxNumLit),
offsetFreq: make([]int32, offsetCodeCount),
codegen: make([]uint8, maxNumLit+offsetCodeCount+1),
......@@ -107,7 +111,7 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
}
func (w *huffmanBitWriter) reset(writer io.Writer) {
w.w = writer
w.writer = writer
w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil
w.bytes = [bufferSize]byte{}
}
......@@ -129,11 +133,21 @@ func (w *huffmanBitWriter) flush() {
n++
}
w.bits = 0
_, w.err = w.w.Write(w.bytes[:n])
w.write(w.bytes[:n])
w.nbytes = 0
}
func (w *huffmanBitWriter) write(b []byte) {
if w.err != nil {
return
}
_, w.err = w.writer.Write(b)
}
func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
if w.err != nil {
return
}
w.bits |= uint64(b) << w.nbits
w.nbits += nb
if w.nbits >= 48 {
......@@ -150,7 +164,7 @@ func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
bytes[5] = byte(bits >> 40)
n += 6
if n >= bufferFlushSize {
_, w.err = w.w.Write(w.bytes[:n])
w.write(w.bytes[:n])
n = 0
}
w.nbytes = n
......@@ -173,13 +187,10 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) {
n++
}
if n != 0 {
_, w.err = w.w.Write(w.bytes[:n])
if w.err != nil {
return
}
w.write(w.bytes[:n])
}
w.nbytes = 0
_, w.err = w.w.Write(bytes)
w.write(bytes)
}
// RFC 1951 3.2.7 specifies a special run-length encoding for specifying
......@@ -341,7 +352,7 @@ func (w *huffmanBitWriter) writeCode(c hcode) {
bytes[5] = byte(bits >> 40)
n += 6
if n >= bufferFlushSize {
_, w.err = w.w.Write(w.bytes[:n])
w.write(w.bytes[:n])
n = 0
}
w.nbytes = n
......@@ -572,6 +583,9 @@ func (w *huffmanBitWriter) indexTokens(tokens []token) (numLiterals, numOffsets
// writeTokens writes a slice of tokens to the output.
// codes for literal and offset encoding must be supplied.
func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) {
if w.err != nil {
return
}
for _, t := range tokens {
if t < matchType {
w.writeCode(leCodes[t.literal()])
......@@ -676,9 +690,9 @@ func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte) {
if n < bufferFlushSize {
continue
}
_, w.err = w.w.Write(w.bytes[:n])
w.write(w.bytes[:n])
if w.err != nil {
return
return // Return early in the event of write failures
}
n = 0
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment