OSDN Git Service

Hulk did something
[bytom/vapor.git] / vendor / github.com / stretchr / testify / _codegen / main.go
1 // This program reads all assertion functions from the assert package and
2 // automatically generates the corresponding requires and forwarded assertions
3
4 package main
5
6 import (
7         "bytes"
8         "flag"
9         "fmt"
10         "go/ast"
11         "go/build"
12         "go/doc"
13         "go/format"
14         "go/importer"
15         "go/parser"
16         "go/token"
17         "go/types"
18         "io"
19         "io/ioutil"
20         "log"
21         "os"
22         "path"
23         "regexp"
24         "strings"
25         "text/template"
26
27         "github.com/ernesto-jimenez/gogen/imports"
28 )
29
30 var (
31         pkg       = flag.String("assert-path", "github.com/stretchr/testify/assert", "Path to the assert package")
32         includeF  = flag.Bool("include-format-funcs", false, "include format functions such as Errorf and Equalf")
33         outputPkg = flag.String("output-package", "", "package for the resulting code")
34         tmplFile  = flag.String("template", "", "What file to load the function template from")
35         out       = flag.String("out", "", "What file to write the source code to")
36 )
37
38 func main() {
39         flag.Parse()
40
41         scope, docs, err := parsePackageSource(*pkg)
42         if err != nil {
43                 log.Fatal(err)
44         }
45
46         importer, funcs, err := analyzeCode(scope, docs)
47         if err != nil {
48                 log.Fatal(err)
49         }
50
51         if err := generateCode(importer, funcs); err != nil {
52                 log.Fatal(err)
53         }
54 }
55
56 func generateCode(importer imports.Importer, funcs []testFunc) error {
57         buff := bytes.NewBuffer(nil)
58
59         tmplHead, tmplFunc, err := parseTemplates()
60         if err != nil {
61                 return err
62         }
63
64         // Generate header
65         if err := tmplHead.Execute(buff, struct {
66                 Name    string
67                 Imports map[string]string
68         }{
69                 *outputPkg,
70                 importer.Imports(),
71         }); err != nil {
72                 return err
73         }
74
75         // Generate funcs
76         for _, fn := range funcs {
77                 buff.Write([]byte("\n\n"))
78                 if err := tmplFunc.Execute(buff, &fn); err != nil {
79                         return err
80                 }
81         }
82
83         code, err := format.Source(buff.Bytes())
84         if err != nil {
85                 return err
86         }
87
88         // Write file
89         output, err := outputFile()
90         if err != nil {
91                 return err
92         }
93         defer output.Close()
94         _, err = io.Copy(output, bytes.NewReader(code))
95         return err
96 }
97
98 func parseTemplates() (*template.Template, *template.Template, error) {
99         tmplHead, err := template.New("header").Parse(headerTemplate)
100         if err != nil {
101                 return nil, nil, err
102         }
103         if *tmplFile != "" {
104                 f, err := ioutil.ReadFile(*tmplFile)
105                 if err != nil {
106                         return nil, nil, err
107                 }
108                 funcTemplate = string(f)
109         }
110         tmpl, err := template.New("function").Parse(funcTemplate)
111         if err != nil {
112                 return nil, nil, err
113         }
114         return tmplHead, tmpl, nil
115 }
116
117 func outputFile() (*os.File, error) {
118         filename := *out
119         if filename == "-" || (filename == "" && *tmplFile == "") {
120                 return os.Stdout, nil
121         }
122         if filename == "" {
123                 filename = strings.TrimSuffix(strings.TrimSuffix(*tmplFile, ".tmpl"), ".go") + ".go"
124         }
125         return os.Create(filename)
126 }
127
128 // analyzeCode takes the types scope and the docs and returns the import
129 // information and information about all the assertion functions.
130 func analyzeCode(scope *types.Scope, docs *doc.Package) (imports.Importer, []testFunc, error) {
131         testingT := scope.Lookup("TestingT").Type().Underlying().(*types.Interface)
132
133         importer := imports.New(*outputPkg)
134         var funcs []testFunc
135         // Go through all the top level functions
136         for _, fdocs := range docs.Funcs {
137                 // Find the function
138                 obj := scope.Lookup(fdocs.Name)
139
140                 fn, ok := obj.(*types.Func)
141                 if !ok {
142                         continue
143                 }
144                 // Check function signature has at least two arguments
145                 sig := fn.Type().(*types.Signature)
146                 if sig.Params().Len() < 2 {
147                         continue
148                 }
149                 // Check first argument is of type testingT
150                 first, ok := sig.Params().At(0).Type().(*types.Named)
151                 if !ok {
152                         continue
153                 }
154                 firstType, ok := first.Underlying().(*types.Interface)
155                 if !ok {
156                         continue
157                 }
158                 if !types.Implements(firstType, testingT) {
159                         continue
160                 }
161
162                 // Skip functions ending with f
163                 if strings.HasSuffix(fdocs.Name, "f") && !*includeF {
164                         continue
165                 }
166
167                 funcs = append(funcs, testFunc{*outputPkg, fdocs, fn})
168                 importer.AddImportsFrom(sig.Params())
169         }
170         return importer, funcs, nil
171 }
172
173 // parsePackageSource returns the types scope and the package documentation from the package
174 func parsePackageSource(pkg string) (*types.Scope, *doc.Package, error) {
175         pd, err := build.Import(pkg, ".", 0)
176         if err != nil {
177                 return nil, nil, err
178         }
179
180         fset := token.NewFileSet()
181         files := make(map[string]*ast.File)
182         fileList := make([]*ast.File, len(pd.GoFiles))
183         for i, fname := range pd.GoFiles {
184                 src, err := ioutil.ReadFile(path.Join(pd.SrcRoot, pd.ImportPath, fname))
185                 if err != nil {
186                         return nil, nil, err
187                 }
188                 f, err := parser.ParseFile(fset, fname, src, parser.ParseComments|parser.AllErrors)
189                 if err != nil {
190                         return nil, nil, err
191                 }
192                 files[fname] = f
193                 fileList[i] = f
194         }
195
196         cfg := types.Config{
197                 Importer: importer.Default(),
198         }
199         info := types.Info{
200                 Defs: make(map[*ast.Ident]types.Object),
201         }
202         tp, err := cfg.Check(pkg, fset, fileList, &info)
203         if err != nil {
204                 return nil, nil, err
205         }
206
207         scope := tp.Scope()
208
209         ap, _ := ast.NewPackage(fset, files, nil, nil)
210         docs := doc.New(ap, pkg, 0)
211
212         return scope, docs, nil
213 }
214
215 type testFunc struct {
216         CurrentPkg string
217         DocInfo    *doc.Func
218         TypeInfo   *types.Func
219 }
220
221 func (f *testFunc) Qualifier(p *types.Package) string {
222         if p == nil || p.Name() == f.CurrentPkg {
223                 return ""
224         }
225         return p.Name()
226 }
227
228 func (f *testFunc) Params() string {
229         sig := f.TypeInfo.Type().(*types.Signature)
230         params := sig.Params()
231         p := ""
232         comma := ""
233         to := params.Len()
234         var i int
235
236         if sig.Variadic() {
237                 to--
238         }
239         for i = 1; i < to; i++ {
240                 param := params.At(i)
241                 p += fmt.Sprintf("%s%s %s", comma, param.Name(), types.TypeString(param.Type(), f.Qualifier))
242                 comma = ", "
243         }
244         if sig.Variadic() {
245                 param := params.At(params.Len() - 1)
246                 p += fmt.Sprintf("%s%s ...%s", comma, param.Name(), types.TypeString(param.Type().(*types.Slice).Elem(), f.Qualifier))
247         }
248         return p
249 }
250
251 func (f *testFunc) ForwardedParams() string {
252         sig := f.TypeInfo.Type().(*types.Signature)
253         params := sig.Params()
254         p := ""
255         comma := ""
256         to := params.Len()
257         var i int
258
259         if sig.Variadic() {
260                 to--
261         }
262         for i = 1; i < to; i++ {
263                 param := params.At(i)
264                 p += fmt.Sprintf("%s%s", comma, param.Name())
265                 comma = ", "
266         }
267         if sig.Variadic() {
268                 param := params.At(params.Len() - 1)
269                 p += fmt.Sprintf("%s%s...", comma, param.Name())
270         }
271         return p
272 }
273
274 func (f *testFunc) ParamsFormat() string {
275         return strings.Replace(f.Params(), "msgAndArgs", "msg string, args", 1)
276 }
277
278 func (f *testFunc) ForwardedParamsFormat() string {
279         return strings.Replace(f.ForwardedParams(), "msgAndArgs", "append([]interface{}{msg}, args...)", 1)
280 }
281
282 func (f *testFunc) Comment() string {
283         return "// " + strings.Replace(strings.TrimSpace(f.DocInfo.Doc), "\n", "\n// ", -1)
284 }
285
286 func (f *testFunc) CommentFormat() string {
287         search := fmt.Sprintf("%s", f.DocInfo.Name)
288         replace := fmt.Sprintf("%sf", f.DocInfo.Name)
289         comment := strings.Replace(f.Comment(), search, replace, -1)
290         exp := regexp.MustCompile(replace + `\(((\(\)|[^)])+)\)`)
291         return exp.ReplaceAllString(comment, replace+`($1, "error message %s", "formatted")`)
292 }
293
294 func (f *testFunc) CommentWithoutT(receiver string) string {
295         search := fmt.Sprintf("assert.%s(t, ", f.DocInfo.Name)
296         replace := fmt.Sprintf("%s.%s(", receiver, f.DocInfo.Name)
297         return strings.Replace(f.Comment(), search, replace, -1)
298 }
299
300 var headerTemplate = `/*
301 * CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
302 * THIS FILE MUST NOT BE EDITED BY HAND
303 */
304
305 package {{.Name}}
306
307 import (
308 {{range $path, $name := .Imports}}
309         {{$name}} "{{$path}}"{{end}}
310 )
311 `
312
313 var funcTemplate = `{{.Comment}}
314 func (fwd *AssertionsForwarder) {{.DocInfo.Name}}({{.Params}}) bool {
315         return assert.{{.DocInfo.Name}}({{.ForwardedParams}})
316 }`