OSDN Git Service

new repo
[bytom/vapor.git] / vendor / github.com / go-kit / kit / transport / http / client_test.go
1 package http_test
2
3 import (
4         "context"
5         "io"
6         "io/ioutil"
7         "net/http"
8         "net/http/httptest"
9         "net/url"
10         "testing"
11         "time"
12
13         httptransport "github.com/go-kit/kit/transport/http"
14 )
15
16 type TestResponse struct {
17         Body   io.ReadCloser
18         String string
19 }
20
21 func TestHTTPClient(t *testing.T) {
22         var (
23                 testbody = "testbody"
24                 encode   = func(context.Context, *http.Request, interface{}) error { return nil }
25                 decode   = func(_ context.Context, r *http.Response) (interface{}, error) {
26                         buffer := make([]byte, len(testbody))
27                         r.Body.Read(buffer)
28                         return TestResponse{r.Body, string(buffer)}, nil
29                 }
30                 headers        = make(chan string, 1)
31                 headerKey      = "X-Foo"
32                 headerVal      = "abcde"
33                 afterHeaderKey = "X-The-Dude"
34                 afterHeaderVal = "Abides"
35                 afterVal       = ""
36                 afterFunc      = func(ctx context.Context, r *http.Response) context.Context {
37                         afterVal = r.Header.Get(afterHeaderKey)
38                         return ctx
39                 }
40         )
41
42         server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
43                 headers <- r.Header.Get(headerKey)
44                 w.Header().Set(afterHeaderKey, afterHeaderVal)
45                 w.WriteHeader(http.StatusOK)
46                 w.Write([]byte(testbody))
47         }))
48
49         client := httptransport.NewClient(
50                 "GET",
51                 mustParse(server.URL),
52                 encode,
53                 decode,
54                 httptransport.ClientBefore(httptransport.SetRequestHeader(headerKey, headerVal)),
55                 httptransport.ClientAfter(afterFunc),
56         )
57
58         res, err := client.Endpoint()(context.Background(), struct{}{})
59         if err != nil {
60                 t.Fatal(err)
61         }
62
63         var have string
64         select {
65         case have = <-headers:
66         case <-time.After(time.Millisecond):
67                 t.Fatalf("timeout waiting for %s", headerKey)
68         }
69         // Check that Request Header was successfully received
70         if want := headerVal; want != have {
71                 t.Errorf("want %q, have %q", want, have)
72         }
73
74         // Check that Response header set from server was received in SetClientAfter
75         if want, have := afterVal, afterHeaderVal; want != have {
76                 t.Errorf("want %q, have %q", want, have)
77         }
78
79         // Check that the response was successfully decoded
80         response, ok := res.(TestResponse)
81         if !ok {
82                 t.Fatal("response should be TestResponse")
83         }
84         if want, have := testbody, response.String; want != have {
85                 t.Errorf("want %q, have %q", want, have)
86         }
87
88         // Check that response body was closed
89         b := make([]byte, 1)
90         _, err = response.Body.Read(b)
91         if err == nil {
92                 t.Fatal("wanted error, got none")
93         }
94         if doNotWant, have := io.EOF, err; doNotWant == have {
95                 t.Errorf("do not want %q, have %q", doNotWant, have)
96         }
97 }
98
99 func TestHTTPClientBufferedStream(t *testing.T) {
100         var (
101                 testbody = "testbody"
102                 encode   = func(context.Context, *http.Request, interface{}) error { return nil }
103                 decode   = func(_ context.Context, r *http.Response) (interface{}, error) {
104                         return TestResponse{r.Body, ""}, nil
105                 }
106         )
107
108         server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
109                 w.WriteHeader(http.StatusOK)
110                 w.Write([]byte(testbody))
111         }))
112
113         client := httptransport.NewClient(
114                 "GET",
115                 mustParse(server.URL),
116                 encode,
117                 decode,
118                 httptransport.BufferedStream(true),
119         )
120
121         res, err := client.Endpoint()(context.Background(), struct{}{})
122         if err != nil {
123                 t.Fatal(err)
124         }
125
126         // Check that the response was successfully decoded
127         response, ok := res.(TestResponse)
128         if !ok {
129                 t.Fatal("response should be TestResponse")
130         }
131
132         // Check that response body was NOT closed
133         b := make([]byte, len(testbody))
134         _, err = response.Body.Read(b)
135         if want, have := io.EOF, err; have != want {
136                 t.Fatalf("want %q, have %q", want, have)
137         }
138         if want, have := testbody, string(b); want != have {
139                 t.Errorf("want %q, have %q", want, have)
140         }
141 }
142
143 func TestClientFinalizer(t *testing.T) {
144         var (
145                 headerKey    = "X-Henlo-Lizer"
146                 headerVal    = "Helllo you stinky lizard"
147                 responseBody = "go eat a fly ugly\n"
148                 done         = make(chan struct{})
149                 encode       = func(context.Context, *http.Request, interface{}) error { return nil }
150                 decode       = func(_ context.Context, r *http.Response) (interface{}, error) {
151                         return TestResponse{r.Body, ""}, nil
152                 }
153         )
154
155         server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
156                 w.Header().Set(headerKey, headerVal)
157                 w.Write([]byte(responseBody))
158         }))
159         defer server.Close()
160
161         client := httptransport.NewClient(
162                 "GET",
163                 mustParse(server.URL),
164                 encode,
165                 decode,
166                 httptransport.ClientFinalizer(func(ctx context.Context, err error) {
167                         responseHeader := ctx.Value(httptransport.ContextKeyResponseHeaders).(http.Header)
168                         if want, have := headerVal, responseHeader.Get(headerKey); want != have {
169                                 t.Errorf("%s: want %q, have %q", headerKey, want, have)
170                         }
171
172                         responseSize := ctx.Value(httptransport.ContextKeyResponseSize).(int64)
173                         if want, have := int64(len(responseBody)), responseSize; want != have {
174                                 t.Errorf("response size: want %d, have %d", want, have)
175                         }
176
177                         close(done)
178                 }),
179         )
180
181         _, err := client.Endpoint()(context.Background(), struct{}{})
182         if err != nil {
183                 t.Fatal(err)
184         }
185
186         select {
187         case <-done:
188         case <-time.After(time.Second):
189                 t.Fatal("timeout waiting for finalizer")
190         }
191 }
192
193 func TestEncodeJSONRequest(t *testing.T) {
194         var header http.Header
195         var body string
196
197         server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
198                 b, err := ioutil.ReadAll(r.Body)
199                 if err != nil && err != io.EOF {
200                         t.Fatal(err)
201                 }
202                 header = r.Header
203                 body = string(b)
204         }))
205
206         defer server.Close()
207
208         serverURL, err := url.Parse(server.URL)
209
210         if err != nil {
211                 t.Fatal(err)
212         }
213
214         client := httptransport.NewClient(
215                 "POST",
216                 serverURL,
217                 httptransport.EncodeJSONRequest,
218                 func(context.Context, *http.Response) (interface{}, error) { return nil, nil },
219         ).Endpoint()
220
221         for _, test := range []struct {
222                 value interface{}
223                 body  string
224         }{
225                 {nil, "null\n"},
226                 {12, "12\n"},
227                 {1.2, "1.2\n"},
228                 {true, "true\n"},
229                 {"test", "\"test\"\n"},
230                 {enhancedRequest{Foo: "foo"}, "{\"foo\":\"foo\"}\n"},
231         } {
232                 if _, err := client(context.Background(), test.value); err != nil {
233                         t.Error(err)
234                         continue
235                 }
236
237                 if body != test.body {
238                         t.Errorf("%v: actual %#v, expected %#v", test.value, body, test.body)
239                 }
240         }
241
242         if _, err := client(context.Background(), enhancedRequest{Foo: "foo"}); err != nil {
243                 t.Fatal(err)
244         }
245
246         if _, ok := header["X-Edward"]; !ok {
247                 t.Fatalf("X-Edward value: actual %v, expected %v", nil, []string{"Snowden"})
248         }
249
250         if v := header.Get("X-Edward"); v != "Snowden" {
251                 t.Errorf("X-Edward string: actual %v, expected %v", v, "Snowden")
252         }
253 }
254
255 func mustParse(s string) *url.URL {
256         u, err := url.Parse(s)
257         if err != nil {
258                 panic(err)
259         }
260         return u
261 }
262
263 type enhancedRequest struct {
264         Foo string `json:"foo"`
265 }
266
267 func (e enhancedRequest) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} }