OSDN Git Service

Merge pull request #41 from Bytom/dev
[bytom/vapor.git] / vendor / github.com / gogo / protobuf / proto / extensions.go
1 // Go support for Protocol Buffers - Google's data interchange format
2 //
3 // Copyright 2010 The Go Authors.  All rights reserved.
4 // https://github.com/golang/protobuf
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions are
8 // met:
9 //
10 //     * Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
12 //     * Redistributions in binary form must reproduce the above
13 // copyright notice, this list of conditions and the following disclaimer
14 // in the documentation and/or other materials provided with the
15 // distribution.
16 //     * Neither the name of Google Inc. nor the names of its
17 // contributors may be used to endorse or promote products derived from
18 // this software without specific prior written permission.
19 //
20 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32 package proto
33
34 /*
35  * Types and routines for supporting protocol buffer extensions.
36  */
37
38 import (
39         "errors"
40         "fmt"
41         "io"
42         "reflect"
43         "strconv"
44         "sync"
45 )
46
47 // ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message.
48 var ErrMissingExtension = errors.New("proto: missing extension")
49
50 // ExtensionRange represents a range of message extensions for a protocol buffer.
51 // Used in code generated by the protocol compiler.
52 type ExtensionRange struct {
53         Start, End int32 // both inclusive
54 }
55
56 // extendableProto is an interface implemented by any protocol buffer generated by the current
57 // proto compiler that may be extended.
58 type extendableProto interface {
59         Message
60         ExtensionRangeArray() []ExtensionRange
61         extensionsWrite() map[int32]Extension
62         extensionsRead() (map[int32]Extension, sync.Locker)
63 }
64
65 // extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous
66 // version of the proto compiler that may be extended.
67 type extendableProtoV1 interface {
68         Message
69         ExtensionRangeArray() []ExtensionRange
70         ExtensionMap() map[int32]Extension
71 }
72
73 // extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto.
74 type extensionAdapter struct {
75         extendableProtoV1
76 }
77
78 func (e extensionAdapter) extensionsWrite() map[int32]Extension {
79         return e.ExtensionMap()
80 }
81
82 func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
83         return e.ExtensionMap(), notLocker{}
84 }
85
86 // notLocker is a sync.Locker whose Lock and Unlock methods are nops.
87 type notLocker struct{}
88
89 func (n notLocker) Lock()   {}
90 func (n notLocker) Unlock() {}
91
92 // extendable returns the extendableProto interface for the given generated proto message.
93 // If the proto message has the old extension format, it returns a wrapper that implements
94 // the extendableProto interface.
95 func extendable(p interface{}) (extendableProto, error) {
96         switch p := p.(type) {
97         case extendableProto:
98                 if isNilPtr(p) {
99                         return nil, fmt.Errorf("proto: nil %T is not extendable", p)
100                 }
101                 return p, nil
102         case extendableProtoV1:
103                 if isNilPtr(p) {
104                         return nil, fmt.Errorf("proto: nil %T is not extendable", p)
105                 }
106                 return extensionAdapter{p}, nil
107         case extensionsBytes:
108                 return slowExtensionAdapter{p}, nil
109         }
110         // Don't allocate a specific error containing %T:
111         // this is the hot path for Clone and MarshalText.
112         return nil, errNotExtendable
113 }
114
115 var errNotExtendable = errors.New("proto: not an extendable proto.Message")
116
117 func isNilPtr(x interface{}) bool {
118         v := reflect.ValueOf(x)
119         return v.Kind() == reflect.Ptr && v.IsNil()
120 }
121
122 // XXX_InternalExtensions is an internal representation of proto extensions.
123 //
124 // Each generated message struct type embeds an anonymous XXX_InternalExtensions field,
125 // thus gaining the unexported 'extensions' method, which can be called only from the proto package.
126 //
127 // The methods of XXX_InternalExtensions are not concurrency safe in general,
128 // but calls to logically read-only methods such as has and get may be executed concurrently.
129 type XXX_InternalExtensions struct {
130         // The struct must be indirect so that if a user inadvertently copies a
131         // generated message and its embedded XXX_InternalExtensions, they
132         // avoid the mayhem of a copied mutex.
133         //
134         // The mutex serializes all logically read-only operations to p.extensionMap.
135         // It is up to the client to ensure that write operations to p.extensionMap are
136         // mutually exclusive with other accesses.
137         p *struct {
138                 mu           sync.Mutex
139                 extensionMap map[int32]Extension
140         }
141 }
142
143 // extensionsWrite returns the extension map, creating it on first use.
144 func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension {
145         if e.p == nil {
146                 e.p = new(struct {
147                         mu           sync.Mutex
148                         extensionMap map[int32]Extension
149                 })
150                 e.p.extensionMap = make(map[int32]Extension)
151         }
152         return e.p.extensionMap
153 }
154
155 // extensionsRead returns the extensions map for read-only use.  It may be nil.
156 // The caller must hold the returned mutex's lock when accessing Elements within the map.
157 func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) {
158         if e.p == nil {
159                 return nil, nil
160         }
161         return e.p.extensionMap, &e.p.mu
162 }
163
164 // ExtensionDesc represents an extension specification.
165 // Used in generated code from the protocol compiler.
166 type ExtensionDesc struct {
167         ExtendedType  Message     // nil pointer to the type that is being extended
168         ExtensionType interface{} // nil pointer to the extension type
169         Field         int32       // field number
170         Name          string      // fully-qualified name of extension, for text formatting
171         Tag           string      // protobuf tag style
172         Filename      string      // name of the file in which the extension is defined
173 }
174
175 func (ed *ExtensionDesc) repeated() bool {
176         t := reflect.TypeOf(ed.ExtensionType)
177         return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
178 }
179
180 // Extension represents an extension in a message.
181 type Extension struct {
182         // When an extension is stored in a message using SetExtension
183         // only desc and value are set. When the message is marshaled
184         // enc will be set to the encoded form of the message.
185         //
186         // When a message is unmarshaled and contains extensions, each
187         // extension will have only enc set. When such an extension is
188         // accessed using GetExtension (or GetExtensions) desc and value
189         // will be set.
190         desc  *ExtensionDesc
191         value interface{}
192         enc   []byte
193 }
194
195 // SetRawExtension is for testing only.
196 func SetRawExtension(base Message, id int32, b []byte) {
197         if ebase, ok := base.(extensionsBytes); ok {
198                 clearExtension(base, id)
199                 ext := ebase.GetExtensions()
200                 *ext = append(*ext, b...)
201                 return
202         }
203         epb, err := extendable(base)
204         if err != nil {
205                 return
206         }
207         extmap := epb.extensionsWrite()
208         extmap[id] = Extension{enc: b}
209 }
210
211 // isExtensionField returns true iff the given field number is in an extension range.
212 func isExtensionField(pb extendableProto, field int32) bool {
213         for _, er := range pb.ExtensionRangeArray() {
214                 if er.Start <= field && field <= er.End {
215                         return true
216                 }
217         }
218         return false
219 }
220
221 // checkExtensionTypes checks that the given extension is valid for pb.
222 func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
223         var pbi interface{} = pb
224         // Check the extended type.
225         if ea, ok := pbi.(extensionAdapter); ok {
226                 pbi = ea.extendableProtoV1
227         }
228         if ea, ok := pbi.(slowExtensionAdapter); ok {
229                 pbi = ea.extensionsBytes
230         }
231         if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
232                 return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a)
233         }
234         // Check the range.
235         if !isExtensionField(pb, extension.Field) {
236                 return errors.New("proto: bad extension number; not in declared ranges")
237         }
238         return nil
239 }
240
241 // extPropKey is sufficient to uniquely identify an extension.
242 type extPropKey struct {
243         base  reflect.Type
244         field int32
245 }
246
247 var extProp = struct {
248         sync.RWMutex
249         m map[extPropKey]*Properties
250 }{
251         m: make(map[extPropKey]*Properties),
252 }
253
254 func extensionProperties(ed *ExtensionDesc) *Properties {
255         key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field}
256
257         extProp.RLock()
258         if prop, ok := extProp.m[key]; ok {
259                 extProp.RUnlock()
260                 return prop
261         }
262         extProp.RUnlock()
263
264         extProp.Lock()
265         defer extProp.Unlock()
266         // Check again.
267         if prop, ok := extProp.m[key]; ok {
268                 return prop
269         }
270
271         prop := new(Properties)
272         prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil)
273         extProp.m[key] = prop
274         return prop
275 }
276
277 // HasExtension returns whether the given extension is present in pb.
278 func HasExtension(pb Message, extension *ExtensionDesc) bool {
279         if epb, doki := pb.(extensionsBytes); doki {
280                 ext := epb.GetExtensions()
281                 buf := *ext
282                 o := 0
283                 for o < len(buf) {
284                         tag, n := DecodeVarint(buf[o:])
285                         fieldNum := int32(tag >> 3)
286                         if int32(fieldNum) == extension.Field {
287                                 return true
288                         }
289                         wireType := int(tag & 0x7)
290                         o += n
291                         l, err := size(buf[o:], wireType)
292                         if err != nil {
293                                 return false
294                         }
295                         o += l
296                 }
297                 return false
298         }
299         // TODO: Check types, field numbers, etc.?
300         epb, err := extendable(pb)
301         if err != nil {
302                 return false
303         }
304         extmap, mu := epb.extensionsRead()
305         if extmap == nil {
306                 return false
307         }
308         mu.Lock()
309         _, ok := extmap[extension.Field]
310         mu.Unlock()
311         return ok
312 }
313
314 // ClearExtension removes the given extension from pb.
315 func ClearExtension(pb Message, extension *ExtensionDesc) {
316         clearExtension(pb, extension.Field)
317 }
318
319 func clearExtension(pb Message, fieldNum int32) {
320         if epb, ok := pb.(extensionsBytes); ok {
321                 offset := 0
322                 for offset != -1 {
323                         offset = deleteExtension(epb, fieldNum, offset)
324                 }
325                 return
326         }
327         epb, err := extendable(pb)
328         if err != nil {
329                 return
330         }
331         // TODO: Check types, field numbers, etc.?
332         extmap := epb.extensionsWrite()
333         delete(extmap, fieldNum)
334 }
335
336 // GetExtension retrieves a proto2 extended field from pb.
337 //
338 // If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
339 // then GetExtension parses the encoded field and returns a Go value of the specified type.
340 // If the field is not present, then the default value is returned (if one is specified),
341 // otherwise ErrMissingExtension is reported.
342 //
343 // If the descriptor is not type complete (i.e., ExtensionDesc.ExtensionType is nil),
344 // then GetExtension returns the raw encoded bytes of the field extension.
345 func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
346         if epb, doki := pb.(extensionsBytes); doki {
347                 ext := epb.GetExtensions()
348                 return decodeExtensionFromBytes(extension, *ext)
349         }
350
351         epb, err := extendable(pb)
352         if err != nil {
353                 return nil, err
354         }
355
356         if extension.ExtendedType != nil {
357                 // can only check type if this is a complete descriptor
358                 if cerr := checkExtensionTypes(epb, extension); cerr != nil {
359                         return nil, cerr
360                 }
361         }
362
363         emap, mu := epb.extensionsRead()
364         if emap == nil {
365                 return defaultExtensionValue(extension)
366         }
367         mu.Lock()
368         defer mu.Unlock()
369         e, ok := emap[extension.Field]
370         if !ok {
371                 // defaultExtensionValue returns the default value or
372                 // ErrMissingExtension if there is no default.
373                 return defaultExtensionValue(extension)
374         }
375
376         if e.value != nil {
377                 // Already decoded. Check the descriptor, though.
378                 if e.desc != extension {
379                         // This shouldn't happen. If it does, it means that
380                         // GetExtension was called twice with two different
381                         // descriptors with the same field number.
382                         return nil, errors.New("proto: descriptor conflict")
383                 }
384                 return e.value, nil
385         }
386
387         if extension.ExtensionType == nil {
388                 // incomplete descriptor
389                 return e.enc, nil
390         }
391
392         v, err := decodeExtension(e.enc, extension)
393         if err != nil {
394                 return nil, err
395         }
396
397         // Remember the decoded version and drop the encoded version.
398         // That way it is safe to mutate what we return.
399         e.value = v
400         e.desc = extension
401         e.enc = nil
402         emap[extension.Field] = e
403         return e.value, nil
404 }
405
406 // defaultExtensionValue returns the default value for extension.
407 // If no default for an extension is defined ErrMissingExtension is returned.
408 func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
409         if extension.ExtensionType == nil {
410                 // incomplete descriptor, so no default
411                 return nil, ErrMissingExtension
412         }
413
414         t := reflect.TypeOf(extension.ExtensionType)
415         props := extensionProperties(extension)
416
417         sf, _, err := fieldDefault(t, props)
418         if err != nil {
419                 return nil, err
420         }
421
422         if sf == nil || sf.value == nil {
423                 // There is no default value.
424                 return nil, ErrMissingExtension
425         }
426
427         if t.Kind() != reflect.Ptr {
428                 // We do not need to return a Ptr, we can directly return sf.value.
429                 return sf.value, nil
430         }
431
432         // We need to return an interface{} that is a pointer to sf.value.
433         value := reflect.New(t).Elem()
434         value.Set(reflect.New(value.Type().Elem()))
435         if sf.kind == reflect.Int32 {
436                 // We may have an int32 or an enum, but the underlying data is int32.
437                 // Since we can't set an int32 into a non int32 reflect.value directly
438                 // set it as a int32.
439                 value.Elem().SetInt(int64(sf.value.(int32)))
440         } else {
441                 value.Elem().Set(reflect.ValueOf(sf.value))
442         }
443         return value.Interface(), nil
444 }
445
446 // decodeExtension decodes an extension encoded in b.
447 func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
448         t := reflect.TypeOf(extension.ExtensionType)
449         unmarshal := typeUnmarshaler(t, extension.Tag)
450
451         // t is a pointer to a struct, pointer to basic type or a slice.
452         // Allocate space to store the pointer/slice.
453         value := reflect.New(t).Elem()
454
455         var err error
456         for {
457                 x, n := decodeVarint(b)
458                 if n == 0 {
459                         return nil, io.ErrUnexpectedEOF
460                 }
461                 b = b[n:]
462                 wire := int(x) & 7
463
464                 b, err = unmarshal(b, valToPointer(value.Addr()), wire)
465                 if err != nil {
466                         return nil, err
467                 }
468
469                 if len(b) == 0 {
470                         break
471                 }
472         }
473         return value.Interface(), nil
474 }
475
476 // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
477 // The returned slice has the same length as es; missing extensions will appear as nil elements.
478 func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
479         epb, err := extendable(pb)
480         if err != nil {
481                 return nil, err
482         }
483         extensions = make([]interface{}, len(es))
484         for i, e := range es {
485                 extensions[i], err = GetExtension(epb, e)
486                 if err == ErrMissingExtension {
487                         err = nil
488                 }
489                 if err != nil {
490                         return
491                 }
492         }
493         return
494 }
495
496 // ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order.
497 // For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
498 // just the Field field, which defines the extension's field number.
499 func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
500         epb, err := extendable(pb)
501         if err != nil {
502                 return nil, err
503         }
504         registeredExtensions := RegisteredExtensions(pb)
505
506         emap, mu := epb.extensionsRead()
507         if emap == nil {
508                 return nil, nil
509         }
510         mu.Lock()
511         defer mu.Unlock()
512         extensions := make([]*ExtensionDesc, 0, len(emap))
513         for extid, e := range emap {
514                 desc := e.desc
515                 if desc == nil {
516                         desc = registeredExtensions[extid]
517                         if desc == nil {
518                                 desc = &ExtensionDesc{Field: extid}
519                         }
520                 }
521
522                 extensions = append(extensions, desc)
523         }
524         return extensions, nil
525 }
526
527 // SetExtension sets the specified extension of pb to the specified value.
528 func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
529         if epb, ok := pb.(extensionsBytes); ok {
530                 newb, err := encodeExtension(extension, value)
531                 if err != nil {
532                         return err
533                 }
534                 bb := epb.GetExtensions()
535                 *bb = append(*bb, newb...)
536                 return nil
537         }
538         epb, err := extendable(pb)
539         if err != nil {
540                 return err
541         }
542         if err := checkExtensionTypes(epb, extension); err != nil {
543                 return err
544         }
545         typ := reflect.TypeOf(extension.ExtensionType)
546         if typ != reflect.TypeOf(value) {
547                 return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", value, extension.ExtensionType)
548         }
549         // nil extension values need to be caught early, because the
550         // encoder can't distinguish an ErrNil due to a nil extension
551         // from an ErrNil due to a missing field. Extensions are
552         // always optional, so the encoder would just swallow the error
553         // and drop all the extensions from the encoded message.
554         if reflect.ValueOf(value).IsNil() {
555                 return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
556         }
557
558         extmap := epb.extensionsWrite()
559         extmap[extension.Field] = Extension{desc: extension, value: value}
560         return nil
561 }
562
563 // ClearAllExtensions clears all extensions from pb.
564 func ClearAllExtensions(pb Message) {
565         if epb, doki := pb.(extensionsBytes); doki {
566                 ext := epb.GetExtensions()
567                 *ext = []byte{}
568                 return
569         }
570         epb, err := extendable(pb)
571         if err != nil {
572                 return
573         }
574         m := epb.extensionsWrite()
575         for k := range m {
576                 delete(m, k)
577         }
578 }
579
580 // A global registry of extensions.
581 // The generated code will register the generated descriptors by calling RegisterExtension.
582
583 var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc)
584
585 // RegisterExtension is called from the generated code.
586 func RegisterExtension(desc *ExtensionDesc) {
587         st := reflect.TypeOf(desc.ExtendedType).Elem()
588         m := extensionMaps[st]
589         if m == nil {
590                 m = make(map[int32]*ExtensionDesc)
591                 extensionMaps[st] = m
592         }
593         if _, ok := m[desc.Field]; ok {
594                 panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field)))
595         }
596         m[desc.Field] = desc
597 }
598
599 // RegisteredExtensions returns a map of the registered extensions of a
600 // protocol buffer struct, indexed by the extension number.
601 // The argument pb should be a nil pointer to the struct type.
602 func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
603         return extensionMaps[reflect.TypeOf(pb).Elem()]
604 }