OSDN Git Service

e2b8559bcd9af5dbbc17f9697fae942a967ba11e
[bytom/vapor.git] / net / http / httpjson / handler_test.go
1 package httpjson
2
3 import (
4         "context"
5         "net/http"
6         "net/http/httptest"
7         "reflect"
8         "strings"
9         "testing"
10         "testing/iotest"
11
12         "github.com/vapor/errors"
13 )
14
15 func TestHandler(t *testing.T) {
16         errX := errors.New("x")
17
18         cases := []struct {
19                 rawQuery string
20                 input    string
21                 output   string
22                 f        interface{}
23                 wantErr  error
24         }{
25                 {"", ``, `{"message":"ok"}`, func() {}, nil},
26                 {"", ``, `1`, func() int { return 1 }, nil},
27                 {"", ``, `{"message":"ok"}`, func() error { return nil }, nil},
28                 {"", ``, ``, func() error { return errX }, errX},
29                 {"", ``, `1`, func() (int, error) { return 1, nil }, nil},
30                 {"", ``, ``, func() (int, error) { return 0, errX }, errX},
31                 {"", `1`, `1`, func(i int) int { return i }, nil},
32                 {"", `1`, `1`, func(i *int) int { return *i }, nil},
33                 {"", `"foo"`, `"foo"`, func(s string) string { return s }, nil},
34                 {"", `{"x":1}`, `1`, func(x struct{ X int }) int { return x.X }, nil},
35                 {"", `{"x":1}`, `1`, func(x *struct{ X int }) int { return x.X }, nil},
36                 {"", ``, `1`, func(ctx context.Context) int { return ctx.Value("k").(int) }, nil},
37         }
38
39         for _, test := range cases {
40                 var gotErr error
41                 errFunc := func(ctx context.Context, w http.ResponseWriter, err error) {
42                         gotErr = err
43                 }
44                 h, err := Handler(test.f, errFunc)
45                 if err != nil {
46                         t.Errorf("Handler(%v) got err %v", test.f, err)
47                         continue
48                 }
49
50                 resp := httptest.NewRecorder()
51                 req, _ := http.NewRequest("GET", "/", strings.NewReader(test.input))
52                 req.URL.RawQuery = test.rawQuery
53                 ctx := context.WithValue(context.Background(), "k", 1)
54                 h.ServeHTTP(resp, req.WithContext(ctx))
55                 if resp.Code != 200 {
56                         t.Errorf("%T response code = %d want 200", test.f, resp.Code)
57                 }
58                 got := strings.TrimSpace(resp.Body.String())
59                 if got != test.output {
60                         t.Errorf("%T response body = %#q want %#q", test.f, got, test.output)
61                 }
62                 if gotErr != test.wantErr {
63                         t.Errorf("%T err = %v want %v", test.f, gotErr, test.wantErr)
64                 }
65         }
66 }
67
68 func TestReadErr(t *testing.T) {
69         var gotErr error
70         errFunc := func(ctx context.Context, w http.ResponseWriter, err error) {
71                 gotErr = errors.Root(err)
72         }
73         h, _ := Handler(func(int) {}, errFunc)
74
75         resp := httptest.NewRecorder()
76         body := iotest.OneByteReader(iotest.TimeoutReader(strings.NewReader("123456")))
77         req, _ := http.NewRequest("GET", "/", body)
78         h.ServeHTTP(resp, req)
79         if got := resp.Body.Len(); got != 0 {
80                 t.Errorf("len(response) = %d want 0", got)
81         }
82         wantErr := ErrBadRequest
83         if gotErr != wantErr {
84                 t.Errorf("err = %v want %v", gotErr, wantErr)
85         }
86 }
87
88 func TestFuncInputTypeError(t *testing.T) {
89         cases := []interface{}{
90                 0,
91                 "foo",
92                 func() (int, int) { return 0, 0 },
93                 func(string, int) {},
94                 func() (int, int, error) { return 0, 0, nil },
95         }
96
97         for _, testf := range cases {
98                 _, _, err := funcInputType(reflect.ValueOf(testf))
99                 if err == nil {
100                         t.Errorf("funcInputType(%T) want error", testf)
101                 }
102
103                 _, err = Handler(testf, nil)
104                 if err == nil {
105                         t.Errorf("funcInputType(%T) want error", testf)
106                 }
107         }
108 }
109
110 var (
111         intType    = reflect.TypeOf(0)
112         intpType   = reflect.TypeOf((*int)(nil))
113         stringType = reflect.TypeOf("")
114 )
115
116 func TestFuncInputTypeOk(t *testing.T) {
117         cases := []struct {
118                 f       interface{}
119                 wantCtx bool
120                 wantT   reflect.Type
121         }{
122                 {func() {}, false, nil},
123                 {func() int { return 0 }, false, nil},
124                 {func() error { return nil }, false, nil},
125                 {func() (int, error) { return 0, nil }, false, nil},
126                 {func(int) {}, false, intType},
127                 {func(*int) {}, false, intpType},
128                 {func(context.Context) {}, true, nil},
129                 {func(string) {}, false, stringType}, // req body is string
130         }
131
132         for _, test := range cases {
133                 gotCtx, gotT, err := funcInputType(reflect.ValueOf(test.f))
134                 if err != nil {
135                         t.Errorf("funcInputType(%T) got error: %v", test.f, err)
136                 }
137                 if gotCtx != test.wantCtx {
138                         t.Errorf("funcInputType(%T) context = %v want %v", test.f, gotCtx, test.wantCtx)
139                 }
140                 if gotT != test.wantT {
141                         t.Errorf("funcInputType(%T) = %v want %v", test.f, gotT, test.wantT)
142                 }
143         }
144 }