OSDN Git Service

Hulk did something
[bytom/vapor.git] / net / http / httpjson / handler_test.go
diff --git a/net/http/httpjson/handler_test.go b/net/http/httpjson/handler_test.go
new file mode 100644 (file)
index 0000000..e2b8559
--- /dev/null
@@ -0,0 +1,144 @@
+package httpjson
+
+import (
+       "context"
+       "net/http"
+       "net/http/httptest"
+       "reflect"
+       "strings"
+       "testing"
+       "testing/iotest"
+
+       "github.com/vapor/errors"
+)
+
+func TestHandler(t *testing.T) {
+       errX := errors.New("x")
+
+       cases := []struct {
+               rawQuery string
+               input    string
+               output   string
+               f        interface{}
+               wantErr  error
+       }{
+               {"", ``, `{"message":"ok"}`, func() {}, nil},
+               {"", ``, `1`, func() int { return 1 }, nil},
+               {"", ``, `{"message":"ok"}`, func() error { return nil }, nil},
+               {"", ``, ``, func() error { return errX }, errX},
+               {"", ``, `1`, func() (int, error) { return 1, nil }, nil},
+               {"", ``, ``, func() (int, error) { return 0, errX }, errX},
+               {"", `1`, `1`, func(i int) int { return i }, nil},
+               {"", `1`, `1`, func(i *int) int { return *i }, nil},
+               {"", `"foo"`, `"foo"`, func(s string) string { return s }, nil},
+               {"", `{"x":1}`, `1`, func(x struct{ X int }) int { return x.X }, nil},
+               {"", `{"x":1}`, `1`, func(x *struct{ X int }) int { return x.X }, nil},
+               {"", ``, `1`, func(ctx context.Context) int { return ctx.Value("k").(int) }, nil},
+       }
+
+       for _, test := range cases {
+               var gotErr error
+               errFunc := func(ctx context.Context, w http.ResponseWriter, err error) {
+                       gotErr = err
+               }
+               h, err := Handler(test.f, errFunc)
+               if err != nil {
+                       t.Errorf("Handler(%v) got err %v", test.f, err)
+                       continue
+               }
+
+               resp := httptest.NewRecorder()
+               req, _ := http.NewRequest("GET", "/", strings.NewReader(test.input))
+               req.URL.RawQuery = test.rawQuery
+               ctx := context.WithValue(context.Background(), "k", 1)
+               h.ServeHTTP(resp, req.WithContext(ctx))
+               if resp.Code != 200 {
+                       t.Errorf("%T response code = %d want 200", test.f, resp.Code)
+               }
+               got := strings.TrimSpace(resp.Body.String())
+               if got != test.output {
+                       t.Errorf("%T response body = %#q want %#q", test.f, got, test.output)
+               }
+               if gotErr != test.wantErr {
+                       t.Errorf("%T err = %v want %v", test.f, gotErr, test.wantErr)
+               }
+       }
+}
+
+func TestReadErr(t *testing.T) {
+       var gotErr error
+       errFunc := func(ctx context.Context, w http.ResponseWriter, err error) {
+               gotErr = errors.Root(err)
+       }
+       h, _ := Handler(func(int) {}, errFunc)
+
+       resp := httptest.NewRecorder()
+       body := iotest.OneByteReader(iotest.TimeoutReader(strings.NewReader("123456")))
+       req, _ := http.NewRequest("GET", "/", body)
+       h.ServeHTTP(resp, req)
+       if got := resp.Body.Len(); got != 0 {
+               t.Errorf("len(response) = %d want 0", got)
+       }
+       wantErr := ErrBadRequest
+       if gotErr != wantErr {
+               t.Errorf("err = %v want %v", gotErr, wantErr)
+       }
+}
+
+func TestFuncInputTypeError(t *testing.T) {
+       cases := []interface{}{
+               0,
+               "foo",
+               func() (int, int) { return 0, 0 },
+               func(string, int) {},
+               func() (int, int, error) { return 0, 0, nil },
+       }
+
+       for _, testf := range cases {
+               _, _, err := funcInputType(reflect.ValueOf(testf))
+               if err == nil {
+                       t.Errorf("funcInputType(%T) want error", testf)
+               }
+
+               _, err = Handler(testf, nil)
+               if err == nil {
+                       t.Errorf("funcInputType(%T) want error", testf)
+               }
+       }
+}
+
+var (
+       intType    = reflect.TypeOf(0)
+       intpType   = reflect.TypeOf((*int)(nil))
+       stringType = reflect.TypeOf("")
+)
+
+func TestFuncInputTypeOk(t *testing.T) {
+       cases := []struct {
+               f       interface{}
+               wantCtx bool
+               wantT   reflect.Type
+       }{
+               {func() {}, false, nil},
+               {func() int { return 0 }, false, nil},
+               {func() error { return nil }, false, nil},
+               {func() (int, error) { return 0, nil }, false, nil},
+               {func(int) {}, false, intType},
+               {func(*int) {}, false, intpType},
+               {func(context.Context) {}, true, nil},
+               {func(string) {}, false, stringType}, // req body is string
+       }
+
+       for _, test := range cases {
+               gotCtx, gotT, err := funcInputType(reflect.ValueOf(test.f))
+               if err != nil {
+                       t.Errorf("funcInputType(%T) got error: %v", test.f, err)
+               }
+               if gotCtx != test.wantCtx {
+                       t.Errorf("funcInputType(%T) context = %v want %v", test.f, gotCtx, test.wantCtx)
+               }
+               if gotT != test.wantT {
+                       t.Errorf("funcInputType(%T) = %v want %v", test.f, gotT, test.wantT)
+               }
+       }
+}