// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

package proto_test

import (
	"bytes"
	"fmt"
	"reflect"
	"sort"
	"testing"

	"github.com/golang/protobuf/proto"
	pb "github.com/golang/protobuf/proto/testdata"
	"golang.org/x/sync/errgroup"
)
46 func TestGetExtensionsWithMissingExtensions(t *testing.T) {
47 msg := &pb.MyMessage{}
49 if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
50 t.Fatalf("Could not set ext1: %s", err)
52 exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
57 t.Fatalf("GetExtensions() failed: %s", err)
60 t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
63 t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
67 func TestExtensionDescsWithMissingExtensions(t *testing.T) {
68 msg := &pb.MyMessage{Count: proto.Int32(0)}
69 extdesc1 := pb.E_Ext_More
70 if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
71 t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
75 if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
76 t.Fatalf("Could not set ext1: %s", err)
78 extdesc2 := &proto.ExtensionDesc{
79 ExtendedType: (*pb.MyMessage)(nil),
80 ExtensionType: (*bool)(nil),
83 Tag: "varint,123456789,opt",
85 ext2 := proto.Bool(false)
86 if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
87 t.Fatalf("Could not set ext2: %s", err)
90 b, err := proto.Marshal(msg)
92 t.Fatalf("Could not marshal msg: %v", err)
94 if err := proto.Unmarshal(b, msg); err != nil {
95 t.Fatalf("Could not unmarshal into msg: %v", err)
98 descs, err := proto.ExtensionDescs(msg)
100 t.Fatalf("proto.ExtensionDescs: got error %v", err)
103 wantDescs := []*proto.ExtensionDesc{extdesc1, &proto.ExtensionDesc{Field: extdesc2.Field}}
104 if !reflect.DeepEqual(descs, wantDescs) {
105 t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
109 type ExtensionDescSlice []*proto.ExtensionDesc
111 func (s ExtensionDescSlice) Len() int { return len(s) }
112 func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
113 func (s ExtensionDescSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
115 func sortExtDescs(s []*proto.ExtensionDesc) {
116 sort.Sort(ExtensionDescSlice(s))
119 func TestGetExtensionStability(t *testing.T) {
120 check := func(m *pb.MyMessage) bool {
121 ext1, err := proto.GetExtension(m, pb.E_Ext_More)
123 t.Fatalf("GetExtension() failed: %s", err)
125 ext2, err := proto.GetExtension(m, pb.E_Ext_More)
127 t.Fatalf("GetExtension() failed: %s", err)
131 msg := &pb.MyMessage{Count: proto.Int32(4)}
133 if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
134 t.Fatalf("Could not set ext1: %s", ext0)
137 t.Errorf("GetExtension() not stable before marshaling")
139 bb, err := proto.Marshal(msg)
141 t.Fatalf("Marshal() failed: %s", err)
143 msg1 := &pb.MyMessage{}
144 err = proto.Unmarshal(bb, msg1)
146 t.Fatalf("Unmarshal() failed: %s", err)
149 t.Errorf("GetExtension() not stable after unmarshaling")
153 func TestGetExtensionDefaults(t *testing.T) {
154 var setFloat64 float64 = 1
155 var setFloat32 float32 = 2
156 var setInt32 int32 = 3
157 var setInt64 int64 = 4
158 var setUint32 uint32 = 5
159 var setUint64 uint64 = 6
162 var setString = "Goodnight string"
163 var setBytes = []byte("Goodnight bytes")
164 var setEnum = pb.DefaultsMessage_TWO
166 type testcase struct {
167 ext *proto.ExtensionDesc // Extension we are testing.
168 want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail).
169 def interface{} // Expected value of extension after ClearExtension().
172 {pb.E_NoDefaultDouble, setFloat64, nil},
173 {pb.E_NoDefaultFloat, setFloat32, nil},
174 {pb.E_NoDefaultInt32, setInt32, nil},
175 {pb.E_NoDefaultInt64, setInt64, nil},
176 {pb.E_NoDefaultUint32, setUint32, nil},
177 {pb.E_NoDefaultUint64, setUint64, nil},
178 {pb.E_NoDefaultSint32, setInt32, nil},
179 {pb.E_NoDefaultSint64, setInt64, nil},
180 {pb.E_NoDefaultFixed32, setUint32, nil},
181 {pb.E_NoDefaultFixed64, setUint64, nil},
182 {pb.E_NoDefaultSfixed32, setInt32, nil},
183 {pb.E_NoDefaultSfixed64, setInt64, nil},
184 {pb.E_NoDefaultBool, setBool, nil},
185 {pb.E_NoDefaultBool, setBool2, nil},
186 {pb.E_NoDefaultString, setString, nil},
187 {pb.E_NoDefaultBytes, setBytes, nil},
188 {pb.E_NoDefaultEnum, setEnum, nil},
189 {pb.E_DefaultDouble, setFloat64, float64(3.1415)},
190 {pb.E_DefaultFloat, setFloat32, float32(3.14)},
191 {pb.E_DefaultInt32, setInt32, int32(42)},
192 {pb.E_DefaultInt64, setInt64, int64(43)},
193 {pb.E_DefaultUint32, setUint32, uint32(44)},
194 {pb.E_DefaultUint64, setUint64, uint64(45)},
195 {pb.E_DefaultSint32, setInt32, int32(46)},
196 {pb.E_DefaultSint64, setInt64, int64(47)},
197 {pb.E_DefaultFixed32, setUint32, uint32(48)},
198 {pb.E_DefaultFixed64, setUint64, uint64(49)},
199 {pb.E_DefaultSfixed32, setInt32, int32(50)},
200 {pb.E_DefaultSfixed64, setInt64, int64(51)},
201 {pb.E_DefaultBool, setBool, true},
202 {pb.E_DefaultBool, setBool2, true},
203 {pb.E_DefaultString, setString, "Hello, string"},
204 {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
205 {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
208 checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
209 val, err := proto.GetExtension(msg, test.ext)
212 return fmt.Errorf("GetExtension(): %s", err)
214 if want := proto.ErrMissingExtension; err != want {
215 return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
220 // All proto2 extension values are either a pointer to a value or a slice of values.
221 ty := reflect.TypeOf(val)
222 tyWant := reflect.TypeOf(test.ext.ExtensionType)
223 if got, want := ty, tyWant; got != want {
224 return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
227 tyeWant := tyWant.Elem()
228 if got, want := tye, tyeWant; got != want {
229 return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
232 // Check the name of the type of the value.
233 // If it is an enum it will be type int32 with the name of the enum.
234 if got, want := tye.Name(), tye.Name(); got != want {
235 return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
238 // Check that value is what we expect.
239 // If we have a pointer in val, get the value it points to.
241 if ty.Kind() == reflect.Ptr {
242 valExp = reflect.ValueOf(val).Elem().Interface()
244 if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
245 return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
251 setTo := func(test testcase) interface{} {
252 setTo := reflect.ValueOf(test.want)
253 if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
254 setTo = reflect.New(typ).Elem()
255 setTo.Set(reflect.New(setTo.Type().Elem()))
256 setTo.Elem().Set(reflect.ValueOf(test.want))
258 return setTo.Interface()
261 for _, test := range tests {
262 msg := &pb.DefaultsMessage{}
263 name := test.ext.Name
265 // Check the initial value.
266 if err := checkVal(test, msg, test.def); err != nil {
267 t.Errorf("%s: %v", name, err)
270 // Set the per-type value and check value.
271 name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
272 if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
273 t.Errorf("%s: SetExtension(): %v", name, err)
276 if err := checkVal(test, msg, test.want); err != nil {
277 t.Errorf("%s: %v", name, err)
281 // Set and check the value.
283 proto.ClearExtension(msg, test.ext)
284 if err := checkVal(test, msg, test.def); err != nil {
285 t.Errorf("%s: %v", name, err)
290 func TestExtensionsRoundTrip(t *testing.T) {
291 msg := &pb.MyMessage{}
293 Data: proto.String("hi"),
296 Data: proto.String("there"),
298 exists := proto.HasExtension(msg, pb.E_Ext_More)
300 t.Error("Extension More present unexpectedly")
302 if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
305 if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
308 e, err := proto.GetExtension(msg, pb.E_Ext_More)
314 t.Errorf("e has type %T, expected testdata.Ext", e)
315 } else if *x.Data != "there" {
316 t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
318 proto.ClearExtension(msg, pb.E_Ext_More)
319 if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
320 t.Errorf("got %v, expected ErrMissingExtension", e)
322 if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
323 t.Error("expected bad extension error, got nil")
325 if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
326 t.Error("expected extension err")
328 if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
329 t.Error("expected some sort of type mismatch error, got nil")
333 func TestNilExtension(t *testing.T) {
334 msg := &pb.MyMessage{
335 Count: proto.Int32(1),
337 if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
340 if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
341 t.Error("expected SetExtension to fail due to a nil extension")
342 } else if want := "proto: SetExtension called with nil value of type *testdata.Ext"; err.Error() != want {
343 t.Errorf("expected error %v, got %v", want, err)
345 // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
346 // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
349 func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
350 // Add a repeated extension to the result.
353 ext []*pb.ComplexExtension
357 []*pb.ComplexExtension{
358 {First: proto.Int32(7)},
359 {Second: proto.Int32(11)},
364 []*pb.ComplexExtension{
365 {Third: []int32{1000}},
366 {Third: []int32{2000}},
370 "two fields and repeated field",
371 []*pb.ComplexExtension{
372 {Third: []int32{1000}},
373 {First: proto.Int32(9)},
374 {Second: proto.Int32(21)},
375 {Third: []int32{2000}},
379 for _, test := range tests {
380 // Marshal message with a repeated extension.
381 msg1 := new(pb.OtherMessage)
382 err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
384 t.Fatalf("[%s] Error setting extension: %v", test.name, err)
386 b, err := proto.Marshal(msg1)
388 t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
391 // Unmarshal and read the merged proto.
392 msg2 := new(pb.OtherMessage)
393 err = proto.Unmarshal(b, msg2)
395 t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
397 e, err := proto.GetExtension(msg2, pb.E_RComplex)
399 t.Fatalf("[%s] Error getting extension: %v", test.name, err)
401 ext := e.([]*pb.ComplexExtension)
403 t.Fatalf("[%s] Invalid extension", test.name)
405 if !reflect.DeepEqual(ext, test.ext) {
406 t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, test.ext)
411 func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
412 // We may see multiple instances of the same extension in the wire
413 // format. For example, the proto compiler may encode custom options in
414 // this way. Here, we verify that we merge the extensions together.
417 ext []*pb.ComplexExtension
421 []*pb.ComplexExtension{
422 {First: proto.Int32(7)},
423 {Second: proto.Int32(11)},
428 []*pb.ComplexExtension{
429 {Third: []int32{1000}},
430 {Third: []int32{2000}},
434 "two fields and repeated field",
435 []*pb.ComplexExtension{
436 {Third: []int32{1000}},
437 {First: proto.Int32(9)},
438 {Second: proto.Int32(21)},
439 {Third: []int32{2000}},
443 for _, test := range tests {
445 var want pb.ComplexExtension
447 // Generate a serialized representation of a repeated extension
448 // by catenating bytes together.
449 for i, e := range test.ext {
450 // Merge to create the wanted proto.
451 proto.Merge(&want, e)
453 // serialize the message
454 msg := new(pb.OtherMessage)
455 err := proto.SetExtension(msg, pb.E_Complex, e)
457 t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
459 b, err := proto.Marshal(msg)
461 t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
466 // Unmarshal and read the merged proto.
467 msg2 := new(pb.OtherMessage)
468 err := proto.Unmarshal(buf.Bytes(), msg2)
470 t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
472 e, err := proto.GetExtension(msg2, pb.E_Complex)
474 t.Fatalf("[%s] Error getting extension: %v", test.name, err)
476 ext := e.(*pb.ComplexExtension)
478 t.Fatalf("[%s] Invalid extension", test.name)
480 if !reflect.DeepEqual(*ext, want) {
481 t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, want)
486 func TestClearAllExtensions(t *testing.T) {
487 // unregistered extension
488 desc := &proto.ExtensionDesc{
489 ExtendedType: (*pb.MyMessage)(nil),
490 ExtensionType: (*bool)(nil),
492 Name: "emptyextension",
496 if proto.HasExtension(m, desc) {
497 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
499 if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
500 t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
502 if !proto.HasExtension(m, desc) {
503 t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
505 proto.ClearAllExtensions(m)
506 if proto.HasExtension(m, desc) {
507 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
511 func TestMarshalRace(t *testing.T) {
512 // unregistered extension
513 desc := &proto.ExtensionDesc{
514 ExtendedType: (*pb.MyMessage)(nil),
515 ExtensionType: (*bool)(nil),
517 Name: "emptyextension",
521 m := &pb.MyMessage{Count: proto.Int32(4)}
522 if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
523 t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
527 for n := 3; n > 0; n-- {
529 _, err := proto.Marshal(m)
533 if err := g.Wait(); err != nil {