Commit fa3e4fc4 authored by Alex Brainman's avatar Alex Brainman

net: fix connection resets when closed on windows

It is common to close network connection while another goroutine is
blocked reading on another goroutine. This sequence corresponds to
windows calls to WSARecv to start io, followed by GetQueuedCompletionStatus
that blocks until io completes, and, finally, closesocket called from
another thread. We were expecting that closesocket would unblock
GetQueuedCompletionStatus, and it does, but not always
(http://code.google.com/p/go/issues/detail?id=4170#c5). Also that sequence
results in connection is being reset.

This CL inserts CancelIo between GetQueuedCompletionStatus and closesocket,
and waits for both WSARecv and GetQueuedCompletionStatus to complete before
proceeding to closesocket.  This seems to fix both connection resets and
issue 4170. It also makes windows code behave similar to unix version.

Unfortunately, CancelIo needs to be called on the same thread as WSARecv.
So we have to employ strategy we use for connections with deadlines to
every connection now. It means, there are 2 unavoidable thread switches
for every io. Some newer versions of windows have new CancelIoEx api that
doesn't have these drawbacks, and this CL uses this capability when available.
As time goes by, we should have less of CancelIo and more of CancelIoEx
systems. Computers with CancelIoEx are also not affected by issue 4195 anymore.

Fixes #3710
Fixes #3746
Fixes #4170
Partial fix for issue 4195

R=golang-dev, mikioh.mikioh, bradfitz, rsc
CC=golang-dev
https://golang.org/cl/6604072
parent ad487dad
......@@ -261,6 +261,8 @@ var pollMaxN int
var pollservers []*pollServer
var startServersOnce []func()
var canCancelIO = true // used for testing current package
func init() {
pollMaxN = runtime.NumCPU()
if pollMaxN > 8 {
......
This diff is collapsed.
......@@ -174,3 +174,42 @@ func TestUDPListenClose(t *testing.T) {
t.Fatal("timeout waiting for UDP close")
}
}
func TestTCPClose(t *testing.T) {
l, err := Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer l.Close()
read := func(r io.Reader) error {
var m [1]byte
_, err := r.Read(m[:])
return err
}
go func() {
c, err := Dial("tcp", l.Addr().String())
if err != nil {
t.Fatal(err)
}
go read(c)
time.Sleep(10 * time.Millisecond)
c.Close()
}()
c, err := l.Accept()
if err != nil {
t.Fatal(err)
}
defer c.Close()
for err == nil {
err = read(c)
}
if err != nil && err != io.EOF {
t.Fatal(err)
}
}
......@@ -499,6 +499,27 @@ func TestClientWriteError(t *testing.T) {
w.done <- true
}
func TestTCPClose(t *testing.T) {
once.Do(startServer)
client, err := dialHTTP()
if err != nil {
t.Fatalf("dialing: %v", err)
}
defer client.Close()
args := Args{17, 8}
var reply Reply
err = client.Call("Arith.Mul", args, &reply)
if err != nil {
t.Fatal("arith error:", err)
}
t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
if reply.C != args.A*args.B {
t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
}
}
func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
b.StopTimer()
once.Do(startServer)
......
......@@ -48,12 +48,12 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
return 0, nil, false
}
c.wio.Lock()
defer c.wio.Unlock()
if err := c.incref(false); err != nil {
return 0, err, true
}
defer c.decref()
c.wio.Lock()
defer c.wio.Unlock()
var o sendfileOp
o.Init(c, 'w')
......
......@@ -146,3 +146,82 @@ func TestTimeoutAccept(t *testing.T) {
// Pass.
}
}
func TestReadWriteDeadline(t *testing.T) {
if !canCancelIO {
t.Logf("skipping test on this system")
return
}
const (
readTimeout = 100 * time.Millisecond
writeTimeout = 200 * time.Millisecond
delta = 40 * time.Millisecond
)
checkTimeout := func(command string, start time.Time, should time.Duration) {
is := time.Now().Sub(start)
d := should - is
if d < -delta || delta < d {
t.Errorf("%s timeout test failed: is=%v should=%v\n", command, is, should)
}
}
ln, err := Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ListenTCP on :0: %v", err)
}
lnquit := make(chan bool)
go func() {
c, err := ln.Accept()
if err != nil {
t.Fatalf("Accept: %v", err)
}
defer c.Close()
lnquit <- true
}()
c, err := Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer c.Close()
start := time.Now()
err = c.SetReadDeadline(start.Add(readTimeout))
if err != nil {
t.Fatalf("SetReadDeadline: %v", err)
}
err = c.SetWriteDeadline(start.Add(writeTimeout))
if err != nil {
t.Fatalf("SetWriteDeadline: %v", err)
}
quit := make(chan bool)
go func() {
var buf [10]byte
_, err = c.Read(buf[:])
if err == nil {
t.Errorf("Read should not succeed")
}
checkTimeout("Read", start, readTimeout)
quit <- true
}()
go func() {
var buf [10000]byte
for {
_, err = c.Write(buf[:])
if err != nil {
break
}
}
checkTimeout("Write", start, writeTimeout)
quit <- true
}()
<-quit
<-quit
<-lnquit
}
......@@ -142,6 +142,7 @@ func NewCallback(fn interface{}) uintptr
//sys GetQueuedCompletionStatus(cphandle Handle, qty *uint32, key *uint32, overlapped **Overlapped, timeout uint32) (err error)
//sys PostQueuedCompletionStatus(cphandle Handle, qty uint32, key uint32, overlapped *Overlapped) (err error)
//sys CancelIo(s Handle) (err error)
//sys CancelIoEx(s Handle, o *Overlapped) (err error)
//sys CreateProcess(appName *uint16, commandLine *uint16, procSecurity *SecurityAttributes, threadSecurity *SecurityAttributes, inheritHandles bool, creationFlags uint32, env *uint16, currentDir *uint16, startupInfo *StartupInfo, outProcInfo *ProcessInformation) (err error) = CreateProcessW
//sys OpenProcess(da uint32, inheritHandle bool, pid uint32) (handle Handle, err error)
//sys TerminateProcess(handle Handle, exitcode uint32) (err error)
......@@ -474,6 +475,10 @@ func Chmod(path string, mode uint32) (err error) {
return SetFileAttributes(p, attrs)
}
func LoadCancelIoEx() error {
return procCancelIoEx.Find()
}
// net api calls
const socket_error = uintptr(^uint32(0))
......
......@@ -49,6 +49,7 @@ var (
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procPostQueuedCompletionStatus = modkernel32.NewProc("PostQueuedCompletionStatus")
procCancelIo = modkernel32.NewProc("CancelIo")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procCreateProcessW = modkernel32.NewProc("CreateProcessW")
procOpenProcess = modkernel32.NewProc("OpenProcess")
procTerminateProcess = modkernel32.NewProc("TerminateProcess")
......@@ -535,6 +536,18 @@ func CancelIo(s Handle) (err error) {
return
}
func CancelIoEx(s Handle, o *Overlapped) (err error) {
r1, _, e1 := Syscall(procCancelIoEx.Addr(), 2, uintptr(s), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = EINVAL
}
}
return
}
func CreateProcess(appName *uint16, commandLine *uint16, procSecurity *SecurityAttributes, threadSecurity *SecurityAttributes, inheritHandles bool, creationFlags uint32, env *uint16, currentDir *uint16, startupInfo *StartupInfo, outProcInfo *ProcessInformation) (err error) {
var _p0 uint32
if inheritHandles {
......
......@@ -49,6 +49,7 @@ var (
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procPostQueuedCompletionStatus = modkernel32.NewProc("PostQueuedCompletionStatus")
procCancelIo = modkernel32.NewProc("CancelIo")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procCreateProcessW = modkernel32.NewProc("CreateProcessW")
procOpenProcess = modkernel32.NewProc("OpenProcess")
procTerminateProcess = modkernel32.NewProc("TerminateProcess")
......@@ -535,6 +536,18 @@ func CancelIo(s Handle) (err error) {
return
}
func CancelIoEx(s Handle, o *Overlapped) (err error) {
r1, _, e1 := Syscall(procCancelIoEx.Addr(), 2, uintptr(s), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = EINVAL
}
}
return
}
func CreateProcess(appName *uint16, commandLine *uint16, procSecurity *SecurityAttributes, threadSecurity *SecurityAttributes, inheritHandles bool, creationFlags uint32, env *uint16, currentDir *uint16, startupInfo *StartupInfo, outProcInfo *ProcessInformation) (err error) {
var _p0 uint32
if inheritHandles {
......
......@@ -20,6 +20,7 @@ const (
ERROR_ENVVAR_NOT_FOUND Errno = 203
ERROR_OPERATION_ABORTED Errno = 995
ERROR_IO_PENDING Errno = 997
ERROR_NOT_FOUND Errno = 1168
)
const (
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment