Commit 6c877e5d authored by Ian Lance Taylor's avatar Ian Lance Taylor

net: avoid race on test hooks with DNS goroutines

The DNS code can start goroutines and not wait for them to complete.
This does no harm, but in tests this can cause a race condition with
the test hooks that are installed and unintalled around the tests.
Add a WaitGroup that tests of DNS can use to avoid the race.

Fixes #21090

Change-Id: I6c1443a9c2378e8b89d0ab1d6390c0e3e726b0ce
Reviewed-on: https://go-review.googlesource.com/82795Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 6af8c0d8
......@@ -65,8 +65,10 @@ func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, e
}
// DoChan is like Do but returns a channel that will receive the
// results when they are ready.
func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result {
// results when they are ready. The second result is true if the function
// will eventually be called, false if it will not (because there is
// a pending request with this key).
func (g *Group) DoChan(key string, fn func() (interface{}, error)) (<-chan Result, bool) {
ch := make(chan Result, 1)
g.mu.Lock()
if g.m == nil {
......@@ -76,7 +78,7 @@ func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result
c.dups++
c.chans = append(c.chans, ch)
g.mu.Unlock()
return ch
return ch, false
}
c := &call{chans: []chan<- Result{ch}}
c.wg.Add(1)
......@@ -85,7 +87,7 @@ func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result
go g.doCall(c, key, fn)
return ch
return ch, true
}
// doCall handles the single call for a key.
......
......@@ -13,6 +13,7 @@ import (
)
func TestCgoLookupIP(t *testing.T) {
defer dnsWaitGroup.Wait()
ctx := context.Background()
_, err, ok := cgoLookupIP(ctx, "localhost")
if !ok {
......@@ -24,6 +25,7 @@ func TestCgoLookupIP(t *testing.T) {
}
func TestCgoLookupIPWithCancel(t *testing.T) {
defer dnsWaitGroup.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err, ok := cgoLookupIP(ctx, "localhost")
......@@ -36,6 +38,7 @@ func TestCgoLookupIPWithCancel(t *testing.T) {
}
func TestCgoLookupPort(t *testing.T) {
defer dnsWaitGroup.Wait()
ctx := context.Background()
_, err, ok := cgoLookupPort(ctx, "tcp", "smtp")
if !ok {
......@@ -47,6 +50,7 @@ func TestCgoLookupPort(t *testing.T) {
}
func TestCgoLookupPortWithCancel(t *testing.T) {
defer dnsWaitGroup.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err, ok := cgoLookupPort(ctx, "tcp", "smtp")
......@@ -59,6 +63,7 @@ func TestCgoLookupPortWithCancel(t *testing.T) {
}
func TestCgoLookupPTR(t *testing.T) {
defer dnsWaitGroup.Wait()
ctx := context.Background()
_, err, ok := cgoLookupPTR(ctx, "127.0.0.1")
if !ok {
......@@ -70,6 +75,7 @@ func TestCgoLookupPTR(t *testing.T) {
}
func TestCgoLookupPTRWithCancel(t *testing.T) {
defer dnsWaitGroup.Wait()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err, ok := cgoLookupPTR(ctx, "127.0.0.1")
......
......@@ -479,7 +479,9 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
var lastErr error
for _, fqdn := range conf.nameList(name) {
for _, qtype := range qtypes {
dnsWaitGroup.Add(1)
go func(qtype uint16) {
defer dnsWaitGroup.Done()
cname, rrs, err := r.tryOneName(ctx, conf, fqdn, qtype)
lane <- racer{cname, rrs, err}
}(qtype)
......
......@@ -203,6 +203,7 @@ var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.
// Issue 13705: don't try to resolve onion addresses, etc
func TestLookupTorOnion(t *testing.T) {
defer dnsWaitGroup.Wait()
r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
addrs, err := r.LookupIPAddr(context.Background(), "foo.onion")
if err != nil {
......@@ -300,6 +301,8 @@ var updateResolvConfTests = []struct {
}
func TestUpdateResolvConf(t *testing.T) {
defer dnsWaitGroup.Wait()
r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
conf, err := newResolvConfTest()
......@@ -455,6 +458,8 @@ var goLookupIPWithResolverConfigTests = []struct {
}
func TestGoLookupIPWithResolverConfig(t *testing.T) {
defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
switch s {
case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
......@@ -547,6 +552,8 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
// Test that goLookupIPOrder falls back to the host file when no DNS servers are available.
func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) {
r := &dnsMsg{
dnsMsgHdr: dnsMsgHdr{
......@@ -603,6 +610,8 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
// querying the original name instead of an error encountered
// querying a generated name.
func TestErrorForOriginalNameWhenSearching(t *testing.T) {
defer dnsWaitGroup.Wait()
const fqdn = "doesnotexist.domain"
conf, err := newResolvConfTest()
......@@ -657,6 +666,8 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
// Issue 15434. If a name server gives a lame referral, continue to the next.
func TestIgnoreLameReferrals(t *testing.T) {
defer dnsWaitGroup.Wait()
conf, err := newResolvConfTest()
if err != nil {
t.Fatal(err)
......@@ -889,6 +900,8 @@ func TestIgnoreDNSForgeries(t *testing.T) {
// Issue 16865. If a name server times out, continue to the next.
func TestRetryTimeout(t *testing.T) {
defer dnsWaitGroup.Wait()
conf, err := newResolvConfTest()
if err != nil {
t.Fatal(err)
......@@ -945,6 +958,8 @@ func TestRotate(t *testing.T) {
}
func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
defer dnsWaitGroup.Wait()
conf, err := newResolvConfTest()
if err != nil {
t.Fatal(err)
......@@ -1008,6 +1023,8 @@ func mockTXTResponse(q *dnsMsg) *dnsMsg {
// Issue 17448. With StrictErrors enabled, temporary errors should make
// LookupIP fail rather than return a partial result.
func TestStrictErrorsLookupIP(t *testing.T) {
defer dnsWaitGroup.Wait()
conf, err := newResolvConfTest()
if err != nil {
t.Fatal(err)
......@@ -1256,6 +1273,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
// Issue 17448. With StrictErrors enabled, temporary errors should make
// LookupTXT stop walking the search list.
func TestStrictErrorsLookupTXT(t *testing.T) {
defer dnsWaitGroup.Wait()
conf, err := newResolvConfTest()
if err != nil {
t.Fatal(err)
......@@ -1312,3 +1331,25 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
}
}
}
// Test for a race between uninstalling the test hooks and closing a
// socket connection. This used to fail when testing with -race.
func TestDNSGoroutineRace(t *testing.T) {
defer dnsWaitGroup.Wait()
fake := fakeDNSServer{func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error) {
time.Sleep(10 * time.Microsecond)
return nil, poll.ErrTimeout
}}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
// The timeout here is less than the timeout used by the server,
// so the goroutine started to query the (fake) server will hang
// around after this test is done if we don't call dnsWaitGroup.Wait.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Microsecond)
defer cancel()
_, err := r.LookupIPAddr(ctx, "where.are.they.now")
if err == nil {
t.Fatal("fake DNS lookup unexpectedly succeeded")
}
}
......@@ -8,6 +8,7 @@ import (
"context"
"internal/nettrace"
"internal/singleflight"
"sync"
)
// protocols contains minimal mappings between internet protocol
......@@ -53,6 +54,10 @@ var services = map[string]map[string]int{
},
}
// dnsWaitGroup can be used by tests to wait for all DNS goroutines to
// complete. This avoids races on the test hooks.
var dnsWaitGroup sync.WaitGroup
const maxProtoLength = len("RSVP-E2E-IGNORE") + 10 // with room to grow
func lookupProtocolMap(name string) (int, error) {
......@@ -189,9 +194,14 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err
resolverFunc = alt
}
ch := lookupGroup.DoChan(host, func() (interface{}, error) {
dnsWaitGroup.Add(1)
ch, called := lookupGroup.DoChan(host, func() (interface{}, error) {
defer dnsWaitGroup.Done()
return testHookLookupIP(ctx, resolverFunc, host)
})
if !called {
dnsWaitGroup.Done()
}
select {
case <-ctx.Done():
......
......@@ -105,6 +105,8 @@ func TestLookupGmailMX(t *testing.T) {
t.Skip("IPv4 is required")
}
defer dnsWaitGroup.Wait()
for _, tt := range lookupGmailMXTests {
mxs, err := LookupMX(tt.name)
if err != nil {
......@@ -137,6 +139,8 @@ func TestLookupGmailNS(t *testing.T) {
t.Skip("IPv4 is required")
}
defer dnsWaitGroup.Wait()
for _, tt := range lookupGmailNSTests {
nss, err := LookupNS(tt.name)
if err != nil {
......@@ -170,6 +174,8 @@ func TestLookupGmailTXT(t *testing.T) {
t.Skip("IPv4 is required")
}
defer dnsWaitGroup.Wait()
for _, tt := range lookupGmailTXTTests {
txts, err := LookupTXT(tt.name)
if err != nil {
......@@ -205,6 +211,8 @@ func TestLookupGooglePublicDNSAddr(t *testing.T) {
t.Skip("both IPv4 and IPv6 are required")
}
defer dnsWaitGroup.Wait()
for _, tt := range lookupGooglePublicDNSAddrTests {
names, err := LookupAddr(tt.addr)
if err != nil {
......@@ -226,6 +234,8 @@ func TestLookupIPv6LinkLocalAddr(t *testing.T) {
t.Skip("IPv6 is required")
}
defer dnsWaitGroup.Wait()
addrs, err := LookupHost("localhost")
if err != nil {
t.Fatal(err)
......@@ -262,6 +272,8 @@ func TestLookupCNAME(t *testing.T) {
t.Skip("IPv4 is required")
}
defer dnsWaitGroup.Wait()
for _, tt := range lookupCNAMETests {
cname, err := LookupCNAME(tt.name)
if err != nil {
......@@ -289,6 +301,8 @@ func TestLookupGoogleHost(t *testing.T) {
t.Skip("IPv4 is required")
}
defer dnsWaitGroup.Wait()
for _, tt := range lookupGoogleHostTests {
addrs, err := LookupHost(tt.name)
if err != nil {
......@@ -313,6 +327,8 @@ func TestLookupLongTXT(t *testing.T) {
testenv.MustHaveExternalNetwork(t)
}
defer dnsWaitGroup.Wait()
txts, err := LookupTXT("golang.rsc.io")
if err != nil {
t.Fatal(err)
......@@ -343,6 +359,8 @@ func TestLookupGoogleIP(t *testing.T) {
t.Skip("IPv4 is required")
}
defer dnsWaitGroup.Wait()
for _, tt := range lookupGoogleIPTests {
ips, err := LookupIP(tt.name)
if err != nil {
......@@ -378,6 +396,7 @@ var revAddrTests = []struct {
}
func TestReverseAddress(t *testing.T) {
defer dnsWaitGroup.Wait()
for i, tt := range revAddrTests {
a, err := reverseaddr(tt.Addr)
if len(tt.ErrPrefix) > 0 && err == nil {
......@@ -401,6 +420,8 @@ func TestDNSFlood(t *testing.T) {
t.Skip("test disabled; use -dnsflood to enable")
}
defer dnsWaitGroup.Wait()
var N = 5000
if runtime.GOOS == "darwin" {
// On Darwin this test consumes kernel threads much
......@@ -482,6 +503,8 @@ func TestLookupDotsWithLocalSource(t *testing.T) {
testenv.MustHaveExternalNetwork(t)
}
defer dnsWaitGroup.Wait()
for i, fn := range []func() func(){forceGoDNS, forceCgoDNS} {
fixup := fn()
if fixup == nil {
......@@ -527,6 +550,8 @@ func TestLookupDotsWithRemoteSource(t *testing.T) {
t.Skip("IPv4 is required")
}
defer dnsWaitGroup.Wait()
if fixup := forceGoDNS(); fixup != nil {
testDots(t, "go")
fixup()
......@@ -747,6 +772,9 @@ func TestLookupNonLDH(t *testing.T) {
if runtime.GOOS == "nacl" {
t.Skip("skip on nacl")
}
defer dnsWaitGroup.Wait()
if fixup := forceGoDNS(); fixup != nil {
defer fixup()
}
......
......@@ -13,6 +13,7 @@ import (
)
func TestGoLookupIP(t *testing.T) {
defer dnsWaitGroup.Wait()
host := "localhost"
ctx := context.Background()
_, err, ok := cgoLookupIP(ctx, host)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment