Commit 636670b8 authored by Scott Bell's avatar Scott Bell Committed by Brad Fitzpatrick

net: use contexts for cgo-based DNS resolution

Although calls to getaddrinfo can't be portably interrupted,
we still benefit from more granular resource management by
pushing the context downwards.

Fixes #15321

Change-Id: I5506195fc6493080410e3d46aaa3fe02018a24fe
Reviewed-on: https://go-review.googlesource.com/22961Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 9edb27e7
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
package net package net
import "context"
func init() { netGo = true } func init() { netGo = true }
type addrinfoErrno int type addrinfoErrno int
...@@ -14,22 +16,22 @@ func (eai addrinfoErrno) Error() string { return "<nil>" } ...@@ -14,22 +16,22 @@ func (eai addrinfoErrno) Error() string { return "<nil>" }
func (eai addrinfoErrno) Temporary() bool { return false } func (eai addrinfoErrno) Temporary() bool { return false }
func (eai addrinfoErrno) Timeout() bool { return false } func (eai addrinfoErrno) Timeout() bool { return false }
func cgoLookupHost(name string) (addrs []string, err error, completed bool) { func cgoLookupHost(ctx context.Context, name string) (addrs []string, err error, completed bool) {
return nil, nil, false return nil, nil, false
} }
func cgoLookupPort(network, service string) (port int, err error, completed bool) { func cgoLookupPort(ctx context.Context, network, service string) (port int, err error, completed bool) {
return 0, nil, false return 0, nil, false
} }
func cgoLookupIP(name string) (addrs []IPAddr, err error, completed bool) { func cgoLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error, completed bool) {
return nil, nil, false return nil, nil, false
} }
func cgoLookupCNAME(name string) (cname string, err error, completed bool) { func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, completed bool) {
return "", nil, false return "", nil, false
} }
func cgoLookupPTR(addr string) (ptrs []string, err error, completed bool) { func cgoLookupPTR(ctx context.Context, addr string) (ptrs []string, err error, completed bool) {
return nil, nil, false return nil, nil, false
} }
...@@ -19,6 +19,7 @@ package net ...@@ -19,6 +19,7 @@ package net
import "C" import "C"
import ( import (
"context"
"syscall" "syscall"
"unsafe" "unsafe"
) )
...@@ -32,18 +33,31 @@ func (eai addrinfoErrno) Error() string { return C.GoString(C.gai_strerror(C.i ...@@ -32,18 +33,31 @@ func (eai addrinfoErrno) Error() string { return C.GoString(C.gai_strerror(C.i
func (eai addrinfoErrno) Temporary() bool { return eai == C.EAI_AGAIN } func (eai addrinfoErrno) Temporary() bool { return eai == C.EAI_AGAIN }
func (eai addrinfoErrno) Timeout() bool { return false } func (eai addrinfoErrno) Timeout() bool { return false }
func cgoLookupHost(name string) (hosts []string, err error, completed bool) { type portLookupResult struct {
addrs, err, completed := cgoLookupIP(name) port int
err error
}
type ipLookupResult struct {
addrs []IPAddr
cname string
err error
}
type reverseLookupResult struct {
names []string
err error
}
func cgoLookupHost(ctx context.Context, name string) (hosts []string, err error, completed bool) {
addrs, err, completed := cgoLookupIP(ctx, name)
for _, addr := range addrs { for _, addr := range addrs {
hosts = append(hosts, addr.String()) hosts = append(hosts, addr.String())
} }
return return
} }
func cgoLookupPort(network, service string) (port int, err error, completed bool) { func cgoLookupPort(ctx context.Context, network, service string) (port int, err error, completed bool) {
acquireThread()
defer releaseThread()
var hints C.struct_addrinfo var hints C.struct_addrinfo
switch network { switch network {
case "": // no hints case "": // no hints
...@@ -64,11 +78,27 @@ func cgoLookupPort(network, service string) (port int, err error, completed bool ...@@ -64,11 +78,27 @@ func cgoLookupPort(network, service string) (port int, err error, completed bool
hints.ai_family = C.AF_INET6 hints.ai_family = C.AF_INET6
} }
} }
if ctx.Done() == nil {
port, err := cgoLookupServicePort(&hints, network, service)
return port, err, true
}
result := make(chan portLookupResult, 1)
go cgoPortLookup(result, &hints, network, service)
select {
case r := <-result:
return r.port, r.err, true
case <-ctx.Done():
// Since there isn't a portable way to cancel the lookup,
// we just let it finish and write to the buffered channel.
return 0, mapErr(ctx.Err()), false
}
}
func cgoLookupServicePort(hints *C.struct_addrinfo, network, service string) (port int, err error) {
s := C.CString(service) s := C.CString(service)
var res *C.struct_addrinfo var res *C.struct_addrinfo
defer C.free(unsafe.Pointer(s)) defer C.free(unsafe.Pointer(s))
gerrno, err := C.getaddrinfo(nil, s, &hints, &res) gerrno, err := C.getaddrinfo(nil, s, hints, &res)
if gerrno != 0 { if gerrno != 0 {
switch gerrno { switch gerrno {
case C.EAI_SYSTEM: case C.EAI_SYSTEM:
...@@ -78,7 +108,7 @@ func cgoLookupPort(network, service string) (port int, err error, completed bool ...@@ -78,7 +108,7 @@ func cgoLookupPort(network, service string) (port int, err error, completed bool
default: default:
err = addrinfoErrno(gerrno) err = addrinfoErrno(gerrno)
} }
return 0, &DNSError{Err: err.Error(), Name: network + "/" + service}, true return 0, &DNSError{Err: err.Error(), Name: network + "/" + service}
} }
defer C.freeaddrinfo(res) defer C.freeaddrinfo(res)
...@@ -87,17 +117,22 @@ func cgoLookupPort(network, service string) (port int, err error, completed bool ...@@ -87,17 +117,22 @@ func cgoLookupPort(network, service string) (port int, err error, completed bool
case C.AF_INET: case C.AF_INET:
sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(r.ai_addr)) sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(r.ai_addr))
p := (*[2]byte)(unsafe.Pointer(&sa.Port)) p := (*[2]byte)(unsafe.Pointer(&sa.Port))
return int(p[0])<<8 | int(p[1]), nil, true return int(p[0])<<8 | int(p[1]), nil
case C.AF_INET6: case C.AF_INET6:
sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(r.ai_addr)) sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(r.ai_addr))
p := (*[2]byte)(unsafe.Pointer(&sa.Port)) p := (*[2]byte)(unsafe.Pointer(&sa.Port))
return int(p[0])<<8 | int(p[1]), nil, true return int(p[0])<<8 | int(p[1]), nil
} }
} }
return 0, &DNSError{Err: "unknown port", Name: network + "/" + service}, true return 0, &DNSError{Err: "unknown port", Name: network + "/" + service}
} }
func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error, completed bool) { func cgoPortLookup(result chan<- portLookupResult, hints *C.struct_addrinfo, network, service string) {
port, err := cgoLookupServicePort(hints, network, service)
result <- portLookupResult{port, err}
}
func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error) {
acquireThread() acquireThread()
defer releaseThread() defer releaseThread()
...@@ -127,7 +162,7 @@ func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error, com ...@@ -127,7 +162,7 @@ func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error, com
default: default:
err = addrinfoErrno(gerrno) err = addrinfoErrno(gerrno)
} }
return nil, "", &DNSError{Err: err.Error(), Name: name}, true return nil, "", &DNSError{Err: err.Error(), Name: name}
} }
defer C.freeaddrinfo(res) defer C.freeaddrinfo(res)
...@@ -156,17 +191,42 @@ func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error, com ...@@ -156,17 +191,42 @@ func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error, com
addrs = append(addrs, addr) addrs = append(addrs, addr)
} }
} }
return addrs, cname, nil, true return addrs, cname, nil
} }
func cgoLookupIP(name string) (addrs []IPAddr, err error, completed bool) { func cgoIPLookup(result chan<- ipLookupResult, name string) {
addrs, _, err, completed = cgoLookupIPCNAME(name) addrs, cname, err := cgoLookupIPCNAME(name)
return result <- ipLookupResult{addrs, cname, err}
} }
func cgoLookupCNAME(name string) (cname string, err error, completed bool) { func cgoLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error, completed bool) {
_, cname, err, completed = cgoLookupIPCNAME(name) if ctx.Done() == nil {
return addrs, _, err = cgoLookupIPCNAME(name)
return addrs, err, true
}
result := make(chan ipLookupResult, 1)
go cgoIPLookup(result, name)
select {
case r := <-result:
return r.addrs, r.err, true
case <-ctx.Done():
return nil, mapErr(ctx.Err()), false
}
}
func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, completed bool) {
if ctx.Done() == nil {
_, cname, err = cgoLookupIPCNAME(name)
return cname, err, true
}
result := make(chan ipLookupResult, 1)
go cgoIPLookup(result, name)
select {
case r := <-result:
return r.cname, r.err, true
case <-ctx.Done():
return "", mapErr(ctx.Err()), false
}
} }
// These are roughly enough for the following: // These are roughly enough for the following:
...@@ -182,10 +242,7 @@ const ( ...@@ -182,10 +242,7 @@ const (
maxNameinfoLen = 4096 maxNameinfoLen = 4096
) )
func cgoLookupPTR(addr string) ([]string, error, bool) { func cgoLookupPTR(ctx context.Context, addr string) (names []string, err error, completed bool) {
acquireThread()
defer releaseThread()
var zone string var zone string
ip := parseIPv4(addr) ip := parseIPv4(addr)
if ip == nil { if ip == nil {
...@@ -198,9 +255,26 @@ func cgoLookupPTR(addr string) ([]string, error, bool) { ...@@ -198,9 +255,26 @@ func cgoLookupPTR(addr string) ([]string, error, bool) {
if sa == nil { if sa == nil {
return nil, &DNSError{Err: "invalid address " + ip.String(), Name: addr}, true return nil, &DNSError{Err: "invalid address " + ip.String(), Name: addr}, true
} }
var err error if ctx.Done() == nil {
var b []byte names, err := cgoLookupAddrPTR(addr, sa, salen)
return names, err, true
}
result := make(chan reverseLookupResult, 1)
go cgoReverseLookup(result, addr, sa, salen)
select {
case r := <-result:
return r.names, r.err, true
case <-ctx.Done():
return nil, mapErr(ctx.Err()), false
}
}
func cgoLookupAddrPTR(addr string, sa *C.struct_sockaddr, salen C.socklen_t) (names []string, err error) {
acquireThread()
defer releaseThread()
var gerrno int var gerrno int
var b []byte
for l := nameinfoLen; l <= maxNameinfoLen; l *= 2 { for l := nameinfoLen; l <= maxNameinfoLen; l *= 2 {
b = make([]byte, l) b = make([]byte, l)
gerrno, err = cgoNameinfoPTR(b, sa, salen) gerrno, err = cgoNameinfoPTR(b, sa, salen)
...@@ -217,16 +291,20 @@ func cgoLookupPTR(addr string) ([]string, error, bool) { ...@@ -217,16 +291,20 @@ func cgoLookupPTR(addr string) ([]string, error, bool) {
default: default:
err = addrinfoErrno(gerrno) err = addrinfoErrno(gerrno)
} }
return nil, &DNSError{Err: err.Error(), Name: addr}, true return nil, &DNSError{Err: err.Error(), Name: addr}
} }
for i := 0; i < len(b); i++ { for i := 0; i < len(b); i++ {
if b[i] == 0 { if b[i] == 0 {
b = b[:i] b = b[:i]
break break
} }
} }
return []string{absDomainName(b)}, nil, true return []string{absDomainName(b)}, nil
}
func cgoReverseLookup(result chan<- reverseLookupResult, addr string, sa *C.struct_sockaddr, salen C.socklen_t) {
names, err := cgoLookupAddrPTR(addr, sa, salen)
result <- reverseLookupResult{names, err}
} }
func cgoSockaddr(ip IP, zone string) (*C.struct_sockaddr, C.socklen_t) { func cgoSockaddr(ip IP, zone string) (*C.struct_sockaddr, C.socklen_t) {
......
...@@ -13,15 +13,70 @@ import ( ...@@ -13,15 +13,70 @@ import (
) )
func TestCgoLookupIP(t *testing.T) { func TestCgoLookupIP(t *testing.T) {
host := "localhost" ctx := context.Background()
_, err, ok := cgoLookupIP(host) _, err, ok := cgoLookupIP(ctx, "localhost")
if !ok { if !ok {
t.Errorf("cgoLookupIP must not be a placeholder") t.Errorf("cgoLookupIP must not be a placeholder")
} }
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if _, err := goLookupIP(context.Background(), host); err != nil { }
func TestCgoLookupIPWithCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err, ok := cgoLookupIP(ctx, "localhost")
if !ok {
t.Errorf("cgoLookupIP must not be a placeholder")
}
if err != nil {
t.Error(err)
}
}
func TestCgoLookupPort(t *testing.T) {
ctx := context.Background()
_, err, ok := cgoLookupPort(ctx, "tcp", "smtp")
if !ok {
t.Errorf("cgoLookupPort must not be a placeholder")
}
if err != nil {
t.Error(err)
}
}
func TestCgoLookupPortWithCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err, ok := cgoLookupPort(ctx, "tcp", "smtp")
if !ok {
t.Errorf("cgoLookupPort must not be a placeholder")
}
if err != nil {
t.Error(err)
}
}
func TestCgoLookupPTR(t *testing.T) {
ctx := context.Background()
_, err, ok := cgoLookupPTR(ctx, "127.0.0.1")
if !ok {
t.Errorf("cgoLookupPTR must not be a placeholder")
}
if err != nil {
t.Error(err)
}
}
func TestCgoLookupPTRWithCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err, ok := cgoLookupPTR(ctx, "127.0.0.1")
if !ok {
t.Errorf("cgoLookupPTR must not be a placeholder")
}
if err != nil {
t.Error(err) t.Error(err)
} }
} }
...@@ -55,7 +55,7 @@ func lookupProtocol(_ context.Context, name string) (int, error) { ...@@ -55,7 +55,7 @@ func lookupProtocol(_ context.Context, name string) (int, error) {
func lookupHost(ctx context.Context, host string) (addrs []string, err error) { func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
order := systemConf().hostLookupOrder(host) order := systemConf().hostLookupOrder(host)
if order == hostLookupCgo { if order == hostLookupCgo {
if addrs, err, ok := cgoLookupHost(host); ok { if addrs, err, ok := cgoLookupHost(ctx, host); ok {
return addrs, err return addrs, err
} }
// cgo not available (or netgo); fall back to Go's DNS resolver // cgo not available (or netgo); fall back to Go's DNS resolver
...@@ -67,8 +67,7 @@ func lookupHost(ctx context.Context, host string) (addrs []string, err error) { ...@@ -67,8 +67,7 @@ func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
order := systemConf().hostLookupOrder(host) order := systemConf().hostLookupOrder(host)
if order == hostLookupCgo { if order == hostLookupCgo {
// TODO(bradfitz): push down ctx, or at least its deadline to start if addrs, err, ok := cgoLookupIP(ctx, host); ok {
if addrs, err, ok := cgoLookupIP(host); ok {
return addrs, err return addrs, err
} }
// cgo not available (or netgo); fall back to Go's DNS resolver // cgo not available (or netgo); fall back to Go's DNS resolver
...@@ -84,7 +83,7 @@ func lookupPort(ctx context.Context, network, service string) (int, error) { ...@@ -84,7 +83,7 @@ func lookupPort(ctx context.Context, network, service string) (int, error) {
// files might be on a remote filesystem, though. This should // files might be on a remote filesystem, though. This should
// probably race goroutines if ctx != context.Background(). // probably race goroutines if ctx != context.Background().
if systemConf().canUseCgo() { if systemConf().canUseCgo() {
if port, err, ok := cgoLookupPort(network, service); ok { if port, err, ok := cgoLookupPort(ctx, network, service); ok {
return port, err return port, err
} }
} }
...@@ -93,8 +92,7 @@ func lookupPort(ctx context.Context, network, service string) (int, error) { ...@@ -93,8 +92,7 @@ func lookupPort(ctx context.Context, network, service string) (int, error) {
func lookupCNAME(ctx context.Context, name string) (string, error) { func lookupCNAME(ctx context.Context, name string) (string, error) {
if systemConf().canUseCgo() { if systemConf().canUseCgo() {
// TODO: use ctx. issue 15321. Or race goroutines. if cname, err, ok := cgoLookupCNAME(ctx, name); ok {
if cname, err, ok := cgoLookupCNAME(name); ok {
return cname, err return cname, err
} }
} }
...@@ -161,7 +159,7 @@ func lookupTXT(ctx context.Context, name string) ([]string, error) { ...@@ -161,7 +159,7 @@ func lookupTXT(ctx context.Context, name string) ([]string, error) {
func lookupAddr(ctx context.Context, addr string) ([]string, error) { func lookupAddr(ctx context.Context, addr string) ([]string, error) {
if systemConf().canUseCgo() { if systemConf().canUseCgo() {
if ptrs, err, ok := cgoLookupPTR(addr); ok { if ptrs, err, ok := cgoLookupPTR(ctx, addr); ok {
return ptrs, err return ptrs, err
} }
} }
......
...@@ -14,14 +14,15 @@ import ( ...@@ -14,14 +14,15 @@ import (
func TestGoLookupIP(t *testing.T) { func TestGoLookupIP(t *testing.T) {
host := "localhost" host := "localhost"
_, err, ok := cgoLookupIP(host) ctx := context.Background()
_, err, ok := cgoLookupIP(ctx, host)
if ok { if ok {
t.Errorf("cgoLookupIP must be a placeholder") t.Errorf("cgoLookupIP must be a placeholder")
} }
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if _, err := goLookupIP(context.Background(), host); err != nil { if _, err := goLookupIP(ctx, host); err != nil {
t.Error(err) t.Error(err)
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment