OSDN Git Service

Hulk did something
[bytom/vapor.git] / vendor / github.com / miekg / dns / types_generate.go
1 //+build ignore
2
// types_generate.go is meant to be run with go generate. It uses
// go/{importer,types} to track down all the RR struct types. Then, for each type,
// it generates conversion tables (TypeToRR and TypeToString) and boilerplate
// methods (len, Header, copy) based on the struct tags. The generated source is
// written to ztypes.go and is meant to be checked into git.
8 package main
9
10 import (
11         "bytes"
12         "fmt"
13         "go/format"
14         "go/importer"
15         "go/types"
16         "log"
17         "os"
18         "strings"
19         "text/template"
20 )
21
// skipLen holds the RR type names for which no len() method is generated:
// main skips any type listed here while emitting the len() functions.
// NOTE(review): presumably these types have hand-written len() methods
// elsewhere in the package — not visible from this file.
var skipLen = map[string]struct{}{"CSYNC": {}, "NSEC": {}, "NSEC3": {}, "OPT": {}}
28
// packageHdr is written verbatim at the start of the generated ztypes.go. It
// carries the "DO NOT EDIT" marker, the package clause, and the imports that
// the generated len() functions rely on: encoding/base64 for the
// base64.StdEncoding.DecodedLen calls and net for net.IPv4len / net.IPv6len.
var packageHdr = `
// Code generated by "go run types_generate.go"; DO NOT EDIT.

package dns

import (
	"encoding/base64"
	"net"
)

`
40
// Templates for the table-like sections of ztypes.go. main passes the whole
// generated buffer through go/format before writing it, so the spacing inside
// the template text does not survive into the output.
var (
	// TypeToRR renders the TypeToRR constructor map for a list of RR type
	// names, leaving out RFC3597.
	TypeToRR = template.Must(template.New("TypeToRR").Parse(typeToRRText))

	// typeToString renders the TypeToString map for a list of TypeX constant
	// suffixes. NSAPPTR is skipped inside the loop and appended once with the
	// hyphenated string "NSAP-PTR".
	typeToString = template.Must(template.New("typeToString").Parse(typeToStringText))

	// headerFunc renders one Header() method, returning &rr.Hdr, per RR type name.
	headerFunc = template.Must(template.New("headerFunc").Parse(headerFuncText))
)

const typeToRRText = `
// TypeToRR is a map of constructors for each RR type.
var TypeToRR = map[uint16]func() RR{
{{range .}}{{if ne . "RFC3597"}}  Type{{.}}:  func() RR { return new({{.}}) },
{{end}}{{end}}                    }

`

const typeToStringText = `
// TypeToString is a map of strings for each RR type.
var TypeToString = map[uint16]string{
{{range .}}{{if ne . "NSAPPTR"}}  Type{{.}}: "{{.}}",
{{end}}{{end}}                    TypeNSAPPTR:    "NSAP-PTR",
}

`

const headerFuncText = `
{{range .}}  func (rr *{{.}}) Header() *RR_Header { return &rr.Hdr }
{{end}}

`
63
// getTypeStruct reports whether t is an RR type and returns the struct that
// directly holds its RR_Header. A type counts as an RR type when the first
// field of its underlying struct has the RR_Header type looked up in scope,
// or when that first field is an embedded field whose type recursively
// qualifies (in which case the innermost struct is returned).
//
// The bool result is true whenever the first field was an embedded field that
// was followed, even if the embedded type did not qualify (the struct is then
// nil). Non-struct types, structs with no fields, and scopes without an
// RR_Header all yield (nil, false) rather than panicking.
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
	st, ok := t.Underlying().(*types.Struct)
	if !ok || st.NumFields() == 0 {
		// Field(0) below would panic on an empty struct such as struct{}.
		return nil, false
	}
	first := st.Field(0)
	if hdr := scope.Lookup("RR_Header"); hdr != nil && first.Type() == hdr.Type() {
		return st, false
	}
	if first.Anonymous() {
		inner, _ := getTypeStruct(first.Type(), scope)
		return inner, true
	}
	return nil, false
}
83
84 func main() {
85         // Import and type-check the package
86         pkg, err := importer.Default().Import("github.com/miekg/dns")
87         fatalIfErr(err)
88         scope := pkg.Scope()
89
90         // Collect constants like TypeX
91         var numberedTypes []string
92         for _, name := range scope.Names() {
93                 o := scope.Lookup(name)
94                 if o == nil || !o.Exported() {
95                         continue
96                 }
97                 b, ok := o.Type().(*types.Basic)
98                 if !ok || b.Kind() != types.Uint16 {
99                         continue
100                 }
101                 if !strings.HasPrefix(o.Name(), "Type") {
102                         continue
103                 }
104                 name := strings.TrimPrefix(o.Name(), "Type")
105                 if name == "PrivateRR" {
106                         continue
107                 }
108                 numberedTypes = append(numberedTypes, name)
109         }
110
111         // Collect actual types (*X)
112         var namedTypes []string
113         for _, name := range scope.Names() {
114                 o := scope.Lookup(name)
115                 if o == nil || !o.Exported() {
116                         continue
117                 }
118                 if st, _ := getTypeStruct(o.Type(), scope); st == nil {
119                         continue
120                 }
121                 if name == "PrivateRR" {
122                         continue
123                 }
124
125                 // Check if corresponding TypeX exists
126                 if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
127                         log.Fatalf("Constant Type%s does not exist.", o.Name())
128                 }
129
130                 namedTypes = append(namedTypes, o.Name())
131         }
132
133         b := &bytes.Buffer{}
134         b.WriteString(packageHdr)
135
136         // Generate TypeToRR
137         fatalIfErr(TypeToRR.Execute(b, namedTypes))
138
139         // Generate typeToString
140         fatalIfErr(typeToString.Execute(b, numberedTypes))
141
142         // Generate headerFunc
143         fatalIfErr(headerFunc.Execute(b, namedTypes))
144
145         // Generate len()
146         fmt.Fprint(b, "// len() functions\n")
147         for _, name := range namedTypes {
148                 if _, ok := skipLen[name]; ok {
149                         continue
150                 }
151                 o := scope.Lookup(name)
152                 st, isEmbedded := getTypeStruct(o.Type(), scope)
153                 if isEmbedded {
154                         continue
155                 }
156                 fmt.Fprintf(b, "func (rr *%s) len(off int, compression map[string]struct{}) int {\n", name)
157                 fmt.Fprintf(b, "l := rr.Hdr.len(off, compression)\n")
158                 for i := 1; i < st.NumFields(); i++ {
159                         o := func(s string) { fmt.Fprintf(b, s, st.Field(i).Name()) }
160
161                         if _, ok := st.Field(i).Type().(*types.Slice); ok {
162                                 switch st.Tag(i) {
163                                 case `dns:"-"`:
164                                         // ignored
165                                 case `dns:"cdomain-name"`:
166                                         o("for _, x := range rr.%s { l += domainNameLen(x, off+l, compression, true) }\n")
167                                 case `dns:"domain-name"`:
168                                         o("for _, x := range rr.%s { l += domainNameLen(x, off+l, compression, false) }\n")
169                                 case `dns:"txt"`:
170                                         o("for _, x := range rr.%s { l += len(x) + 1 }\n")
171                                 default:
172                                         log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
173                                 }
174                                 continue
175                         }
176
177                         switch {
178                         case st.Tag(i) == `dns:"-"`:
179                                 // ignored
180                         case st.Tag(i) == `dns:"cdomain-name"`:
181                                 o("l += domainNameLen(rr.%s, off+l, compression, true)\n")
182                         case st.Tag(i) == `dns:"domain-name"`:
183                                 o("l += domainNameLen(rr.%s, off+l, compression, false)\n")
184                         case st.Tag(i) == `dns:"octet"`:
185                                 o("l += len(rr.%s)\n")
186                         case strings.HasPrefix(st.Tag(i), `dns:"size-base64`):
187                                 fallthrough
188                         case st.Tag(i) == `dns:"base64"`:
189                                 o("l += base64.StdEncoding.DecodedLen(len(rr.%s))\n")
190                         case strings.HasPrefix(st.Tag(i), `dns:"size-hex:`): // this has an extra field where the length is stored
191                                 o("l += len(rr.%s)/2\n")
192                         case strings.HasPrefix(st.Tag(i), `dns:"size-hex`):
193                                 fallthrough
194                         case st.Tag(i) == `dns:"hex"`:
195                                 o("l += len(rr.%s)/2 + 1\n")
196                         case st.Tag(i) == `dns:"any"`:
197                                 o("l += len(rr.%s)\n")
198                         case st.Tag(i) == `dns:"a"`:
199                                 o("l += net.IPv4len // %s\n")
200                         case st.Tag(i) == `dns:"aaaa"`:
201                                 o("l += net.IPv6len // %s\n")
202                         case st.Tag(i) == `dns:"txt"`:
203                                 o("for _, t := range rr.%s { l += len(t) + 1 }\n")
204                         case st.Tag(i) == `dns:"uint48"`:
205                                 o("l += 6 // %s\n")
206                         case st.Tag(i) == "":
207                                 switch st.Field(i).Type().(*types.Basic).Kind() {
208                                 case types.Uint8:
209                                         o("l++ // %s\n")
210                                 case types.Uint16:
211                                         o("l += 2 // %s\n")
212                                 case types.Uint32:
213                                         o("l += 4 // %s\n")
214                                 case types.Uint64:
215                                         o("l += 8 // %s\n")
216                                 case types.String:
217                                         o("l += len(rr.%s) + 1\n")
218                                 default:
219                                         log.Fatalln(name, st.Field(i).Name())
220                                 }
221                         default:
222                                 log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
223                         }
224                 }
225                 fmt.Fprintf(b, "return l }\n")
226         }
227
228         // Generate copy()
229         fmt.Fprint(b, "// copy() functions\n")
230         for _, name := range namedTypes {
231                 o := scope.Lookup(name)
232                 st, isEmbedded := getTypeStruct(o.Type(), scope)
233                 if isEmbedded {
234                         continue
235                 }
236                 fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name)
237                 fields := []string{"rr.Hdr"}
238                 for i := 1; i < st.NumFields(); i++ {
239                         f := st.Field(i).Name()
240                         if sl, ok := st.Field(i).Type().(*types.Slice); ok {
241                                 t := sl.Underlying().String()
242                                 t = strings.TrimPrefix(t, "[]")
243                                 if strings.Contains(t, ".") {
244                                         splits := strings.Split(t, ".")
245                                         t = splits[len(splits)-1]
246                                 }
247                                 // For the EDNS0 interface (used in the OPT RR), we need to call the copy method on each element.
248                                 if t == "EDNS0" {
249                                         fmt.Fprintf(b, "%s := make([]%s, len(rr.%s));\nfor i,e := range rr.%s {\n %s[i] = e.copy()\n}\n",
250                                                 f, t, f, f, f)
251                                         fields = append(fields, f)
252                                         continue
253                                 }
254                                 fmt.Fprintf(b, "%s := make([]%s, len(rr.%s)); copy(%s, rr.%s)\n",
255                                         f, t, f, f, f)
256                                 fields = append(fields, f)
257                                 continue
258                         }
259                         if st.Field(i).Type().String() == "net.IP" {
260                                 fields = append(fields, "copyIP(rr."+f+")")
261                                 continue
262                         }
263                         fields = append(fields, "rr."+f)
264                 }
265                 fmt.Fprintf(b, "return &%s{%s}\n", name, strings.Join(fields, ","))
266                 fmt.Fprintf(b, "}\n")
267         }
268
269         // gofmt
270         res, err := format.Source(b.Bytes())
271         if err != nil {
272                 b.WriteTo(os.Stderr)
273                 log.Fatal(err)
274         }
275
276         // write result
277         f, err := os.Create("ztypes.go")
278         fatalIfErr(err)
279         defer f.Close()
280         f.Write(res)
281 }
282
// fatalIfErr terminates the generator through log.Fatal when err is non-nil.
// A nil error is a no-op.
func fatalIfErr(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}