OSDN Git Service

Merge pull request #41 from Bytom/dev
[bytom/vapor.git] / vendor / github.com / golang / protobuf / proto / extensions_test.go
1 // Go support for Protocol Buffers - Google's data interchange format
2 //
3 // Copyright 2014 The Go Authors.  All rights reserved.
4 // https://github.com/golang/protobuf
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions are
8 // met:
9 //
10 //     * Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
12 //     * Redistributions in binary form must reproduce the above
13 // copyright notice, this list of conditions and the following disclaimer
14 // in the documentation and/or other materials provided with the
15 // distribution.
16 //     * Neither the name of Google Inc. nor the names of its
17 // contributors may be used to endorse or promote products derived from
18 // this software without specific prior written permission.
19 //
20 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32 package proto_test
33
34 import (
35         "bytes"
36         "fmt"
37         "reflect"
38         "sort"
39         "testing"
40
41         "github.com/golang/protobuf/proto"
42         pb "github.com/golang/protobuf/proto/testdata"
43         "golang.org/x/sync/errgroup"
44 )
45
46 func TestGetExtensionsWithMissingExtensions(t *testing.T) {
47         msg := &pb.MyMessage{}
48         ext1 := &pb.Ext{}
49         if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
50                 t.Fatalf("Could not set ext1: %s", err)
51         }
52         exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
53                 pb.E_Ext_More,
54                 pb.E_Ext_Text,
55         })
56         if err != nil {
57                 t.Fatalf("GetExtensions() failed: %s", err)
58         }
59         if exts[0] != ext1 {
60                 t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
61         }
62         if exts[1] != nil {
63                 t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
64         }
65 }
66
67 func TestExtensionDescsWithMissingExtensions(t *testing.T) {
68         msg := &pb.MyMessage{Count: proto.Int32(0)}
69         extdesc1 := pb.E_Ext_More
70         if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
71                 t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
72         }
73
74         ext1 := &pb.Ext{}
75         if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
76                 t.Fatalf("Could not set ext1: %s", err)
77         }
78         extdesc2 := &proto.ExtensionDesc{
79                 ExtendedType:  (*pb.MyMessage)(nil),
80                 ExtensionType: (*bool)(nil),
81                 Field:         123456789,
82                 Name:          "a.b",
83                 Tag:           "varint,123456789,opt",
84         }
85         ext2 := proto.Bool(false)
86         if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
87                 t.Fatalf("Could not set ext2: %s", err)
88         }
89
90         b, err := proto.Marshal(msg)
91         if err != nil {
92                 t.Fatalf("Could not marshal msg: %v", err)
93         }
94         if err := proto.Unmarshal(b, msg); err != nil {
95                 t.Fatalf("Could not unmarshal into msg: %v", err)
96         }
97
98         descs, err := proto.ExtensionDescs(msg)
99         if err != nil {
100                 t.Fatalf("proto.ExtensionDescs: got error %v", err)
101         }
102         sortExtDescs(descs)
103         wantDescs := []*proto.ExtensionDesc{extdesc1, &proto.ExtensionDesc{Field: extdesc2.Field}}
104         if !reflect.DeepEqual(descs, wantDescs) {
105                 t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
106         }
107 }
108
109 type ExtensionDescSlice []*proto.ExtensionDesc
110
111 func (s ExtensionDescSlice) Len() int           { return len(s) }
112 func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
113 func (s ExtensionDescSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
114
115 func sortExtDescs(s []*proto.ExtensionDesc) {
116         sort.Sort(ExtensionDescSlice(s))
117 }
118
119 func TestGetExtensionStability(t *testing.T) {
120         check := func(m *pb.MyMessage) bool {
121                 ext1, err := proto.GetExtension(m, pb.E_Ext_More)
122                 if err != nil {
123                         t.Fatalf("GetExtension() failed: %s", err)
124                 }
125                 ext2, err := proto.GetExtension(m, pb.E_Ext_More)
126                 if err != nil {
127                         t.Fatalf("GetExtension() failed: %s", err)
128                 }
129                 return ext1 == ext2
130         }
131         msg := &pb.MyMessage{Count: proto.Int32(4)}
132         ext0 := &pb.Ext{}
133         if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
134                 t.Fatalf("Could not set ext1: %s", ext0)
135         }
136         if !check(msg) {
137                 t.Errorf("GetExtension() not stable before marshaling")
138         }
139         bb, err := proto.Marshal(msg)
140         if err != nil {
141                 t.Fatalf("Marshal() failed: %s", err)
142         }
143         msg1 := &pb.MyMessage{}
144         err = proto.Unmarshal(bb, msg1)
145         if err != nil {
146                 t.Fatalf("Unmarshal() failed: %s", err)
147         }
148         if !check(msg1) {
149                 t.Errorf("GetExtension() not stable after unmarshaling")
150         }
151 }
152
153 func TestGetExtensionDefaults(t *testing.T) {
154         var setFloat64 float64 = 1
155         var setFloat32 float32 = 2
156         var setInt32 int32 = 3
157         var setInt64 int64 = 4
158         var setUint32 uint32 = 5
159         var setUint64 uint64 = 6
160         var setBool = true
161         var setBool2 = false
162         var setString = "Goodnight string"
163         var setBytes = []byte("Goodnight bytes")
164         var setEnum = pb.DefaultsMessage_TWO
165
166         type testcase struct {
167                 ext  *proto.ExtensionDesc // Extension we are testing.
168                 want interface{}          // Expected value of extension, or nil (meaning that GetExtension will fail).
169                 def  interface{}          // Expected value of extension after ClearExtension().
170         }
171         tests := []testcase{
172                 {pb.E_NoDefaultDouble, setFloat64, nil},
173                 {pb.E_NoDefaultFloat, setFloat32, nil},
174                 {pb.E_NoDefaultInt32, setInt32, nil},
175                 {pb.E_NoDefaultInt64, setInt64, nil},
176                 {pb.E_NoDefaultUint32, setUint32, nil},
177                 {pb.E_NoDefaultUint64, setUint64, nil},
178                 {pb.E_NoDefaultSint32, setInt32, nil},
179                 {pb.E_NoDefaultSint64, setInt64, nil},
180                 {pb.E_NoDefaultFixed32, setUint32, nil},
181                 {pb.E_NoDefaultFixed64, setUint64, nil},
182                 {pb.E_NoDefaultSfixed32, setInt32, nil},
183                 {pb.E_NoDefaultSfixed64, setInt64, nil},
184                 {pb.E_NoDefaultBool, setBool, nil},
185                 {pb.E_NoDefaultBool, setBool2, nil},
186                 {pb.E_NoDefaultString, setString, nil},
187                 {pb.E_NoDefaultBytes, setBytes, nil},
188                 {pb.E_NoDefaultEnum, setEnum, nil},
189                 {pb.E_DefaultDouble, setFloat64, float64(3.1415)},
190                 {pb.E_DefaultFloat, setFloat32, float32(3.14)},
191                 {pb.E_DefaultInt32, setInt32, int32(42)},
192                 {pb.E_DefaultInt64, setInt64, int64(43)},
193                 {pb.E_DefaultUint32, setUint32, uint32(44)},
194                 {pb.E_DefaultUint64, setUint64, uint64(45)},
195                 {pb.E_DefaultSint32, setInt32, int32(46)},
196                 {pb.E_DefaultSint64, setInt64, int64(47)},
197                 {pb.E_DefaultFixed32, setUint32, uint32(48)},
198                 {pb.E_DefaultFixed64, setUint64, uint64(49)},
199                 {pb.E_DefaultSfixed32, setInt32, int32(50)},
200                 {pb.E_DefaultSfixed64, setInt64, int64(51)},
201                 {pb.E_DefaultBool, setBool, true},
202                 {pb.E_DefaultBool, setBool2, true},
203                 {pb.E_DefaultString, setString, "Hello, string"},
204                 {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
205                 {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
206         }
207
208         checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
209                 val, err := proto.GetExtension(msg, test.ext)
210                 if err != nil {
211                         if valWant != nil {
212                                 return fmt.Errorf("GetExtension(): %s", err)
213                         }
214                         if want := proto.ErrMissingExtension; err != want {
215                                 return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
216                         }
217                         return nil
218                 }
219
220                 // All proto2 extension values are either a pointer to a value or a slice of values.
221                 ty := reflect.TypeOf(val)
222                 tyWant := reflect.TypeOf(test.ext.ExtensionType)
223                 if got, want := ty, tyWant; got != want {
224                         return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
225                 }
226                 tye := ty.Elem()
227                 tyeWant := tyWant.Elem()
228                 if got, want := tye, tyeWant; got != want {
229                         return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
230                 }
231
232                 // Check the name of the type of the value.
233                 // If it is an enum it will be type int32 with the name of the enum.
234                 if got, want := tye.Name(), tye.Name(); got != want {
235                         return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
236                 }
237
238                 // Check that value is what we expect.
239                 // If we have a pointer in val, get the value it points to.
240                 valExp := val
241                 if ty.Kind() == reflect.Ptr {
242                         valExp = reflect.ValueOf(val).Elem().Interface()
243                 }
244                 if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
245                         return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
246                 }
247
248                 return nil
249         }
250
251         setTo := func(test testcase) interface{} {
252                 setTo := reflect.ValueOf(test.want)
253                 if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
254                         setTo = reflect.New(typ).Elem()
255                         setTo.Set(reflect.New(setTo.Type().Elem()))
256                         setTo.Elem().Set(reflect.ValueOf(test.want))
257                 }
258                 return setTo.Interface()
259         }
260
261         for _, test := range tests {
262                 msg := &pb.DefaultsMessage{}
263                 name := test.ext.Name
264
265                 // Check the initial value.
266                 if err := checkVal(test, msg, test.def); err != nil {
267                         t.Errorf("%s: %v", name, err)
268                 }
269
270                 // Set the per-type value and check value.
271                 name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
272                 if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
273                         t.Errorf("%s: SetExtension(): %v", name, err)
274                         continue
275                 }
276                 if err := checkVal(test, msg, test.want); err != nil {
277                         t.Errorf("%s: %v", name, err)
278                         continue
279                 }
280
281                 // Set and check the value.
282                 name += " (cleared)"
283                 proto.ClearExtension(msg, test.ext)
284                 if err := checkVal(test, msg, test.def); err != nil {
285                         t.Errorf("%s: %v", name, err)
286                 }
287         }
288 }
289
290 func TestExtensionsRoundTrip(t *testing.T) {
291         msg := &pb.MyMessage{}
292         ext1 := &pb.Ext{
293                 Data: proto.String("hi"),
294         }
295         ext2 := &pb.Ext{
296                 Data: proto.String("there"),
297         }
298         exists := proto.HasExtension(msg, pb.E_Ext_More)
299         if exists {
300                 t.Error("Extension More present unexpectedly")
301         }
302         if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
303                 t.Error(err)
304         }
305         if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
306                 t.Error(err)
307         }
308         e, err := proto.GetExtension(msg, pb.E_Ext_More)
309         if err != nil {
310                 t.Error(err)
311         }
312         x, ok := e.(*pb.Ext)
313         if !ok {
314                 t.Errorf("e has type %T, expected testdata.Ext", e)
315         } else if *x.Data != "there" {
316                 t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
317         }
318         proto.ClearExtension(msg, pb.E_Ext_More)
319         if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
320                 t.Errorf("got %v, expected ErrMissingExtension", e)
321         }
322         if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
323                 t.Error("expected bad extension error, got nil")
324         }
325         if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
326                 t.Error("expected extension err")
327         }
328         if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
329                 t.Error("expected some sort of type mismatch error, got nil")
330         }
331 }
332
333 func TestNilExtension(t *testing.T) {
334         msg := &pb.MyMessage{
335                 Count: proto.Int32(1),
336         }
337         if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
338                 t.Fatal(err)
339         }
340         if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
341                 t.Error("expected SetExtension to fail due to a nil extension")
342         } else if want := "proto: SetExtension called with nil value of type *testdata.Ext"; err.Error() != want {
343                 t.Errorf("expected error %v, got %v", want, err)
344         }
345         // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
346         // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
347 }
348
349 func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
350         // Add a repeated extension to the result.
351         tests := []struct {
352                 name string
353                 ext  []*pb.ComplexExtension
354         }{
355                 {
356                         "two fields",
357                         []*pb.ComplexExtension{
358                                 {First: proto.Int32(7)},
359                                 {Second: proto.Int32(11)},
360                         },
361                 },
362                 {
363                         "repeated field",
364                         []*pb.ComplexExtension{
365                                 {Third: []int32{1000}},
366                                 {Third: []int32{2000}},
367                         },
368                 },
369                 {
370                         "two fields and repeated field",
371                         []*pb.ComplexExtension{
372                                 {Third: []int32{1000}},
373                                 {First: proto.Int32(9)},
374                                 {Second: proto.Int32(21)},
375                                 {Third: []int32{2000}},
376                         },
377                 },
378         }
379         for _, test := range tests {
380                 // Marshal message with a repeated extension.
381                 msg1 := new(pb.OtherMessage)
382                 err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
383                 if err != nil {
384                         t.Fatalf("[%s] Error setting extension: %v", test.name, err)
385                 }
386                 b, err := proto.Marshal(msg1)
387                 if err != nil {
388                         t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
389                 }
390
391                 // Unmarshal and read the merged proto.
392                 msg2 := new(pb.OtherMessage)
393                 err = proto.Unmarshal(b, msg2)
394                 if err != nil {
395                         t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
396                 }
397                 e, err := proto.GetExtension(msg2, pb.E_RComplex)
398                 if err != nil {
399                         t.Fatalf("[%s] Error getting extension: %v", test.name, err)
400                 }
401                 ext := e.([]*pb.ComplexExtension)
402                 if ext == nil {
403                         t.Fatalf("[%s] Invalid extension", test.name)
404                 }
405                 if !reflect.DeepEqual(ext, test.ext) {
406                         t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, test.ext)
407                 }
408         }
409 }
410
411 func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
412         // We may see multiple instances of the same extension in the wire
413         // format. For example, the proto compiler may encode custom options in
414         // this way. Here, we verify that we merge the extensions together.
415         tests := []struct {
416                 name string
417                 ext  []*pb.ComplexExtension
418         }{
419                 {
420                         "two fields",
421                         []*pb.ComplexExtension{
422                                 {First: proto.Int32(7)},
423                                 {Second: proto.Int32(11)},
424                         },
425                 },
426                 {
427                         "repeated field",
428                         []*pb.ComplexExtension{
429                                 {Third: []int32{1000}},
430                                 {Third: []int32{2000}},
431                         },
432                 },
433                 {
434                         "two fields and repeated field",
435                         []*pb.ComplexExtension{
436                                 {Third: []int32{1000}},
437                                 {First: proto.Int32(9)},
438                                 {Second: proto.Int32(21)},
439                                 {Third: []int32{2000}},
440                         },
441                 },
442         }
443         for _, test := range tests {
444                 var buf bytes.Buffer
445                 var want pb.ComplexExtension
446
447                 // Generate a serialized representation of a repeated extension
448                 // by catenating bytes together.
449                 for i, e := range test.ext {
450                         // Merge to create the wanted proto.
451                         proto.Merge(&want, e)
452
453                         // serialize the message
454                         msg := new(pb.OtherMessage)
455                         err := proto.SetExtension(msg, pb.E_Complex, e)
456                         if err != nil {
457                                 t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
458                         }
459                         b, err := proto.Marshal(msg)
460                         if err != nil {
461                                 t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
462                         }
463                         buf.Write(b)
464                 }
465
466                 // Unmarshal and read the merged proto.
467                 msg2 := new(pb.OtherMessage)
468                 err := proto.Unmarshal(buf.Bytes(), msg2)
469                 if err != nil {
470                         t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
471                 }
472                 e, err := proto.GetExtension(msg2, pb.E_Complex)
473                 if err != nil {
474                         t.Fatalf("[%s] Error getting extension: %v", test.name, err)
475                 }
476                 ext := e.(*pb.ComplexExtension)
477                 if ext == nil {
478                         t.Fatalf("[%s] Invalid extension", test.name)
479                 }
480                 if !reflect.DeepEqual(*ext, want) {
481                         t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, want)
482                 }
483         }
484 }
485
486 func TestClearAllExtensions(t *testing.T) {
487         // unregistered extension
488         desc := &proto.ExtensionDesc{
489                 ExtendedType:  (*pb.MyMessage)(nil),
490                 ExtensionType: (*bool)(nil),
491                 Field:         101010100,
492                 Name:          "emptyextension",
493                 Tag:           "varint,0,opt",
494         }
495         m := &pb.MyMessage{}
496         if proto.HasExtension(m, desc) {
497                 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
498         }
499         if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
500                 t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
501         }
502         if !proto.HasExtension(m, desc) {
503                 t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
504         }
505         proto.ClearAllExtensions(m)
506         if proto.HasExtension(m, desc) {
507                 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
508         }
509 }
510
511 func TestMarshalRace(t *testing.T) {
512         // unregistered extension
513         desc := &proto.ExtensionDesc{
514                 ExtendedType:  (*pb.MyMessage)(nil),
515                 ExtensionType: (*bool)(nil),
516                 Field:         101010100,
517                 Name:          "emptyextension",
518                 Tag:           "varint,0,opt",
519         }
520
521         m := &pb.MyMessage{Count: proto.Int32(4)}
522         if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
523                 t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
524         }
525
526         var g errgroup.Group
527         for n := 3; n > 0; n-- {
528                 g.Go(func() error {
529                         _, err := proto.Marshal(m)
530                         return err
531                 })
532         }
533         if err := g.Wait(); err != nil {
534                 t.Fatal(err)
535         }
536 }