diff options
author | David Symonds <dsymonds@golang.org> | 2014-12-22 16:15:28 +1100 |
---|---|---|
committer | David Symonds <dsymonds@golang.org> | 2014-12-22 16:15:41 +1100 |
commit | 3ea3e05dbfae8d2dd0cc219b5f0f0ba4606d15eb (patch) | |
tree | 942d3f6c3a75d32d2fd0cac4328756820880127f | |
parent | 904b4403245c2d9a240f8ae32ada68850110e736 (diff) | |
download | protobuf-3ea3e05dbfae8d2dd0cc219b5f0f0ba4606d15eb.tar.gz |
Support map<k,v> protocol buffer fields.
-rw-r--r-- | proto/all_test.go | 80 | ||||
-rw-r--r-- | proto/clone.go | 23 | ||||
-rw-r--r-- | proto/clone_test.go | 25 | ||||
-rw-r--r-- | proto/decode.go | 68 | ||||
-rw-r--r-- | proto/encode.go | 105 | ||||
-rw-r--r-- | proto/equal.go | 15 | ||||
-rw-r--r-- | proto/equal_test.go | 25 | ||||
-rw-r--r-- | proto/pointer_reflect.go | 5 | ||||
-rw-r--r-- | proto/pointer_unsafe.go | 5 | ||||
-rw-r--r-- | proto/properties.go | 27 | ||||
-rw-r--r-- | proto/size_test.go | 4 | ||||
-rw-r--r-- | proto/testdata/test.pb.go | 33 | ||||
-rw-r--r-- | proto/testdata/test.proto | 6 | ||||
-rw-r--r-- | proto/text.go | 64 | ||||
-rw-r--r-- | proto/text_parser.go | 66 | ||||
-rw-r--r-- | proto/text_parser_test.go | 25 | ||||
-rw-r--r-- | proto/text_test.go | 7 | ||||
-rw-r--r-- | protoc-gen-go/generator/generator.go | 40 | ||||
-rw-r--r-- | protoc-gen-go/testdata/my_test/test.pb.go | 26 | ||||
-rw-r--r-- | protoc-gen-go/testdata/my_test/test.pb.go.golden | 26 | ||||
-rw-r--r-- | protoc-gen-go/testdata/my_test/test.proto | 5 |
21 files changed, 667 insertions, 13 deletions
diff --git a/proto/all_test.go b/proto/all_test.go index 6d74ddf..3fade17 100644 --- a/proto/all_test.go +++ b/proto/all_test.go @@ -1833,6 +1833,86 @@ func fuzzUnmarshal(t *testing.T, data []byte) { Unmarshal(data, pb) } +func TestMapFieldMarshal(t *testing.T) { + m := &MessageWithMap{ + NameMapping: map[int32]string{ + 1: "Rob", + 4: "Ian", + 8: "Dave", + }, + } + b, err := Marshal(m) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + + // b should be the concatenation of these three byte sequences in some order. + parts := []string{ + "\n\a\b\x01\x12\x03Rob", + "\n\a\b\x04\x12\x03Ian", + "\n\b\b\x08\x12\x04Dave", + } + ok := false + for i := range parts { + for j := range parts { + if j == i { + continue + } + for k := range parts { + if k == i || k == j { + continue + } + try := parts[i] + parts[j] + parts[k] + if bytes.Equal(b, []byte(try)) { + ok = true + break + } + } + } + } + if !ok { + t.Fatalf("Incorrect Marshal output.\n got %q\nwant %q (or a permutation of that)", b, parts[0]+parts[1]+parts[2]) + } + t.Logf("FYI b: %q", b) + + (new(Buffer)).DebugPrint("Dump of b", b) +} + +func TestMapFieldRoundTrips(t *testing.T) { + m := &MessageWithMap{ + NameMapping: map[int32]string{ + 1: "Rob", + 4: "Ian", + 8: "Dave", + }, + MsgMapping: map[int64]*FloatingPoint{ + 0x7001: &FloatingPoint{F: Float64(2.0)}, + }, + ByteMapping: map[bool][]byte{ + false: []byte("that's not right!"), + true: []byte("aye, 'tis true!"), + }, + } + b, err := Marshal(m) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + t.Logf("FYI b: %q", b) + m2 := new(MessageWithMap) + if err := Unmarshal(b, m2); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + for _, pair := range [][2]interface{}{ + {m.NameMapping, m2.NameMapping}, + {m.MsgMapping, m2.MsgMapping}, + {m.ByteMapping, m2.ByteMapping}, + } { + if !reflect.DeepEqual(pair[0], pair[1]) { + t.Errorf("Map did not survive a round trip.\ninitial: %v\n final: %v", pair[0], pair[1]) + } + } +} + // Benchmarks func testMsg() *GoTest { diff --git a/proto/clone.go b/proto/clone.go index fd11aaf..ae276fd 100644 --- a/proto/clone.go +++ b/proto/clone.go @@ -113,6 +113,29 @@ func mergeAny(out, in reflect.Value) { case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, reflect.String, reflect.Uint32, reflect.Uint64: out.Set(in) + case reflect.Map: + if in.Len() == 0 { + return + } + if out.IsNil() { + out.Set(reflect.MakeMap(in.Type())) + } + // For maps with value types of *T or []byte we need to deep copy each value. + elemKind := in.Type().Elem().Kind() + for _, key := range in.MapKeys() { + var val reflect.Value + switch elemKind { + case reflect.Ptr: + val = reflect.New(in.Type().Elem().Elem()) + mergeAny(val, in.MapIndex(key)) + case reflect.Slice: + val = in.MapIndex(key) + val = reflect.ValueOf(append([]byte{}, val.Bytes()...)) + default: + val = in.MapIndex(key) + } + out.SetMapIndex(key, val) + } case reflect.Ptr: if in.IsNil() { return diff --git a/proto/clone_test.go b/proto/clone_test.go index 03c2a7b..1ac177d 100644 --- a/proto/clone_test.go +++ b/proto/clone_test.go @@ -189,6 +189,31 @@ var mergeTests = []struct { dst: &pb.OtherMessage{Value: []byte("bar")}, want: &pb.OtherMessage{Value: []byte("foo")}, }, + { + src: &pb.MessageWithMap{ + NameMapping: map[int32]string{6: "Nigel"}, + MsgMapping: map[int64]*pb.FloatingPoint{ + 0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)}, + }, + ByteMapping: map[bool][]byte{true: []byte("wowsa")}, + }, + dst: &pb.MessageWithMap{ + NameMapping: map[int32]string{ + 6: "Bruce", // should be overwritten + 7: "Andrew", + }, + }, + want: &pb.MessageWithMap{ + NameMapping: map[int32]string{ + 6: "Nigel", + 7: "Andrew", + }, + MsgMapping: map[int64]*pb.FloatingPoint{ + 0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)}, + }, + ByteMapping: map[bool][]byte{true: []byte("wowsa")}, + }, + }, } func TestMerge(t *testing.T) { diff --git a/proto/decode.go b/proto/decode.go index 6166dd4..88622c3 100644 --- a/proto/decode.go +++ b/proto/decode.go @@ -178,7 +178,7 @@ func (p *Buffer) DecodeZigzag32() (x uint64, err error) { func (p *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) { n, err := p.DecodeVarint() if err != nil { - return + return nil, err } nb := int(n) @@ -668,6 +668,72 @@ func (o *Buffer) dec_slice_slice_byte(p *Properties, base structPointer) error { return nil } +// Decode a map field. +func (o *Buffer) dec_new_map(p *Properties, base structPointer) error { + raw, err := o.DecodeRawBytes(false) + if err != nil { + return err + } + oi := o.index // index at the end of this map entry + o.index -= len(raw) // move buffer back to start of map entry + + mptr := structPointer_Map(base, p.field, p.mtype) // *map[K]V + if mptr.Elem().IsNil() { + mptr.Elem().Set(reflect.MakeMap(mptr.Type().Elem())) + } + v := mptr.Elem() // map[K]V + + // Prepare addressable doubly-indirect placeholders for the key and value types. + // See enc_new_map for why. + keyptr := reflect.New(reflect.PtrTo(p.mtype.Key())).Elem() // addressable *K + keybase := toStructPointer(keyptr.Addr()) // **K + + var valbase structPointer + var valptr reflect.Value + switch p.mtype.Elem().Kind() { + case reflect.Slice: + // []byte + var dummy []byte + valptr = reflect.ValueOf(&dummy) // *[]byte + valbase = toStructPointer(valptr) // *[]byte + case reflect.Ptr: + // message; valptr is **Msg; need to allocate the intermediate pointer + valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V + valptr.Set(reflect.New(valptr.Type().Elem())) + valbase = toStructPointer(valptr) + default: + // everything else + valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V + valbase = toStructPointer(valptr.Addr()) // **V + } + + // Decode. + // This parses a restricted wire format, namely the encoding of a message + // with two fields. See enc_new_map for the format. + for o.index < oi { + // tagcode for key and value properties are always a single byte + // because they have tags 1 and 2. + tagcode := o.buf[o.index] + o.index++ + switch tagcode { + case p.mkeyprop.tagcode[0]: + if err := p.mkeyprop.dec(o, p.mkeyprop, keybase); err != nil { + return err + } + case p.mvalprop.tagcode[0]: + if err := p.mvalprop.dec(o, p.mvalprop, valbase); err != nil { + return err + } + default: + // TODO: Should we silently skip this instead? + return fmt.Errorf("proto: bad map data tag %d", raw[0]) + } + } + + v.SetMapIndex(keyptr.Elem(), valptr.Elem()) + return nil +} + // Decode a group. func (o *Buffer) dec_struct_group(p *Properties, base structPointer) error { bas := structPointer_GetStructPointer(base, p.field) diff --git a/proto/encode.go b/proto/encode.go index cc202cd..f5050e3 100644 --- a/proto/encode.go +++ b/proto/encode.go @@ -1069,6 +1069,104 @@ func size_map(p *Properties, base structPointer) int { return sizeExtensionMap(v) } +// Encode a map field. +func (o *Buffer) enc_new_map(p *Properties, base structPointer) error { + var state errorState // XXX: or do we need to plumb this through? + + /* + A map defined as + map<key_type, value_type> map_field = N; + is encoded in the same way as + message MapFieldEntry { + key_type key = 1; + value_type value = 2; + } + repeated MapFieldEntry map_field = N; + */ + + v := structPointer_Map(base, p.field, p.mtype).Elem() // map[K]V + if v.Len() == 0 { + return nil + } + + keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype) + + enc := func() error { + if err := p.mkeyprop.enc(o, p.mkeyprop, keybase); err != nil { + return err + } + if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil { + return err + } + return nil + } + + for _, key := range v.MapKeys() { + val := v.MapIndex(key) + + keycopy.Set(key) + valcopy.Set(val) + + o.buf = append(o.buf, p.tagcode...) + if err := o.enc_len_thing(enc, &state); err != nil { + return err + } + } + return nil +} + +func size_new_map(p *Properties, base structPointer) int { + v := structPointer_Map(base, p.field, p.mtype).Elem() // map[K]V + + keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype) + + n := 0 + for _, key := range v.MapKeys() { + val := v.MapIndex(key) + keycopy.Set(key) + valcopy.Set(val) + + // Tag codes are two bytes per map entry. + n += 2 + n += p.mkeyprop.size(p.mkeyprop, keybase) + n += p.mvalprop.size(p.mvalprop, valbase) + } + return n +} + +// mapEncodeScratch returns a new reflect.Value matching the map's value type, +// and a structPointer suitable for passing to an encoder or sizer. +func mapEncodeScratch(mapType reflect.Type) (keycopy, valcopy reflect.Value, keybase, valbase structPointer) { + // Prepare addressable doubly-indirect placeholders for the key and value types. + // This is needed because the element-type encoders expect **T, but the map iteration produces T. + + keycopy = reflect.New(mapType.Key()).Elem() // addressable K + keyptr := reflect.New(reflect.PtrTo(keycopy.Type())).Elem() // addressable *K + keyptr.Set(keycopy.Addr()) // + keybase = toStructPointer(keyptr.Addr()) // **K + + // Value types are more varied and require special handling. + switch mapType.Elem().Kind() { + case reflect.Slice: + // []byte + var dummy []byte + valcopy = reflect.ValueOf(&dummy).Elem() // addressable []byte + valbase = toStructPointer(valcopy.Addr()) + case reflect.Ptr: + // message; the generated field type is map[K]*Msg (so V is *Msg), + // so we only need one level of indirection. + valcopy = reflect.New(mapType.Elem()).Elem() // addressable V + valbase = toStructPointer(valcopy.Addr()) + default: + // everything else + valcopy = reflect.New(mapType.Elem()).Elem() // addressable V + valptr := reflect.New(reflect.PtrTo(valcopy.Type())).Elem() // addressable *V + valptr.Set(valcopy.Addr()) // + valbase = toStructPointer(valptr.Addr()) // **V + } + return +} + // Encode a struct. func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error { var state errorState @@ -1123,10 +1221,15 @@ var zeroes [20]byte // longer than any conceivable sizeVarint // Encode a struct, preceded by its encoded length (as a varint). func (o *Buffer) enc_len_struct(prop *StructProperties, base structPointer, state *errorState) error { + return o.enc_len_thing(func() error { return o.enc_struct(prop, base) }, state) +} + +// Encode something, preceded by its encoded length (as a varint). +func (o *Buffer) enc_len_thing(enc func() error, state *errorState) error { iLen := len(o.buf) o.buf = append(o.buf, 0, 0, 0, 0) // reserve four bytes for length iMsg := len(o.buf) - err := o.enc_struct(prop, base) + err := enc() if err != nil && !state.shouldContinue(err, nil) { return err } diff --git a/proto/equal.go b/proto/equal.go index ebdfdca..d8673a3 100644 --- a/proto/equal.go +++ b/proto/equal.go @@ -154,6 +154,21 @@ func equalAny(v1, v2 reflect.Value) bool { return v1.Float() == v2.Float() case reflect.Int32, reflect.Int64: return v1.Int() == v2.Int() + case reflect.Map: + if v1.Len() != v2.Len() { + return false + } + for _, key := range v1.MapKeys() { + val2 := v2.MapIndex(key) + if !val2.IsValid() { + // This key was not found in the second map. + return false + } + if !equalAny(v1.MapIndex(key), val2) { + return false + } + } + return true case reflect.Ptr: return equalAny(v1.Elem(), v2.Elem()) case reflect.Slice: diff --git a/proto/equal_test.go b/proto/equal_test.go index ebcf340..cc25833 100644 --- a/proto/equal_test.go +++ b/proto/equal_test.go @@ -155,6 +155,31 @@ var EqualTests = []struct { }, true, }, + + { + "map same", + &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}}, + &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}}, + true, + }, + { + "map different entry", + &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}}, + &pb.MessageWithMap{NameMapping: map[int32]string{2: "Rob"}}, + false, + }, + { + "map different key only", + &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}}, + &pb.MessageWithMap{NameMapping: map[int32]string{2: "Ken"}}, + false, + }, + { + "map different value only", + &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}}, + &pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob"}}, + false, + }, } func TestEqual(t *testing.T) { diff --git a/proto/pointer_reflect.go b/proto/pointer_reflect.go index 42c387a..93259a3 100644 --- a/proto/pointer_reflect.go +++ b/proto/pointer_reflect.go @@ -144,6 +144,11 @@ func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { return structPointer_ifield(p, f).(*map[int32]Extension) } +// Map returns the reflect.Value for the address of a map field in the struct. +func structPointer_Map(p structPointer, f field, typ reflect.Type) reflect.Value { + return structPointer_field(p, f).Addr() +} + // SetStructPointer writes a *struct field in the struct. func structPointer_SetStructPointer(p structPointer, f field, q structPointer) { structPointer_field(p, f).Set(q.v) diff --git a/proto/pointer_unsafe.go b/proto/pointer_unsafe.go index cf9fc9a..c52db1c 100644 --- a/proto/pointer_unsafe.go +++ b/proto/pointer_unsafe.go @@ -130,6 +130,11 @@ func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f))) } +// Map returns the reflect.Value for the address of a map field in the struct. +func structPointer_Map(p structPointer, f field, typ reflect.Type) reflect.Value { + return reflect.NewAt(typ, unsafe.Pointer(uintptr(p)+uintptr(f))) +} + // SetStructPointer writes a *struct field in the struct. func structPointer_SetStructPointer(p structPointer, f field, q structPointer) { *(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) = q diff --git a/proto/properties.go b/proto/properties.go index 4420881..730a595 100644 --- a/proto/properties.go +++ b/proto/properties.go @@ -171,6 +171,10 @@ type Properties struct { isMarshaler bool isUnmarshaler bool + mtype reflect.Type // set for map types only + mkeyprop *Properties // set for map types only + mvalprop *Properties // set for map types only + size sizer valSize valueSizer // set for bool and numeric types only @@ -299,7 +303,7 @@ func logNoSliceEnc(t1, t2 reflect.Type) { var protoMessageType = reflect.TypeOf((*Message)(nil)).Elem() // Initialize the fields for encoding and decoding. -func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) { +func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lockGetProp bool) { p.enc = nil p.dec = nil p.size = nil @@ -342,7 +346,7 @@ func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) { case reflect.Ptr: switch t2 := t1.Elem(); t2.Kind() { default: - fmt.Fprintf(os.Stderr, "proto: no encoder function for %T -> %T\n", t1, t2) + fmt.Fprintf(os.Stderr, "proto: no encoder function for %v -> %v\n", t1, t2) break case reflect.Bool: p.enc = (*Buffer).enc_bool @@ -502,6 +506,23 @@ func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) { p.size = size_slice_slice_byte } } + + case reflect.Map: + p.enc = (*Buffer).enc_new_map + p.dec = (*Buffer).dec_new_map + p.size = size_new_map + + p.mtype = t1 + p.mkeyprop = &Properties{} + p.mkeyprop.init(reflect.PtrTo(p.mtype.Key()), "Key", f.Tag.Get("protobuf_key"), nil, lockGetProp) + p.mvalprop = &Properties{} + vtype := p.mtype.Elem() + if vtype.Kind() != reflect.Ptr && vtype.Kind() != reflect.Slice { + // The value type is not a message (*T) or bytes ([]byte), + // so we need encoders for the pointer to this type. + vtype = reflect.PtrTo(vtype) + } + p.mvalprop.init(vtype, "Value", f.Tag.Get("protobuf_val"), nil, lockGetProp) } // precalculate tag code @@ -570,7 +591,7 @@ func (p *Properties) init(typ reflect.Type, name, tag string, f *reflect.StructF return } p.Parse(tag) - p.setEncAndDec(typ, lockGetProp) + p.setEncAndDec(typ, f, lockGetProp) } var ( diff --git a/proto/size_test.go b/proto/size_test.go index 4f87f3b..e5f92d6 100644 --- a/proto/size_test.go +++ b/proto/size_test.go @@ -113,6 +113,10 @@ var SizeTests = []struct { {"proto3 bytes", &proto3pb.Message{Data: []byte("wowsa")}}, {"proto3 bytes, empty", &proto3pb.Message{Data: []byte{}}}, {"proto3 enum", &proto3pb.Message{Hilarity: proto3pb.Message_PUNS}}, + + {"map field", &pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob", 7: "Andrew"}}}, + {"map field with message", &pb.MessageWithMap{MsgMapping: map[int64]*pb.FloatingPoint{0x7001: &pb.FloatingPoint{F: Float64(2.0)}}}}, + {"map field with bytes", &pb.MessageWithMap{ByteMapping: map[bool][]byte{true: []byte("this time for sure")}}}, } func TestSize(t *testing.T) { diff --git a/proto/testdata/test.pb.go b/proto/testdata/test.pb.go index 4d46e6e..f47d9e0 100644 --- a/proto/testdata/test.pb.go +++ b/proto/testdata/test.pb.go @@ -33,6 +33,7 @@ It has these top-level messages: GroupOld GroupNew FloatingPoint + MessageWithMap */ package testdata @@ -1885,6 +1886,38 @@ func (m *FloatingPoint) GetF() float64 { return 0 } +type MessageWithMap struct { + NameMapping map[int32]string `protobuf:"bytes,1,rep,name=name_mapping" json:"name_mapping,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + MsgMapping map[int64]*FloatingPoint `protobuf:"bytes,2,rep,name=msg_mapping" json:"msg_mapping,omitempty" protobuf_key:"zigzag64,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + ByteMapping map[bool][]byte `protobuf:"bytes,3,rep,name=byte_mapping" json:"byte_mapping,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *MessageWithMap) Reset() { *m = MessageWithMap{} } +func (m *MessageWithMap) String() string { return proto.CompactTextString(m) } +func (*MessageWithMap) ProtoMessage() {} + +func (m *MessageWithMap) GetNameMapping() map[int32]string { + if m != nil { + return m.NameMapping + } + return nil +} + +func (m *MessageWithMap) GetMsgMapping() map[int64]*FloatingPoint { + if m != nil { + return m.MsgMapping + } + return nil +} + +func (m *MessageWithMap) GetByteMapping() map[bool][]byte { + if m != nil { + return m.ByteMapping + } + return nil +} + var E_Greeting = &proto.ExtensionDesc{ ExtendedType: (*MyMessage)(nil), ExtensionType: ([]string)(nil), diff --git a/proto/testdata/test.proto b/proto/testdata/test.proto index ac9542a..6cc755b 100644 --- a/proto/testdata/test.proto +++ b/proto/testdata/test.proto @@ -426,3 +426,9 @@ message GroupNew { message FloatingPoint { required double f = 1; } + +message MessageWithMap { + map<int32, string> name_mapping = 1; + map<sint64, FloatingPoint> msg_mapping = 2; + map<bool, bytes> byte_mapping = 3; +} diff --git a/proto/text.go b/proto/text.go index 426db1e..f41a946 100644 --- a/proto/text.go +++ b/proto/text.go @@ -244,6 +244,70 @@ func writeStruct(w *textWriter, sv reflect.Value) error { } continue } + if fv.Kind() == reflect.Map { + // Map fields are rendered as a repeated struct with key/value fields. + keys := fv.MapKeys() // TODO: should we sort these for deterministic output? + for _, key := range keys { + val := fv.MapIndex(key) + if err := writeName(w, props); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + // open struct + if err := w.WriteByte('<'); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte('\n'); err != nil { + return err + } + } + w.indent() + // key + if _, err := w.WriteString("key:"); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + if err := writeAny(w, key, props.mkeyprop); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + // value + if _, err := w.WriteString("value:"); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + if err := writeAny(w, val, props.mvalprop); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + // close struct + w.unindent() + if err := w.WriteByte('>'); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + } + continue + } if props.proto3 && fv.Kind() == reflect.Slice && fv.Len() == 0 { // empty bytes field continue diff --git a/proto/text_parser.go b/proto/text_parser.go index f733f30..ddd9579 100644 --- a/proto/text_parser.go +++ b/proto/text_parser.go @@ -355,6 +355,18 @@ func (p *textParser) next() *token { return &p.cur } +func (p *textParser) consumeToken(s string) error { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value != s { + p.back() + return p.errorf("expected %q, found %q", s, tok.value) + } + return nil +} + // Return a RequiredNotSetError indicating which required field was not set. func (p *textParser) missingRequiredFieldError(sv reflect.Value) *RequiredNotSetError { st := sv.Type() @@ -518,6 +530,60 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error { dst := sv.Field(fi) + if dst.Kind() == reflect.Map { + // Consume any colon. + if err := p.checkForColon(props, dst.Type()); err != nil { + return err + } + + // Construct the map if it doesn't already exist. + if dst.IsNil() { + dst.Set(reflect.MakeMap(dst.Type())) + } + key := reflect.New(dst.Type().Key()).Elem() + val := reflect.New(dst.Type().Elem()).Elem() + + // The map entry should be this sequence of tokens: + // < key : KEY value : VALUE > + // Technically the "key" and "value" could come in any order, + // but in practice they won't. + + tok := p.next() + var terminator string + switch tok.value { + case "<": + terminator = ">" + case "{": + terminator = "}" + default: + return p.errorf("expected '{' or '<', found %q", tok.value) + } + if err := p.consumeToken("key"); err != nil { + return err + } + if err := p.consumeToken(":"); err != nil { + return err + } + if err := p.readAny(key, props.mkeyprop); err != nil { + return err + } + if err := p.consumeToken("value"); err != nil { + return err + } + if err := p.consumeToken(":"); err != nil { + return err + } + if err := p.readAny(val, props.mvalprop); err != nil { + return err + } + if err := p.consumeToken(terminator); err != nil { + return err + } + + dst.SetMapIndex(key, val) + continue + } + // Check that it's not already set if it's not a repeated field. if !props.Repeated && fieldSet[name] { return p.errorf("non-repeated field %q was repeated", name) diff --git a/proto/text_parser_test.go b/proto/text_parser_test.go index 89ab106..e5ee8b9 100644 --- a/proto/text_parser_test.go +++ b/proto/text_parser_test.go @@ -459,6 +459,31 @@ func TestProto3TextParsing(t *testing.T) { } } +func TestMapParsing(t *testing.T) { + m := new(MessageWithMap) + const in = `name_mapping:<key:1234 value:"Feist"> name_mapping:<key:1 value:"Beatles">` + + `msg_mapping:<key:-4 value:<f: 2.0>>` + + `byte_mapping:<key:true value:"so be it">` + want := &MessageWithMap{ + NameMapping: map[int32]string{ + 1: "Beatles", + 1234: "Feist", + }, + MsgMapping: map[int64]*FloatingPoint{ + -4: {F: Float64(2.0)}, + }, + ByteMapping: map[bool][]byte{ + true: []byte("so be it"), + }, + } + if err := UnmarshalText(in, m); err != nil { + t.Fatal(err) + } + if !Equal(m, want) { + t.Errorf("\n got %v\nwant %v", m, want) + } +} + var benchInput string func init() { diff --git a/proto/text_test.go b/proto/text_test.go index 404920e..707bedd 100644 --- a/proto/text_test.go +++ b/proto/text_test.go @@ -419,6 +419,13 @@ func TestProto3Text(t *testing.T) { {&proto3pb.Message{Data: []byte{}}, ``}, // trivial case {&proto3pb.Message{Name: "Rob", HeightInCm: 175}, `name:"Rob" height_in_cm:175`}, + // empty map + {&pb.MessageWithMap{}, ``}, + // non-empty map; current map format is the same as a repeated struct + { + &pb.MessageWithMap{NameMapping: map[int32]string{1234: "Feist"}}, + `name_mapping:<key:1234 value:"Feist" >`, + }, } for _, test := range tests { got := strings.TrimSpace(test.m.String()) diff --git a/protoc-gen-go/generator/generator.go b/protoc-gen-go/generator/generator.go index 4b309cb..8b02ba5 100644 --- a/protoc-gen-go/generator/generator.go +++ b/protoc-gen-go/generator/generator.go @@ -102,6 +102,12 @@ func fileIsProto3(file *descriptor.FileDescriptorProto) bool { func (c *common) proto3() bool { return fileIsProto3(c.file) } +func fileUsesMaps(file *descriptor.FileDescriptorProto) bool { + return true +} + +func (c *common) usesMaps() bool { return fileUsesMaps(c.file) } + // Descriptor represents a protocol buffer message. type Descriptor struct { common @@ -1011,6 +1017,10 @@ func (g *Generator) generate(file *FileDescriptor) { g.generateEnum(enum) } for _, desc := range g.file.desc { + // Don't generate virtual messages for maps. + if desc.GetOptions().GetMapEntry() && desc.usesMaps() { + continue + } g.generateMessage(desc) } for _, ext := range g.file.ext { @@ -1499,6 +1509,7 @@ func (g *Generator) generateMessage(message *Descriptor) { } fieldNames := make(map[*descriptor.FieldDescriptorProto]string) fieldGetterNames := make(map[*descriptor.FieldDescriptorProto]string) + mapFieldTypes := make(map[*descriptor.FieldDescriptorProto]string) g.PrintComments(message.path) g.P("type ", ccTypeName, " struct {") @@ -1516,6 +1527,32 @@ func (g *Generator) generateMessage(message *Descriptor) { typename, wiretype := g.GoType(message, field) jsonName := *field.Name tag := fmt.Sprintf("protobuf:%s json:%q", g.goTag(message, field, wiretype), jsonName+",omitempty") + + if *field.Type == descriptor.FieldDescriptorProto_TYPE_MESSAGE { + desc := g.ObjectNamed(field.GetTypeName()) + if d, ok := desc.(*Descriptor); ok && d.GetOptions().GetMapEntry() && d.usesMaps() { + // Figure out the Go types and tags for the key and value types. + keyField, valField := d.Field[0], d.Field[1] + keyType, keyWire := g.GoType(d, keyField) + valType, valWire := g.GoType(d, valField) + keyTag, valTag := g.goTag(d, keyField, keyWire), g.goTag(d, valField, valWire) + + // We don't use stars, except for message-typed values. + keyType = strings.TrimPrefix(keyType, "*") + switch *valField.Type { + case descriptor.FieldDescriptorProto_TYPE_MESSAGE: + g.RecordTypeUse(valField.GetTypeName()) + default: + valType = strings.TrimPrefix(valType, "*") + } + + typename = fmt.Sprintf("map[%s]%s", keyType, valType) + mapFieldTypes[field] = typename // record for the getter generation + + tag += fmt.Sprintf(" protobuf_key:%s protobuf_val:%s", keyTag, valTag) + } + } + fieldNames[field] = fieldName fieldGetterNames[field] = fieldGetterName g.P(fieldName, "\t", typename, "\t`", tag, "`") @@ -1655,6 +1692,9 @@ func (g *Generator) generateMessage(message *Descriptor) { for _, field := range message.Field { fname := fieldNames[field] typename, _ := g.GoType(message, field) + if t, ok := mapFieldTypes[field]; ok { + typename = t + } mname := "Get" + fieldGetterNames[field] star := "" if needsStar(*field.Type) && typename[0] == '*' { diff --git a/protoc-gen-go/testdata/my_test/test.pb.go b/protoc-gen-go/testdata/my_test/test.pb.go index c39e58f..5cd7b2a 100644 --- a/protoc-gen-go/testdata/my_test/test.pb.go +++ b/protoc-gen-go/testdata/my_test/test.pb.go @@ -174,10 +174,14 @@ type Request struct { Hue *Request_Color `protobuf:"varint,3,opt,name=hue,enum=my.test.Request_Color" json:"hue,omitempty"` Hat *HatType `protobuf:"varint,4,opt,name=hat,enum=my.test.HatType,def=1" json:"hat,omitempty"` // optional imp.ImportedMessage.Owner owner = 6; - Deadline *float32 `protobuf:"fixed32,7,opt,name=deadline,def=inf" json:"deadline,omitempty"` - Somegroup *Request_SomeGroup `protobuf:"group,8,opt,name=SomeGroup" json:"somegroup,omitempty"` - Reset_ *int32 `protobuf:"varint,12,opt,name=reset" json:"reset,omitempty"` - XXX_unrecognized []byte `json:"-"` + Deadline *float32 `protobuf:"fixed32,7,opt,name=deadline,def=inf" json:"deadline,omitempty"` + Somegroup *Request_SomeGroup `protobuf:"group,8,opt,name=SomeGroup" json:"somegroup,omitempty"` + // This is a map field. It will generate map[int32]string. + NameMapping map[int32]string `protobuf:"bytes,14,rep,name=name_mapping" json:"name_mapping,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + // This is a map field whose value type is a message. + MsgMapping map[int64]*Reply `protobuf:"bytes,15,rep,name=msg_mapping" json:"msg_mapping,omitempty" protobuf_key:"zigzag64,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + Reset_ *int32 `protobuf:"varint,12,opt,name=reset" json:"reset,omitempty"` + XXX_unrecognized []byte `json:"-"` } func (m *Request) Reset() { *m = Request{} } @@ -223,6 +227,20 @@ func (m *Request) GetSomegroup() *Request_SomeGroup { return nil } +func (m *Request) GetNameMapping() map[int32]string { + if m != nil { + return m.NameMapping + } + return nil +} + +func (m *Request) GetMsgMapping() map[int64]*Reply { + if m != nil { + return m.MsgMapping + } + return nil +} + func (m *Request) GetReset_() int32 { if m != nil && m.Reset_ != nil { return *m.Reset_ diff --git a/protoc-gen-go/testdata/my_test/test.pb.go.golden b/protoc-gen-go/testdata/my_test/test.pb.go.golden index c39e58f..5cd7b2a 100644 --- a/protoc-gen-go/testdata/my_test/test.pb.go.golden +++ b/protoc-gen-go/testdata/my_test/test.pb.go.golden @@ -174,10 +174,14 @@ type Request struct { Hue *Request_Color `protobuf:"varint,3,opt,name=hue,enum=my.test.Request_Color" json:"hue,omitempty"` Hat *HatType `protobuf:"varint,4,opt,name=hat,enum=my.test.HatType,def=1" json:"hat,omitempty"` // optional imp.ImportedMessage.Owner owner = 6; - Deadline *float32 `protobuf:"fixed32,7,opt,name=deadline,def=inf" json:"deadline,omitempty"` - Somegroup *Request_SomeGroup `protobuf:"group,8,opt,name=SomeGroup" json:"somegroup,omitempty"` - Reset_ *int32 `protobuf:"varint,12,opt,name=reset" json:"reset,omitempty"` - XXX_unrecognized []byte `json:"-"` + Deadline *float32 `protobuf:"fixed32,7,opt,name=deadline,def=inf" json:"deadline,omitempty"` + Somegroup *Request_SomeGroup `protobuf:"group,8,opt,name=SomeGroup" json:"somegroup,omitempty"` + // This is a map field. It will generate map[int32]string. + NameMapping map[int32]string `protobuf:"bytes,14,rep,name=name_mapping" json:"name_mapping,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + // This is a map field whose value type is a message. + MsgMapping map[int64]*Reply `protobuf:"bytes,15,rep,name=msg_mapping" json:"msg_mapping,omitempty" protobuf_key:"zigzag64,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + Reset_ *int32 `protobuf:"varint,12,opt,name=reset" json:"reset,omitempty"` + XXX_unrecognized []byte `json:"-"` } func (m *Request) Reset() { *m = Request{} } @@ -223,6 +227,20 @@ func (m *Request) GetSomegroup() *Request_SomeGroup { return nil } +func (m *Request) GetNameMapping() map[int32]string { + if m != nil { + return m.NameMapping + } + return nil +} + +func (m *Request) GetMsgMapping() map[int64]*Reply { + if m != nil { + return m.MsgMapping + } + return nil +} + func (m *Request) GetReset_() int32 { if m != nil && m.Reset_ != nil { return *m.Reset_ diff --git a/protoc-gen-go/testdata/my_test/test.proto b/protoc-gen-go/testdata/my_test/test.proto index efcc0a0..af69c47 100644 --- a/protoc-gen-go/testdata/my_test/test.proto +++ b/protoc-gen-go/testdata/my_test/test.proto @@ -72,6 +72,11 @@ message Request { // optional imp.PubliclyImportedEnum pub_enum = 13 [default=HAIR]; + // This is a map field. It will generate map[int32]string. + map<int32, string> name_mapping = 14; + // This is a map field whose value type is a message. + map<sint64, Reply> msg_mapping = 15; + optional int32 reset = 12; } |