aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Symonds <dsymonds@golang.org>2014-04-15 14:01:13 +1000
committerDavid Symonds <dsymonds@golang.org>2014-04-15 14:01:13 +1000
commit29bcc89b25777c71ec610d4230b212c1af63c315 (patch)
treea787c35f9e75e5c2b05282deeb989cc7c2c28e10
parentf530f25b30c27550602600ad056192c4328f577d (diff)
downloadprotobuf-29bcc89b25777c71ec610d4230b212c1af63c315.tar.gz
goprotobuf: Add tests for the Marshaler interface during marshaling.
LGTM=r R=r CC=golang-codereviews https://codereview.appspot.com/87130044
-rw-r--r--proto/all_test.go58
1 files changed, 58 insertions, 0 deletions
diff --git a/proto/all_test.go b/proto/all_test.go
index 73179e4..7b03cc3 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -34,6 +34,7 @@ package proto_test
import (
"bytes"
"encoding/json"
+ "errors"
"fmt"
"math"
"math/rand"
@@ -394,6 +395,63 @@ func TestNumericPrimitives(t *testing.T) {
}
}
+// fakeMarshaler is a simple struct implementing Marshaler and Message interfaces.
+type fakeMarshaler struct {
+ b []byte
+ err error
+}
+
+func (f fakeMarshaler) Marshal() ([]byte, error) {
+ return f.b, f.err
+}
+
+func (f fakeMarshaler) String() string {
+ return fmt.Sprintf("Bytes: %v Error: %v", f.b, f.err)
+}
+
+func (f fakeMarshaler) ProtoMessage() {}
+
+func (f fakeMarshaler) Reset() {}
+
+// Simple tests for proto messages that implement the Marshaler interface.
+func TestMarshalerEncoding(t *testing.T) {
+ tests := []struct {
+ name string
+ m Message
+ want []byte
+ wantErr error
+ }{
+ {
+ name: "Marshaler that fails",
+ m: fakeMarshaler{
+ err: errors.New("some marshal err"),
+ b: []byte{5, 6, 7},
+ },
+ // Since there's an error, nothing should be written to buffer.
+ want: nil,
+ wantErr: errors.New("some marshal err"),
+ },
+ {
+ name: "Marshaler that succeeds",
+ m: fakeMarshaler{
+ b: []byte{0, 1, 2, 3, 4, 127, 255},
+ },
+ want: []byte{0, 1, 2, 3, 4, 127, 255},
+ wantErr: nil,
+ },
+ }
+ for _, test := range tests {
+ b := NewBuffer(nil)
+ err := b.Marshal(test.m)
+ if !reflect.DeepEqual(test.wantErr, err) {
+ t.Errorf("%s: got err %v wanted %v", test.name, err, test.wantErr)
+ }
+ if !reflect.DeepEqual(test.want, b.Bytes()) {
+ t.Errorf("%s: got bytes %v wanted %v", test.name, b.Bytes(), test.want)
+ }
+ }
+}
+
// Simple tests for bytes
func TestBytesPrimitives(t *testing.T) {
o := old()