Commit 48e75337 authored by Mikio Hara's avatar Mikio Hara

net: separate DNS transport from DNS query-response interaction

Before fixing issue 6579 this CL separates DNS transport from
DNS message interaction to make it easier to add builtin DNS
resolver control logic.

Update #6579

LGTM=alex, kevlar
R=golang-codereviews, alex, gobot, iant, minux, kevlar
CC=golang-codereviews
https://golang.org/cl/101220044
parent f2f17c0f
......@@ -16,6 +16,7 @@
package net
import (
"errors"
"io"
"math/rand"
"os"
......@@ -23,118 +24,187 @@ import (
"time"
)
// Send a request on the connection and hope for a reply.
// Up to cfg.attempts attempts.
func exchange(cfg *dnsConfig, c Conn, name string, qtype uint16) (*dnsMsg, error) {
_, useTCP := c.(*TCPConn)
if len(name) >= 256 {
return nil, &DNSError{Err: "name too long", Name: name}
// A dnsConn represents a DNS transport endpoint.
type dnsConn interface {
Conn
// readDNSResponse reads a DNS response message from the DNS
// transport endpoint and returns the received DNS response
// message.
readDNSResponse() (*dnsMsg, error)
// writeDNSQuery writes a DNS query message to the DNS
// connection endpoint.
writeDNSQuery(*dnsMsg) error
}
func (c *UDPConn) readDNSResponse() (*dnsMsg, error) {
b := make([]byte, 512) // see RFC 1035
n, err := c.Read(b)
if err != nil {
return nil, err
}
msg := &dnsMsg{}
if !msg.Unpack(b[:n]) {
return nil, errors.New("cannot unmarshal DNS message")
}
return msg, nil
}
func (c *UDPConn) writeDNSQuery(msg *dnsMsg) error {
b, ok := msg.Pack()
if !ok {
return errors.New("cannot marshal DNS message")
}
if _, err := c.Write(b); err != nil {
return err
}
return nil
}
func (c *TCPConn) readDNSResponse() (*dnsMsg, error) {
b := make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
if _, err := io.ReadFull(c, b[:2]); err != nil {
return nil, err
}
l := int(b[0])<<8 | int(b[1])
if l > len(b) {
b = make([]byte, l)
}
n, err := io.ReadFull(c, b[:l])
if err != nil {
return nil, err
}
out := new(dnsMsg)
out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
out.question = []dnsQuestion{
{name, qtype, dnsClassINET},
msg := &dnsMsg{}
if !msg.Unpack(b[:n]) {
return nil, errors.New("cannot unmarshal DNS message")
}
out.recursion_desired = true
msg, ok := out.Pack()
return msg, nil
}
func (c *TCPConn) writeDNSQuery(msg *dnsMsg) error {
b, ok := msg.Pack()
if !ok {
return nil, &DNSError{Err: "internal error - cannot pack message", Name: name}
return errors.New("cannot marshal DNS message")
}
l := uint16(len(b))
b = append([]byte{byte(l >> 8), byte(l)}, b...)
if _, err := c.Write(b); err != nil {
return err
}
if useTCP {
mlen := uint16(len(msg))
msg = append([]byte{byte(mlen >> 8), byte(mlen)}, msg...)
return nil
}
func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
switch network {
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
default:
return nil, UnknownNetworkError(network)
}
for attempt := 0; attempt < cfg.attempts; attempt++ {
n, err := c.Write(msg)
// Calling Dial here is scary -- we have to be sure not to
// dial a name that will require a DNS lookup, or Dial will
// call back here to translate it. The DNS config parser has
// already checked that all the cfg.servers[i] are IP
// addresses, which Dial will use without a DNS lookup.
c, err := d.Dial(network, server)
if err != nil {
return nil, err
}
switch network {
case "tcp", "tcp4", "tcp6":
return c.(*TCPConn), nil
case "udp", "udp4", "udp6":
return c.(*UDPConn), nil
}
panic("unreachable")
}
// exchange sends a query on the connection and hopes for a response.
func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
d := Dialer{Timeout: timeout}
out := dnsMsg{
dnsMsgHdr: dnsMsgHdr{
recursion_desired: true,
},
question: []dnsQuestion{
{name, qtype, dnsClassINET},
},
}
for _, network := range []string{"udp", "tcp"} {
c, err := d.dialDNS(network, server)
if err != nil {
return nil, err
}
if cfg.timeout == 0 {
c.SetReadDeadline(noDeadline)
} else {
c.SetReadDeadline(time.Now().Add(time.Duration(cfg.timeout) * time.Second))
defer c.Close()
if timeout > 0 {
c.SetDeadline(time.Now().Add(timeout))
}
buf := make([]byte, 2000)
if useTCP {
n, err = io.ReadFull(c, buf[:2])
if err != nil {
if e, ok := err.(Error); ok && e.Timeout() {
continue
}
}
mlen := int(buf[0])<<8 | int(buf[1])
if mlen > len(buf) {
buf = make([]byte, mlen)
}
n, err = io.ReadFull(c, buf[:mlen])
} else {
n, err = c.Read(buf)
out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
if err := c.writeDNSQuery(&out); err != nil {
return nil, err
}
in, err := c.readDNSResponse()
if err != nil {
if e, ok := err.(Error); ok && e.Timeout() {
continue
}
return nil, err
}
buf = buf[:n]
in := new(dnsMsg)
if !in.Unpack(buf) || in.id != out.id {
if in.id != out.id {
return nil, errors.New("DNS message ID mismatch")
}
if in.truncated { // see RFC 5966
continue
}
return in, nil
}
var server string
if a := c.RemoteAddr(); a != nil {
server = a.String()
}
return nil, &DNSError{Err: "no answer from server", Name: name, Server: server, IsTimeout: true}
return nil, errors.New("no answer from DNS server")
}
// Do a lookup for a single name, which must be rooted
// (otherwise answer will not find the answers).
func tryOneName(cfg *dnsConfig, name string, qtype uint16) (cname string, addrs []dnsRR, err error) {
func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
if len(cfg.servers) == 0 {
return "", nil, &DNSError{Err: "no DNS servers", Name: name}
}
for i := 0; i < len(cfg.servers); i++ {
// Calling Dial here is scary -- we have to be sure
// not to dial a name that will require a DNS lookup,
// or Dial will call back here to translate it.
// The DNS config parser has already checked that
// all the cfg.servers[i] are IP addresses, which
// Dial will use without a DNS lookup.
server := cfg.servers[i] + ":53"
c, cerr := Dial("udp", server)
if cerr != nil {
err = cerr
continue
}
msg, merr := exchange(cfg, c, name, qtype)
c.Close()
if merr != nil {
err = merr
continue
if len(name) >= 256 {
return "", nil, &DNSError{Err: "DNS name too long", Name: name}
}
timeout := time.Duration(cfg.timeout) * time.Second
var lastErr error
for _, server := range cfg.servers {
server += ":53"
lastErr = &DNSError{
Err: "no answer from DNS server",
Name: name,
Server: server,
IsTimeout: true,
}
if msg.truncated { // see RFC 5966
c, cerr = Dial("tcp", server)
if cerr != nil {
err = cerr
continue
for i := 0; i < cfg.attempts; i++ {
msg, err := exchange(server, name, qtype, timeout)
if err != nil {
if nerr, ok := err.(Error); ok && nerr.Timeout() {
lastErr = &DNSError{
Err: nerr.Error(),
Name: name,
Server: server,
IsTimeout: true,
}
continue
}
lastErr = &DNSError{
Err: err.Error(),
Name: name,
Server: server,
}
break
}
msg, merr = exchange(cfg, c, name, qtype)
c.Close()
if merr != nil {
err = merr
continue
cname, addrs, err := answer(name, server, msg, qtype)
if err == nil || err.(*DNSError).Err == noSuchHost {
return cname, addrs, err
}
}
cname, addrs, err = answer(name, server, msg, qtype)
if err == nil || err.(*DNSError).Err == noSuchHost {
break
lastErr = err
}
}
return
return "", nil, lastErr
}
func convertRR_A(records []dnsRR) []IP {
......
......@@ -16,19 +16,79 @@ import (
"time"
)
func TestTCPLookup(t *testing.T) {
var dnsTransportFallbackTests = []struct {
server string
name string
qtype uint16
timeout int
rcode int
}{
// Querying "com." with qtype=255 usually makes an answer
// which requires more than 512 bytes.
{"8.8.8.8:53", "com.", dnsTypeALL, 2, dnsRcodeSuccess},
{"8.8.4.4:53", "com.", dnsTypeALL, 4, dnsRcodeSuccess},
}
func TestDNSTransportFallback(t *testing.T) {
if testing.Short() || !*testExternal {
t.Skip("skipping test to avoid external network")
}
c, err := Dial("tcp", "8.8.8.8:53")
if err != nil {
t.Fatalf("Dial failed: %v", err)
for _, tt := range dnsTransportFallbackTests {
timeout := time.Duration(tt.timeout) * time.Second
msg, err := exchange(tt.server, tt.name, tt.qtype, timeout)
if err != nil {
t.Error(err)
continue
}
switch msg.rcode {
case tt.rcode, dnsRcodeServerFailure:
default:
t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode)
continue
}
}
defer c.Close()
cfg := &dnsConfig{timeout: 10, attempts: 3}
_, err = exchange(cfg, c, "com.", dnsTypeALL)
if err != nil {
t.Fatalf("exchange failed: %v", err)
}
// See RFC 6761 for further information about the reserved, pseudo
// domain names.
var specialDomainNameTests = []struct {
name string
qtype uint16
rcode int
}{
// Name resoltion APIs and libraries should not recongnize the
// followings as special.
{"1.0.168.192.in-addr.arpa.", dnsTypePTR, dnsRcodeNameError},
{"test.", dnsTypeALL, dnsRcodeNameError},
{"example.com.", dnsTypeALL, dnsRcodeSuccess},
// Name resoltion APIs and libraries should recongnize the
// followings as special and should not send any queries.
// Though, we test those names here for verifying nagative
// answers at DNS query-response interaction level.
{"localhost.", dnsTypeALL, dnsRcodeNameError},
{"invalid.", dnsTypeALL, dnsRcodeNameError},
}
func TestSpecialDomainName(t *testing.T) {
if testing.Short() || !*testExternal {
t.Skip("skipping test to avoid external network")
}
server := "8.8.8.8:53"
for _, tt := range specialDomainNameTests {
msg, err := exchange(server, tt.name, tt.qtype, 0)
if err != nil {
t.Error(err)
continue
}
switch msg.rcode {
case tt.rcode, dnsRcodeServerFailure:
default:
t.Errorf("got %v from %v; want %v", msg.rcode, server, tt.rcode)
continue
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment