Commit a30eaa12 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

net/http: fix up Response.Write edge cases

The Go HTTP server doesn't use Response.Write, but others do,
so make it correct. Add a bunch more tests.

This bug is almost a year old. :/

Fixes #5381

LGTM=adg
R=golang-codereviews, adg
CC=dsymonds, golang-codereviews, rsc
https://golang.org/cl/85740046
parent 9b3e2aa1
...@@ -8,6 +8,7 @@ package http ...@@ -8,6 +8,7 @@ package http
import ( import (
"bufio" "bufio"
"bytes"
"crypto/tls" "crypto/tls"
"errors" "errors"
"io" "io"
...@@ -199,7 +200,6 @@ func (r *Response) ProtoAtLeast(major, minor int) bool { ...@@ -199,7 +200,6 @@ func (r *Response) ProtoAtLeast(major, minor int) bool {
// //
// Body is closed after it is sent. // Body is closed after it is sent.
func (r *Response) Write(w io.Writer) error { func (r *Response) Write(w io.Writer) error {
// Status line // Status line
text := r.Status text := r.Status
if text == "" { if text == "" {
...@@ -212,10 +212,45 @@ func (r *Response) Write(w io.Writer) error { ...@@ -212,10 +212,45 @@ func (r *Response) Write(w io.Writer) error {
protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor) protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor)
statusCode := strconv.Itoa(r.StatusCode) + " " statusCode := strconv.Itoa(r.StatusCode) + " "
text = strings.TrimPrefix(text, statusCode) text = strings.TrimPrefix(text, statusCode)
io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n") if _, err := io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n"); err != nil {
return err
}
// Clone it, so we can modify r1 as needed.
r1 := new(Response)
*r1 = *r
if r1.ContentLength == 0 && r1.Body != nil {
// Is it actually 0 length? Or just unknown?
var buf [1]byte
n, err := r1.Body.Read(buf[:])
if err != nil && err != io.EOF {
return err
}
if n == 0 {
// Reset it to a known zero reader, in case underlying one
// is unhappy being read repeatedly.
r1.Body = eofReader
} else {
r1.ContentLength = -1
r1.Body = struct {
io.Reader
io.Closer
}{
io.MultiReader(bytes.NewReader(buf[:1]), r.Body),
r.Body,
}
}
}
// If we're sending a non-chunked HTTP/1.1 response without a
// content-length, the only way to do that is the old HTTP/1.0
// way, by noting the EOF with a connection close, so we need
// to set Close.
if r1.ContentLength == -1 && !r1.Close && r1.ProtoAtLeast(1, 1) && !chunked(r1.TransferEncoding) {
r1.Close = true
}
// Process Body,ContentLength,Close,Trailer // Process Body,ContentLength,Close,Trailer
tw, err := newTransferWriter(r) tw, err := newTransferWriter(r1)
if err != nil { if err != nil {
return err return err
} }
...@@ -230,8 +265,16 @@ func (r *Response) Write(w io.Writer) error { ...@@ -230,8 +265,16 @@ func (r *Response) Write(w io.Writer) error {
return err return err
} }
if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) {
if _, err := io.WriteString(w, "Content-Length: 0\r\n"); err != nil {
return err
}
}
// End-of-header // End-of-header
io.WriteString(w, "\r\n") if _, err := io.WriteString(w, "\r\n"); err != nil {
return err
}
// Write body and trailer // Write body and trailer
err = tw.WriteBody(w) err = tw.WriteBody(w)
......
...@@ -29,6 +29,10 @@ func dummyReq(method string) *Request { ...@@ -29,6 +29,10 @@ func dummyReq(method string) *Request {
return &Request{Method: method} return &Request{Method: method}
} }
func dummyReq11(method string) *Request {
return &Request{Method: method, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1}
}
var respTests = []respTest{ var respTests = []respTest{
// Unchunked response without Content-Length. // Unchunked response without Content-Length.
{ {
......
...@@ -26,7 +26,7 @@ func TestResponseWrite(t *testing.T) { ...@@ -26,7 +26,7 @@ func TestResponseWrite(t *testing.T) {
ProtoMinor: 0, ProtoMinor: 0,
Request: dummyReq("GET"), Request: dummyReq("GET"),
Header: Header{}, Header: Header{},
Body: ioutil.NopCloser(bytes.NewBufferString("abcdef")), Body: ioutil.NopCloser(strings.NewReader("abcdef")),
ContentLength: 6, ContentLength: 6,
}, },
...@@ -49,6 +49,106 @@ func TestResponseWrite(t *testing.T) { ...@@ -49,6 +49,106 @@ func TestResponseWrite(t *testing.T) {
"\r\n" + "\r\n" +
"abcdef", "abcdef",
}, },
// HTTP/1.1 response with unknown length and Connection: close
{
Response{
StatusCode: 200,
ProtoMajor: 1,
ProtoMinor: 1,
Request: dummyReq("GET"),
Header: Header{},
Body: ioutil.NopCloser(strings.NewReader("abcdef")),
ContentLength: -1,
Close: true,
},
"HTTP/1.1 200 OK\r\n" +
"Connection: close\r\n" +
"\r\n" +
"abcdef",
},
// HTTP/1.1 response with unknown length and not setting connection: close
{
Response{
StatusCode: 200,
ProtoMajor: 1,
ProtoMinor: 1,
Request: dummyReq11("GET"),
Header: Header{},
Body: ioutil.NopCloser(strings.NewReader("abcdef")),
ContentLength: -1,
Close: false,
},
"HTTP/1.1 200 OK\r\n" +
"Connection: close\r\n" +
"\r\n" +
"abcdef",
},
// HTTP/1.1 response with unknown length and not setting connection: close, but
// setting chunked.
{
Response{
StatusCode: 200,
ProtoMajor: 1,
ProtoMinor: 1,
Request: dummyReq11("GET"),
Header: Header{},
Body: ioutil.NopCloser(strings.NewReader("abcdef")),
ContentLength: -1,
TransferEncoding: []string{"chunked"},
Close: false,
},
"HTTP/1.1 200 OK\r\n" +
"Transfer-Encoding: chunked\r\n\r\n" +
"6\r\nabcdef\r\n0\r\n\r\n",
},
// HTTP/1.1 response 0 content-length, and nil body
{
Response{
StatusCode: 200,
ProtoMajor: 1,
ProtoMinor: 1,
Request: dummyReq11("GET"),
Header: Header{},
Body: nil,
ContentLength: 0,
Close: false,
},
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"\r\n",
},
// HTTP/1.1 response 0 content-length, and non-nil empty body
{
Response{
StatusCode: 200,
ProtoMajor: 1,
ProtoMinor: 1,
Request: dummyReq11("GET"),
Header: Header{},
Body: ioutil.NopCloser(strings.NewReader("")),
ContentLength: 0,
Close: false,
},
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"\r\n",
},
// HTTP/1.1 response 0 content-length, and non-nil non-empty body
{
Response{
StatusCode: 200,
ProtoMajor: 1,
ProtoMinor: 1,
Request: dummyReq11("GET"),
Header: Header{},
Body: ioutil.NopCloser(strings.NewReader("foo")),
ContentLength: 0,
Close: false,
},
"HTTP/1.1 200 OK\r\n" +
"Connection: close\r\n" +
"\r\nfoo",
},
// HTTP/1.1, chunked coding; empty trailer; close // HTTP/1.1, chunked coding; empty trailer; close
{ {
Response{ Response{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment