Commit 059c68bf authored by Alex Brainman's avatar Alex Brainman

net: implement ip protocol name to number resolver for windows

Fixes #2215.
Fixes #2216.

R=golang-dev, dave, rsc
CC=golang-dev
https://golang.org/cl/5248055
parent d69b820e
......@@ -11,6 +11,7 @@ import (
"bytes"
"flag"
"os"
"runtime"
"testing"
)
......@@ -64,7 +65,7 @@ var dsthost = flag.String("dsthost", "127.0.0.1", "Destination for the ICMP ECHO
// test (raw) IP socket using ICMP
func TestICMP(t *testing.T) {
if os.Getuid() != 0 {
if runtime.GOOS != "windows" && os.Getuid() != 0 {
t.Logf("test disabled; must be root")
return
}
......
......@@ -10,12 +10,9 @@ package net
import (
"os"
"sync"
"syscall"
)
var onceReadProtocols sync.Once
func sockaddrToIP(sa syscall.Sockaddr) Addr {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
......@@ -209,33 +206,7 @@ func (c *IPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) {
return c.WriteToIP(b, a)
}
var protocols map[string]int
func readProtocols() {
protocols = make(map[string]int)
if file, err := open("/etc/protocols"); err == nil {
for line, ok := file.readLine(); ok; line, ok = file.readLine() {
// tcp 6 TCP # transmission control protocol
if i := byteIndex(line, '#'); i >= 0 {
line = line[0:i]
}
f := getFields(line)
if len(f) < 2 {
continue
}
if proto, _, ok := dtoi(f[1], 0); ok {
protocols[f[0]] = proto
for _, alias := range f[2:] {
protocols[alias] = proto
}
}
}
file.close()
}
}
func splitNetProto(netProto string) (net string, proto int, err os.Error) {
onceReadProtocols.Do(readProtocols)
i := last(netProto, ':')
if i < 0 { // no colon
return "", 0, os.NewError("no IP protocol specified")
......@@ -244,13 +215,12 @@ func splitNetProto(netProto string) (net string, proto int, err os.Error) {
protostr := netProto[i+1:]
proto, i, ok := dtoi(protostr, 0)
if !ok || i != len(protostr) {
// lookup by name
proto, ok = protocols[protostr]
if ok {
return
proto, err = lookupProtocol(protostr)
if err != nil {
return "", 0, err
}
}
return
return net, proto, nil
}
// DialIP connects to the remote address raddr on the network net,
......
......@@ -8,8 +8,50 @@ package net
import (
"os"
"sync"
)
var (
protocols map[string]int
onceReadProtocols sync.Once
)
// readProtocols loads contents of /etc/protocols into protocols map
// for quick access.
func readProtocols() {
protocols = make(map[string]int)
if file, err := open("/etc/protocols"); err == nil {
for line, ok := file.readLine(); ok; line, ok = file.readLine() {
// tcp 6 TCP # transmission control protocol
if i := byteIndex(line, '#'); i >= 0 {
line = line[0:i]
}
f := getFields(line)
if len(f) < 2 {
continue
}
if proto, _, ok := dtoi(f[1], 0); ok {
protocols[f[0]] = proto
for _, alias := range f[2:] {
protocols[alias] = proto
}
}
}
file.close()
}
}
// lookupProtocol looks up IP protocol name in /etc/protocols and
// returns correspondent protocol number.
func lookupProtocol(name string) (proto int, err os.Error) {
onceReadProtocols.Do(readProtocols)
proto, found := protocols[name]
if !found {
return 0, os.NewError("unknown IP protocol specified: " + name)
}
return
}
// LookupHost looks up the given host using the local resolver.
// It returns an array of that host's addresses.
func LookupHost(host string) (addrs []string, err os.Error) {
......
......@@ -11,8 +11,22 @@ import (
"sync"
)
var hostentLock sync.Mutex
var serventLock sync.Mutex
var (
protoentLock sync.Mutex
hostentLock sync.Mutex
serventLock sync.Mutex
)
// lookupProtocol looks up IP protocol name and returns correspondent protocol number.
func lookupProtocol(name string) (proto int, err os.Error) {
protoentLock.Lock()
defer protoentLock.Unlock()
p, e := syscall.GetProtoByName(name)
if e != 0 {
return 0, os.NewSyscallError("GetProtoByName", e)
}
return int(p.Proto), nil
}
func LookupHost(name string) (addrs []string, err os.Error) {
ips, err := LookupIP(name)
......
......@@ -502,6 +502,7 @@ func Chmod(path string, mode uint32) (errno int) {
//sys GetHostByName(name string) (h *Hostent, errno int) [failretval==nil] = ws2_32.gethostbyname
//sys GetServByName(name string, proto string) (s *Servent, errno int) [failretval==nil] = ws2_32.getservbyname
//sys Ntohs(netshort uint16) (u uint16) = ws2_32.ntohs
//sys GetProtoByName(name string) (p *Protoent, errno int) [failretval==nil] = ws2_32.getprotobyname
//sys DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status uint32) = dnsapi.DnsQuery_W
//sys DnsRecordListFree(rl *DNSRecord, freetype uint32) = dnsapi.DnsRecordListFree
//sys GetIfEntry(pIfRow *MibIfRow) (errcode int) = iphlpapi.GetIfEntry
......
......@@ -101,6 +101,7 @@ var (
procgethostbyname = modws2_32.NewProc("gethostbyname")
procgetservbyname = modws2_32.NewProc("getservbyname")
procntohs = modws2_32.NewProc("ntohs")
procgetprotobyname = modws2_32.NewProc("getprotobyname")
procDnsQuery_W = moddnsapi.NewProc("DnsQuery_W")
procDnsRecordListFree = moddnsapi.NewProc("DnsRecordListFree")
procGetIfEntry = modiphlpapi.NewProc("GetIfEntry")
......@@ -1314,6 +1315,21 @@ func Ntohs(netshort uint16) (u uint16) {
return
}
func GetProtoByName(name string) (p *Protoent, errno int) {
r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(StringBytePtr(name))), 0, 0)
p = (*Protoent)(unsafe.Pointer(r0))
if p == nil {
if e1 != 0 {
errno = int(e1)
} else {
errno = EINVAL
}
} else {
errno = 0
}
return
}
func DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status uint32) {
r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(StringToUTF16Ptr(name))), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr)))
status = uint32(r0)
......
......@@ -101,6 +101,7 @@ var (
procgethostbyname = modws2_32.NewProc("gethostbyname")
procgetservbyname = modws2_32.NewProc("getservbyname")
procntohs = modws2_32.NewProc("ntohs")
procgetprotobyname = modws2_32.NewProc("getprotobyname")
procDnsQuery_W = moddnsapi.NewProc("DnsQuery_W")
procDnsRecordListFree = moddnsapi.NewProc("DnsRecordListFree")
procGetIfEntry = modiphlpapi.NewProc("GetIfEntry")
......@@ -1314,6 +1315,21 @@ func Ntohs(netshort uint16) (u uint16) {
return
}
func GetProtoByName(name string) (p *Protoent, errno int) {
r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(StringBytePtr(name))), 0, 0)
p = (*Protoent)(unsafe.Pointer(r0))
if p == nil {
if e1 != 0 {
errno = int(e1)
} else {
errno = EINVAL
}
} else {
errno = 0
}
return
}
func DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status uint32) {
r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(StringToUTF16Ptr(name))), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr)))
status = uint32(r0)
......
......@@ -411,6 +411,12 @@ type Hostent struct {
AddrList **byte
}
type Protoent struct {
Name *byte
Aliases **byte
Proto uint16
}
const (
DNS_TYPE_A = 0x0001
DNS_TYPE_NS = 0x0002
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment