Commit a4c73ec9 authored by Mikio Hara's avatar Mikio Hara

icmp: use subtests

This change reorganizes the existing test cases for supporting new
messages and extension in upcoming CLs. It also renames ping_test.go
to diag_test.go for clarification.

Updates golang/go#24440.

Change-Id: I5753a2ddbe76423d996a37f583fcf32b65d380e9
Reviewed-on: https://go-review.googlesource.com/63995Reviewed-by: 's avatarBrad Fitzpatrick <bradfitz@golang.org>
parent 24dd3780
......@@ -21,102 +21,102 @@ import (
"golang.org/x/net/ipv6"
)
func googleAddr(c *icmp.PacketConn, protocol int) (net.Addr, error) {
const host = "www.google.com"
ips, err := net.LookupIP(host)
if err != nil {
return nil, err
}
netaddr := func(ip net.IP) (net.Addr, error) {
switch c.LocalAddr().(type) {
case *net.UDPAddr:
return &net.UDPAddr{IP: ip}, nil
case *net.IPAddr:
return &net.IPAddr{IP: ip}, nil
default:
return nil, errors.New("neither UDPAddr nor IPAddr")
}
}
for _, ip := range ips {
switch protocol {
case iana.ProtocolICMP:
if ip.To4() != nil {
return netaddr(ip)
}
case iana.ProtocolIPv6ICMP:
if ip.To16() != nil && ip.To4() == nil {
return netaddr(ip)
}
}
}
return nil, errors.New("no A or AAAA record")
}
type pingTest struct {
type diagTest struct {
network, address string
protocol int
mtype icmp.Type
}
var nonPrivilegedPingTests = []pingTest{
{"udp4", "0.0.0.0", iana.ProtocolICMP, ipv4.ICMPTypeEcho},
{"udp6", "::", iana.ProtocolIPv6ICMP, ipv6.ICMPTypeEchoRequest},
m icmp.Message
}
func TestNonPrivilegedPing(t *testing.T) {
func TestDiag(t *testing.T) {
if testing.Short() {
t.Skip("avoid external network")
}
t.Run("Ping/NonPrivileged", func(t *testing.T) {
switch runtime.GOOS {
case "darwin":
case "linux":
t.Log("you may need to adjust the net.ipv4.ping_group_range kernel state")
default:
t.Skipf("not supported on %s", runtime.GOOS)
t.Logf("not supported on %s", runtime.GOOS)
return
}
for i, dt := range []diagTest{
{
"udp4", "0.0.0.0", iana.ProtocolICMP,
icmp.Message{
Type: ipv4.ICMPTypeEcho, Code: 0,
Body: &icmp.Echo{
ID: os.Getpid() & 0xffff,
Data: []byte("HELLO-R-U-THERE"),
},
},
},
for i, tt := range nonPrivilegedPingTests {
if err := doPing(tt, i); err != nil {
{
"udp6", "::", iana.ProtocolIPv6ICMP,
icmp.Message{
Type: ipv6.ICMPTypeEchoRequest, Code: 0,
Body: &icmp.Echo{
ID: os.Getpid() & 0xffff,
Data: []byte("HELLO-R-U-THERE"),
},
},
},
} {
if err := doDiag(dt, i); err != nil {
t.Error(err)
}
}
}
var privilegedPingTests = []pingTest{
{"ip4:icmp", "0.0.0.0", iana.ProtocolICMP, ipv4.ICMPTypeEcho},
{"ip6:ipv6-icmp", "::", iana.ProtocolIPv6ICMP, ipv6.ICMPTypeEchoRequest},
}
func TestPrivilegedPing(t *testing.T) {
if testing.Short() {
t.Skip("avoid external network")
}
})
t.Run("Ping/Privileged", func(t *testing.T) {
if m, ok := nettest.SupportsRawIPSocket(); !ok {
t.Skip(m)
t.Log(m)
return
}
for i, dt := range []diagTest{
{
"ip4:icmp", "0.0.0.0", iana.ProtocolICMP,
icmp.Message{
Type: ipv4.ICMPTypeEcho, Code: 0,
Body: &icmp.Echo{
ID: os.Getpid() & 0xffff,
Data: []byte("HELLO-R-U-THERE"),
},
},
},
for i, tt := range privilegedPingTests {
if err := doPing(tt, i); err != nil {
{
"ip6:ipv6-icmp", "::", iana.ProtocolIPv6ICMP,
icmp.Message{
Type: ipv6.ICMPTypeEchoRequest, Code: 0,
Body: &icmp.Echo{
ID: os.Getpid() & 0xffff,
Data: []byte("HELLO-R-U-THERE"),
},
},
},
} {
if err := doDiag(dt, i); err != nil {
t.Error(err)
}
}
})
}
func doPing(tt pingTest, seq int) error {
c, err := icmp.ListenPacket(tt.network, tt.address)
func doDiag(dt diagTest, seq int) error {
c, err := icmp.ListenPacket(dt.network, dt.address)
if err != nil {
return err
}
defer c.Close()
dst, err := googleAddr(c, tt.protocol)
dst, err := googleAddr(c, dt.protocol)
if err != nil {
return err
}
if tt.network != "udp6" && tt.protocol == iana.ProtocolIPv6ICMP {
if dt.network != "udp6" && dt.protocol == iana.ProtocolIPv6ICMP {
var f ipv6.ICMPFilter
f.SetAll(true)
f.Accept(ipv6.ICMPTypeDestinationUnreachable)
......@@ -129,14 +129,11 @@ func doPing(tt pingTest, seq int) error {
}
}
wm := icmp.Message{
Type: tt.mtype, Code: 0,
Body: &icmp.Echo{
ID: os.Getpid() & 0xffff, Seq: 1 << uint(seq),
Data: []byte("HELLO-R-U-THERE"),
},
switch m := dt.m.Body.(type) {
case *icmp.Echo:
m.Seq = 1 << uint(seq)
}
wb, err := wm.Marshal(nil)
wb, err := dt.m.Marshal(nil)
if err != nil {
return err
}
......@@ -154,18 +151,45 @@ func doPing(tt pingTest, seq int) error {
if err != nil {
return err
}
rm, err := icmp.ParseMessage(tt.protocol, rb[:n])
rm, err := icmp.ParseMessage(dt.protocol, rb[:n])
if err != nil {
return err
}
switch rm.Type {
case ipv4.ICMPTypeEchoReply, ipv6.ICMPTypeEchoReply:
switch {
case dt.m.Type == ipv4.ICMPTypeEcho && rm.Type == ipv4.ICMPTypeEchoReply:
fallthrough
case dt.m.Type == ipv6.ICMPTypeEchoRequest && rm.Type == ipv6.ICMPTypeEchoReply:
return nil
default:
return fmt.Errorf("got %+v from %v; want echo reply", rm, peer)
}
}
func googleAddr(c *icmp.PacketConn, protocol int) (net.Addr, error) {
host := "ipv4.google.com"
if protocol == iana.ProtocolIPv6ICMP {
host = "ipv6.google.com"
}
ips, err := net.LookupIP(host)
if err != nil {
return nil, err
}
netaddr := func(ip net.IP) (net.Addr, error) {
switch c.LocalAddr().(type) {
case *net.UDPAddr:
return &net.UDPAddr{IP: ip}, nil
case *net.IPAddr:
return &net.IPAddr{IP: ip}, nil
default:
return nil, errors.New("neither UDPAddr nor IPAddr")
}
}
if len(ips) > 0 {
return netaddr(ips[0])
}
return nil, errors.New("no A or AAAA record")
}
func TestConcurrentNonPrivilegedListenPacket(t *testing.T) {
if testing.Short() {
t.Skip("avoid external network")
......
......@@ -5,6 +5,7 @@
package icmp
import (
"fmt"
"net"
"reflect"
"testing"
......@@ -12,12 +13,60 @@ import (
"golang.org/x/net/internal/iana"
)
var marshalAndParseExtensionTests = []struct {
func TestMarshalAndParseExtension(t *testing.T) {
fn := func(t *testing.T, proto int, hdr, obj []byte, te Extension) error {
b, err := te.Marshal(proto)
if err != nil {
return err
}
if !reflect.DeepEqual(b, obj) {
return fmt.Errorf("got %#v; want %#v", b, obj)
}
for i, wire := range []struct {
data []byte // original datagram
inlattr int // length of padded original datagram, a hint
outlattr int // length of padded original datagram, a want
err error
}{
{nil, 0, -1, errNoExtension},
{make([]byte, 127), 128, -1, errNoExtension},
{make([]byte, 128), 127, -1, errNoExtension},
{make([]byte, 128), 128, -1, errNoExtension},
{make([]byte, 128), 129, -1, errNoExtension},
{append(make([]byte, 128), append(hdr, obj...)...), 127, 128, nil},
{append(make([]byte, 128), append(hdr, obj...)...), 128, 128, nil},
{append(make([]byte, 128), append(hdr, obj...)...), 129, 128, nil},
{append(make([]byte, 512), append(hdr, obj...)...), 511, -1, errNoExtension},
{append(make([]byte, 512), append(hdr, obj...)...), 512, 512, nil},
{append(make([]byte, 512), append(hdr, obj...)...), 513, -1, errNoExtension},
} {
exts, l, err := parseExtensions(wire.data, wire.inlattr)
if err != wire.err {
return fmt.Errorf("#%d: got %v; want %v", i, err, wire.err)
}
if wire.err != nil {
continue
}
if l != wire.outlattr {
return fmt.Errorf("#%d: got %d; want %d", i, l, wire.outlattr)
}
if !reflect.DeepEqual(exts, []Extension{te}) {
return fmt.Errorf("#%d: got %#v; want %#v", i, exts[0], te)
}
}
return nil
}
t.Run("MPLSLabelStack", func(t *testing.T) {
for _, et := range []struct {
proto int
hdr []byte
obj []byte
exts []Extension
}{
ext Extension
}{
// MPLS label stack with no label
{
proto: iana.ProtocolICMP,
......@@ -27,13 +76,11 @@ var marshalAndParseExtensionTests = []struct {
obj: []byte{
0x00, 0x04, 0x01, 0x01,
},
exts: []Extension{
&MPLSLabelStack{
ext: &MPLSLabelStack{
Class: classMPLSLabelStack,
Type: typeIncomingMPLSLabelStack,
},
},
},
// MPLS label stack with a single label
{
proto: iana.ProtocolIPv6ICMP,
......@@ -44,8 +91,7 @@ var marshalAndParseExtensionTests = []struct {
0x00, 0x08, 0x01, 0x01,
0x03, 0xe8, 0xe9, 0xff,
},
exts: []Extension{
&MPLSLabelStack{
ext: &MPLSLabelStack{
Class: classMPLSLabelStack,
Type: typeIncomingMPLSLabelStack,
Labels: []MPLSLabel{
......@@ -58,7 +104,6 @@ var marshalAndParseExtensionTests = []struct {
},
},
},
},
// MPLS label stack with multiple labels
{
proto: iana.ProtocolICMP,
......@@ -70,8 +115,7 @@ var marshalAndParseExtensionTests = []struct {
0x03, 0xe8, 0xde, 0xfe,
0x03, 0xe8, 0xe1, 0xff,
},
exts: []Extension{
&MPLSLabelStack{
ext: &MPLSLabelStack{
Class: classMPLSLabelStack,
Type: typeIncomingMPLSLabelStack,
Labels: []MPLSLabel{
......@@ -90,7 +134,19 @@ var marshalAndParseExtensionTests = []struct {
},
},
},
},
} {
if err := fn(t, et.proto, et.hdr, et.obj, et.ext); err != nil {
t.Error(err)
}
}
})
t.Run("InterfaceInfo", func(t *testing.T) {
for _, et := range []struct {
proto int
hdr []byte
obj []byte
ext Extension
}{
// Interface information with no attribute
{
proto: iana.ProtocolICMP,
......@@ -100,12 +156,10 @@ var marshalAndParseExtensionTests = []struct {
obj: []byte{
0x00, 0x04, 0x02, 0x00,
},
exts: []Extension{
&InterfaceInfo{
ext: &InterfaceInfo{
Class: classInterfaceInfo,
},
},
},
// Interface information with ifIndex and name
{
proto: iana.ProtocolICMP,
......@@ -118,8 +172,7 @@ var marshalAndParseExtensionTests = []struct {
0x08, byte('e'), byte('n'), byte('1'),
byte('0'), byte('1'), 0x00, 0x00,
},
exts: []Extension{
&InterfaceInfo{
ext: &InterfaceInfo{
Class: classInterfaceInfo,
Type: 0x0a,
Interface: &net.Interface{
......@@ -128,7 +181,6 @@ var marshalAndParseExtensionTests = []struct {
},
},
},
},
// Interface information with ifIndex, IPAddr, name and MTU
{
proto: iana.ProtocolIPv6ICMP,
......@@ -147,8 +199,7 @@ var marshalAndParseExtensionTests = []struct {
byte('0'), byte('1'), 0x00, 0x00,
0x00, 0x00, 0x20, 0x00,
},
exts: []Extension{
&InterfaceInfo{
ext: &InterfaceInfo{
Class: classInterfaceInfo,
Type: 0x0f,
Interface: &net.Interface{
......@@ -162,96 +213,25 @@ var marshalAndParseExtensionTests = []struct {
},
},
},
},
}
func TestMarshalAndParseExtension(t *testing.T) {
for i, tt := range marshalAndParseExtensionTests {
for j, ext := range tt.exts {
var err error
var b []byte
switch ext := ext.(type) {
case *MPLSLabelStack:
b, err = ext.Marshal(tt.proto)
if err != nil {
t.Errorf("#%v/%v: %v", i, j, err)
continue
}
case *InterfaceInfo:
b, err = ext.Marshal(tt.proto)
if err != nil {
t.Errorf("#%v/%v: %v", i, j, err)
continue
}
}
if !reflect.DeepEqual(b, tt.obj) {
t.Errorf("#%v/%v: got %#v; want %#v", i, j, b, tt.obj)
continue
}
}
for j, wire := range []struct {
data []byte // original datagram
inlattr int // length of padded original datagram, a hint
outlattr int // length of padded original datagram, a want
err error
}{
{nil, 0, -1, errNoExtension},
{make([]byte, 127), 128, -1, errNoExtension},
{make([]byte, 128), 127, -1, errNoExtension},
{make([]byte, 128), 128, -1, errNoExtension},
{make([]byte, 128), 129, -1, errNoExtension},
{append(make([]byte, 128), append(tt.hdr, tt.obj...)...), 127, 128, nil},
{append(make([]byte, 128), append(tt.hdr, tt.obj...)...), 128, 128, nil},
{append(make([]byte, 128), append(tt.hdr, tt.obj...)...), 129, 128, nil},
{append(make([]byte, 512), append(tt.hdr, tt.obj...)...), 511, -1, errNoExtension},
{append(make([]byte, 512), append(tt.hdr, tt.obj...)...), 512, 512, nil},
{append(make([]byte, 512), append(tt.hdr, tt.obj...)...), 513, -1, errNoExtension},
} {
exts, l, err := parseExtensions(wire.data, wire.inlattr)
if err != wire.err {
t.Errorf("#%v/%v: got %v; want %v", i, j, err, wire.err)
continue
}
if wire.err != nil {
continue
}
if l != wire.outlattr {
t.Errorf("#%v/%v: got %v; want %v", i, j, l, wire.outlattr)
}
if !reflect.DeepEqual(exts, tt.exts) {
for j, ext := range exts {
switch ext := ext.(type) {
case *MPLSLabelStack:
want := tt.exts[j].(*MPLSLabelStack)
t.Errorf("#%v/%v: got %#v; want %#v", i, j, ext, want)
case *InterfaceInfo:
want := tt.exts[j].(*InterfaceInfo)
t.Errorf("#%v/%v: got %#v; want %#v", i, j, ext, want)
}
}
continue
}
if err := fn(t, et.proto, et.hdr, et.obj, et.ext); err != nil {
t.Error(err)
}
}
})
}
var parseInterfaceNameTests = []struct {
func TestParseInterfaceName(t *testing.T) {
ifi := InterfaceInfo{Interface: &net.Interface{}}
for i, tt := range []struct {
b []byte
error
}{
}{
{[]byte{0, 'e', 'n', '0'}, errInvalidExtension},
{[]byte{4, 'e', 'n', '0'}, nil},
{[]byte{7, 'e', 'n', '0', 0xff, 0xff, 0xff, 0xff}, errInvalidExtension},
{[]byte{8, 'e', 'n', '0', 0xff, 0xff, 0xff}, errMessageTooShort},
}
func TestParseInterfaceName(t *testing.T) {
ifi := InterfaceInfo{Interface: &net.Interface{}}
for i, tt := range parseInterfaceNameTests {
} {
if _, err := ifi.parseName(tt.b); err != tt.error {
t.Errorf("#%d: got %v; want %v", i, err, tt.error)
}
......
......@@ -15,30 +15,28 @@ import (
"golang.org/x/net/ipv4"
)
type ipv4HeaderTest struct {
wireHeaderFromKernel [ipv4.HeaderLen]byte
wireHeaderFromTradBSDKernel [ipv4.HeaderLen]byte
Header *ipv4.Header
}
var ipv4HeaderLittleEndianTest = ipv4HeaderTest{
// TODO(mikio): Add platform dependent wire header formats when
// we support new platforms.
wireHeaderFromKernel: [ipv4.HeaderLen]byte{
func TestParseIPv4Header(t *testing.T) {
switch socket.NativeEndian {
case binary.LittleEndian:
t.Run("LittleEndian", func(t *testing.T) {
// TODO(mikio): Add platform dependent wire
// header formats when we support new
// platforms.
wireHeaderFromKernel := [ipv4.HeaderLen]byte{
0x45, 0x01, 0xbe, 0xef,
0xca, 0xfe, 0x45, 0xdc,
0xff, 0x01, 0xde, 0xad,
172, 16, 254, 254,
192, 168, 0, 1,
},
wireHeaderFromTradBSDKernel: [ipv4.HeaderLen]byte{
}
wireHeaderFromTradBSDKernel := [ipv4.HeaderLen]byte{
0x45, 0x01, 0xef, 0xbe,
0xca, 0xfe, 0x45, 0xdc,
0xff, 0x01, 0xde, 0xad,
172, 16, 254, 254,
192, 168, 0, 1,
},
Header: &ipv4.Header{
}
th := &ipv4.Header{
Version: ipv4.Version,
Len: ipv4.HeaderLen,
TOS: 1,
......@@ -51,33 +49,27 @@ var ipv4HeaderLittleEndianTest = ipv4HeaderTest{
Checksum: 0xdead,
Src: net.IPv4(172, 16, 254, 254),
Dst: net.IPv4(192, 168, 0, 1),
},
}
func TestParseIPv4Header(t *testing.T) {
tt := &ipv4HeaderLittleEndianTest
if socket.NativeEndian != binary.LittleEndian {
t.Skip("no test for non-little endian machine yet")
}
var wh []byte
switch runtime.GOOS {
case "darwin":
wh = tt.wireHeaderFromTradBSDKernel[:]
wh = wireHeaderFromTradBSDKernel[:]
case "freebsd":
if freebsdVersion >= 1000000 {
wh = tt.wireHeaderFromKernel[:]
wh = wireHeaderFromKernel[:]
} else {
wh = tt.wireHeaderFromTradBSDKernel[:]
wh = wireHeaderFromTradBSDKernel[:]
}
default:
wh = tt.wireHeaderFromKernel[:]
wh = wireHeaderFromKernel[:]
}
h, err := ParseIPv4Header(wh)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(h, tt.Header) {
t.Fatalf("got %#v; want %#v", h, tt.Header)
if !reflect.DeepEqual(h, th) {
t.Fatalf("got %#v; want %#v", h, th)
}
})
}
}
......@@ -15,7 +15,41 @@ import (
"golang.org/x/net/ipv6"
)
var marshalAndParseMessageForIPv4Tests = []icmp.Message{
func TestMarshalAndParseMessage(t *testing.T) {
fn := func(t *testing.T, proto int, tms []icmp.Message) {
var pshs [][]byte
switch proto {
case iana.ProtocolICMP:
pshs = [][]byte{nil}
case iana.ProtocolIPv6ICMP:
pshs = [][]byte{
icmp.IPv6PseudoHeader(net.ParseIP("fe80::1"), net.ParseIP("ff02::1")),
nil,
}
}
for i, tm := range tms {
for _, psh := range pshs {
b, err := tm.Marshal(psh)
if err != nil {
t.Fatal(err)
}
m, err := icmp.ParseMessage(proto, b)
if err != nil {
t.Fatal(err)
}
if m.Type != tm.Type || m.Code != tm.Code {
t.Errorf("#%d: got %#v; want %#v", i, m, &tm)
}
if !reflect.DeepEqual(m.Body, tm.Body) {
t.Errorf("#%d: got %#v; want %#v", i, m.Body, tm.Body)
}
}
}
}
t.Run("IPv4", func(t *testing.T) {
fn(t, iana.ProtocolICMP,
[]icmp.Message{
{
Type: ipv4.ICMPTypeDestinationUnreachable, Code: 15,
Body: &icmp.DstUnreach{
......@@ -48,28 +82,11 @@ var marshalAndParseMessageForIPv4Tests = []icmp.Message{
Data: []byte{0x80, 0x40, 0x20, 0x10},
},
},
}
func TestMarshalAndParseMessageForIPv4(t *testing.T) {
for i, tt := range marshalAndParseMessageForIPv4Tests {
b, err := tt.Marshal(nil)
if err != nil {
t.Fatal(err)
}
m, err := icmp.ParseMessage(iana.ProtocolICMP, b)
if err != nil {
t.Fatal(err)
}
if m.Type != tt.Type || m.Code != tt.Code {
t.Errorf("#%v: got %v; want %v", i, m, &tt)
}
if !reflect.DeepEqual(m.Body, tt.Body) {
t.Errorf("#%v: got %v; want %v", i, m.Body, tt.Body)
}
}
}
var marshalAndParseMessageForIPv6Tests = []icmp.Message{
})
})
t.Run("IPv6", func(t *testing.T) {
fn(t, iana.ProtocolIPv6ICMP,
[]icmp.Message{
{
Type: ipv6.ICMPTypeDestinationUnreachable, Code: 6,
Body: &icmp.DstUnreach{
......@@ -109,26 +126,6 @@ var marshalAndParseMessageForIPv6Tests = []icmp.Message{
Data: []byte{0x80, 0x40, 0x20, 0x10},
},
},
}
func TestMarshalAndParseMessageForIPv6(t *testing.T) {
pshicmp := icmp.IPv6PseudoHeader(net.ParseIP("fe80::1"), net.ParseIP("ff02::1"))
for i, tt := range marshalAndParseMessageForIPv6Tests {
for _, psh := range [][]byte{pshicmp, nil} {
b, err := tt.Marshal(psh)
if err != nil {
t.Fatal(err)
}
m, err := icmp.ParseMessage(iana.ProtocolIPv6ICMP, b)
if err != nil {
t.Fatal(err)
}
if m.Type != tt.Type || m.Code != tt.Code {
t.Errorf("#%v: got %v; want %v", i, m, &tt)
}
if !reflect.DeepEqual(m.Body, tt.Body) {
t.Errorf("#%v: got %v; want %v", i, m.Body, tt.Body)
}
}
}
})
})
}
......@@ -5,6 +5,7 @@
package icmp_test
import (
"errors"
"fmt"
"net"
"reflect"
......@@ -16,7 +17,80 @@ import (
"golang.org/x/net/ipv6"
)
var marshalAndParseMultipartMessageForIPv4Tests = []icmp.Message{
func TestMarshalAndParseMultipartMessage(t *testing.T) {
fn := func(t *testing.T, proto int, tm icmp.Message) error {
b, err := tm.Marshal(nil)
if err != nil {
return err
}
switch proto {
case iana.ProtocolICMP:
if b[5] != 32 {
return fmt.Errorf("got %d; want 32", b[5])
}
case iana.ProtocolIPv6ICMP:
if b[4] != 16 {
return fmt.Errorf("got %d; want 16", b[4])
}
default:
return fmt.Errorf("unknown protocol: %d", proto)
}
m, err := icmp.ParseMessage(proto, b)
if err != nil {
return err
}
if m.Type != tm.Type || m.Code != tm.Code {
return fmt.Errorf("got %v; want %v", m, &tm)
}
switch m.Type {
case ipv4.ICMPTypeDestinationUnreachable:
got, want := m.Body.(*icmp.DstUnreach), tm.Body.(*icmp.DstUnreach)
if !reflect.DeepEqual(got.Extensions, want.Extensions) {
return errors.New(dumpExtensions(got.Extensions, want.Extensions))
}
if len(got.Data) != 128 {
return fmt.Errorf("got %d; want 128", len(got.Data))
}
case ipv4.ICMPTypeTimeExceeded:
got, want := m.Body.(*icmp.TimeExceeded), tm.Body.(*icmp.TimeExceeded)
if !reflect.DeepEqual(got.Extensions, want.Extensions) {
return errors.New(dumpExtensions(got.Extensions, want.Extensions))
}
if len(got.Data) != 128 {
return fmt.Errorf("got %d; want 128", len(got.Data))
}
case ipv4.ICMPTypeParameterProblem:
got, want := m.Body.(*icmp.ParamProb), tm.Body.(*icmp.ParamProb)
if !reflect.DeepEqual(got.Extensions, want.Extensions) {
return errors.New(dumpExtensions(got.Extensions, want.Extensions))
}
if len(got.Data) != 128 {
return fmt.Errorf("got %d; want 128", len(got.Data))
}
case ipv6.ICMPTypeDestinationUnreachable:
got, want := m.Body.(*icmp.DstUnreach), tm.Body.(*icmp.DstUnreach)
if !reflect.DeepEqual(got.Extensions, want.Extensions) {
return errors.New(dumpExtensions(got.Extensions, want.Extensions))
}
if len(got.Data) != 128 {
return fmt.Errorf("got %d; want 128", len(got.Data))
}
case ipv6.ICMPTypeTimeExceeded:
got, want := m.Body.(*icmp.TimeExceeded), tm.Body.(*icmp.TimeExceeded)
if !reflect.DeepEqual(got.Extensions, want.Extensions) {
return errors.New(dumpExtensions(got.Extensions, want.Extensions))
}
if len(got.Data) != 128 {
return fmt.Errorf("got %d; want 128", len(got.Data))
}
default:
return fmt.Errorf("unknown message type: %v", m.Type)
}
return nil
}
t.Run("IPv4", func(t *testing.T) {
for i, tm := range []icmp.Message{
{
Type: ipv4.ICMPTypeDestinationUnreachable, Code: 15,
Body: &icmp.DstUnreach{
......@@ -126,54 +200,14 @@ var marshalAndParseMultipartMessageForIPv4Tests = []icmp.Message{
},
},
},
}
func TestMarshalAndParseMultipartMessageForIPv4(t *testing.T) {
for i, tt := range marshalAndParseMultipartMessageForIPv4Tests {
b, err := tt.Marshal(nil)
if err != nil {
t.Fatal(err)
}
if b[5] != 32 {
t.Errorf("#%v: got %v; want 32", i, b[5])
}
m, err := icmp.ParseMessage(iana.ProtocolICMP, b)
if err != nil {
t.Fatal(err)
} {
if err := fn(t, iana.ProtocolICMP, tm); err != nil {
t.Errorf("#%d: %v", i, err)
}
if m.Type != tt.Type || m.Code != tt.Code {
t.Errorf("#%v: got %v; want %v", i, m, &tt)
}
switch m.Type {
case ipv4.ICMPTypeDestinationUnreachable:
got, want := m.Body.(*icmp.DstUnreach), tt.Body.(*icmp.DstUnreach)
if !reflect.DeepEqual(got.Extensions, want.Extensions) {
t.Error(dumpExtensions(i, got.Extensions, want.Extensions))
}
if len(got.Data) != 128 {
t.Errorf("#%v: got %v; want 128", i, len(got.Data))
}
case ipv4.ICMPTypeTimeExceeded:
got, want := m.Body.(*icmp.TimeExceeded), tt.Body.(*icmp.TimeExceeded)
if !reflect.DeepEqual(got.Extensions, want.Extensions) {
t.Error(dumpExtensions(i, got.Extensions, want.Extensions))
}
if len(got.Data) != 128 {
t.Errorf("#%v: got %v; want 128", i, len(got.Data))
}
case ipv4.ICMPTypeParameterProblem:
got, want := m.Body.(*icmp.ParamProb), tt.Body.(*icmp.ParamProb)
if !reflect.DeepEqual(got.Extensions, want.Extensions) {
t.Error(dumpExtensions(i, got.Extensions, want.Extensions))
}
if len(got.Data) != 128 {
t.Errorf("#%v: got %v; want 128", i, len(got.Data))
}
}
}
}
var marshalAndParseMultipartMessageForIPv6Tests = []icmp.Message{
})
t.Run("IPv6", func(t *testing.T) {
for i, tm := range []icmp.Message{
{
Type: ipv6.ICMPTypeDestinationUnreachable, Code: 6,
Body: &icmp.DstUnreach{
......@@ -253,72 +287,42 @@ var marshalAndParseMultipartMessageForIPv6Tests = []icmp.Message{
},
},
},
}
func TestMarshalAndParseMultipartMessageForIPv6(t *testing.T) {
pshicmp := icmp.IPv6PseudoHeader(net.ParseIP("fe80::1"), net.ParseIP("ff02::1"))
for i, tt := range marshalAndParseMultipartMessageForIPv6Tests {
for _, psh := range [][]byte{pshicmp, nil} {
b, err := tt.Marshal(psh)
if err != nil {
t.Fatal(err)
}
if b[4] != 16 {
t.Errorf("#%v: got %v; want 16", i, b[4])
}
m, err := icmp.ParseMessage(iana.ProtocolIPv6ICMP, b)
if err != nil {
t.Fatal(err)
}
if m.Type != tt.Type || m.Code != tt.Code {
t.Errorf("#%v: got %v; want %v", i, m, &tt)
}
switch m.Type {
case ipv6.ICMPTypeDestinationUnreachable:
got, want := m.Body.(*icmp.DstUnreach), tt.Body.(*icmp.DstUnreach)
if !reflect.DeepEqual(got.Extensions, want.Extensions) {
t.Error(dumpExtensions(i, got.Extensions, want.Extensions))
}
if len(got.Data) != 128 {
t.Errorf("#%v: got %v; want 128", i, len(got.Data))
}
case ipv6.ICMPTypeTimeExceeded:
got, want := m.Body.(*icmp.TimeExceeded), tt.Body.(*icmp.TimeExceeded)
if !reflect.DeepEqual(got.Extensions, want.Extensions) {
t.Error(dumpExtensions(i, got.Extensions, want.Extensions))
}
if len(got.Data) != 128 {
t.Errorf("#%v: got %v; want 128", i, len(got.Data))
}
}
} {
if err := fn(t, iana.ProtocolIPv6ICMP, tm); err != nil {
t.Errorf("#%d: %v", i, err)
}
}
})
}
func dumpExtensions(i int, gotExts, wantExts []icmp.Extension) string {
func dumpExtensions(gotExts, wantExts []icmp.Extension) string {
var s string
for j, got := range gotExts {
for i, got := range gotExts {
switch got := got.(type) {
case *icmp.MPLSLabelStack:
want := wantExts[j].(*icmp.MPLSLabelStack)
want := wantExts[i].(*icmp.MPLSLabelStack)
if !reflect.DeepEqual(got, want) {
s += fmt.Sprintf("#%v/%v: got %#v; want %#v\n", i, j, got, want)
s += fmt.Sprintf("#%d: got %#v; want %#v\n", i, got, want)
}
case *icmp.InterfaceInfo:
want := wantExts[j].(*icmp.InterfaceInfo)
want := wantExts[i].(*icmp.InterfaceInfo)
if !reflect.DeepEqual(got, want) {
s += fmt.Sprintf("#%v/%v: got %#v, %#v, %#v; want %#v, %#v, %#v\n", i, j, got, got.Interface, got.Addr, want, want.Interface, want.Addr)
s += fmt.Sprintf("#%d: got %#v, %#v, %#v; want %#v, %#v, %#v\n", i, got, got.Interface, got.Addr, want, want.Interface, want.Addr)
}
}
}
if len(s) == 0 {
return "<nil>"
}
return s[:len(s)-1]
}
var multipartMessageBodyLenTests = []struct {
func TestMultipartMessageBodyLen(t *testing.T) {
for i, tt := range []struct {
proto int
in icmp.MessageBody
out int
}{
}{
{
iana.ProtocolICMP,
&icmp.DstUnreach{
......@@ -431,10 +435,7 @@ var multipartMessageBodyLenTests = []struct {
},
4 + 4 + 4 + 0 + 136, // [length, unused], extension header, object header, object payload and original datagram
},
}
func TestMultipartMessageBodyLen(t *testing.T) {
for i, tt := range multipartMessageBodyLenTests {
} {
if out := tt.in.Len(tt.proto); out != tt.out {
t.Errorf("#%d: got %d; want %d", i, out, tt.out)
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment