diff options
author | Mikio Hara <mikioh.mikioh@gmail.com> | 2015-01-15 01:00:16 +0900 |
---|---|---|
committer | Mikio Hara <mikioh.mikioh@gmail.com> | 2015-02-06 05:26:11 +0000 |
commit | 71586c3cf98f806af322c5a361660eb046e00501 (patch) | |
tree | 67f0ebbb7bc897669317e17e1dc5d071141eeed2 | |
parent | 074db39ac33144545f748023c2857998d29c613f (diff) | |
download | net-71586c3cf98f806af322c5a361660eb046e00501.tar.gz |
icmp: add extensions for MPLS
This change implements ICMP multi-part message marshaler, parser and
extensions for MPLS which are used for route trace applications as
described in RFC 4950.
API breaking changes:
type MessageBody interface, Len() int
type Extension interface, Len() int
type Extension interface, Marshal() ([]byte, error)
are replaced with
type MessageBody interface, Len(int) int
type Extension interface, Len(int) int
type Extension interface, Marshal(int) ([]byte, error)
Change-Id: Iee1f2e03916d49b8dfe3a89fe682c702d40ecc85
Reviewed-on: https://go-review.googlesource.com/2794
Reviewed-by: Ian Lance Taylor <iant@golang.org>
-rw-r--r-- | icmp/dstunreach.go | 19 | ||||
-rw-r--r-- | icmp/echo.go | 2 | ||||
-rw-r--r-- | icmp/extension.go | 69 | ||||
-rw-r--r-- | icmp/extension_test.go | 158 | ||||
-rw-r--r-- | icmp/message.go | 33 | ||||
-rw-r--r-- | icmp/message_test.go | 12 | ||||
-rw-r--r-- | icmp/messagebody.go | 5 | ||||
-rw-r--r-- | icmp/mpls.go | 75 | ||||
-rw-r--r-- | icmp/multipart.go | 103 | ||||
-rw-r--r-- | icmp/multipart_test.go | 223 | ||||
-rw-r--r-- | icmp/packettoobig.go | 4 | ||||
-rw-r--r-- | icmp/paramprob.go | 39 | ||||
-rw-r--r-- | icmp/timeexceeded.go | 19 |
13 files changed, 698 insertions, 63 deletions
diff --git a/icmp/dstunreach.go b/icmp/dstunreach.go index d1905a9..01dc660 100644 --- a/icmp/dstunreach.go +++ b/icmp/dstunreach.go @@ -12,31 +12,30 @@ type DstUnreach struct { } // Len implements the Len method of MessageBody interface. -func (p *DstUnreach) Len() int { +func (p *DstUnreach) Len(proto int) int { if p == nil { return 0 } - return 4 + len(p.Data) + l, _ := multipartMessageBodyDataLen(proto, p.Data, p.Extensions) + return l } // Marshal implements the Marshal method of MessageBody interface. func (p *DstUnreach) Marshal(proto int) ([]byte, error) { - b := make([]byte, 4+len(p.Data)) - copy(b[4:], p.Data) - return b, nil + return marshalMultipartMessageBody(proto, p.Data, p.Extensions) } // parseDstUnreach parses b as an ICMP destination unreachable message // body. func parseDstUnreach(proto int, b []byte) (MessageBody, error) { - bodyLen := len(b) - if bodyLen < 4 { + if len(b) < 4 { return nil, errMessageTooShort } p := &DstUnreach{} - if bodyLen > 4 { - p.Data = make([]byte, bodyLen-4) - copy(p.Data, b[4:]) + var err error + p.Data, p.Extensions, err = parseMultipartMessageBody(proto, b) + if err != nil { + return nil, err } return p, nil } diff --git a/icmp/echo.go b/icmp/echo.go index 6b373fc..8943eab 100644 --- a/icmp/echo.go +++ b/icmp/echo.go @@ -12,7 +12,7 @@ type Echo struct { } // Len implements the Len method of MessageBody interface. -func (p *Echo) Len() int { +func (p *Echo) Len(proto int) int { if p == nil { return 0 } diff --git a/icmp/extension.go b/icmp/extension.go index 7575f02..37af0e1 100644 --- a/icmp/extension.go +++ b/icmp/extension.go @@ -7,10 +7,75 @@ package icmp // An Extension represents an ICMP extension. type Extension interface { // Len returns the length of ICMP extension. - Len() int + // Proto must be either the ICMPv4 or ICMPv6 protocol number. + Len(proto int) int // Marshal returns the binary enconding of ICMP extension. - Marshal() ([]byte, error) + // Proto must be either the ICMPv4 or ICMPv6 protocol number. + Marshal(proto int) ([]byte, error) } const extensionVersion = 2 + +func validExtensionHeader(b []byte) bool { + v := int(b[0]&0xf0) >> 4 + s := uint16(b[2])<<8 | uint16(b[3]) + if s != 0 { + s = checksum(b) + } + if v != extensionVersion || s != 0 { + return false + } + return true +} + +// parseExtensions parses b as a list of ICMP extensions. +// The length attribute l must be the length attribute field in +// received icmp messages. +// +// It will return a list of ICMP extensions and an adjusted length +// attribute that represents the length of the padded original +// datagram field. Otherwise, it returns an error. +func parseExtensions(b []byte, l int) ([]Extension, int, error) { + // Still a lot of non-RFC 4884 compliant implementations are + // out there. Set the length attribute l to 128 when it looks + // inappropriate for backwards compatibility. + // + // A minimal extension at least requires 8 octets; 4 octets + // for an extension header, and 4 octets for a single object + // header. + // + // See RFC 4884 for further information. + if 128 > l || l+8 > len(b) { + l = 128 + } + if l+8 > len(b) { + return nil, -1, errNoExtension + } + if !validExtensionHeader(b[l:]) { + if l == 128 { + return nil, -1, errNoExtension + } + l = 128 + if !validExtensionHeader(b[l:]) { + return nil, -1, errNoExtension + } + } + var exts []Extension + for b = b[l+4:]; len(b) >= 4; { + ol := int(b[0])<<8 | int(b[1]) + if 4 > ol || ol > len(b) { + break + } + switch b[2] { + case classMPLSLabelStack: + ext, err := parseMPLSLabelStack(b[:ol]) + if err != nil { + return nil, -1, err + } + exts = append(exts, ext) + } + b = b[ol:] + } + return exts, l, nil +} diff --git a/icmp/extension_test.go b/icmp/extension_test.go new file mode 100644 index 0000000..6ed9e70 --- /dev/null +++ b/icmp/extension_test.go @@ -0,0 +1,158 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package icmp + +import ( + "reflect" + "testing" + + "golang.org/x/net/internal/iana" +) + +var marshalAndParseExtensionTests = []struct { + proto int + hdr []byte + obj []byte + exts []Extension +}{ + // MPLS label stack with no label + { + proto: iana.ProtocolICMP, + hdr: []byte{ + 0x20, 0x00, 0x00, 0x00, + }, + obj: []byte{ + 0x00, 0x04, 0x01, 0x01, + }, + exts: []Extension{ + &MPLSLabelStack{ + Class: classMPLSLabelStack, + Type: typeIncomingMPLSLabelStack, + }, + }, + }, + // MPLS label stack with a single label + { + proto: iana.ProtocolIPv6ICMP, + hdr: []byte{ + 0x20, 0x00, 0x00, 0x00, + }, + obj: []byte{ + 0x00, 0x08, 0x01, 0x01, + 0x03, 0xe8, 0xe9, 0xff, + }, + exts: []Extension{ + &MPLSLabelStack{ + Class: classMPLSLabelStack, + Type: typeIncomingMPLSLabelStack, + Labels: []MPLSLabel{ + { + Label: 16014, + TC: 0x4, + S: true, + TTL: 255, + }, + }, + }, + }, + }, + // MPLS label stack with multiple labels + { + proto: iana.ProtocolICMP, + hdr: []byte{ + 0x20, 0x00, 0x00, 0x00, + }, + obj: []byte{ + 0x00, 0x0c, 0x01, 0x01, + 0x03, 0xe8, 0xde, 0xfe, + 0x03, 0xe8, 0xe1, 0xff, + }, + exts: []Extension{ + &MPLSLabelStack{ + Class: classMPLSLabelStack, + Type: typeIncomingMPLSLabelStack, + Labels: []MPLSLabel{ + { + Label: 16013, + TC: 0x7, + S: false, + TTL: 254, + }, + { + Label: 16014, + TC: 0, + S: true, + TTL: 255, + }, + }, + }, + }, + }, +} + +func TestMarshalAndParseExtension(t *testing.T) { + for i, tt := range marshalAndParseExtensionTests { + for j, ext := range tt.exts { + var err error + var b []byte + switch ext := ext.(type) { + case *MPLSLabelStack: + b, err = ext.Marshal(tt.proto) + if err != nil { + t.Errorf("#%v/%v: %v", i, j, err) + continue + } + } + if !reflect.DeepEqual(b, tt.obj) { + t.Errorf("#%v/%v: got %#v; want %#v", i, j, b, tt.obj) + continue + } + } + + for j, wire := range []struct { + data []byte // original datagram + inlattr int // length of padded original datagram, a hint + outlattr int // length of padded original datagram, a want + err error + }{ + {nil, 0, -1, errNoExtension}, + {make([]byte, 127), 128, -1, errNoExtension}, + + {make([]byte, 128), 127, -1, errNoExtension}, + {make([]byte, 128), 128, -1, errNoExtension}, + {make([]byte, 128), 129, -1, errNoExtension}, + + {append(make([]byte, 128), append(tt.hdr, tt.obj...)...), 127, 128, nil}, + {append(make([]byte, 128), append(tt.hdr, tt.obj...)...), 128, 128, nil}, + {append(make([]byte, 128), append(tt.hdr, tt.obj...)...), 129, 128, nil}, + + {append(make([]byte, 512), append(tt.hdr, tt.obj...)...), 511, -1, errNoExtension}, + {append(make([]byte, 512), append(tt.hdr, tt.obj...)...), 512, 512, nil}, + {append(make([]byte, 512), append(tt.hdr, tt.obj...)...), 513, -1, errNoExtension}, + } { + exts, l, err := parseExtensions(wire.data, wire.inlattr) + if err != wire.err { + t.Errorf("#%v/%v: got %v; want %v", i, j, err, wire.err) + continue + } + if wire.err != nil { + continue + } + if l != wire.outlattr { + t.Errorf("#%v/%v: got %v; want %v", i, j, l, wire.outlattr) + } + if !reflect.DeepEqual(exts, tt.exts) { + for j, ext := range exts { + switch ext := ext.(type) { + case *MPLSLabelStack: + want := tt.exts[j].(*MPLSLabelStack) + t.Errorf("#%v/%v: got %#v; want %#v", i, j, ext, want) + } + } + continue + } + } + } +} diff --git a/icmp/message.go b/icmp/message.go index 4e6bec7..13b57f7 100644 --- a/icmp/message.go +++ b/icmp/message.go @@ -8,6 +8,7 @@ // // ICMPv4 and ICMPv6 are defined in RFC 792 and RFC 4443. // Multi-part message support for ICMP is defined in RFC 4884. +// ICMP extensions for MPLS are defined in RFC 4950. package icmp // import "golang.org/x/net/icmp" import ( @@ -25,8 +26,23 @@ var ( errHeaderTooShort = errors.New("header too short") errBufferTooShort = errors.New("buffer too short") errOpNoSupport = errors.New("operation not supported") + errNoExtension = errors.New("no extension") ) +func checksum(b []byte) uint16 { + csumcv := len(b) - 1 // checksum coverage + s := uint32(0) + for i := 0; i < csumcv; i += 2 { + s += uint32(b[i+1])<<8 | uint32(b[i]) + } + if csumcv&1 == 0 { + s += uint32(b[csumcv]) + } + s = s>>16 + s&0xffff + s = s + s>>16 + return ^uint16(s) +} + // A Type represents an ICMP message type. type Type interface { Protocol() int @@ -63,7 +79,7 @@ func (m *Message) Marshal(psh []byte) ([]byte, error) { if m.Type.Protocol() == iana.ProtocolIPv6ICMP && psh != nil { b = append(psh, b...) } - if m.Body != nil && m.Body.Len() != 0 { + if m.Body != nil && m.Body.Len(m.Type.Protocol()) != 0 { mb, err := m.Body.Marshal(m.Type.Protocol()) if err != nil { return nil, err @@ -77,20 +93,11 @@ func (m *Message) Marshal(psh []byte) ([]byte, error) { off, l := 2*net.IPv6len, len(b)-len(psh) b[off], b[off+1], b[off+2], b[off+3] = byte(l>>24), byte(l>>16), byte(l>>8), byte(l) } - csumcv := len(b) - 1 // checksum coverage - s := uint32(0) - for i := 0; i < csumcv; i += 2 { - s += uint32(b[i+1])<<8 | uint32(b[i]) - } - if csumcv&1 == 0 { - s += uint32(b[csumcv]) - } - s = s>>16 + s&0xffff - s = s + s>>16 + s := checksum(b) // Place checksum back in header; using ^= avoids the // assumption the checksum bytes are zero. - b[len(psh)+2] ^= byte(^s) - b[len(psh)+3] ^= byte(^s >> 8) + b[len(psh)+2] ^= byte(s) + b[len(psh)+3] ^= byte(s >> 8) return b[len(psh):], nil } diff --git a/icmp/message_test.go b/icmp/message_test.go index 162bf7f..5d2605f 100644 --- a/icmp/message_test.go +++ b/icmp/message_test.go @@ -51,7 +51,7 @@ var marshalAndParseMessageForIPv4Tests = []icmp.Message{ } func TestMarshalAndParseMessageForIPv4(t *testing.T) { - for _, tt := range marshalAndParseMessageForIPv4Tests { + for i, tt := range marshalAndParseMessageForIPv4Tests { b, err := tt.Marshal(nil) if err != nil { t.Fatal(err) @@ -61,10 +61,10 @@ func TestMarshalAndParseMessageForIPv4(t *testing.T) { t.Fatal(err) } if m.Type != tt.Type || m.Code != tt.Code { - t.Errorf("got %v; want %v", m, &tt) + t.Errorf("#%v: got %v; want %v", i, m, &tt) } if !reflect.DeepEqual(m.Body, tt.Body) { - t.Errorf("got %v; want %v", m.Body, tt.Body) + t.Errorf("#%v: got %v; want %v", i, m.Body, tt.Body) } } } @@ -113,7 +113,7 @@ var marshalAndParseMessageForIPv6Tests = []icmp.Message{ func TestMarshalAndParseMessageForIPv6(t *testing.T) { pshicmp := icmp.IPv6PseudoHeader(net.ParseIP("fe80::1"), net.ParseIP("ff02::1")) - for _, tt := range marshalAndParseMessageForIPv6Tests { + for i, tt := range marshalAndParseMessageForIPv6Tests { for _, psh := range [][]byte{pshicmp, nil} { b, err := tt.Marshal(psh) if err != nil { @@ -124,10 +124,10 @@ func TestMarshalAndParseMessageForIPv6(t *testing.T) { t.Fatal(err) } if m.Type != tt.Type || m.Code != tt.Code { - t.Errorf("got %v; want %v", m, &tt) + t.Errorf("#%v: got %v; want %v", i, m, &tt) } if !reflect.DeepEqual(m.Body, tt.Body) { - t.Errorf("got %v; want %v", m.Body, tt.Body) + t.Errorf("#%v: got %v; want %v", i, m.Body, tt.Body) } } } diff --git a/icmp/messagebody.go b/icmp/messagebody.go index 30f2df8..d314480 100644 --- a/icmp/messagebody.go +++ b/icmp/messagebody.go @@ -7,7 +7,8 @@ package icmp // A MessageBody represents an ICMP message body. type MessageBody interface { // Len returns the length of ICMP message body. - Len() int + // Proto must be either the ICMPv4 or ICMPv6 protocol number. + Len(proto int) int // Marshal returns the binary enconding of ICMP message body. // Proto must be either the ICMPv4 or ICMPv6 protocol number. @@ -20,7 +21,7 @@ type DefaultMessageBody struct { } // Len implements the Len method of MessageBody interface. -func (p *DefaultMessageBody) Len() int { +func (p *DefaultMessageBody) Len(proto int) int { if p == nil { return 0 } diff --git a/icmp/mpls.go b/icmp/mpls.go new file mode 100644 index 0000000..31bcfe8 --- /dev/null +++ b/icmp/mpls.go @@ -0,0 +1,75 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package icmp + +// A MPLSLabel represents a MPLS label stack entry. +type MPLSLabel struct { + Label int // label value + TC int // traffic class; formerly experimental use + S bool // bottom of stack + TTL int // time to live +} + +const ( + classMPLSLabelStack = 1 + typeIncomingMPLSLabelStack = 1 +) + +// A MPLSLabelStack represents a MPLS label stack. +type MPLSLabelStack struct { + Class int // extension object class number + Type int // extension object sub-type + Labels []MPLSLabel +} + +// Len implements the Len method of Extension interface. +func (ls *MPLSLabelStack) Len(proto int) int { + return 4 + (4 * len(ls.Labels)) +} + +// Marshal implements the Marshal method of Extension interface. +func (ls *MPLSLabelStack) Marshal(proto int) ([]byte, error) { + b := make([]byte, ls.Len(proto)) + if err := ls.marshal(proto, b); err != nil { + return nil, err + } + return b, nil +} + +func (ls *MPLSLabelStack) marshal(proto int, b []byte) error { + l := ls.Len(proto) + b[0], b[1] = byte(l>>8), byte(l) + b[2], b[3] = classMPLSLabelStack, typeIncomingMPLSLabelStack + off := 4 + for _, ll := range ls.Labels { + b[off], b[off+1], b[off+2] = byte(ll.Label>>12), byte(ll.Label>>4&0xff), byte(ll.Label<<4&0xf0) + b[off+2] |= byte(ll.TC << 1 & 0x0e) + if ll.S { + b[off+2] |= 0x1 + } + b[off+3] = byte(ll.TTL) + off += 4 + } + return nil +} + +func parseMPLSLabelStack(b []byte) (Extension, error) { + ls := &MPLSLabelStack{ + Class: int(b[2]), + Type: int(b[3]), + } + for b = b[4:]; len(b) >= 4; b = b[4:] { + ll := MPLSLabel{ + Label: int(b[0])<<12 | int(b[1])<<4 | int(b[2])>>4, + TC: int(b[2]&0x0e) >> 1, + TTL: int(b[3]), + } + if b[2]&0x1 != 0 { + ll.S = true + } + ls.Labels = append(ls.Labels, ll) + } + return ls, nil +} diff --git a/icmp/multipart.go b/icmp/multipart.go new file mode 100644 index 0000000..3f89a76 --- /dev/null +++ b/icmp/multipart.go @@ -0,0 +1,103 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package icmp + +import "golang.org/x/net/internal/iana" + +// multipartMessageBodyDataLen takes b as an original datagram and +// exts as extensions, and returns a required length for message body +// and a required length for a padded original datagram in wire +// format. +func multipartMessageBodyDataLen(proto int, b []byte, exts []Extension) (bodyLen, dataLen int) { + for _, ext := range exts { + bodyLen += ext.Len(proto) + } + if bodyLen > 0 { + dataLen = multipartMessageOrigDatagramLen(proto, b) + bodyLen += 4 // length of extension header + } else { + dataLen = len(b) + } + bodyLen += dataLen + return bodyLen, dataLen +} + +// multipartMessageOrigDatagramLen takes b as an original datagram, +// and returns a required length for a padded orignal datagram in wire +// format. +func multipartMessageOrigDatagramLen(proto int, b []byte) int { + roundup := func(b []byte, align int) int { + // According to RFC 4884, the padded original datagram + // field must contain at least 128 octets. + if len(b) < 128 { + return 128 + } + r := len(b) + return (r + align) &^ (align - 1) + } + switch proto { + case iana.ProtocolICMP: + return roundup(b, 4) + case iana.ProtocolIPv6ICMP: + return roundup(b, 8) + default: + return len(b) + } +} + +// marshalMultipartMessageBody takes data as an original datagram and +// exts as extesnsions, and returns a binary encoding of message body. +// It can be used for non-multipart message bodies when exts is nil. +func marshalMultipartMessageBody(proto int, data []byte, exts []Extension) ([]byte, error) { + bodyLen, dataLen := multipartMessageBodyDataLen(proto, data, exts) + b := make([]byte, 4+bodyLen) + copy(b[4:], data) + off := dataLen + 4 + if len(exts) > 0 { + b[dataLen+4] = byte(extensionVersion << 4) + off += 4 // length of object header + for _, ext := range exts { + switch ext := ext.(type) { + case *MPLSLabelStack: + if err := ext.marshal(proto, b[off:]); err != nil { + return nil, err + } + off += ext.Len(proto) + } + } + s := checksum(b[dataLen+4:]) + b[dataLen+4+2] ^= byte(s) + b[dataLen+4+3] ^= byte(s >> 8) + switch proto { + case iana.ProtocolICMP: + b[1] = byte(dataLen / 4) + case iana.ProtocolIPv6ICMP: + b[0] = byte(dataLen / 8) + } + } + return b, nil +} + +// parseMultipartMessageBody parses b as either a non-multipart +// message body or a multipart message body. +func parseMultipartMessageBody(proto int, b []byte) ([]byte, []Extension, error) { + var l int + switch proto { + case iana.ProtocolICMP: + l = 4 * int(b[1]) + case iana.ProtocolIPv6ICMP: + l = 8 * int(b[0]) + } + if len(b) == 4 { + return nil, nil, nil + } + exts, l, err := parseExtensions(b[4:], l) + if err != nil { + l = len(b) - 4 + } + data := make([]byte, l) + copy(data, b[4:]) + return data, exts, nil +} diff --git a/icmp/multipart_test.go b/icmp/multipart_test.go new file mode 100644 index 0000000..505db86 --- /dev/null +++ b/icmp/multipart_test.go @@ -0,0 +1,223 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package icmp_test + +import ( + "fmt" + "net" + "reflect" + "testing" + + "golang.org/x/net/icmp" + "golang.org/x/net/internal/iana" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var marshalAndParseMultipartMessageForIPv4Tests = []icmp.Message{ + { + Type: ipv4.ICMPTypeDestinationUnreachable, Code: 15, + Body: &icmp.DstUnreach{ + Data: []byte("ERROR-INVOKING-PACKET"), + Extensions: []icmp.Extension{ + &icmp.MPLSLabelStack{ + Class: 1, + Type: 1, + Labels: []icmp.MPLSLabel{ + { + Label: 16014, + TC: 0x4, + S: true, + TTL: 255, + }, + }, + }, + }, + }, + }, + { + Type: ipv4.ICMPTypeTimeExceeded, Code: 1, + Body: &icmp.TimeExceeded{ + Data: []byte("ERROR-INVOKING-PACKET"), + Extensions: []icmp.Extension{ + &icmp.MPLSLabelStack{ + Class: 1, + Type: 1, + Labels: []icmp.MPLSLabel{ + { + Label: 16014, + TC: 0x4, + S: true, + TTL: 255, + }, + }, + }, + }, + }, + }, + { + Type: ipv4.ICMPTypeParameterProblem, Code: 2, + Body: &icmp.ParamProb{ + Pointer: 8, + Data: []byte("ERROR-INVOKING-PACKET"), + Extensions: []icmp.Extension{ + &icmp.MPLSLabelStack{ + Class: 1, + Type: 1, + Labels: []icmp.MPLSLabel{ + { + Label: 16014, + TC: 0x4, + S: true, + TTL: 255, + }, + }, + }, + }, + }, + }, +} + +func TestMarshalAndParseMultipartMessageForIPv4(t *testing.T) { + for i, tt := range marshalAndParseMultipartMessageForIPv4Tests { + b, err := tt.Marshal(nil) + if err != nil { + t.Fatal(err) + } + if b[5] != 32 { + t.Errorf("#%v: got %v; want 32", i, b[5]) + } + m, err := icmp.ParseMessage(iana.ProtocolICMP, b) + if err != nil { + t.Fatal(err) + } + if m.Type != tt.Type || m.Code != tt.Code { + t.Errorf("#%v: got %v; want %v", i, m, &tt) + } + switch m.Type { + case ipv4.ICMPTypeDestinationUnreachable: + got, want := m.Body.(*icmp.DstUnreach), tt.Body.(*icmp.DstUnreach) + if !reflect.DeepEqual(got.Extensions, want.Extensions) { + t.Errorf("#%v: got %#v; want %#v", i, got.Extensions, want.Extensions) + } + if len(got.Data) != 128 { + t.Errorf("#%v: got %v; want 128", i, len(got.Data)) + } + case ipv4.ICMPTypeTimeExceeded: + got, want := m.Body.(*icmp.TimeExceeded), tt.Body.(*icmp.TimeExceeded) + if !reflect.DeepEqual(got.Extensions, want.Extensions) { + t.Error(dumpExtensions(i, got.Extensions, want.Extensions)) + } + if len(got.Data) != 128 { + t.Errorf("#%v: got %v; want 128", i, len(got.Data)) + } + case ipv4.ICMPTypeParameterProblem: + got, want := m.Body.(*icmp.ParamProb), tt.Body.(*icmp.ParamProb) + if !reflect.DeepEqual(got.Extensions, want.Extensions) { + t.Error(dumpExtensions(i, got.Extensions, want.Extensions)) + } + if len(got.Data) != 128 { + t.Errorf("#%v: got %v; want 128", i, len(got.Data)) + } + } + } +} + +var marshalAndParseMultipartMessageForIPv6Tests = []icmp.Message{ + { + Type: ipv6.ICMPTypeDestinationUnreachable, Code: 6, + Body: &icmp.DstUnreach{ + Data: []byte("ERROR-INVOKING-PACKET"), + Extensions: []icmp.Extension{ + &icmp.MPLSLabelStack{ + Class: 1, + Type: 1, + Labels: []icmp.MPLSLabel{ + { + Label: 16014, + TC: 0x4, + S: true, + TTL: 255, + }, + }, + }, + }, + }, + }, + { + Type: ipv6.ICMPTypeTimeExceeded, Code: 1, + Body: &icmp.TimeExceeded{ + Data: []byte("ERROR-INVOKING-PACKET"), + Extensions: []icmp.Extension{ + &icmp.MPLSLabelStack{ + Class: 1, + Type: 1, + Labels: []icmp.MPLSLabel{ + { + Label: 16014, + TC: 0x4, + S: true, + TTL: 255, + }, + }, + }, + }, + }, + }, +} + +func TestMarshalAndParseMultipartMessageForIPv6(t *testing.T) { + pshicmp := icmp.IPv6PseudoHeader(net.ParseIP("fe80::1"), net.ParseIP("ff02::1")) + for i, tt := range marshalAndParseMultipartMessageForIPv6Tests { + for _, psh := range [][]byte{pshicmp, nil} { + b, err := tt.Marshal(psh) + if err != nil { + t.Fatal(err) + } + if b[4] != 16 { + t.Errorf("#%v: got %v; want 16", i, b[4]) + } + m, err := icmp.ParseMessage(iana.ProtocolIPv6ICMP, b) + if err != nil { + t.Fatal(err) + } + if m.Type != tt.Type || m.Code != tt.Code { + t.Errorf("#%v: got %v; want %v", i, m, &tt) + } + switch m.Type { + case ipv6.ICMPTypeDestinationUnreachable: + got, want := m.Body.(*icmp.DstUnreach), tt.Body.(*icmp.DstUnreach) + if !reflect.DeepEqual(got.Extensions, want.Extensions) { + t.Error(dumpExtensions(i, got.Extensions, want.Extensions)) + } + if len(got.Data) != 128 { + t.Errorf("#%v: got %v; want 128", i, len(got.Data)) + } + case ipv6.ICMPTypeTimeExceeded: + got, want := m.Body.(*icmp.TimeExceeded), tt.Body.(*icmp.TimeExceeded) + if !reflect.DeepEqual(got.Extensions, want.Extensions) { + t.Error(dumpExtensions(i, got.Extensions, want.Extensions)) + } + if len(got.Data) != 128 { + t.Errorf("#%v: got %v; want 128", i, len(got.Data)) + } + } + } + } +} + +func dumpExtensions(i int, gotExts, wantExts []icmp.Extension) string { + var s string + for j, got := range gotExts { + switch got := got.(type) { + case *icmp.MPLSLabelStack: + want := wantExts[j].(*icmp.MPLSLabelStack) + if !reflect.DeepEqual(got, want) { + s += fmt.Sprintf("#%v/%v: got %#v; want %#v\n", i, j, got, want) + } + } + } + return s[:len(s)-1] +} diff --git a/icmp/packettoobig.go b/icmp/packettoobig.go index 0628e38..91d289b 100644 --- a/icmp/packettoobig.go +++ b/icmp/packettoobig.go @@ -7,11 +7,11 @@ package icmp // A PacketTooBig represents an ICMP packet too big message body. type PacketTooBig struct { MTU int // maximum transmission unit of the nexthop link - Data []byte // data + Data []byte // data, known as original datagram field } // Len implements the Len method of MessageBody interface. -func (p *PacketTooBig) Len() int { +func (p *PacketTooBig) Len(proto int) int { if p == nil { return 0 } diff --git a/icmp/paramprob.go b/icmp/paramprob.go index 48ae601..f200a7c 100644 --- a/icmp/paramprob.go +++ b/icmp/paramprob.go @@ -14,42 +14,47 @@ type ParamProb struct { } // Len implements the Len method of MessageBody interface. -func (p *ParamProb) Len() int { +func (p *ParamProb) Len(proto int) int { if p == nil { return 0 } - return 4 + len(p.Data) + l, _ := multipartMessageBodyDataLen(proto, p.Data, p.Extensions) + return l } // Marshal implements the Marshal method of MessageBody interface. func (p *ParamProb) Marshal(proto int) ([]byte, error) { - b := make([]byte, 4+len(p.Data)) - switch proto { - case iana.ProtocolICMP: - b[0] = byte(p.Pointer) - case iana.ProtocolIPv6ICMP: + if proto == iana.ProtocolIPv6ICMP { + b := make([]byte, 4+p.Len(proto)) b[0], b[1], b[2], b[3] = byte(p.Pointer>>24), byte(p.Pointer>>16), byte(p.Pointer>>8), byte(p.Pointer) + copy(b[4:], p.Data) + return b, nil } - copy(b[4:], p.Data) + b, err := marshalMultipartMessageBody(proto, p.Data, p.Extensions) + if err != nil { + return nil, err + } + b[0] = byte(p.Pointer) return b, nil } // parseParamProb parses b as an ICMP parameter problem message body. func parseParamProb(proto int, b []byte) (MessageBody, error) { - bodyLen := len(b) - if bodyLen < 4 { + if len(b) < 4 { return nil, errMessageTooShort } p := &ParamProb{} - switch proto { - case iana.ProtocolICMP: - p.Pointer = uintptr(b[0]) - case iana.ProtocolIPv6ICMP: + if proto == iana.ProtocolIPv6ICMP { p.Pointer = uintptr(b[0])<<24 | uintptr(b[1])<<16 | uintptr(b[2])<<8 | uintptr(b[3]) - } - if bodyLen > 4 { - p.Data = make([]byte, bodyLen-4) + p.Data = make([]byte, len(b)-4) copy(p.Data, b[4:]) + return p, nil + } + p.Pointer = uintptr(b[0]) + var err error + p.Data, p.Extensions, err = parseMultipartMessageBody(proto, b) + if err != nil { + return nil, err } return p, nil } diff --git a/icmp/timeexceeded.go b/icmp/timeexceeded.go index f61f431..18628c8 100644 --- a/icmp/timeexceeded.go +++ b/icmp/timeexceeded.go @@ -11,30 +11,29 @@ type TimeExceeded struct { } // Len implements the Len method of MessageBody interface. -func (p *TimeExceeded) Len() int { +func (p *TimeExceeded) Len(proto int) int { if p == nil { return 0 } - return 4 + len(p.Data) + l, _ := multipartMessageBodyDataLen(proto, p.Data, p.Extensions) + return l } // Marshal implements the Marshal method of MessageBody interface. func (p *TimeExceeded) Marshal(proto int) ([]byte, error) { - b := make([]byte, 4+len(p.Data)) - copy(b[4:], p.Data) - return b, nil + return marshalMultipartMessageBody(proto, p.Data, p.Extensions) } // parseTimeExceeded parses b as an ICMP time exceeded message body. func parseTimeExceeded(proto int, b []byte) (MessageBody, error) { - bodyLen := len(b) - if bodyLen < 4 { + if len(b) < 4 { return nil, errMessageTooShort } p := &TimeExceeded{} - if bodyLen > 4 { - p.Data = make([]byte, bodyLen-4) - copy(p.Data, b[4:]) + var err error + p.Data, p.Extensions, err = parseMultipartMessageBody(proto, b) + if err != nil { + return nil, err } return p, nil } |