Commit 1ea796ee authored by Mark Ryan's avatar Mark Ryan Committed by Brad Fitzpatrick

encoding/base32: ensure base32 decoder propagates errors correctly

A number of issues in decoder.Read and newlineFilteringReader.Read were
preventing errors from the reader supplying the encoded data from being
propagated to the caller.  Fixing these issues revealed some additional
problems in which valid decoded data was not always returned to the user
when errors were actually propagated.

This commit fixes both the error propagation and the lost decoded data
problems.  It also adds some new unit tests to ensure errors are handled
correctly by decoder.Read.  The new unit tests increase the test coverage
of this package from 96.2% to 97.9%.

Fixes #20044

Change-Id: I1a8632da20135906e2d191c2a8825b10e7ecc4c5
Reviewed-on: https://go-review.googlesource.com/42094Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 26a85211
...@@ -343,18 +343,33 @@ type decoder struct { ...@@ -343,18 +343,33 @@ type decoder struct {
outbuf [1024 / 8 * 5]byte outbuf [1024 / 8 * 5]byte
} }
func (d *decoder) Read(p []byte) (n int, err error) { func readEncodedData(r io.Reader, buf []byte, min int) (n int, err error) {
if d.err != nil { for n < min && err == nil {
return 0, d.err var nn int
nn, err = r.Read(buf[n:])
n += nn
}
if n < min && n > 0 && err == io.EOF {
err = io.ErrUnexpectedEOF
} }
return
}
func (d *decoder) Read(p []byte) (n int, err error) {
// Use leftover decoded output from last read. // Use leftover decoded output from last read.
if len(d.out) > 0 { if len(d.out) > 0 {
n = copy(p, d.out) n = copy(p, d.out)
d.out = d.out[n:] d.out = d.out[n:]
if len(d.out) == 0 {
return n, d.err
}
return n, nil return n, nil
} }
if d.err != nil {
return 0, d.err
}
// Read a chunk. // Read a chunk.
nn := len(p) / 5 * 8 nn := len(p) / 5 * 8
if nn < 8 { if nn < 8 {
...@@ -363,7 +378,8 @@ func (d *decoder) Read(p []byte) (n int, err error) { ...@@ -363,7 +378,8 @@ func (d *decoder) Read(p []byte) (n int, err error) {
if nn > len(d.buf) { if nn > len(d.buf) {
nn = len(d.buf) nn = len(d.buf)
} }
nn, d.err = io.ReadAtLeast(d.r, d.buf[d.nbuf:nn], 8-d.nbuf)
nn, d.err = readEncodedData(d.r, d.buf[d.nbuf:nn], 8-d.nbuf)
d.nbuf += nn d.nbuf += nn
if d.nbuf < 8 { if d.nbuf < 8 {
return 0, d.err return 0, d.err
...@@ -373,21 +389,30 @@ func (d *decoder) Read(p []byte) (n int, err error) { ...@@ -373,21 +389,30 @@ func (d *decoder) Read(p []byte) (n int, err error) {
nr := d.nbuf / 8 * 8 nr := d.nbuf / 8 * 8
nw := d.nbuf / 8 * 5 nw := d.nbuf / 8 * 5
if nw > len(p) { if nw > len(p) {
nw, d.end, d.err = d.enc.decode(d.outbuf[0:], d.buf[0:nr]) nw, d.end, err = d.enc.decode(d.outbuf[0:], d.buf[0:nr])
d.out = d.outbuf[0:nw] d.out = d.outbuf[0:nw]
n = copy(p, d.out) n = copy(p, d.out)
d.out = d.out[n:] d.out = d.out[n:]
} else { } else {
n, d.end, d.err = d.enc.decode(p, d.buf[0:nr]) n, d.end, err = d.enc.decode(p, d.buf[0:nr])
} }
d.nbuf -= nr d.nbuf -= nr
for i := 0; i < d.nbuf; i++ { for i := 0; i < d.nbuf; i++ {
d.buf[i] = d.buf[i+nr] d.buf[i] = d.buf[i+nr]
} }
if d.err == nil { if err != nil && (d.err == nil || d.err == io.EOF) {
d.err = err d.err = err
} }
if len(d.out) > 0 {
// We cannot return all the decoded bytes to the caller in this
// invocation of Read, so we return a nil error to ensure that Read
// will be called again. The error stored in d.err, if any, will be
// returned with the last set of decoded bytes.
return n, nil
}
return n, d.err return n, d.err
} }
...@@ -407,7 +432,7 @@ func (r *newlineFilteringReader) Read(p []byte) (int, error) { ...@@ -407,7 +432,7 @@ func (r *newlineFilteringReader) Read(p []byte) (int, error) {
offset++ offset++
} }
} }
if offset > 0 { if err != nil || offset > 0 {
return offset, err return offset, err
} }
// Previous buffer entirely whitespace, read again // Previous buffer entirely whitespace, read again
......
...@@ -6,6 +6,7 @@ package base32 ...@@ -6,6 +6,7 @@ package base32
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"io/ioutil" "io/ioutil"
"strings" "strings"
...@@ -123,6 +124,160 @@ func TestDecoder(t *testing.T) { ...@@ -123,6 +124,160 @@ func TestDecoder(t *testing.T) {
} }
} }
type badReader struct {
data []byte
errs []error
called int
limit int
}
// Populates p with data, returns a count of the bytes written and an
// error. The error returned is taken from badReader.errs, with each
// invocation of Read returning the next error in this slice, or io.EOF,
// if all errors from the slice have already been returned. The
// number of bytes returned is determined by the size of the input buffer
// the test passes to decoder.Read and will be a multiple of 8, unless
// badReader.limit is non zero.
func (b *badReader) Read(p []byte) (int, error) {
lim := len(p)
if b.limit != 0 && b.limit < lim {
lim = b.limit
}
if len(b.data) < lim {
lim = len(b.data)
}
for i := range p[:lim] {
p[i] = b.data[i]
}
b.data = b.data[lim:]
err := io.EOF
if b.called < len(b.errs) {
err = b.errs[b.called]
}
b.called++
return lim, err
}
// TestIssue20044 tests that decoder.Read behaves correctly when the caller
// supplied reader returns an error.
func TestIssue20044(t *testing.T) {
badErr := errors.New("bad reader error")
testCases := []struct {
r badReader
res string
err error
dbuflen int
}{
// Check valid input data accompanied by an error is processed and the error is propagated.
{r: badReader{data: []byte("MY======"), errs: []error{badErr}},
res: "f", err: badErr},
// Check a read error accompanied by input data consisting of newlines only is propagated.
{r: badReader{data: []byte("\n\n\n\n\n\n\n\n"), errs: []error{badErr, nil}},
res: "", err: badErr},
// Reader will be called twice. The first time it will return 8 newline characters. The
// second time valid base32 encoded data and an error. The data should be decoded
// correctly and the error should be propagated.
{r: badReader{data: []byte("\n\n\n\n\n\n\n\nMY======"), errs: []error{nil, badErr}},
res: "f", err: badErr, dbuflen: 8},
// Reader returns invalid input data (too short) and an error. Verify the reader
// error is returned.
{r: badReader{data: []byte("MY====="), errs: []error{badErr}},
res: "", err: badErr},
// Reader returns invalid input data (too short) but no error. Verify io.ErrUnexpectedEOF
// is returned.
{r: badReader{data: []byte("MY====="), errs: []error{nil}},
res: "", err: io.ErrUnexpectedEOF},
// Reader returns invalid input data and an error. Verify the reader and not the
// decoder error is returned.
{r: badReader{data: []byte("Ma======"), errs: []error{badErr}},
res: "", err: badErr},
// Reader returns valid data and io.EOF. Check data is decoded and io.EOF is propagated.
{r: badReader{data: []byte("MZXW6YTB"), errs: []error{io.EOF}},
res: "fooba", err: io.EOF},
// Check errors are properly reported when decoder.Read is called multiple times.
// decoder.Read will be called 8 times, badReader.Read will be called twice, returning
// valid data both times but an error on the second call.
{r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{nil, badErr}},
res: "leasure.", err: badErr, dbuflen: 1},
// Check io.EOF is properly reported when decoder.Read is called multiple times.
// decoder.Read will be called 8 times, badReader.Read will be called twice, returning
// valid data both times but io.EOF on the second call.
{r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{nil, io.EOF}},
res: "leasure.", err: io.EOF, dbuflen: 1},
// The following two test cases check that errors are propagated correctly when more than
// 8 bytes are read at a time.
{r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{io.EOF}},
res: "leasure.", err: io.EOF, dbuflen: 11},
{r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{badErr}},
res: "leasure.", err: badErr, dbuflen: 11},
// Check that errors are correctly propagated when the reader returns valid bytes in
// groups that are not divisible by 8. The first read will return 11 bytes and no
// error. The second will return 7 and an error. The data should be decoded correctly
// and the error should be propagated.
{r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{nil, badErr}, limit: 11},
res: "leasure.", err: badErr},
}
for _, tc := range testCases {
input := tc.r.data
decoder := NewDecoder(StdEncoding, &tc.r)
var dbuflen int
if tc.dbuflen > 0 {
dbuflen = tc.dbuflen
} else {
dbuflen = StdEncoding.DecodedLen(len(input))
}
dbuf := make([]byte, dbuflen)
var err error
var res []byte
for err == nil {
var n int
n, err = decoder.Read(dbuf)
if n > 0 {
res = append(res, dbuf[:n]...)
}
}
testEqual(t, "Decoding of %q = %q, want %q", string(input), string(res), tc.res)
testEqual(t, "Decoding of %q err = %v, expected %v", string(input), err, tc.err)
}
}
// TestDecoderError verifies decode errors are propagated when there are no read
// errors.
func TestDecoderError(t *testing.T) {
for _, readErr := range []error{io.EOF, nil} {
input := "MZXW6YTb"
dbuf := make([]byte, StdEncoding.DecodedLen(len(input)))
br := badReader{data: []byte(input), errs: []error{readErr}}
decoder := NewDecoder(StdEncoding, &br)
n, err := decoder.Read(dbuf)
testEqual(t, "Read after EOF, n = %d, expected %d", n, 0)
if _, ok := err.(CorruptInputError); !ok {
t.Errorf("Corrupt input error expected. Found %T", err)
}
}
}
// TestReaderEOF ensures decoder.Read behaves correctly when input data is
// exhausted.
func TestReaderEOF(t *testing.T) {
for _, readErr := range []error{io.EOF, nil} {
input := "MZXW6YTB"
br := badReader{data: []byte(input), errs: []error{nil, readErr}}
decoder := NewDecoder(StdEncoding, &br)
dbuf := make([]byte, StdEncoding.DecodedLen(len(input)))
n, err := decoder.Read(dbuf)
testEqual(t, "Decoding of %q err = %v, expected %v", string(input), err, error(nil))
n, err = decoder.Read(dbuf)
testEqual(t, "Read after EOF, n = %d, expected %d", n, 0)
testEqual(t, "Read after EOF, err = %v, expected %v", err, io.EOF)
n, err = decoder.Read(dbuf)
testEqual(t, "Read after EOF, n = %d, expected %d", n, 0)
testEqual(t, "Read after EOF, err = %v, expected %v", err, io.EOF)
}
}
func TestDecoderBuffering(t *testing.T) { func TestDecoderBuffering(t *testing.T) {
for bs := 1; bs <= 12; bs++ { for bs := 1; bs <= 12; bs++ {
decoder := NewDecoder(StdEncoding, strings.NewReader(bigtest.encoded)) decoder := NewDecoder(StdEncoding, strings.NewReader(bigtest.encoded))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment