Commit 3411d632 authored by Matthew Dempsky's avatar Matthew Dempsky

net: keep waiting for valid DNS response until timeout

Prevents denial of service attacks from bogus UDP packets.

Fixes #13281.

Change-Id: Ifb51b17a1b0807bfd27b144d6037431701184e7b
Reviewed-on: https://go-review.googlesource.com/22126Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 9f1ccd64
...@@ -38,46 +38,67 @@ type dnsConn interface { ...@@ -38,46 +38,67 @@ type dnsConn interface {
SetDeadline(time.Time) error SetDeadline(time.Time) error
// readDNSResponse reads a DNS response message from the DNS // dnsRoundTrip executes a single DNS transaction, returning a
// transport endpoint and returns the received DNS response // DNS response message for the provided DNS query message.
// message. dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
readDNSResponse() (*dnsMsg, error)
// writeDNSQuery writes a DNS query message to the DNS
// connection endpoint.
writeDNSQuery(*dnsMsg) error
} }
func (c *UDPConn) readDNSResponse() (*dnsMsg, error) { func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
b := make([]byte, 512) // see RFC 1035 return dnsRoundTripUDP(c, query)
n, err := c.Read(b) }
if err != nil {
// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's
// "UDP usage" transport mechanism. c should be a packet-oriented connection,
// such as a *UDPConn.
func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
b, ok := query.Pack()
if !ok {
return nil, errors.New("cannot marshal DNS message")
}
if _, err := c.Write(b); err != nil {
return nil, err return nil, err
} }
msg := &dnsMsg{}
if !msg.Unpack(b[:n]) { b = make([]byte, 512) // see RFC 1035
return nil, errors.New("cannot unmarshal DNS message") for {
n, err := c.Read(b)
if err != nil {
return nil, err
}
resp := &dnsMsg{}
if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) {
// Ignore invalid responses as they may be malicious
// forgery attempts. Instead continue waiting until
// timeout. See golang.org/issue/13281.
continue
}
return resp, nil
} }
return msg, nil
} }
func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error { func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) {
b, ok := msg.Pack() return dnsRoundTripTCP(c, out)
}
// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's
// "TCP usage" transport mechanism. c should be a stream-oriented connection,
// such as a *TCPConn.
func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
b, ok := query.Pack()
if !ok { if !ok {
return errors.New("cannot marshal DNS message") return nil, errors.New("cannot marshal DNS message")
} }
l := len(b)
b = append([]byte{byte(l >> 8), byte(l)}, b...)
if _, err := c.Write(b); err != nil { if _, err := c.Write(b); err != nil {
return err return nil, err
} }
return nil
}
func (c *TCPConn) readDNSResponse() (*dnsMsg, error) { b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
if _, err := io.ReadFull(c, b[:2]); err != nil { if _, err := io.ReadFull(c, b[:2]); err != nil {
return nil, err return nil, err
} }
l := int(b[0])<<8 | int(b[1]) l = int(b[0])<<8 | int(b[1])
if l > len(b) { if l > len(b) {
b = make([]byte, l) b = make([]byte, l)
} }
...@@ -85,24 +106,14 @@ func (c *TCPConn) readDNSResponse() (*dnsMsg, error) { ...@@ -85,24 +106,14 @@ func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg := &dnsMsg{} resp := &dnsMsg{}
if !msg.Unpack(b[:n]) { if !resp.Unpack(b[:n]) {
return nil, errors.New("cannot unmarshal DNS message") return nil, errors.New("cannot unmarshal DNS message")
} }
return msg, nil if !resp.IsResponseTo(query) {
} return nil, errors.New("invalid DNS response")
func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
b, ok := msg.Pack()
if !ok {
return errors.New("cannot marshal DNS message")
} }
l := uint16(len(b)) return resp, nil
b = append([]byte{byte(l >> 8), byte(l)}, b...)
if _, err := c.Write(b); err != nil {
return err
}
return nil
} }
func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) { func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) {
...@@ -150,16 +161,10 @@ func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg, ...@@ -150,16 +161,10 @@ func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg,
c.SetDeadline(d) c.SetDeadline(d)
} }
out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano()) out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
if err := c.writeDNSQuery(&out); err != nil { in, err := c.dnsRoundTrip(&out)
return nil, mapErr(err)
}
in, err := c.readDNSResponse()
if err != nil { if err != nil {
return nil, mapErr(err) return nil, mapErr(err)
} }
if in.id != out.id {
return nil, errors.New("DNS message ID mismatch")
}
if in.truncated { // see RFC 5966 if in.truncated { // see RFC 5966
continue continue
} }
......
...@@ -567,9 +567,6 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { ...@@ -567,9 +567,6 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
} }
type fakeDNSConn struct { type fakeDNSConn struct {
// last query
qmu sync.Mutex // guards q
q *dnsMsg
// reply handler // reply handler
rh func(*dnsMsg) (*dnsMsg, error) rh func(*dnsMsg) (*dnsMsg, error)
} }
...@@ -586,16 +583,76 @@ func (f *fakeDNSConn) SetDeadline(time.Time) error { ...@@ -586,16 +583,76 @@ func (f *fakeDNSConn) SetDeadline(time.Time) error {
return nil return nil
} }
func (f *fakeDNSConn) writeDNSQuery(q *dnsMsg) error { func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) {
f.qmu.Lock() return f.rh(q)
defer f.qmu.Unlock()
f.q = q
return nil
} }
func (f *fakeDNSConn) readDNSResponse() (*dnsMsg, error) { // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
f.qmu.Lock() func TestIgnoreDNSForgeries(t *testing.T) {
q := f.q const TestAddr uint32 = 0x80420001
f.qmu.Unlock()
return f.rh(q) c, s := Pipe()
go func() {
b := make([]byte, 512)
n, err := s.Read(b)
if err != nil {
t.Fatal(err)
}
msg := &dnsMsg{}
if !msg.Unpack(b[:n]) {
t.Fatal("invalid DNS query")
}
s.Write([]byte("garbage DNS response packet"))
msg.response = true
msg.id++ // make invalid ID
b, ok := msg.Pack()
if !ok {
t.Fatal("failed to pack DNS response")
}
s.Write(b)
msg.id-- // restore original ID
msg.answer = []dnsRR{
&dnsRR_A{
Hdr: dnsRR_Header{
Name: "www.example.com.",
Rrtype: dnsTypeA,
Class: dnsClassINET,
Rdlength: 4,
},
A: TestAddr,
},
}
b, ok = msg.Pack()
if !ok {
t.Fatal("failed to pack DNS response")
}
s.Write(b)
}()
msg := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: 42,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
},
}
resp, err := dnsRoundTripUDP(c, msg)
if err != nil {
t.Fatalf("dnsRoundTripUDP failed: %v", err)
}
if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr {
t.Error("got address %v, want %v", got, TestAddr)
}
} }
...@@ -934,3 +934,23 @@ func (dns *dnsMsg) String() string { ...@@ -934,3 +934,23 @@ func (dns *dnsMsg) String() string {
} }
return s return s
} }
// IsResponseTo reports whether m is an acceptable response to query.
func (m *dnsMsg) IsResponseTo(query *dnsMsg) bool {
if !m.response {
return false
}
if m.id != query.id {
return false
}
if len(m.question) != len(query.question) {
return false
}
for i, q := range m.question {
q2 := query.question[i]
if !equalASCIILabel(q.Name, q2.Name) || q.Qtype != q2.Qtype || q.Qclass != q2.Qclass {
return false
}
}
return true
}
...@@ -280,6 +280,124 @@ func TestDNSParseTXTCorruptTXTLengthReply(t *testing.T) { ...@@ -280,6 +280,124 @@ func TestDNSParseTXTCorruptTXTLengthReply(t *testing.T) {
} }
} }
func TestIsResponseTo(t *testing.T) {
// Sample DNS query.
query := dnsMsg{
dnsMsgHdr: dnsMsgHdr{
id: 42,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
},
}
resp := query
resp.response = true
if !resp.IsResponseTo(&query) {
t.Error("got false, want true")
}
badResponses := []dnsMsg{
// Different ID.
{
dnsMsgHdr: dnsMsgHdr{
id: 43,
response: true,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
},
},
// Different query name.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
question: []dnsQuestion{
{
Name: "www.google.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
},
},
// Different query type.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeAAAA,
Qclass: dnsClassINET,
},
},
},
// Different query class.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassCSNET,
},
},
},
// No questions.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
},
// Extra questions.
{
dnsMsgHdr: dnsMsgHdr{
id: 42,
response: true,
},
question: []dnsQuestion{
{
Name: "www.example.com.",
Qtype: dnsTypeA,
Qclass: dnsClassINET,
},
{
Name: "www.golang.org.",
Qtype: dnsTypeAAAA,
Qclass: dnsClassINET,
},
},
},
}
for i := range badResponses {
if badResponses[i].IsResponseTo(&query) {
t.Error("%v: got true, want false", i)
}
}
}
// Valid DNS SRV reply // Valid DNS SRV reply
const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" + const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
"6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" + "6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment