Commit fa3e4fc4 authored by Alex Brainman's avatar Alex Brainman

net: fix connection resets when closed on windows

It is common to close network connection while another goroutine is
blocked reading on another goroutine. This sequence corresponds to
windows calls to WSARecv to start io, followed by GetQueuedCompletionStatus
that blocks until io completes, and, finally, closesocket called from
another thread. We were expecting that closesocket would unblock
GetQueuedCompletionStatus, and it does, but not always
(http://code.google.com/p/go/issues/detail?id=4170#c5). Also that sequence
results in connection is being reset.

This CL inserts CancelIo between GetQueuedCompletionStatus and closesocket,
and waits for both WSARecv and GetQueuedCompletionStatus to complete before
proceeding to closesocket.  This seems to fix both connection resets and
issue 4170. It also makes windows code behave similar to unix version.

Unfortunately, CancelIo needs to be called on the same thread as WSARecv.
So we have to employ strategy we use for connections with deadlines to
every connection now. It means, there are 2 unavoidable thread switches
for every io. Some newer versions of windows have new CancelIoEx api that
doesn't have these drawbacks, and this CL uses this capability when available.
As time goes by, we should have less of CancelIo and more of CancelIoEx
systems. Computers with CancelIoEx are also not affected by issue 4195 anymore.

Fixes #3710
Fixes #3746
Fixes #4170
Partial fix for issue 4195

R=golang-dev, mikioh.mikioh, bradfitz, rsc
CC=golang-dev
https://golang.org/cl/6604072
parent ad487dad
......@@ -261,6 +261,8 @@ var pollMaxN int
var pollservers []*pollServer
var startServersOnce []func()
var canCancelIO = true // used for testing current package
func init() {
pollMaxN = runtime.NumCPU()
if pollMaxN > 8 {
......
......@@ -17,19 +17,32 @@ import (
var initErr error
// CancelIo Windows API cancels all outstanding IO for a particular
// socket on current thread. To overcome that limitation, we run
// special goroutine, locked to OS single thread, that both starts
// and cancels IO. It means, there are 2 unavoidable thread switches
// for every IO.
// Some newer versions of Windows has new CancelIoEx API, that does
// not have that limitation and can be used from any thread. This
// package uses CancelIoEx API, if present, otherwise it fallback
// to CancelIo.
var canCancelIO bool // determines if CancelIoEx API is present
func init() {
var d syscall.WSAData
e := syscall.WSAStartup(uint32(0x202), &d)
if e != nil {
initErr = os.NewSyscallError("WSAStartup", e)
}
canCancelIO = syscall.LoadCancelIoEx() == nil
}
func closesocket(s syscall.Handle) error {
return syscall.Closesocket(s)
}
// Interface for all io operations.
// Interface for all IO operations.
type anOpIface interface {
Op() *anOp
Name() string
......@@ -42,7 +55,7 @@ type ioResult struct {
err error
}
// anOp implements functionality common to all io operations.
// anOp implements functionality common to all IO operations.
type anOp struct {
// Used by IOCP interface, it must be first field
// of the struct, as our code rely on it.
......@@ -75,7 +88,7 @@ func (o *anOp) Op() *anOp {
return o
}
// bufOp is used by io operations that read / write
// bufOp is used by IO operations that read / write
// data from / to client buffer.
type bufOp struct {
anOp
......@@ -92,7 +105,7 @@ func (o *bufOp) Init(fd *netFD, buf []byte, mode int) {
}
}
// resultSrv will retrieve all io completion results from
// resultSrv will retrieve all IO completion results from
// iocp and send them to the correspondent waiting client
// goroutine via channel supplied in the request.
type resultSrv struct {
......@@ -107,7 +120,7 @@ func (s *resultSrv) Run() {
r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE)
switch {
case r.err == nil:
// Dequeued successfully completed io packet.
// Dequeued successfully completed IO packet.
case r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil:
// Wait has timed out (should not happen now, but might be used in the future).
panic("GetQueuedCompletionStatus timed out")
......@@ -115,22 +128,23 @@ func (s *resultSrv) Run() {
// Failed to dequeue anything -> report the error.
panic("GetQueuedCompletionStatus failed " + r.err.Error())
default:
// Dequeued failed io packet.
// Dequeued failed IO packet.
}
(*anOp)(unsafe.Pointer(o)).resultc <- r
}
}
// ioSrv executes net io requests.
// ioSrv executes net IO requests.
type ioSrv struct {
submchan chan anOpIface // submit io requests
canchan chan anOpIface // cancel io requests
submchan chan anOpIface // submit IO requests
canchan chan anOpIface // cancel IO requests
}
// ProcessRemoteIO will execute submit io requests on behalf
// ProcessRemoteIO will execute submit IO requests on behalf
// of other goroutines, all on a single os thread, so it can
// cancel them later. Results of all operations will be sent
// back to their requesters via channel supplied in request.
// It is used only when the CancelIoEx API is unavailable.
func (s *ioSrv) ProcessRemoteIO() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
......@@ -144,20 +158,21 @@ func (s *ioSrv) ProcessRemoteIO() {
}
}
// ExecIO executes a single io operation. It either executes it
// inline, or, if a deadline is employed, passes the request onto
// ExecIO executes a single IO operation oi. It submits and cancels
// IO in the current thread for systems where Windows CancelIoEx API
// is available. Alternatively, it passes the request onto
// a special goroutine and waits for completion or cancels request.
// deadline is unix nanos.
func (s *ioSrv) ExecIO(oi anOpIface, deadline int64) (int, error) {
var err error
o := oi.Op()
if deadline != 0 {
if canCancelIO {
err = oi.Submit()
} else {
// Send request to a special dedicated thread,
// so it can stop the io with CancelIO later.
// so it can stop the IO with CancelIO later.
s.submchan <- oi
err = <-o.errnoc
} else {
err = oi.Submit()
}
switch err {
case nil:
......@@ -168,27 +183,45 @@ func (s *ioSrv) ExecIO(oi anOpIface, deadline int64) (int, error) {
default:
return 0, &OpError{oi.Name(), o.fd.net, o.fd.laddr, err}
}
// Wait for our request to complete.
var r ioResult
// Setup timer, if deadline is given.
var timer <-chan time.Time
if deadline != 0 {
dt := deadline - time.Now().UnixNano()
if dt < 1 {
dt = 1
}
timer := time.NewTimer(time.Duration(dt) * time.Nanosecond)
defer timer.Stop()
select {
case r = <-o.resultc:
case <-timer.C:
t := time.NewTimer(time.Duration(dt) * time.Nanosecond)
defer t.Stop()
timer = t.C
}
// Wait for our request to complete.
var r ioResult
var cancelled bool
select {
case r = <-o.resultc:
case <-timer:
cancelled = true
case <-o.fd.closec:
cancelled = true
}
if cancelled {
// Cancel it.
if canCancelIO {
err := syscall.CancelIoEx(syscall.Handle(o.Op().fd.sysfd), &o.o)
// Assuming ERROR_NOT_FOUND is returned, if IO is completed.
if err != nil && err != syscall.ERROR_NOT_FOUND {
// TODO(brainman): maybe do something else, but panic.
panic(err)
}
} else {
s.canchan <- oi
<-o.errnoc
r = <-o.resultc
if r.err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled
r.err = syscall.EWOULDBLOCK
}
}
} else {
// Wait for IO to be canceled or complete successfully.
r = <-o.resultc
if r.err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled
r.err = syscall.EWOULDBLOCK
}
}
if r.err != nil {
err = &OpError{oi.Name(), o.fd.net, o.fd.laddr, r.err}
......@@ -211,9 +244,13 @@ func startServer() {
go resultsrv.Run()
iosrv = new(ioSrv)
iosrv.submchan = make(chan anOpIface)
iosrv.canchan = make(chan anOpIface)
go iosrv.ProcessRemoteIO()
if !canCancelIO {
// Only CancelIo API is available. Lets start special goroutine
// locked to an OS thread, that both starts and cancels IO.
iosrv.submchan = make(chan anOpIface)
iosrv.canchan = make(chan anOpIface)
go iosrv.ProcessRemoteIO()
}
}
// Network file descriptor.
......@@ -233,6 +270,7 @@ type netFD struct {
raddr Addr
resultc [2]chan ioResult // read/write completion results
errnoc [2]chan error // read/write submit or cancel operation errors
closec chan bool // used by Close to cancel pending IO
// owned by client
rdeadline int64
......@@ -247,6 +285,7 @@ func allocFD(fd syscall.Handle, family, sotype int, net string) *netFD {
family: family,
sotype: sotype,
net: net,
closec: make(chan bool),
}
runtime.SetFinalizer(netfd, (*netFD).Close)
return netfd
......@@ -299,24 +338,12 @@ func (fd *netFD) incref(closing bool) error {
// Remove a reference to this FD and close if we've been asked to do so (and
// there are no references left.
func (fd *netFD) decref() {
if fd == nil {
return
}
fd.sysmu.Lock()
fd.sysref--
// NOTE(rsc): On Unix we check fd.sysref == 0 here before closing,
// but on Windows we have no way to wake up the blocked I/O other
// than closing the socket (or calling Shutdown, which breaks other
// programs that might have a reference to the socket). So there is
// a small race here that we might close fd.sysfd and then some other
// goroutine might start a read of fd.sysfd (having read it before we
// write InvalidHandle to it), which might refer to some other file
// if the specific handle value gets reused. I think handle values on
// Windows are not reused as aggressively as file descriptors on Unix,
// so this might be tolerable.
if fd.closing && fd.sysfd != syscall.InvalidHandle {
// In case the user has set linger, switch to blocking mode so
// the close blocks. As long as this doesn't happen often, we
// can handle the extra OS processes. Otherwise we'll need to
// use the resultsrv for Close too. Sigh.
syscall.SetNonblock(fd.sysfd, false)
if fd.closing && fd.sysref == 0 && fd.sysfd != syscall.InvalidHandle {
closesocket(fd.sysfd)
fd.sysfd = syscall.InvalidHandle
// no need for a finalizer anymore
......@@ -329,7 +356,14 @@ func (fd *netFD) Close() error {
if err := fd.incref(true); err != nil {
return err
}
fd.decref()
defer fd.decref()
// unblock pending reader and writer
close(fd.closec)
// wait for both reader and writer to exit
fd.rio.Lock()
defer fd.rio.Unlock()
fd.wio.Lock()
defer fd.wio.Unlock()
return nil
}
......@@ -368,18 +402,12 @@ func (o *readOp) Name() string {
}
func (fd *netFD) Read(buf []byte) (int, error) {
if fd == nil {
return 0, syscall.EINVAL
}
fd.rio.Lock()
defer fd.rio.Unlock()
if err := fd.incref(false); err != nil {
return 0, err
}
defer fd.decref()
if fd.sysfd == syscall.InvalidHandle {
return 0, syscall.EINVAL
}
fd.rio.Lock()
defer fd.rio.Unlock()
var o readOp
o.Init(fd, buf, 'r')
n, err := iosrv.ExecIO(&o, fd.rdeadline)
......@@ -407,18 +435,15 @@ func (o *readFromOp) Name() string {
}
func (fd *netFD) ReadFrom(buf []byte) (n int, sa syscall.Sockaddr, err error) {
if fd == nil {
return 0, nil, syscall.EINVAL
}
if len(buf) == 0 {
return 0, nil, nil
}
fd.rio.Lock()
defer fd.rio.Unlock()
if err := fd.incref(false); err != nil {
return 0, nil, err
}
defer fd.decref()
fd.rio.Lock()
defer fd.rio.Unlock()
var o readFromOp
o.Init(fd, buf, 'r')
o.rsan = int32(unsafe.Sizeof(o.rsa))
......@@ -446,15 +471,12 @@ func (o *writeOp) Name() string {
}
func (fd *netFD) Write(buf []byte) (int, error) {
if fd == nil {
return 0, syscall.EINVAL
}
fd.wio.Lock()
defer fd.wio.Unlock()
if err := fd.incref(false); err != nil {
return 0, err
}
defer fd.decref()
fd.wio.Lock()
defer fd.wio.Unlock()
var o writeOp
o.Init(fd, buf, 'w')
return iosrv.ExecIO(&o, fd.wdeadline)
......@@ -477,21 +499,15 @@ func (o *writeToOp) Name() string {
}
func (fd *netFD) WriteTo(buf []byte, sa syscall.Sockaddr) (int, error) {
if fd == nil {
return 0, syscall.EINVAL
}
if len(buf) == 0 {
return 0, nil
}
fd.wio.Lock()
defer fd.wio.Unlock()
if err := fd.incref(false); err != nil {
return 0, err
}
defer fd.decref()
if fd.sysfd == syscall.InvalidHandle {
return 0, syscall.EINVAL
}
fd.wio.Lock()
defer fd.wio.Unlock()
var o writeToOp
o.Init(fd, buf, 'w')
o.sa = sa
......
......@@ -174,3 +174,42 @@ func TestUDPListenClose(t *testing.T) {
t.Fatal("timeout waiting for UDP close")
}
}
func TestTCPClose(t *testing.T) {
l, err := Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer l.Close()
read := func(r io.Reader) error {
var m [1]byte
_, err := r.Read(m[:])
return err
}
go func() {
c, err := Dial("tcp", l.Addr().String())
if err != nil {
t.Fatal(err)
}
go read(c)
time.Sleep(10 * time.Millisecond)
c.Close()
}()
c, err := l.Accept()
if err != nil {
t.Fatal(err)
}
defer c.Close()
for err == nil {
err = read(c)
}
if err != nil && err != io.EOF {
t.Fatal(err)
}
}
......@@ -499,6 +499,27 @@ func TestClientWriteError(t *testing.T) {
w.done <- true
}
func TestTCPClose(t *testing.T) {
once.Do(startServer)
client, err := dialHTTP()
if err != nil {
t.Fatalf("dialing: %v", err)
}
defer client.Close()
args := Args{17, 8}
var reply Reply
err = client.Call("Arith.Mul", args, &reply)
if err != nil {
t.Fatal("arith error:", err)
}
t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
if reply.C != args.A*args.B {
t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
}
}
func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
b.StopTimer()
once.Do(startServer)
......
......@@ -48,12 +48,12 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
return 0, nil, false
}
c.wio.Lock()
defer c.wio.Unlock()
if err := c.incref(false); err != nil {
return 0, err, true
}
defer c.decref()
c.wio.Lock()
defer c.wio.Unlock()
var o sendfileOp
o.Init(c, 'w')
......
......@@ -146,3 +146,82 @@ func TestTimeoutAccept(t *testing.T) {
// Pass.
}
}
func TestReadWriteDeadline(t *testing.T) {
if !canCancelIO {
t.Logf("skipping test on this system")
return
}
const (
readTimeout = 100 * time.Millisecond
writeTimeout = 200 * time.Millisecond
delta = 40 * time.Millisecond
)
checkTimeout := func(command string, start time.Time, should time.Duration) {
is := time.Now().Sub(start)
d := should - is
if d < -delta || delta < d {
t.Errorf("%s timeout test failed: is=%v should=%v\n", command, is, should)
}
}
ln, err := Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("ListenTCP on :0: %v", err)
}
lnquit := make(chan bool)
go func() {
c, err := ln.Accept()
if err != nil {
t.Fatalf("Accept: %v", err)
}
defer c.Close()
lnquit <- true
}()
c, err := Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer c.Close()
start := time.Now()
err = c.SetReadDeadline(start.Add(readTimeout))
if err != nil {
t.Fatalf("SetReadDeadline: %v", err)
}
err = c.SetWriteDeadline(start.Add(writeTimeout))
if err != nil {
t.Fatalf("SetWriteDeadline: %v", err)
}
quit := make(chan bool)
go func() {
var buf [10]byte
_, err = c.Read(buf[:])
if err == nil {
t.Errorf("Read should not succeed")
}
checkTimeout("Read", start, readTimeout)
quit <- true
}()
go func() {
var buf [10000]byte
for {
_, err = c.Write(buf[:])
if err != nil {
break
}
}
checkTimeout("Write", start, writeTimeout)
quit <- true
}()
<-quit
<-quit
<-lnquit
}
......@@ -142,6 +142,7 @@ func NewCallback(fn interface{}) uintptr
//sys GetQueuedCompletionStatus(cphandle Handle, qty *uint32, key *uint32, overlapped **Overlapped, timeout uint32) (err error)
//sys PostQueuedCompletionStatus(cphandle Handle, qty uint32, key uint32, overlapped *Overlapped) (err error)
//sys CancelIo(s Handle) (err error)
//sys CancelIoEx(s Handle, o *Overlapped) (err error)
//sys CreateProcess(appName *uint16, commandLine *uint16, procSecurity *SecurityAttributes, threadSecurity *SecurityAttributes, inheritHandles bool, creationFlags uint32, env *uint16, currentDir *uint16, startupInfo *StartupInfo, outProcInfo *ProcessInformation) (err error) = CreateProcessW
//sys OpenProcess(da uint32, inheritHandle bool, pid uint32) (handle Handle, err error)
//sys TerminateProcess(handle Handle, exitcode uint32) (err error)
......@@ -474,6 +475,10 @@ func Chmod(path string, mode uint32) (err error) {
return SetFileAttributes(p, attrs)
}
func LoadCancelIoEx() error {
return procCancelIoEx.Find()
}
// net api calls
const socket_error = uintptr(^uint32(0))
......
......@@ -49,6 +49,7 @@ var (
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procPostQueuedCompletionStatus = modkernel32.NewProc("PostQueuedCompletionStatus")
procCancelIo = modkernel32.NewProc("CancelIo")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procCreateProcessW = modkernel32.NewProc("CreateProcessW")
procOpenProcess = modkernel32.NewProc("OpenProcess")
procTerminateProcess = modkernel32.NewProc("TerminateProcess")
......@@ -535,6 +536,18 @@ func CancelIo(s Handle) (err error) {
return
}
func CancelIoEx(s Handle, o *Overlapped) (err error) {
r1, _, e1 := Syscall(procCancelIoEx.Addr(), 2, uintptr(s), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = EINVAL
}
}
return
}
func CreateProcess(appName *uint16, commandLine *uint16, procSecurity *SecurityAttributes, threadSecurity *SecurityAttributes, inheritHandles bool, creationFlags uint32, env *uint16, currentDir *uint16, startupInfo *StartupInfo, outProcInfo *ProcessInformation) (err error) {
var _p0 uint32
if inheritHandles {
......
......@@ -49,6 +49,7 @@ var (
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procPostQueuedCompletionStatus = modkernel32.NewProc("PostQueuedCompletionStatus")
procCancelIo = modkernel32.NewProc("CancelIo")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procCreateProcessW = modkernel32.NewProc("CreateProcessW")
procOpenProcess = modkernel32.NewProc("OpenProcess")
procTerminateProcess = modkernel32.NewProc("TerminateProcess")
......@@ -535,6 +536,18 @@ func CancelIo(s Handle) (err error) {
return
}
func CancelIoEx(s Handle, o *Overlapped) (err error) {
r1, _, e1 := Syscall(procCancelIoEx.Addr(), 2, uintptr(s), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
if e1 != 0 {
err = error(e1)
} else {
err = EINVAL
}
}
return
}
func CreateProcess(appName *uint16, commandLine *uint16, procSecurity *SecurityAttributes, threadSecurity *SecurityAttributes, inheritHandles bool, creationFlags uint32, env *uint16, currentDir *uint16, startupInfo *StartupInfo, outProcInfo *ProcessInformation) (err error) {
var _p0 uint32
if inheritHandles {
......
......@@ -20,6 +20,7 @@ const (
ERROR_ENVVAR_NOT_FOUND Errno = 203
ERROR_OPERATION_ABORTED Errno = 995
ERROR_IO_PENDING Errno = 997
ERROR_NOT_FOUND Errno = 1168
)
const (
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment