Commit e49bc465 authored by Aman Gupta's avatar Aman Gupta Committed by Alex Brainman

net: implement ReadMsg/WriteMsg on windows

This means {Read,Write}Msg{UDP,IP} now work on windows.

Fixes #9252

Change-Id: Ifb105f9ad18d61289b22d7358a95faabe73d2d02
Reviewed-on: https://go-review.googlesource.com/76393
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: 's avatarAlex Brainman <alex.brainman@gmail.com>
parent 5d0cab03
......@@ -154,7 +154,7 @@ var pkgDeps = map[string][]string{
"syscall",
},
"internal/poll": {"L0", "internal/race", "syscall", "time", "unicode/utf16", "unicode/utf8"},
"internal/poll": {"L0", "internal/race", "syscall", "time", "unicode/utf16", "unicode/utf8", "internal/syscall/windows"},
"os": {"L1", "os", "syscall", "time", "internal/poll", "internal/syscall/windows"},
"path/filepath": {"L2", "os", "syscall", "internal/syscall/windows"},
"io/ioutil": {"L2", "os", "path/filepath", "time"},
......
......@@ -7,6 +7,7 @@ package poll
import (
"errors"
"internal/race"
"internal/syscall/windows"
"io"
"runtime"
"sync"
......@@ -92,6 +93,7 @@ type operation struct {
fd *FD
errc chan error
buf syscall.WSABuf
msg windows.WSAMsg
sa syscall.Sockaddr
rsa *syscall.RawSockaddrAny
rsan int32
......@@ -132,6 +134,22 @@ func (o *operation) ClearBufs() {
o.bufs = o.bufs[:0]
}
func (o *operation) InitMsg(p []byte, oob []byte) {
o.InitBuf(p)
o.msg.Buffers = &o.buf
o.msg.BufferCount = 1
o.msg.Name = nil
o.msg.Namelen = 0
o.msg.Flags = 0
o.msg.Control.Len = uint32(len(oob))
o.msg.Control.Buf = nil
if len(oob) != 0 {
o.msg.Control.Buf = &oob[0]
}
}
// ioSrv executes net IO requests.
type ioSrv struct {
req chan ioSrvReq
......@@ -898,3 +916,77 @@ func (fd *FD) RawRead(f func(uintptr) bool) error {
func (fd *FD) RawWrite(f func(uintptr) bool) error {
return errors.New("not implemented")
}
func sockaddrToRaw(sa syscall.Sockaddr) (unsafe.Pointer, int32, error) {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
var raw syscall.RawSockaddrInet4
raw.Family = syscall.AF_INET
p := (*[2]byte)(unsafe.Pointer(&raw.Port))
p[0] = byte(sa.Port >> 8)
p[1] = byte(sa.Port)
for i := 0; i < len(sa.Addr); i++ {
raw.Addr[i] = sa.Addr[i]
}
return unsafe.Pointer(&raw), int32(unsafe.Sizeof(raw)), nil
case *syscall.SockaddrInet6:
var raw syscall.RawSockaddrInet6
raw.Family = syscall.AF_INET6
p := (*[2]byte)(unsafe.Pointer(&raw.Port))
p[0] = byte(sa.Port >> 8)
p[1] = byte(sa.Port)
raw.Scope_id = sa.ZoneId
for i := 0; i < len(sa.Addr); i++ {
raw.Addr[i] = sa.Addr[i]
}
return unsafe.Pointer(&raw), int32(unsafe.Sizeof(raw)), nil
default:
return nil, 0, syscall.EWINDOWS
}
}
// ReadMsg wraps the WSARecvMsg network call.
func (fd *FD) ReadMsg(p []byte, oob []byte) (int, int, int, syscall.Sockaddr, error) {
if err := fd.readLock(); err != nil {
return 0, 0, 0, nil, err
}
defer fd.readUnlock()
o := &fd.rop
o.InitMsg(p, oob)
o.rsa = new(syscall.RawSockaddrAny)
o.msg.Name = o.rsa
o.msg.Namelen = int32(unsafe.Sizeof(*o.rsa))
n, err := rsrv.ExecIO(o, func(o *operation) error {
return windows.WSARecvMsg(o.fd.Sysfd, &o.msg, &o.qty, &o.o, nil)
})
err = fd.eofError(n, err)
var sa syscall.Sockaddr
if err == nil {
sa, err = o.rsa.Sockaddr()
}
return n, int(o.msg.Control.Len), int(o.msg.Flags), sa, err
}
// WriteMsg wraps the WSASendMsg network call.
func (fd *FD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (int, int, error) {
if err := fd.writeLock(); err != nil {
return 0, 0, err
}
defer fd.writeUnlock()
o := &fd.wop
o.InitMsg(p, oob)
if sa != nil {
rsa, len, err := sockaddrToRaw(sa)
if err != nil {
return 0, 0, err
}
o.msg.Name = (*syscall.RawSockaddrAny)(rsa)
o.msg.Namelen = len
}
n, err := wsrv.ExecIO(o, func(o *operation) error {
return windows.WSASendMsg(o.fd.Sysfd, &o.msg, 0, &o.qty, &o.o, nil)
})
return n, int(o.msg.Control.Len), err
}
......@@ -4,7 +4,11 @@
package windows
import "syscall"
import (
"sync"
"syscall"
"unsafe"
)
const (
ERROR_SHARING_VIOLATION syscall.Errno = 32
......@@ -115,10 +119,109 @@ const (
const (
WSA_FLAG_OVERLAPPED = 0x01
WSA_FLAG_NO_HANDLE_INHERIT = 0x80
WSAEMSGSIZE syscall.Errno = 10040
MSG_TRUNC = 0x0100
MSG_CTRUNC = 0x0200
socket_error = uintptr(^uint32(0))
)
var WSAID_WSASENDMSG = syscall.GUID{
Data1: 0xa441e712,
Data2: 0x754f,
Data3: 0x43ca,
Data4: [8]byte{0x84, 0xa7, 0x0d, 0xee, 0x44, 0xcf, 0x60, 0x6d},
}
var WSAID_WSARECVMSG = syscall.GUID{
Data1: 0xf689d7c8,
Data2: 0x6f1f,
Data3: 0x436b,
Data4: [8]byte{0x8a, 0x53, 0xe5, 0x4f, 0xe3, 0x51, 0xc3, 0x22},
}
var sendRecvMsgFunc struct {
once sync.Once
sendAddr uintptr
recvAddr uintptr
err error
}
type WSAMsg struct {
Name *syscall.RawSockaddrAny
Namelen int32
Buffers *syscall.WSABuf
BufferCount uint32
Control syscall.WSABuf
Flags uint32
}
//sys WSASocket(af int32, typ int32, protocol int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = ws2_32.WSASocketW
func loadWSASendRecvMsg() error {
sendRecvMsgFunc.once.Do(func() {
var s syscall.Handle
s, sendRecvMsgFunc.err = syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
if sendRecvMsgFunc.err != nil {
return
}
defer syscall.CloseHandle(s)
var n uint32
sendRecvMsgFunc.err = syscall.WSAIoctl(s,
syscall.SIO_GET_EXTENSION_FUNCTION_POINTER,
(*byte)(unsafe.Pointer(&WSAID_WSARECVMSG)),
uint32(unsafe.Sizeof(WSAID_WSARECVMSG)),
(*byte)(unsafe.Pointer(&sendRecvMsgFunc.recvAddr)),
uint32(unsafe.Sizeof(sendRecvMsgFunc.recvAddr)),
&n, nil, 0)
if sendRecvMsgFunc.err != nil {
return
}
sendRecvMsgFunc.err = syscall.WSAIoctl(s,
syscall.SIO_GET_EXTENSION_FUNCTION_POINTER,
(*byte)(unsafe.Pointer(&WSAID_WSASENDMSG)),
uint32(unsafe.Sizeof(WSAID_WSASENDMSG)),
(*byte)(unsafe.Pointer(&sendRecvMsgFunc.sendAddr)),
uint32(unsafe.Sizeof(sendRecvMsgFunc.sendAddr)),
&n, nil, 0)
})
return sendRecvMsgFunc.err
}
func WSASendMsg(fd syscall.Handle, msg *WSAMsg, flags uint32, bytesSent *uint32, overlapped *syscall.Overlapped, croutine *byte) error {
err := loadWSASendRecvMsg()
if err != nil {
return err
}
r1, _, e1 := syscall.Syscall6(sendRecvMsgFunc.sendAddr, 6, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(flags), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(croutine)))
if r1 == socket_error {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return err
}
func WSARecvMsg(fd syscall.Handle, msg *WSAMsg, bytesReceived *uint32, overlapped *syscall.Overlapped, croutine *byte) error {
err := loadWSASendRecvMsg()
if err != nil {
return err
}
r1, _, e1 := syscall.Syscall6(sendRecvMsgFunc.recvAddr, 5, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(unsafe.Pointer(bytesReceived)), uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(croutine)), 0)
if r1 == socket_error {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return err
}
const (
ComputerNameNetBIOS = 0
ComputerNameDnsHostname = 1
......
......@@ -223,17 +223,21 @@ func (fd *netFD) accept() (*netFD, error) {
return netfd, nil
}
func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
n, oobn, flags, sa, err = fd.pfd.ReadMsg(p, oob)
runtime.KeepAlive(fd)
return n, oobn, flags, sa, wrapSyscallError("wsarecvmsg", err)
}
func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsg(p, oob, sa)
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError("wsasendmsg", err)
}
// Unimplemented functions.
func (fd *netFD) dup() (*os.File, error) {
// TODO: Implement this
return nil, syscall.EWINDOWS
}
func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
return 0, 0, 0, nil, syscall.EWINDOWS
}
func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
return 0, 0, syscall.EWINDOWS
}
......@@ -21,7 +21,7 @@ import (
// change the behavior of these methods; use Read or ReadMsgIP
// instead.
// BUG(mikio): On NaCl, Plan 9 and Windows, the ReadMsgIP and
// BUG(mikio): On NaCl and Plan 9, the ReadMsgIP and
// WriteMsgIP methods of IPConn are not implemented.
// BUG(mikio): On Windows, the File method of IPConn is not
......
......@@ -149,12 +149,18 @@ func testableListenArgs(network, address, client string) bool {
return true
}
var condFatalf = func() func(*testing.T, string, ...interface{}) {
// A few APIs, File, Read/WriteMsg{UDP,IP}, are not
// implemented yet on both Plan 9 and Windows.
func condFatalf(t *testing.T, api string, format string, args ...interface{}) {
// A few APIs like File and Read/WriteMsg{UDP,IP} are not
// fully implemented yet on Plan 9 and Windows.
switch runtime.GOOS {
case "plan9", "windows":
return (*testing.T).Logf
case "windows":
if api == "file" {
t.Logf(format, args...)
return
}
case "plan9":
t.Logf(format, args...)
return
}
return (*testing.T).Fatalf
}()
t.Fatalf(format, args...)
}
......@@ -54,7 +54,7 @@ func TestTCPListenerSpecificMethods(t *testing.T) {
}
if f, err := ln.File(); err != nil {
condFatalf(t, "%v", err)
condFatalf(t, "file", "%v", err)
} else {
f.Close()
}
......@@ -139,14 +139,14 @@ func TestUDPConnSpecificMethods(t *testing.T) {
t.Fatal(err)
}
if _, _, err := c.WriteMsgUDP(wb, nil, c.LocalAddr().(*UDPAddr)); err != nil {
condFatalf(t, "%v", err)
condFatalf(t, "udp", "%v", err)
}
if _, _, _, _, err := c.ReadMsgUDP(rb, nil); err != nil {
condFatalf(t, "%v", err)
condFatalf(t, "udp", "%v", err)
}
if f, err := c.File(); err != nil {
condFatalf(t, "%v", err)
condFatalf(t, "file", "%v", err)
} else {
f.Close()
}
......@@ -184,7 +184,7 @@ func TestIPConnSpecificMethods(t *testing.T) {
c.SetWriteBuffer(2048)
if f, err := c.File(); err != nil {
condFatalf(t, "%v", err)
condFatalf(t, "file", "%v", err)
} else {
f.Close()
}
......
......@@ -9,7 +9,7 @@ import (
"syscall"
)
// BUG(mikio): On NaCl, Plan 9 and Windows, the ReadMsgUDP and
// BUG(mikio): On NaCl and Plan 9, the ReadMsgUDP and
// WriteMsgUDP methods of UDPConn are not implemented.
// BUG(mikio): On Windows, the File method of UDPConn is not
......
......@@ -161,7 +161,7 @@ func testWriteToConn(t *testing.T, raddr string) {
}
_, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, nil)
switch runtime.GOOS {
case "nacl", "windows": // see golang.org/issue/9252
case "nacl": // see golang.org/issue/9252
t.Skipf("not implemented yet on %s", runtime.GOOS)
default:
if err != nil {
......@@ -204,7 +204,7 @@ func testWriteToPacketConn(t *testing.T, raddr string) {
}
_, _, err = c.(*UDPConn).WriteMsgUDP(b, nil, ra)
switch runtime.GOOS {
case "nacl", "windows": // see golang.org/issue/9252
case "nacl": // see golang.org/issue/9252
t.Skipf("not implemented yet on %s", runtime.GOOS)
default:
if err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment