aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Symonds <dsymonds@golang.org>2014-12-22 16:15:28 +1100
committerDavid Symonds <dsymonds@golang.org>2014-12-22 16:15:41 +1100
commit3ea3e05dbfae8d2dd0cc219b5f0f0ba4606d15eb (patch)
tree942d3f6c3a75d32d2fd0cac4328756820880127f
parent904b4403245c2d9a240f8ae32ada68850110e736 (diff)
downloadprotobuf-3ea3e05dbfae8d2dd0cc219b5f0f0ba4606d15eb.tar.gz
Support map<k,v> protocol buffer fields.
-rw-r--r--proto/all_test.go80
-rw-r--r--proto/clone.go23
-rw-r--r--proto/clone_test.go25
-rw-r--r--proto/decode.go68
-rw-r--r--proto/encode.go105
-rw-r--r--proto/equal.go15
-rw-r--r--proto/equal_test.go25
-rw-r--r--proto/pointer_reflect.go5
-rw-r--r--proto/pointer_unsafe.go5
-rw-r--r--proto/properties.go27
-rw-r--r--proto/size_test.go4
-rw-r--r--proto/testdata/test.pb.go33
-rw-r--r--proto/testdata/test.proto6
-rw-r--r--proto/text.go64
-rw-r--r--proto/text_parser.go66
-rw-r--r--proto/text_parser_test.go25
-rw-r--r--proto/text_test.go7
-rw-r--r--protoc-gen-go/generator/generator.go40
-rw-r--r--protoc-gen-go/testdata/my_test/test.pb.go26
-rw-r--r--protoc-gen-go/testdata/my_test/test.pb.go.golden26
-rw-r--r--protoc-gen-go/testdata/my_test/test.proto5
21 files changed, 667 insertions, 13 deletions
diff --git a/proto/all_test.go b/proto/all_test.go
index 6d74ddf..3fade17 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -1833,6 +1833,86 @@ func fuzzUnmarshal(t *testing.T, data []byte) {
Unmarshal(data, pb)
}
+func TestMapFieldMarshal(t *testing.T) {
+ m := &MessageWithMap{
+ NameMapping: map[int32]string{
+ 1: "Rob",
+ 4: "Ian",
+ 8: "Dave",
+ },
+ }
+ b, err := Marshal(m)
+ if err != nil {
+ t.Fatalf("Marshal: %v", err)
+ }
+
+ // b should be the concatenation of these three byte sequences in some order.
+ parts := []string{
+ "\n\a\b\x01\x12\x03Rob",
+ "\n\a\b\x04\x12\x03Ian",
+ "\n\b\b\x08\x12\x04Dave",
+ }
+ ok := false
+ for i := range parts {
+ for j := range parts {
+ if j == i {
+ continue
+ }
+ for k := range parts {
+ if k == i || k == j {
+ continue
+ }
+ try := parts[i] + parts[j] + parts[k]
+ if bytes.Equal(b, []byte(try)) {
+ ok = true
+ break
+ }
+ }
+ }
+ }
+ if !ok {
+ t.Fatalf("Incorrect Marshal output.\n got %q\nwant %q (or a permutation of that)", b, parts[0]+parts[1]+parts[2])
+ }
+ t.Logf("FYI b: %q", b)
+
+ (new(Buffer)).DebugPrint("Dump of b", b)
+}
+
+func TestMapFieldRoundTrips(t *testing.T) {
+ m := &MessageWithMap{
+ NameMapping: map[int32]string{
+ 1: "Rob",
+ 4: "Ian",
+ 8: "Dave",
+ },
+ MsgMapping: map[int64]*FloatingPoint{
+ 0x7001: &FloatingPoint{F: Float64(2.0)},
+ },
+ ByteMapping: map[bool][]byte{
+ false: []byte("that's not right!"),
+ true: []byte("aye, 'tis true!"),
+ },
+ }
+ b, err := Marshal(m)
+ if err != nil {
+ t.Fatalf("Marshal: %v", err)
+ }
+ t.Logf("FYI b: %q", b)
+ m2 := new(MessageWithMap)
+ if err := Unmarshal(b, m2); err != nil {
+ t.Fatalf("Unmarshal: %v", err)
+ }
+ for _, pair := range [][2]interface{}{
+ {m.NameMapping, m2.NameMapping},
+ {m.MsgMapping, m2.MsgMapping},
+ {m.ByteMapping, m2.ByteMapping},
+ } {
+ if !reflect.DeepEqual(pair[0], pair[1]) {
+ t.Errorf("Map did not survive a round trip.\ninitial: %v\n final: %v", pair[0], pair[1])
+ }
+ }
+}
+
// Benchmarks
func testMsg() *GoTest {
diff --git a/proto/clone.go b/proto/clone.go
index fd11aaf..ae276fd 100644
--- a/proto/clone.go
+++ b/proto/clone.go
@@ -113,6 +113,29 @@ func mergeAny(out, in reflect.Value) {
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
reflect.String, reflect.Uint32, reflect.Uint64:
out.Set(in)
+ case reflect.Map:
+ if in.Len() == 0 {
+ return
+ }
+ if out.IsNil() {
+ out.Set(reflect.MakeMap(in.Type()))
+ }
+ // For maps with value types of *T or []byte we need to deep copy each value.
+ elemKind := in.Type().Elem().Kind()
+ for _, key := range in.MapKeys() {
+ var val reflect.Value
+ switch elemKind {
+ case reflect.Ptr:
+ val = reflect.New(in.Type().Elem().Elem())
+ mergeAny(val, in.MapIndex(key))
+ case reflect.Slice:
+ val = in.MapIndex(key)
+ val = reflect.ValueOf(append([]byte{}, val.Bytes()...))
+ default:
+ val = in.MapIndex(key)
+ }
+ out.SetMapIndex(key, val)
+ }
case reflect.Ptr:
if in.IsNil() {
return
diff --git a/proto/clone_test.go b/proto/clone_test.go
index 03c2a7b..1ac177d 100644
--- a/proto/clone_test.go
+++ b/proto/clone_test.go
@@ -189,6 +189,31 @@ var mergeTests = []struct {
dst: &pb.OtherMessage{Value: []byte("bar")},
want: &pb.OtherMessage{Value: []byte("foo")},
},
+ {
+ src: &pb.MessageWithMap{
+ NameMapping: map[int32]string{6: "Nigel"},
+ MsgMapping: map[int64]*pb.FloatingPoint{
+ 0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)},
+ },
+ ByteMapping: map[bool][]byte{true: []byte("wowsa")},
+ },
+ dst: &pb.MessageWithMap{
+ NameMapping: map[int32]string{
+ 6: "Bruce", // should be overwritten
+ 7: "Andrew",
+ },
+ },
+ want: &pb.MessageWithMap{
+ NameMapping: map[int32]string{
+ 6: "Nigel",
+ 7: "Andrew",
+ },
+ MsgMapping: map[int64]*pb.FloatingPoint{
+ 0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)},
+ },
+ ByteMapping: map[bool][]byte{true: []byte("wowsa")},
+ },
+ },
}
func TestMerge(t *testing.T) {
diff --git a/proto/decode.go b/proto/decode.go
index 6166dd4..88622c3 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -178,7 +178,7 @@ func (p *Buffer) DecodeZigzag32() (x uint64, err error) {
func (p *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) {
n, err := p.DecodeVarint()
if err != nil {
- return
+ return nil, err
}
nb := int(n)
@@ -668,6 +668,72 @@ func (o *Buffer) dec_slice_slice_byte(p *Properties, base structPointer) error {
return nil
}
+// Decode a map field.
+func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
+ raw, err := o.DecodeRawBytes(false)
+ if err != nil {
+ return err
+ }
+ oi := o.index // index at the end of this map entry
+ o.index -= len(raw) // move buffer back to start of map entry
+
+ mptr := structPointer_Map(base, p.field, p.mtype) // *map[K]V
+ if mptr.Elem().IsNil() {
+ mptr.Elem().Set(reflect.MakeMap(mptr.Type().Elem()))
+ }
+ v := mptr.Elem() // map[K]V
+
+ // Prepare addressable doubly-indirect placeholders for the key and value types.
+ // See enc_new_map for why.
+ keyptr := reflect.New(reflect.PtrTo(p.mtype.Key())).Elem() // addressable *K
+ keybase := toStructPointer(keyptr.Addr()) // **K
+
+ var valbase structPointer
+ var valptr reflect.Value
+ switch p.mtype.Elem().Kind() {
+ case reflect.Slice:
+ // []byte
+ var dummy []byte
+ valptr = reflect.ValueOf(&dummy) // *[]byte
+ valbase = toStructPointer(valptr) // *[]byte
+ case reflect.Ptr:
+ // message; valptr is **Msg; need to allocate the intermediate pointer
+ valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
+ valptr.Set(reflect.New(valptr.Type().Elem()))
+ valbase = toStructPointer(valptr)
+ default:
+ // everything else
+ valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
+ valbase = toStructPointer(valptr.Addr()) // **V
+ }
+
+ // Decode.
+ // This parses a restricted wire format, namely the encoding of a message
+ // with two fields. See enc_new_map for the format.
+ for o.index < oi {
+ // tagcode for key and value properties are always a single byte
+ // because they have tags 1 and 2.
+ tagcode := o.buf[o.index]
+ o.index++
+ switch tagcode {
+ case p.mkeyprop.tagcode[0]:
+ if err := p.mkeyprop.dec(o, p.mkeyprop, keybase); err != nil {
+ return err
+ }
+ case p.mvalprop.tagcode[0]:
+ if err := p.mvalprop.dec(o, p.mvalprop, valbase); err != nil {
+ return err
+ }
+ default:
+ // TODO: Should we silently skip this instead?
+ return fmt.Errorf("proto: bad map data tag %d", raw[0])
+ }
+ }
+
+ v.SetMapIndex(keyptr.Elem(), valptr.Elem())
+ return nil
+}
+
// Decode a group.
func (o *Buffer) dec_struct_group(p *Properties, base structPointer) error {
bas := structPointer_GetStructPointer(base, p.field)
diff --git a/proto/encode.go b/proto/encode.go
index cc202cd..f5050e3 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -1069,6 +1069,104 @@ func size_map(p *Properties, base structPointer) int {
return sizeExtensionMap(v)
}
+// Encode a map field.
+func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
+ var state errorState // XXX: or do we need to plumb this through?
+
+ /*
+ A map defined as
+ map<key_type, value_type> map_field = N;
+ is encoded in the same way as
+ message MapFieldEntry {
+ key_type key = 1;
+ value_type value = 2;
+ }
+ repeated MapFieldEntry map_field = N;
+ */
+
+ v := structPointer_Map(base, p.field, p.mtype).Elem() // map[K]V
+ if v.Len() == 0 {
+ return nil
+ }
+
+ keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype)
+
+ enc := func() error {
+ if err := p.mkeyprop.enc(o, p.mkeyprop, keybase); err != nil {
+ return err
+ }
+ if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil {
+ return err
+ }
+ return nil
+ }
+
+ for _, key := range v.MapKeys() {
+ val := v.MapIndex(key)
+
+ keycopy.Set(key)
+ valcopy.Set(val)
+
+ o.buf = append(o.buf, p.tagcode...)
+ if err := o.enc_len_thing(enc, &state); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func size_new_map(p *Properties, base structPointer) int {
+ v := structPointer_Map(base, p.field, p.mtype).Elem() // map[K]V
+
+ keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype)
+
+ n := 0
+ for _, key := range v.MapKeys() {
+ val := v.MapIndex(key)
+ keycopy.Set(key)
+ valcopy.Set(val)
+
+ // Tag codes are two bytes per map entry.
+ n += 2
+ n += p.mkeyprop.size(p.mkeyprop, keybase)
+ n += p.mvalprop.size(p.mvalprop, valbase)
+ }
+ return n
+}
+
+// mapEncodeScratch returns a new reflect.Value matching the map's value type,
+// and a structPointer suitable for passing to an encoder or sizer.
+func mapEncodeScratch(mapType reflect.Type) (keycopy, valcopy reflect.Value, keybase, valbase structPointer) {
+ // Prepare addressable doubly-indirect placeholders for the key and value types.
+ // This is needed because the element-type encoders expect **T, but the map iteration produces T.
+
+ keycopy = reflect.New(mapType.Key()).Elem() // addressable K
+ keyptr := reflect.New(reflect.PtrTo(keycopy.Type())).Elem() // addressable *K
+ keyptr.Set(keycopy.Addr()) //
+ keybase = toStructPointer(keyptr.Addr()) // **K
+
+ // Value types are more varied and require special handling.
+ switch mapType.Elem().Kind() {
+ case reflect.Slice:
+ // []byte
+ var dummy []byte
+ valcopy = reflect.ValueOf(&dummy).Elem() // addressable []byte
+ valbase = toStructPointer(valcopy.Addr())
+ case reflect.Ptr:
+ // message; the generated field type is map[K]*Msg (so V is *Msg),
+ // so we only need one level of indirection.
+ valcopy = reflect.New(mapType.Elem()).Elem() // addressable V
+ valbase = toStructPointer(valcopy.Addr())
+ default:
+ // everything else
+ valcopy = reflect.New(mapType.Elem()).Elem() // addressable V
+ valptr := reflect.New(reflect.PtrTo(valcopy.Type())).Elem() // addressable *V
+ valptr.Set(valcopy.Addr()) //
+ valbase = toStructPointer(valptr.Addr()) // **V
+ }
+ return
+}
+
// Encode a struct.
func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error {
var state errorState
@@ -1123,10 +1221,15 @@ var zeroes [20]byte // longer than any conceivable sizeVarint
// Encode a struct, preceded by its encoded length (as a varint).
func (o *Buffer) enc_len_struct(prop *StructProperties, base structPointer, state *errorState) error {
+ return o.enc_len_thing(func() error { return o.enc_struct(prop, base) }, state)
+}
+
+// Encode something, preceded by its encoded length (as a varint).
+func (o *Buffer) enc_len_thing(enc func() error, state *errorState) error {
iLen := len(o.buf)
o.buf = append(o.buf, 0, 0, 0, 0) // reserve four bytes for length
iMsg := len(o.buf)
- err := o.enc_struct(prop, base)
+ err := enc()
if err != nil && !state.shouldContinue(err, nil) {
return err
}
diff --git a/proto/equal.go b/proto/equal.go
index ebdfdca..d8673a3 100644
--- a/proto/equal.go
+++ b/proto/equal.go
@@ -154,6 +154,21 @@ func equalAny(v1, v2 reflect.Value) bool {
return v1.Float() == v2.Float()
case reflect.Int32, reflect.Int64:
return v1.Int() == v2.Int()
+ case reflect.Map:
+ if v1.Len() != v2.Len() {
+ return false
+ }
+ for _, key := range v1.MapKeys() {
+ val2 := v2.MapIndex(key)
+ if !val2.IsValid() {
+ // This key was not found in the second map.
+ return false
+ }
+ if !equalAny(v1.MapIndex(key), val2) {
+ return false
+ }
+ }
+ return true
case reflect.Ptr:
return equalAny(v1.Elem(), v2.Elem())
case reflect.Slice:
diff --git a/proto/equal_test.go b/proto/equal_test.go
index ebcf340..cc25833 100644
--- a/proto/equal_test.go
+++ b/proto/equal_test.go
@@ -155,6 +155,31 @@ var EqualTests = []struct {
},
true,
},
+
+ {
+ "map same",
+ &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
+ &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
+ true,
+ },
+ {
+ "map different entry",
+ &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
+ &pb.MessageWithMap{NameMapping: map[int32]string{2: "Rob"}},
+ false,
+ },
+ {
+ "map different key only",
+ &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
+ &pb.MessageWithMap{NameMapping: map[int32]string{2: "Ken"}},
+ false,
+ },
+ {
+ "map different value only",
+ &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
+ &pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob"}},
+ false,
+ },
}
func TestEqual(t *testing.T) {
diff --git a/proto/pointer_reflect.go b/proto/pointer_reflect.go
index 42c387a..93259a3 100644
--- a/proto/pointer_reflect.go
+++ b/proto/pointer_reflect.go
@@ -144,6 +144,11 @@ func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return structPointer_ifield(p, f).(*map[int32]Extension)
}
+// Map returns the reflect.Value for the address of a map field in the struct.
+func structPointer_Map(p structPointer, f field, typ reflect.Type) reflect.Value {
+ return structPointer_field(p, f).Addr()
+}
+
// SetStructPointer writes a *struct field in the struct.
func structPointer_SetStructPointer(p structPointer, f field, q structPointer) {
structPointer_field(p, f).Set(q.v)
diff --git a/proto/pointer_unsafe.go b/proto/pointer_unsafe.go
index cf9fc9a..c52db1c 100644
--- a/proto/pointer_unsafe.go
+++ b/proto/pointer_unsafe.go
@@ -130,6 +130,11 @@ func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension {
return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f)))
}
+// Map returns the reflect.Value for the address of a map field in the struct.
+func structPointer_Map(p structPointer, f field, typ reflect.Type) reflect.Value {
+ return reflect.NewAt(typ, unsafe.Pointer(uintptr(p)+uintptr(f)))
+}
+
// SetStructPointer writes a *struct field in the struct.
func structPointer_SetStructPointer(p structPointer, f field, q structPointer) {
*(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) = q
diff --git a/proto/properties.go b/proto/properties.go
index 4420881..730a595 100644
--- a/proto/properties.go
+++ b/proto/properties.go
@@ -171,6 +171,10 @@ type Properties struct {
isMarshaler bool
isUnmarshaler bool
+ mtype reflect.Type // set for map types only
+ mkeyprop *Properties // set for map types only
+ mvalprop *Properties // set for map types only
+
size sizer
valSize valueSizer // set for bool and numeric types only
@@ -299,7 +303,7 @@ func logNoSliceEnc(t1, t2 reflect.Type) {
var protoMessageType = reflect.TypeOf((*Message)(nil)).Elem()
// Initialize the fields for encoding and decoding.
-func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) {
+func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lockGetProp bool) {
p.enc = nil
p.dec = nil
p.size = nil
@@ -342,7 +346,7 @@ func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) {
case reflect.Ptr:
switch t2 := t1.Elem(); t2.Kind() {
default:
- fmt.Fprintf(os.Stderr, "proto: no encoder function for %T -> %T\n", t1, t2)
+ fmt.Fprintf(os.Stderr, "proto: no encoder function for %v -> %v\n", t1, t2)
break
case reflect.Bool:
p.enc = (*Buffer).enc_bool
@@ -502,6 +506,23 @@ func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) {
p.size = size_slice_slice_byte
}
}
+
+ case reflect.Map:
+ p.enc = (*Buffer).enc_new_map
+ p.dec = (*Buffer).dec_new_map
+ p.size = size_new_map
+
+ p.mtype = t1
+ p.mkeyprop = &Properties{}
+ p.mkeyprop.init(reflect.PtrTo(p.mtype.Key()), "Key", f.Tag.Get("protobuf_key"), nil, lockGetProp)
+ p.mvalprop = &Properties{}
+ vtype := p.mtype.Elem()
+ if vtype.Kind() != reflect.Ptr && vtype.Kind() != reflect.Slice {
+ // The value type is not a message (*T) or bytes ([]byte),
+ // so we need encoders for the pointer to this type.
+ vtype = reflect.PtrTo(vtype)
+ }
+ p.mvalprop.init(vtype, "Value", f.Tag.Get("protobuf_val"), nil, lockGetProp)
}
// precalculate tag code
@@ -570,7 +591,7 @@ func (p *Properties) init(typ reflect.Type, name, tag string, f *reflect.StructF
return
}
p.Parse(tag)
- p.setEncAndDec(typ, lockGetProp)
+ p.setEncAndDec(typ, f, lockGetProp)
}
var (
diff --git a/proto/size_test.go b/proto/size_test.go
index 4f87f3b..e5f92d6 100644
--- a/proto/size_test.go
+++ b/proto/size_test.go
@@ -113,6 +113,10 @@ var SizeTests = []struct {
{"proto3 bytes", &proto3pb.Message{Data: []byte("wowsa")}},
{"proto3 bytes, empty", &proto3pb.Message{Data: []byte{}}},
{"proto3 enum", &proto3pb.Message{Hilarity: proto3pb.Message_PUNS}},
+
+ {"map field", &pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob", 7: "Andrew"}}},
+ {"map field with message", &pb.MessageWithMap{MsgMapping: map[int64]*pb.FloatingPoint{0x7001: &pb.FloatingPoint{F: Float64(2.0)}}}},
+ {"map field with bytes", &pb.MessageWithMap{ByteMapping: map[bool][]byte{true: []byte("this time for sure")}}},
}
func TestSize(t *testing.T) {
diff --git a/proto/testdata/test.pb.go b/proto/testdata/test.pb.go
index 4d46e6e..f47d9e0 100644
--- a/proto/testdata/test.pb.go
+++ b/proto/testdata/test.pb.go
@@ -33,6 +33,7 @@ It has these top-level messages:
GroupOld
GroupNew
FloatingPoint
+ MessageWithMap
*/
package testdata
@@ -1885,6 +1886,38 @@ func (m *FloatingPoint) GetF() float64 {
return 0
}
+type MessageWithMap struct {
+ NameMapping map[int32]string `protobuf:"bytes,1,rep,name=name_mapping" json:"name_mapping,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
+ MsgMapping map[int64]*FloatingPoint `protobuf:"bytes,2,rep,name=msg_mapping" json:"msg_mapping,omitempty" protobuf_key:"zigzag64,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
+ ByteMapping map[bool][]byte `protobuf:"bytes,3,rep,name=byte_mapping" json:"byte_mapping,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
+ XXX_unrecognized []byte `json:"-"`
+}
+
+func (m *MessageWithMap) Reset() { *m = MessageWithMap{} }
+func (m *MessageWithMap) String() string { return proto.CompactTextString(m) }
+func (*MessageWithMap) ProtoMessage() {}
+
+func (m *MessageWithMap) GetNameMapping() map[int32]string {
+ if m != nil {
+ return m.NameMapping
+ }
+ return nil
+}
+
+func (m *MessageWithMap) GetMsgMapping() map[int64]*FloatingPoint {
+ if m != nil {
+ return m.MsgMapping
+ }
+ return nil
+}
+
+func (m *MessageWithMap) GetByteMapping() map[bool][]byte {
+ if m != nil {
+ return m.ByteMapping
+ }
+ return nil
+}
+
var E_Greeting = &proto.ExtensionDesc{
ExtendedType: (*MyMessage)(nil),
ExtensionType: ([]string)(nil),
diff --git a/proto/testdata/test.proto b/proto/testdata/test.proto
index ac9542a..6cc755b 100644
--- a/proto/testdata/test.proto
+++ b/proto/testdata/test.proto
@@ -426,3 +426,9 @@ message GroupNew {
message FloatingPoint {
required double f = 1;
}
+
+message MessageWithMap {
+ map<int32, string> name_mapping = 1;
+ map<sint64, FloatingPoint> msg_mapping = 2;
+ map<bool, bytes> byte_mapping = 3;
+}
diff --git a/proto/text.go b/proto/text.go
index 426db1e..f41a946 100644
--- a/proto/text.go
+++ b/proto/text.go
@@ -244,6 +244,70 @@ func writeStruct(w *textWriter, sv reflect.Value) error {
}
continue
}
+ if fv.Kind() == reflect.Map {
+ // Map fields are rendered as a repeated struct with key/value fields.
+ keys := fv.MapKeys() // TODO: should we sort these for deterministic output?
+ for _, key := range keys {
+ val := fv.MapIndex(key)
+ if err := writeName(w, props); err != nil {
+ return err
+ }
+ if !w.compact {
+ if err := w.WriteByte(' '); err != nil {
+ return err
+ }
+ }
+ // open struct
+ if err := w.WriteByte('<'); err != nil {
+ return err
+ }
+ if !w.compact {
+ if err := w.WriteByte('\n'); err != nil {
+ return err
+ }
+ }
+ w.indent()
+ // key
+ if _, err := w.WriteString("key:"); err != nil {
+ return err
+ }
+ if !w.compact {
+ if err := w.WriteByte(' '); err != nil {
+ return err
+ }
+ }
+ if err := writeAny(w, key, props.mkeyprop); err != nil {
+ return err
+ }
+ if err := w.WriteByte('\n'); err != nil {
+ return err
+ }
+ // value
+ if _, err := w.WriteString("value:"); err != nil {
+ return err
+ }
+ if !w.compact {
+ if err := w.WriteByte(' '); err != nil {
+ return err
+ }
+ }
+ if err := writeAny(w, val, props.mvalprop); err != nil {
+ return err
+ }
+ if err := w.WriteByte('\n'); err != nil {
+ return err
+ }
+ // close struct
+ w.unindent()
+ if err := w.WriteByte('>'); err != nil {
+ return err
+ }
+ if err := w.WriteByte('\n'); err != nil {
+ return err
+ }
+ }
+ continue
+ }
if props.proto3 && fv.Kind() == reflect.Slice && fv.Len() == 0 {
// empty bytes field
continue
diff --git a/proto/text_parser.go b/proto/text_parser.go
index f733f30..ddd9579 100644
--- a/proto/text_parser.go
+++ b/proto/text_parser.go
@@ -355,6 +355,18 @@ func (p *textParser) next() *token {
return &p.cur
}
+func (p *textParser) consumeToken(s string) error {
+ tok := p.next()
+ if tok.err != nil {
+ return tok.err
+ }
+ if tok.value != s {
+ p.back()
+ return p.errorf("expected %q, found %q", s, tok.value)
+ }
+ return nil
+}
+
// Return a RequiredNotSetError indicating which required field was not set.
func (p *textParser) missingRequiredFieldError(sv reflect.Value) *RequiredNotSetError {
st := sv.Type()
@@ -518,6 +530,60 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error {
dst := sv.Field(fi)
+ if dst.Kind() == reflect.Map {
+ // Consume any colon.
+ if err := p.checkForColon(props, dst.Type()); err != nil {
+ return err
+ }
+
+ // Construct the map if it doesn't already exist.
+ if dst.IsNil() {
+ dst.Set(reflect.MakeMap(dst.Type()))
+ }
+ key := reflect.New(dst.Type().Key()).Elem()
+ val := reflect.New(dst.Type().Elem()).Elem()
+
+ // The map entry should be this sequence of tokens:
+ // < key : KEY value : VALUE >
+ // Technically the "key" and "value" could come in any order,
+ // but in practice they won't.
+
+ tok := p.next()
+ var terminator string
+ switch tok.value {
+ case "<":
+ terminator = ">"
+ case "{":
+ terminator = "}"
+ default:
+ return p.errorf("expected '{' or '<', found %q", tok.value)
+ }
+ if err := p.consumeToken("key"); err != nil {
+ return err
+ }
+ if err := p.consumeToken(":"); err != nil {
+ return err
+ }
+ if err := p.readAny(key, props.mkeyprop); err != nil {
+ return err
+ }
+ if err := p.consumeToken("value"); err != nil {
+ return err
+ }
+ if err := p.consumeToken(":"); err != nil {
+ return err
+ }
+ if err := p.readAny(val, props.mvalprop); err != nil {
+ return err
+ }
+ if err := p.consumeToken(terminator); err != nil {
+ return err
+ }
+
+ dst.SetMapIndex(key, val)
+ continue
+ }
+
// Check that it's not already set if it's not a repeated field.
if !props.Repeated && fieldSet[name] {
return p.errorf("non-repeated field %q was repeated", name)
diff --git a/proto/text_parser_test.go b/proto/text_parser_test.go
index 89ab106..e5ee8b9 100644
--- a/proto/text_parser_test.go
+++ b/proto/text_parser_test.go
@@ -459,6 +459,31 @@ func TestProto3TextParsing(t *testing.T) {
}
}
+func TestMapParsing(t *testing.T) {
+ m := new(MessageWithMap)
+ const in = `name_mapping:<key:1234 value:"Feist"> name_mapping:<key:1 value:"Beatles">` +
+ `msg_mapping:<key:-4 value:<f: 2.0>>` +
+ `byte_mapping:<key:true value:"so be it">`
+ want := &MessageWithMap{
+ NameMapping: map[int32]string{
+ 1: "Beatles",
+ 1234: "Feist",
+ },
+ MsgMapping: map[int64]*FloatingPoint{
+ -4: {F: Float64(2.0)},
+ },
+ ByteMapping: map[bool][]byte{
+ true: []byte("so be it"),
+ },
+ }
+ if err := UnmarshalText(in, m); err != nil {
+ t.Fatal(err)
+ }
+ if !Equal(m, want) {
+ t.Errorf("\n got %v\nwant %v", m, want)
+ }
+}
+
var benchInput string
func init() {
diff --git a/proto/text_test.go b/proto/text_test.go
index 404920e..707bedd 100644
--- a/proto/text_test.go
+++ b/proto/text_test.go
@@ -419,6 +419,13 @@ func TestProto3Text(t *testing.T) {
{&proto3pb.Message{Data: []byte{}}, ``},
// trivial case
{&proto3pb.Message{Name: "Rob", HeightInCm: 175}, `name:"Rob" height_in_cm:175`},
+ // empty map
+ {&pb.MessageWithMap{}, ``},
+ // non-empty map; current map format is the same as a repeated struct
+ {
+ &pb.MessageWithMap{NameMapping: map[int32]string{1234: "Feist"}},
+ `name_mapping:<key:1234 value:"Feist" >`,
+ },
}
for _, test := range tests {
got := strings.TrimSpace(test.m.String())
diff --git a/protoc-gen-go/generator/generator.go b/protoc-gen-go/generator/generator.go
index 4b309cb..8b02ba5 100644
--- a/protoc-gen-go/generator/generator.go
+++ b/protoc-gen-go/generator/generator.go
@@ -102,6 +102,12 @@ func fileIsProto3(file *descriptor.FileDescriptorProto) bool {
func (c *common) proto3() bool { return fileIsProto3(c.file) }
+func fileUsesMaps(file *descriptor.FileDescriptorProto) bool {
+ return true
+}
+
+func (c *common) usesMaps() bool { return fileUsesMaps(c.file) }
+
// Descriptor represents a protocol buffer message.
type Descriptor struct {
common
@@ -1011,6 +1017,10 @@ func (g *Generator) generate(file *FileDescriptor) {
g.generateEnum(enum)
}
for _, desc := range g.file.desc {
+ // Don't generate virtual messages for maps.
+ if desc.GetOptions().GetMapEntry() && desc.usesMaps() {
+ continue
+ }
g.generateMessage(desc)
}
for _, ext := range g.file.ext {
@@ -1499,6 +1509,7 @@ func (g *Generator) generateMessage(message *Descriptor) {
}
fieldNames := make(map[*descriptor.FieldDescriptorProto]string)
fieldGetterNames := make(map[*descriptor.FieldDescriptorProto]string)
+ mapFieldTypes := make(map[*descriptor.FieldDescriptorProto]string)
g.PrintComments(message.path)
g.P("type ", ccTypeName, " struct {")
@@ -1516,6 +1527,32 @@ func (g *Generator) generateMessage(message *Descriptor) {
typename, wiretype := g.GoType(message, field)
jsonName := *field.Name
tag := fmt.Sprintf("protobuf:%s json:%q", g.goTag(message, field, wiretype), jsonName+",omitempty")
+
+ if *field.Type == descriptor.FieldDescriptorProto_TYPE_MESSAGE {
+ desc := g.ObjectNamed(field.GetTypeName())
+ if d, ok := desc.(*Descriptor); ok && d.GetOptions().GetMapEntry() && d.usesMaps() {
+ // Figure out the Go types and tags for the key and value types.
+ keyField, valField := d.Field[0], d.Field[1]
+ keyType, keyWire := g.GoType(d, keyField)
+ valType, valWire := g.GoType(d, valField)
+ keyTag, valTag := g.goTag(d, keyField, keyWire), g.goTag(d, valField, valWire)
+
+ // We don't use stars, except for message-typed values.
+ keyType = strings.TrimPrefix(keyType, "*")
+ switch *valField.Type {
+ case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
+ g.RecordTypeUse(valField.GetTypeName())
+ default:
+ valType = strings.TrimPrefix(valType, "*")
+ }
+
+ typename = fmt.Sprintf("map[%s]%s", keyType, valType)
+ mapFieldTypes[field] = typename // record for the getter generation
+
+ tag += fmt.Sprintf(" protobuf_key:%s protobuf_val:%s", keyTag, valTag)
+ }
+ }
+
fieldNames[field] = fieldName
fieldGetterNames[field] = fieldGetterName
g.P(fieldName, "\t", typename, "\t`", tag, "`")
@@ -1655,6 +1692,9 @@ func (g *Generator) generateMessage(message *Descriptor) {
for _, field := range message.Field {
fname := fieldNames[field]
typename, _ := g.GoType(message, field)
+ if t, ok := mapFieldTypes[field]; ok {
+ typename = t
+ }
mname := "Get" + fieldGetterNames[field]
star := ""
if needsStar(*field.Type) && typename[0] == '*' {
diff --git a/protoc-gen-go/testdata/my_test/test.pb.go b/protoc-gen-go/testdata/my_test/test.pb.go
index c39e58f..5cd7b2a 100644
--- a/protoc-gen-go/testdata/my_test/test.pb.go
+++ b/protoc-gen-go/testdata/my_test/test.pb.go
@@ -174,10 +174,14 @@ type Request struct {
Hue *Request_Color `protobuf:"varint,3,opt,name=hue,enum=my.test.Request_Color" json:"hue,omitempty"`
Hat *HatType `protobuf:"varint,4,opt,name=hat,enum=my.test.HatType,def=1" json:"hat,omitempty"`
// optional imp.ImportedMessage.Owner owner = 6;
- Deadline *float32 `protobuf:"fixed32,7,opt,name=deadline,def=inf" json:"deadline,omitempty"`
- Somegroup *Request_SomeGroup `protobuf:"group,8,opt,name=SomeGroup" json:"somegroup,omitempty"`
- Reset_ *int32 `protobuf:"varint,12,opt,name=reset" json:"reset,omitempty"`
- XXX_unrecognized []byte `json:"-"`
+ Deadline *float32 `protobuf:"fixed32,7,opt,name=deadline,def=inf" json:"deadline,omitempty"`
+ Somegroup *Request_SomeGroup `protobuf:"group,8,opt,name=SomeGroup" json:"somegroup,omitempty"`
+ // This is a map field. It will generate map[int32]string.
+ NameMapping map[int32]string `protobuf:"bytes,14,rep,name=name_mapping" json:"name_mapping,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
+ // This is a map field whose value type is a message.
+ MsgMapping map[int64]*Reply `protobuf:"bytes,15,rep,name=msg_mapping" json:"msg_mapping,omitempty" protobuf_key:"zigzag64,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
+ Reset_ *int32 `protobuf:"varint,12,opt,name=reset" json:"reset,omitempty"`
+ XXX_unrecognized []byte `json:"-"`
}
func (m *Request) Reset() { *m = Request{} }
@@ -223,6 +227,20 @@ func (m *Request) GetSomegroup() *Request_SomeGroup {
return nil
}
+func (m *Request) GetNameMapping() map[int32]string {
+ if m != nil {
+ return m.NameMapping
+ }
+ return nil
+}
+
+func (m *Request) GetMsgMapping() map[int64]*Reply {
+ if m != nil {
+ return m.MsgMapping
+ }
+ return nil
+}
+
func (m *Request) GetReset_() int32 {
if m != nil && m.Reset_ != nil {
return *m.Reset_
diff --git a/protoc-gen-go/testdata/my_test/test.pb.go.golden b/protoc-gen-go/testdata/my_test/test.pb.go.golden
index c39e58f..5cd7b2a 100644
--- a/protoc-gen-go/testdata/my_test/test.pb.go.golden
+++ b/protoc-gen-go/testdata/my_test/test.pb.go.golden
@@ -174,10 +174,14 @@ type Request struct {
Hue *Request_Color `protobuf:"varint,3,opt,name=hue,enum=my.test.Request_Color" json:"hue,omitempty"`
Hat *HatType `protobuf:"varint,4,opt,name=hat,enum=my.test.HatType,def=1" json:"hat,omitempty"`
// optional imp.ImportedMessage.Owner owner = 6;
- Deadline *float32 `protobuf:"fixed32,7,opt,name=deadline,def=inf" json:"deadline,omitempty"`
- Somegroup *Request_SomeGroup `protobuf:"group,8,opt,name=SomeGroup" json:"somegroup,omitempty"`
- Reset_ *int32 `protobuf:"varint,12,opt,name=reset" json:"reset,omitempty"`
- XXX_unrecognized []byte `json:"-"`
+ Deadline *float32 `protobuf:"fixed32,7,opt,name=deadline,def=inf" json:"deadline,omitempty"`
+ Somegroup *Request_SomeGroup `protobuf:"group,8,opt,name=SomeGroup" json:"somegroup,omitempty"`
+ // This is a map field. It will generate map[int32]string.
+ NameMapping map[int32]string `protobuf:"bytes,14,rep,name=name_mapping" json:"name_mapping,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
+ // This is a map field whose value type is a message.
+ MsgMapping map[int64]*Reply `protobuf:"bytes,15,rep,name=msg_mapping" json:"msg_mapping,omitempty" protobuf_key:"zigzag64,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
+ Reset_ *int32 `protobuf:"varint,12,opt,name=reset" json:"reset,omitempty"`
+ XXX_unrecognized []byte `json:"-"`
}
func (m *Request) Reset() { *m = Request{} }
@@ -223,6 +227,20 @@ func (m *Request) GetSomegroup() *Request_SomeGroup {
return nil
}
+func (m *Request) GetNameMapping() map[int32]string {
+ if m != nil {
+ return m.NameMapping
+ }
+ return nil
+}
+
+func (m *Request) GetMsgMapping() map[int64]*Reply {
+ if m != nil {
+ return m.MsgMapping
+ }
+ return nil
+}
+
func (m *Request) GetReset_() int32 {
if m != nil && m.Reset_ != nil {
return *m.Reset_
diff --git a/protoc-gen-go/testdata/my_test/test.proto b/protoc-gen-go/testdata/my_test/test.proto
index efcc0a0..af69c47 100644
--- a/protoc-gen-go/testdata/my_test/test.proto
+++ b/protoc-gen-go/testdata/my_test/test.proto
@@ -72,6 +72,11 @@ message Request {
// optional imp.PubliclyImportedEnum pub_enum = 13 [default=HAIR];
+ // This is a map field. It will generate map[int32]string.
+ map<int32, string> name_mapping = 14;
+ // This is a map field whose value type is a message.
+ map<sint64, Reply> msg_mapping = 15;
+
optional int32 reset = 12;
}