OSDN Git Service

Hulk did something
[bytom/vapor.git] / vendor / github.com / miekg / dns / msg_generate.go
1 //+build ignore
2
3 // msg_generate.go is meant to run with go generate. It will use
4 // go/{importer,types} to track down all the RR struct types. Then for each type
5 // it will generate pack/unpack methods based on the struct tags. The generated source is
6 // written to zmsg.go, and is meant to be checked into git.
7 package main
8
9 import (
10         "bytes"
11         "fmt"
12         "go/format"
13         "go/importer"
14         "go/types"
15         "log"
16         "os"
17         "strings"
18 )
19
// packageHdr is the preamble placed at the very start of the generated
// source: the "Code generated ... DO NOT EDIT" marker followed by the package
// clause of the dns package. main writes it to the output buffer before any
// generated function.
var packageHdr = "\n" +
	"// Code generated by \"go run msg_generate.go\"; DO NOT EDIT.\n" +
	"\n" +
	"package dns\n" +
	"\n"
26
// getTypeStruct takes a type and the package scope and returns the
// (innermost) struct if the type is considered an RR type. A type currently
// counts as an RR type when its underlying struct starts with an RR_Header
// field, or when its first field is an embedded (anonymous) struct that
// itself qualifies; this could be redefined as "implements the RR interface".
//
// The returned struct is nil when t is not an RR type: its underlying type is
// not a struct, the struct has no fields, or its first field is neither an
// RR_Header nor an embedded field. The bool result reports whether an embedded
// struct had to be resolved to reach the answer; note that it is true even
// when the embedded type turned out not to be an RR type (struct == nil).
//
// The scope is expected to declare RR_Header.
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
	st, ok := t.Underlying().(*types.Struct)
	if !ok {
		return nil, false
	}
	// A struct without fields cannot start with RR_Header; bail out before
	// Field(0) indexes past the end of the field list and panics.
	if st.NumFields() == 0 {
		return nil, false
	}
	first := st.Field(0)
	if first.Type() == scope.Lookup("RR_Header").Type() {
		return st, false
	}
	if first.Anonymous() {
		// Follow the embedded type; the innermost struct is the one whose
		// fields get packed and unpacked.
		inner, _ := getTypeStruct(first.Type(), scope)
		return inner, true
	}
	return nil, false
}
46
47 func main() {
48         // Import and type-check the package
49         pkg, err := importer.Default().Import("github.com/miekg/dns")
50         fatalIfErr(err)
51         scope := pkg.Scope()
52
53         // Collect actual types (*X)
54         var namedTypes []string
55         for _, name := range scope.Names() {
56                 o := scope.Lookup(name)
57                 if o == nil || !o.Exported() {
58                         continue
59                 }
60                 if st, _ := getTypeStruct(o.Type(), scope); st == nil {
61                         continue
62                 }
63                 if name == "PrivateRR" {
64                         continue
65                 }
66
67                 // Check if corresponding TypeX exists
68                 if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
69                         log.Fatalf("Constant Type%s does not exist.", o.Name())
70                 }
71
72                 namedTypes = append(namedTypes, o.Name())
73         }
74
75         b := &bytes.Buffer{}
76         b.WriteString(packageHdr)
77
78         fmt.Fprint(b, "// pack*() functions\n\n")
79         for _, name := range namedTypes {
80                 o := scope.Lookup(name)
81                 st, _ := getTypeStruct(o.Type(), scope)
82
83                 fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {\n", name)
84                 for i := 1; i < st.NumFields(); i++ {
85                         o := func(s string) {
86                                 fmt.Fprintf(b, s, st.Field(i).Name())
87                                 fmt.Fprint(b, `if err != nil {
88 return off, err
89 }
90 `)
91                         }
92
93                         if _, ok := st.Field(i).Type().(*types.Slice); ok {
94                                 switch st.Tag(i) {
95                                 case `dns:"-"`: // ignored
96                                 case `dns:"txt"`:
97                                         o("off, err = packStringTxt(rr.%s, msg, off)\n")
98                                 case `dns:"opt"`:
99                                         o("off, err = packDataOpt(rr.%s, msg, off)\n")
100                                 case `dns:"nsec"`:
101                                         o("off, err = packDataNsec(rr.%s, msg, off)\n")
102                                 case `dns:"domain-name"`:
103                                         o("off, err = packDataDomainNames(rr.%s, msg, off, compression, false)\n")
104                                 default:
105                                         log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
106                                 }
107                                 continue
108                         }
109
110                         switch {
111                         case st.Tag(i) == `dns:"-"`: // ignored
112                         case st.Tag(i) == `dns:"cdomain-name"`:
113                                 o("off, err = packDomainName(rr.%s, msg, off, compression, compress)\n")
114                         case st.Tag(i) == `dns:"domain-name"`:
115                                 o("off, err = packDomainName(rr.%s, msg, off, compression, false)\n")
116                         case st.Tag(i) == `dns:"a"`:
117                                 o("off, err = packDataA(rr.%s, msg, off)\n")
118                         case st.Tag(i) == `dns:"aaaa"`:
119                                 o("off, err = packDataAAAA(rr.%s, msg, off)\n")
120                         case st.Tag(i) == `dns:"uint48"`:
121                                 o("off, err = packUint48(rr.%s, msg, off)\n")
122                         case st.Tag(i) == `dns:"txt"`:
123                                 o("off, err = packString(rr.%s, msg, off)\n")
124
125                         case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
126                                 fallthrough
127                         case st.Tag(i) == `dns:"base32"`:
128                                 o("off, err = packStringBase32(rr.%s, msg, off)\n")
129
130                         case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
131                                 fallthrough
132                         case st.Tag(i) == `dns:"base64"`:
133                                 o("off, err = packStringBase64(rr.%s, msg, off)\n")
134
135                         case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`):
136                                 // directly write instead of using o() so we get the error check in the correct place
137                                 field := st.Field(i).Name()
138                                 fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. empty
139 if rr.%s != "-" {
140   off, err = packStringHex(rr.%s, msg, off)
141   if err != nil {
142     return off, err
143   }
144 }
145 `, field, field)
146                                 continue
147                         case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
148                                 fallthrough
149                         case st.Tag(i) == `dns:"hex"`:
150                                 o("off, err = packStringHex(rr.%s, msg, off)\n")
151                         case st.Tag(i) == `dns:"any"`:
152                                 o("off, err = packStringAny(rr.%s, msg, off)\n")
153                         case st.Tag(i) == `dns:"octet"`:
154                                 o("off, err = packStringOctet(rr.%s, msg, off)\n")
155                         case st.Tag(i) == "":
156                                 switch st.Field(i).Type().(*types.Basic).Kind() {
157                                 case types.Uint8:
158                                         o("off, err = packUint8(rr.%s, msg, off)\n")
159                                 case types.Uint16:
160                                         o("off, err = packUint16(rr.%s, msg, off)\n")
161                                 case types.Uint32:
162                                         o("off, err = packUint32(rr.%s, msg, off)\n")
163                                 case types.Uint64:
164                                         o("off, err = packUint64(rr.%s, msg, off)\n")
165                                 case types.String:
166                                         o("off, err = packString(rr.%s, msg, off)\n")
167                                 default:
168                                         log.Fatalln(name, st.Field(i).Name())
169                                 }
170                         default:
171                                 log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
172                         }
173                 }
174                 fmt.Fprintln(b, "return off, nil }\n")
175         }
176
177         fmt.Fprint(b, "// unpack*() functions\n\n")
178         for _, name := range namedTypes {
179                 o := scope.Lookup(name)
180                 st, _ := getTypeStruct(o.Type(), scope)
181
182                 fmt.Fprintf(b, "func (rr *%s) unpack(msg []byte, off int) (off1 int, err error) {\n", name)
183                 fmt.Fprint(b, `rdStart := off
184 _ = rdStart
185
186 `)
187                 for i := 1; i < st.NumFields(); i++ {
188                         o := func(s string) {
189                                 fmt.Fprintf(b, s, st.Field(i).Name())
190                                 fmt.Fprint(b, `if err != nil {
191 return off, err
192 }
193 `)
194                         }
195
196                         // size-* are special, because they reference a struct member we should use for the length.
197                         if strings.HasPrefix(st.Tag(i), `dns:"size-`) {
198                                 structMember := structMember(st.Tag(i))
199                                 structTag := structTag(st.Tag(i))
200                                 switch structTag {
201                                 case "hex":
202                                         fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
203                                 case "base32":
204                                         fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
205                                 case "base64":
206                                         fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
207                                 default:
208                                         log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
209                                 }
210                                 fmt.Fprint(b, `if err != nil {
211 return off, err
212 }
213 `)
214                                 continue
215                         }
216
217                         if _, ok := st.Field(i).Type().(*types.Slice); ok {
218                                 switch st.Tag(i) {
219                                 case `dns:"-"`: // ignored
220                                 case `dns:"txt"`:
221                                         o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
222                                 case `dns:"opt"`:
223                                         o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
224                                 case `dns:"nsec"`:
225                                         o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
226                                 case `dns:"domain-name"`:
227                                         o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
228                                 default:
229                                         log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
230                                 }
231                                 continue
232                         }
233
234                         switch st.Tag(i) {
235                         case `dns:"-"`: // ignored
236                         case `dns:"cdomain-name"`:
237                                 fallthrough
238                         case `dns:"domain-name"`:
239                                 o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
240                         case `dns:"a"`:
241                                 o("rr.%s, off, err = unpackDataA(msg, off)\n")
242                         case `dns:"aaaa"`:
243                                 o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
244                         case `dns:"uint48"`:
245                                 o("rr.%s, off, err = unpackUint48(msg, off)\n")
246                         case `dns:"txt"`:
247                                 o("rr.%s, off, err = unpackString(msg, off)\n")
248                         case `dns:"base32"`:
249                                 o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
250                         case `dns:"base64"`:
251                                 o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
252                         case `dns:"hex"`:
253                                 o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
254                         case `dns:"any"`:
255                                 o("rr.%s, off, err = unpackStringAny(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
256                         case `dns:"octet"`:
257                                 o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
258                         case "":
259                                 switch st.Field(i).Type().(*types.Basic).Kind() {
260                                 case types.Uint8:
261                                         o("rr.%s, off, err = unpackUint8(msg, off)\n")
262                                 case types.Uint16:
263                                         o("rr.%s, off, err = unpackUint16(msg, off)\n")
264                                 case types.Uint32:
265                                         o("rr.%s, off, err = unpackUint32(msg, off)\n")
266                                 case types.Uint64:
267                                         o("rr.%s, off, err = unpackUint64(msg, off)\n")
268                                 case types.String:
269                                         o("rr.%s, off, err = unpackString(msg, off)\n")
270                                 default:
271                                         log.Fatalln(name, st.Field(i).Name())
272                                 }
273                         default:
274                                 log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
275                         }
276                         // If we've hit len(msg) we return without error.
277                         if i < st.NumFields()-1 {
278                                 fmt.Fprintf(b, `if off == len(msg) {
279 return off, nil
280         }
281 `)
282                         }
283                 }
284                 fmt.Fprintf(b, "return off, nil }\n\n")
285         }
286
287         // gofmt
288         res, err := format.Source(b.Bytes())
289         if err != nil {
290                 b.WriteTo(os.Stderr)
291                 log.Fatal(err)
292         }
293
294         // write result
295         f, err := os.Create("zmsg.go")
296         fatalIfErr(err)
297         defer f.Close()
298         f.Write(res)
299 }
300
// structMember extracts the member name from a size-* struct tag: given
// dns:"size-base32:SaltLength" it yields SaltLength. It takes everything after
// the last ':' (the whole string when there is no ':') and, when that piece is
// longer than one byte, drops its final byte, which is expected to be the
// closing double quote.
func structMember(s string) string {
	// LastIndex is -1 when s has no ':', which makes the slice start at 0.
	last := s[strings.LastIndex(s, ":")+1:]
	if len(last) <= 1 {
		return last
	}
	return last[:len(last)-1]
}
314
// structTag extracts the encoding name from a size-* struct tag: given
// dns:"size-base32:SaltLength" it returns base32.
//
// The tag is split on ':' and the second piece is expected to begin with
// `"size-`; what follows that prefix is returned. The empty string is returned
// when the tag has no second piece or when that piece lacks the `"size-`
// prefix, so malformed tags never cause an out-of-range slice.
func structTag(s string) string {
	const prefix = `"size-`

	fields := strings.Split(s, ":")
	if len(fields) < 2 || !strings.HasPrefix(fields[1], prefix) {
		return ""
	}
	return fields[1][len(prefix):]
}
323
// fatalIfErr is a no-op for a nil error; for any other error it hands the
// error to log.Fatal, which logs it and terminates the program.
func fatalIfErr(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}