OSDN Git Service

Merge pull request #41 from Bytom/dev
[bytom/vapor.git] / vendor / github.com / gogo / protobuf / proto / extensions_test.go
1 // Go support for Protocol Buffers - Google's data interchange format
2 //
3 // Copyright 2014 The Go Authors.  All rights reserved.
4 // https://github.com/golang/protobuf
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions are
8 // met:
9 //
10 //     * Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
12 //     * Redistributions in binary form must reproduce the above
13 // copyright notice, this list of conditions and the following disclaimer
14 // in the documentation and/or other materials provided with the
15 // distribution.
16 //     * Neither the name of Google Inc. nor the names of its
17 // contributors may be used to endorse or promote products derived from
18 // this software without specific prior written permission.
19 //
20 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32 package proto_test
33
34 import (
35         "bytes"
36         "fmt"
37         "io"
38         "reflect"
39         "sort"
40         "strings"
41         "testing"
42
43         "github.com/gogo/protobuf/proto"
44         pb "github.com/gogo/protobuf/proto/test_proto"
45 )
46
47 func TestGetExtensionsWithMissingExtensions(t *testing.T) {
48         msg := &pb.MyMessage{}
49         ext1 := &pb.Ext{}
50         if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
51                 t.Fatalf("Could not set ext1: %s", err)
52         }
53         exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
54                 pb.E_Ext_More,
55                 pb.E_Ext_Text,
56         })
57         if err != nil {
58                 t.Fatalf("GetExtensions() failed: %s", err)
59         }
60         if exts[0] != ext1 {
61                 t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
62         }
63         if exts[1] != nil {
64                 t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
65         }
66 }
67
68 func TestGetExtensionWithEmptyBuffer(t *testing.T) {
69         // Make sure that GetExtension returns an error if its
70         // undecoded buffer is empty.
71         msg := &pb.MyMessage{}
72         proto.SetRawExtension(msg, pb.E_Ext_More.Field, []byte{})
73         _, err := proto.GetExtension(msg, pb.E_Ext_More)
74         if want := io.ErrUnexpectedEOF; err != want {
75                 t.Errorf("unexpected error in GetExtension from empty buffer: got %v, want %v", err, want)
76         }
77 }
78
79 func TestGetExtensionForIncompleteDesc(t *testing.T) {
80         msg := &pb.MyMessage{Count: proto.Int32(0)}
81         extdesc1 := &proto.ExtensionDesc{
82                 ExtendedType:  (*pb.MyMessage)(nil),
83                 ExtensionType: (*bool)(nil),
84                 Field:         123456789,
85                 Name:          "a.b",
86                 Tag:           "varint,123456789,opt",
87         }
88         ext1 := proto.Bool(true)
89         if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
90                 t.Fatalf("Could not set ext1: %s", err)
91         }
92         extdesc2 := &proto.ExtensionDesc{
93                 ExtendedType:  (*pb.MyMessage)(nil),
94                 ExtensionType: ([]byte)(nil),
95                 Field:         123456790,
96                 Name:          "a.c",
97                 Tag:           "bytes,123456790,opt",
98         }
99         ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7}
100         if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
101                 t.Fatalf("Could not set ext2: %s", err)
102         }
103         extdesc3 := &proto.ExtensionDesc{
104                 ExtendedType:  (*pb.MyMessage)(nil),
105                 ExtensionType: (*pb.Ext)(nil),
106                 Field:         123456791,
107                 Name:          "a.d",
108                 Tag:           "bytes,123456791,opt",
109         }
110         ext3 := &pb.Ext{Data: proto.String("foo")}
111         if err := proto.SetExtension(msg, extdesc3, ext3); err != nil {
112                 t.Fatalf("Could not set ext3: %s", err)
113         }
114
115         b, err := proto.Marshal(msg)
116         if err != nil {
117                 t.Fatalf("Could not marshal msg: %v", err)
118         }
119         if err := proto.Unmarshal(b, msg); err != nil {
120                 t.Fatalf("Could not unmarshal into msg: %v", err)
121         }
122
123         var expected proto.Buffer
124         if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil {
125                 t.Fatalf("failed to compute expected prefix for ext1: %s", err)
126         }
127         if err := expected.EncodeVarint(1 /* bool true */); err != nil {
128                 t.Fatalf("failed to compute expected value for ext1: %s", err)
129         }
130
131         if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil {
132                 t.Fatalf("Failed to get raw value for ext1: %s", err)
133         } else if !reflect.DeepEqual(b, expected.Bytes()) {
134                 t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes())
135         }
136
137         expected = proto.Buffer{} // reset
138         if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil {
139                 t.Fatalf("failed to compute expected prefix for ext2: %s", err)
140         }
141         if err := expected.EncodeRawBytes(ext2); err != nil {
142                 t.Fatalf("failed to compute expected value for ext2: %s", err)
143         }
144
145         if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil {
146                 t.Fatalf("Failed to get raw value for ext2: %s", err)
147         } else if !reflect.DeepEqual(b, expected.Bytes()) {
148                 t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes())
149         }
150
151         expected = proto.Buffer{} // reset
152         if err := expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil {
153                 t.Fatalf("failed to compute expected prefix for ext3: %s", err)
154         }
155         if b, err := proto.Marshal(ext3); err != nil {
156                 t.Fatalf("failed to compute expected value for ext3: %s", err)
157         } else if err := expected.EncodeRawBytes(b); err != nil {
158                 t.Fatalf("failed to compute expected value for ext3: %s", err)
159         }
160
161         if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil {
162                 t.Fatalf("Failed to get raw value for ext3: %s", err)
163         } else if !reflect.DeepEqual(b, expected.Bytes()) {
164                 t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes())
165         }
166 }
167
168 func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) {
169         msg := &pb.MyMessage{Count: proto.Int32(0)}
170         extdesc1 := pb.E_Ext_More
171         if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
172                 t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
173         }
174
175         ext1 := &pb.Ext{}
176         if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
177                 t.Fatalf("Could not set ext1: %s", err)
178         }
179         extdesc2 := &proto.ExtensionDesc{
180                 ExtendedType:  (*pb.MyMessage)(nil),
181                 ExtensionType: (*bool)(nil),
182                 Field:         123456789,
183                 Name:          "a.b",
184                 Tag:           "varint,123456789,opt",
185         }
186         ext2 := proto.Bool(false)
187         if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
188                 t.Fatalf("Could not set ext2: %s", err)
189         }
190
191         b, err := proto.Marshal(msg)
192         if err != nil {
193                 t.Fatalf("Could not marshal msg: %v", err)
194         }
195         if err = proto.Unmarshal(b, msg); err != nil {
196                 t.Fatalf("Could not unmarshal into msg: %v", err)
197         }
198
199         descs, err := proto.ExtensionDescs(msg)
200         if err != nil {
201                 t.Fatalf("proto.ExtensionDescs: got error %v", err)
202         }
203         sortExtDescs(descs)
204         wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}}
205         if !reflect.DeepEqual(descs, wantDescs) {
206                 t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
207         }
208 }
209
210 type ExtensionDescSlice []*proto.ExtensionDesc
211
212 func (s ExtensionDescSlice) Len() int           { return len(s) }
213 func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
214 func (s ExtensionDescSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
215
216 func sortExtDescs(s []*proto.ExtensionDesc) {
217         sort.Sort(ExtensionDescSlice(s))
218 }
219
220 func TestGetExtensionStability(t *testing.T) {
221         check := func(m *pb.MyMessage) bool {
222                 ext1, err := proto.GetExtension(m, pb.E_Ext_More)
223                 if err != nil {
224                         t.Fatalf("GetExtension() failed: %s", err)
225                 }
226                 ext2, err := proto.GetExtension(m, pb.E_Ext_More)
227                 if err != nil {
228                         t.Fatalf("GetExtension() failed: %s", err)
229                 }
230                 return ext1 == ext2
231         }
232         msg := &pb.MyMessage{Count: proto.Int32(4)}
233         ext0 := &pb.Ext{}
234         if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
235                 t.Fatalf("Could not set ext1: %s", ext0)
236         }
237         if !check(msg) {
238                 t.Errorf("GetExtension() not stable before marshaling")
239         }
240         bb, err := proto.Marshal(msg)
241         if err != nil {
242                 t.Fatalf("Marshal() failed: %s", err)
243         }
244         msg1 := &pb.MyMessage{}
245         err = proto.Unmarshal(bb, msg1)
246         if err != nil {
247                 t.Fatalf("Unmarshal() failed: %s", err)
248         }
249         if !check(msg1) {
250                 t.Errorf("GetExtension() not stable after unmarshaling")
251         }
252 }
253
254 func TestGetExtensionDefaults(t *testing.T) {
255         var setFloat64 float64 = 1
256         var setFloat32 float32 = 2
257         var setInt32 int32 = 3
258         var setInt64 int64 = 4
259         var setUint32 uint32 = 5
260         var setUint64 uint64 = 6
261         var setBool = true
262         var setBool2 = false
263         var setString = "Goodnight string"
264         var setBytes = []byte("Goodnight bytes")
265         var setEnum = pb.DefaultsMessage_TWO
266
267         type testcase struct {
268                 ext  *proto.ExtensionDesc // Extension we are testing.
269                 want interface{}          // Expected value of extension, or nil (meaning that GetExtension will fail).
270                 def  interface{}          // Expected value of extension after ClearExtension().
271         }
272         tests := []testcase{
273                 {pb.E_NoDefaultDouble, setFloat64, nil},
274                 {pb.E_NoDefaultFloat, setFloat32, nil},
275                 {pb.E_NoDefaultInt32, setInt32, nil},
276                 {pb.E_NoDefaultInt64, setInt64, nil},
277                 {pb.E_NoDefaultUint32, setUint32, nil},
278                 {pb.E_NoDefaultUint64, setUint64, nil},
279                 {pb.E_NoDefaultSint32, setInt32, nil},
280                 {pb.E_NoDefaultSint64, setInt64, nil},
281                 {pb.E_NoDefaultFixed32, setUint32, nil},
282                 {pb.E_NoDefaultFixed64, setUint64, nil},
283                 {pb.E_NoDefaultSfixed32, setInt32, nil},
284                 {pb.E_NoDefaultSfixed64, setInt64, nil},
285                 {pb.E_NoDefaultBool, setBool, nil},
286                 {pb.E_NoDefaultBool, setBool2, nil},
287                 {pb.E_NoDefaultString, setString, nil},
288                 {pb.E_NoDefaultBytes, setBytes, nil},
289                 {pb.E_NoDefaultEnum, setEnum, nil},
290                 {pb.E_DefaultDouble, setFloat64, float64(3.1415)},
291                 {pb.E_DefaultFloat, setFloat32, float32(3.14)},
292                 {pb.E_DefaultInt32, setInt32, int32(42)},
293                 {pb.E_DefaultInt64, setInt64, int64(43)},
294                 {pb.E_DefaultUint32, setUint32, uint32(44)},
295                 {pb.E_DefaultUint64, setUint64, uint64(45)},
296                 {pb.E_DefaultSint32, setInt32, int32(46)},
297                 {pb.E_DefaultSint64, setInt64, int64(47)},
298                 {pb.E_DefaultFixed32, setUint32, uint32(48)},
299                 {pb.E_DefaultFixed64, setUint64, uint64(49)},
300                 {pb.E_DefaultSfixed32, setInt32, int32(50)},
301                 {pb.E_DefaultSfixed64, setInt64, int64(51)},
302                 {pb.E_DefaultBool, setBool, true},
303                 {pb.E_DefaultBool, setBool2, true},
304                 {pb.E_DefaultString, setString, "Hello, string,def=foo"},
305                 {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
306                 {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
307         }
308
309         checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
310                 val, err := proto.GetExtension(msg, test.ext)
311                 if err != nil {
312                         if valWant != nil {
313                                 return fmt.Errorf("GetExtension(): %s", err)
314                         }
315                         if want := proto.ErrMissingExtension; err != want {
316                                 return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
317                         }
318                         return nil
319                 }
320
321                 // All proto2 extension values are either a pointer to a value or a slice of values.
322                 ty := reflect.TypeOf(val)
323                 tyWant := reflect.TypeOf(test.ext.ExtensionType)
324                 if got, want := ty, tyWant; got != want {
325                         return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
326                 }
327                 tye := ty.Elem()
328                 tyeWant := tyWant.Elem()
329                 if got, want := tye, tyeWant; got != want {
330                         return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
331                 }
332
333                 // Check the name of the type of the value.
334                 // If it is an enum it will be type int32 with the name of the enum.
335                 if got, want := tye.Name(), tye.Name(); got != want {
336                         return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
337                 }
338
339                 // Check that value is what we expect.
340                 // If we have a pointer in val, get the value it points to.
341                 valExp := val
342                 if ty.Kind() == reflect.Ptr {
343                         valExp = reflect.ValueOf(val).Elem().Interface()
344                 }
345                 if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
346                         return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
347                 }
348
349                 return nil
350         }
351
352         setTo := func(test testcase) interface{} {
353                 setTo := reflect.ValueOf(test.want)
354                 if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
355                         setTo = reflect.New(typ).Elem()
356                         setTo.Set(reflect.New(setTo.Type().Elem()))
357                         setTo.Elem().Set(reflect.ValueOf(test.want))
358                 }
359                 return setTo.Interface()
360         }
361
362         for _, test := range tests {
363                 msg := &pb.DefaultsMessage{}
364                 name := test.ext.Name
365
366                 // Check the initial value.
367                 if err := checkVal(test, msg, test.def); err != nil {
368                         t.Errorf("%s: %v", name, err)
369                 }
370
371                 // Set the per-type value and check value.
372                 name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
373                 if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
374                         t.Errorf("%s: SetExtension(): %v", name, err)
375                         continue
376                 }
377                 if err := checkVal(test, msg, test.want); err != nil {
378                         t.Errorf("%s: %v", name, err)
379                         continue
380                 }
381
382                 // Set and check the value.
383                 name += " (cleared)"
384                 proto.ClearExtension(msg, test.ext)
385                 if err := checkVal(test, msg, test.def); err != nil {
386                         t.Errorf("%s: %v", name, err)
387                 }
388         }
389 }
390
391 func TestNilMessage(t *testing.T) {
392         name := "nil interface"
393         if got, err := proto.GetExtension(nil, pb.E_Ext_More); err == nil {
394                 t.Errorf("%s: got %T %v, expected to fail", name, got, got)
395         } else if !strings.Contains(err.Error(), "extendable") {
396                 t.Errorf("%s: got error %v, expected not-extendable error", name, err)
397         }
398
399         // Regression tests: all functions of the Extension API
400         // used to panic when passed (*M)(nil), where M is a concrete message
401         // type.  Now they handle this gracefully as a no-op or reported error.
402         var nilMsg *pb.MyMessage
403         desc := pb.E_Ext_More
404
405         isNotExtendable := func(err error) bool {
406                 return strings.Contains(fmt.Sprint(err), "not extendable")
407         }
408
409         if proto.HasExtension(nilMsg, desc) {
410                 t.Error("HasExtension(nil) = true")
411         }
412
413         if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) {
414                 t.Errorf("GetExtensions(nil) = %q (wrong error)", err)
415         }
416
417         if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) {
418                 t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err)
419         }
420
421         if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) {
422                 t.Errorf("SetExtension(nil) = %q (wrong error)", err)
423         }
424
425         proto.ClearExtension(nilMsg, desc) // no-op
426         proto.ClearAllExtensions(nilMsg)   // no-op
427 }
428
429 func TestExtensionsRoundTrip(t *testing.T) {
430         msg := &pb.MyMessage{}
431         ext1 := &pb.Ext{
432                 Data: proto.String("hi"),
433         }
434         ext2 := &pb.Ext{
435                 Data: proto.String("there"),
436         }
437         exists := proto.HasExtension(msg, pb.E_Ext_More)
438         if exists {
439                 t.Error("Extension More present unexpectedly")
440         }
441         if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
442                 t.Error(err)
443         }
444         if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
445                 t.Error(err)
446         }
447         e, err := proto.GetExtension(msg, pb.E_Ext_More)
448         if err != nil {
449                 t.Error(err)
450         }
451         x, ok := e.(*pb.Ext)
452         if !ok {
453                 t.Errorf("e has type %T, expected test_proto.Ext", e)
454         } else if *x.Data != "there" {
455                 t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
456         }
457         proto.ClearExtension(msg, pb.E_Ext_More)
458         if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
459                 t.Errorf("got %v, expected ErrMissingExtension", e)
460         }
461         if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
462                 t.Error("expected bad extension error, got nil")
463         }
464         if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
465                 t.Error("expected extension err")
466         }
467         if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
468                 t.Error("expected some sort of type mismatch error, got nil")
469         }
470 }
471
472 func TestNilExtension(t *testing.T) {
473         msg := &pb.MyMessage{
474                 Count: proto.Int32(1),
475         }
476         if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
477                 t.Fatal(err)
478         }
479         if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
480                 t.Error("expected SetExtension to fail due to a nil extension")
481         } else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb.Ext)); err.Error() != want {
482                 t.Errorf("expected error %v, got %v", want, err)
483         }
484         // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
485         // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
486 }
487
488 func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
489         // Add a repeated extension to the result.
490         tests := []struct {
491                 name string
492                 ext  []*pb.ComplexExtension
493         }{
494                 {
495                         "two fields",
496                         []*pb.ComplexExtension{
497                                 {First: proto.Int32(7)},
498                                 {Second: proto.Int32(11)},
499                         },
500                 },
501                 {
502                         "repeated field",
503                         []*pb.ComplexExtension{
504                                 {Third: []int32{1000}},
505                                 {Third: []int32{2000}},
506                         },
507                 },
508                 {
509                         "two fields and repeated field",
510                         []*pb.ComplexExtension{
511                                 {Third: []int32{1000}},
512                                 {First: proto.Int32(9)},
513                                 {Second: proto.Int32(21)},
514                                 {Third: []int32{2000}},
515                         },
516                 },
517         }
518         for _, test := range tests {
519                 // Marshal message with a repeated extension.
520                 msg1 := new(pb.OtherMessage)
521                 err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
522                 if err != nil {
523                         t.Fatalf("[%s] Error setting extension: %v", test.name, err)
524                 }
525                 b, err := proto.Marshal(msg1)
526                 if err != nil {
527                         t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
528                 }
529
530                 // Unmarshal and read the merged proto.
531                 msg2 := new(pb.OtherMessage)
532                 err = proto.Unmarshal(b, msg2)
533                 if err != nil {
534                         t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
535                 }
536                 e, err := proto.GetExtension(msg2, pb.E_RComplex)
537                 if err != nil {
538                         t.Fatalf("[%s] Error getting extension: %v", test.name, err)
539                 }
540                 ext := e.([]*pb.ComplexExtension)
541                 if ext == nil {
542                         t.Fatalf("[%s] Invalid extension", test.name)
543                 }
544                 if len(ext) != len(test.ext) {
545                         t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext))
546                 }
547                 for i := range test.ext {
548                         if !proto.Equal(ext[i], test.ext[i]) {
549                                 t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i])
550                         }
551                 }
552         }
553 }
554
555 func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
556         // We may see multiple instances of the same extension in the wire
557         // format. For example, the proto compiler may encode custom options in
558         // this way. Here, we verify that we merge the extensions together.
559         tests := []struct {
560                 name string
561                 ext  []*pb.ComplexExtension
562         }{
563                 {
564                         "two fields",
565                         []*pb.ComplexExtension{
566                                 {First: proto.Int32(7)},
567                                 {Second: proto.Int32(11)},
568                         },
569                 },
570                 {
571                         "repeated field",
572                         []*pb.ComplexExtension{
573                                 {Third: []int32{1000}},
574                                 {Third: []int32{2000}},
575                         },
576                 },
577                 {
578                         "two fields and repeated field",
579                         []*pb.ComplexExtension{
580                                 {Third: []int32{1000}},
581                                 {First: proto.Int32(9)},
582                                 {Second: proto.Int32(21)},
583                                 {Third: []int32{2000}},
584                         },
585                 },
586         }
587         for _, test := range tests {
588                 var buf bytes.Buffer
589                 var want pb.ComplexExtension
590
591                 // Generate a serialized representation of a repeated extension
592                 // by catenating bytes together.
593                 for i, e := range test.ext {
594                         // Merge to create the wanted proto.
595                         proto.Merge(&want, e)
596
597                         // serialize the message
598                         msg := new(pb.OtherMessage)
599                         err := proto.SetExtension(msg, pb.E_Complex, e)
600                         if err != nil {
601                                 t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
602                         }
603                         b, err := proto.Marshal(msg)
604                         if err != nil {
605                                 t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
606                         }
607                         buf.Write(b)
608                 }
609
610                 // Unmarshal and read the merged proto.
611                 msg2 := new(pb.OtherMessage)
612                 err := proto.Unmarshal(buf.Bytes(), msg2)
613                 if err != nil {
614                         t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
615                 }
616                 e, err := proto.GetExtension(msg2, pb.E_Complex)
617                 if err != nil {
618                         t.Fatalf("[%s] Error getting extension: %v", test.name, err)
619                 }
620                 ext := e.(*pb.ComplexExtension)
621                 if ext == nil {
622                         t.Fatalf("[%s] Invalid extension", test.name)
623                 }
624                 if !proto.Equal(ext, &want) {
625                         t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, &want)
626
627                 }
628         }
629 }
630
631 func TestClearAllExtensions(t *testing.T) {
632         // unregistered extension
633         desc := &proto.ExtensionDesc{
634                 ExtendedType:  (*pb.MyMessage)(nil),
635                 ExtensionType: (*bool)(nil),
636                 Field:         101010100,
637                 Name:          "emptyextension",
638                 Tag:           "varint,0,opt",
639         }
640         m := &pb.MyMessage{}
641         if proto.HasExtension(m, desc) {
642                 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
643         }
644         if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
645                 t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
646         }
647         if !proto.HasExtension(m, desc) {
648                 t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
649         }
650         proto.ClearAllExtensions(m)
651         if proto.HasExtension(m, desc) {
652                 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
653         }
654 }
655
656 func TestMarshalRace(t *testing.T) {
657         ext := &pb.Ext{}
658         m := &pb.MyMessage{Count: proto.Int32(4)}
659         if err := proto.SetExtension(m, pb.E_Ext_More, ext); err != nil {
660                 t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
661         }
662
663         b, err := proto.Marshal(m)
664         if err != nil {
665                 t.Fatalf("Could not marshal message: %v", err)
666         }
667         if err := proto.Unmarshal(b, m); err != nil {
668                 t.Fatalf("Could not unmarshal message: %v", err)
669         }
670         // after Unmarshal, the extension is in undecoded form.
671         // GetExtension will decode it lazily. Make sure this does
672         // not race against Marshal.
673
674         errChan := make(chan error, 6)
675         for n := 3; n > 0; n-- {
676                 go func() {
677                         _, err := proto.Marshal(m)
678                         errChan <- err
679                 }()
680                 go func() {
681                         _, err := proto.GetExtension(m, pb.E_Ext_More)
682                         errChan <- err
683                 }()
684         }
685         for i := 0; i < 6; i++ {
686                 err := <-errChan
687                 if err != nil {
688                         t.Fatal(err)
689                 }
690         }
691 }