OSDN Git Service

new repo
[bytom/vapor.git] / vendor / golang.org / x / crypto / ssh / mux_test.go
1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package ssh
6
7 import (
8         "io"
9         "io/ioutil"
10         "sync"
11         "testing"
12 )
13
14 func muxPair() (*mux, *mux) {
15         a, b := memPipe()
16
17         s := newMux(a)
18         c := newMux(b)
19
20         return s, c
21 }
22
23 // Returns both ends of a channel, and the mux for the the 2nd
24 // channel.
25 func channelPair(t *testing.T) (*channel, *channel, *mux) {
26         c, s := muxPair()
27
28         res := make(chan *channel, 1)
29         go func() {
30                 newCh, ok := <-s.incomingChannels
31                 if !ok {
32                         t.Fatalf("No incoming channel")
33                 }
34                 if newCh.ChannelType() != "chan" {
35                         t.Fatalf("got type %q want chan", newCh.ChannelType())
36                 }
37                 ch, _, err := newCh.Accept()
38                 if err != nil {
39                         t.Fatalf("Accept %v", err)
40                 }
41                 res <- ch.(*channel)
42         }()
43
44         ch, err := c.openChannel("chan", nil)
45         if err != nil {
46                 t.Fatalf("OpenChannel: %v", err)
47         }
48
49         return <-res, ch, c
50 }
51
52 // Test that stderr and stdout can be addressed from different
53 // goroutines. This is intended for use with the race detector.
54 func TestMuxChannelExtendedThreadSafety(t *testing.T) {
55         writer, reader, mux := channelPair(t)
56         defer writer.Close()
57         defer reader.Close()
58         defer mux.Close()
59
60         var wr, rd sync.WaitGroup
61         magic := "hello world"
62
63         wr.Add(2)
64         go func() {
65                 io.WriteString(writer, magic)
66                 wr.Done()
67         }()
68         go func() {
69                 io.WriteString(writer.Stderr(), magic)
70                 wr.Done()
71         }()
72
73         rd.Add(2)
74         go func() {
75                 c, err := ioutil.ReadAll(reader)
76                 if string(c) != magic {
77                         t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
78                 }
79                 rd.Done()
80         }()
81         go func() {
82                 c, err := ioutil.ReadAll(reader.Stderr())
83                 if string(c) != magic {
84                         t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
85                 }
86                 rd.Done()
87         }()
88
89         wr.Wait()
90         writer.CloseWrite()
91         rd.Wait()
92 }
93
94 func TestMuxReadWrite(t *testing.T) {
95         s, c, mux := channelPair(t)
96         defer s.Close()
97         defer c.Close()
98         defer mux.Close()
99
100         magic := "hello world"
101         magicExt := "hello stderr"
102         go func() {
103                 _, err := s.Write([]byte(magic))
104                 if err != nil {
105                         t.Fatalf("Write: %v", err)
106                 }
107                 _, err = s.Extended(1).Write([]byte(magicExt))
108                 if err != nil {
109                         t.Fatalf("Write: %v", err)
110                 }
111                 err = s.Close()
112                 if err != nil {
113                         t.Fatalf("Close: %v", err)
114                 }
115         }()
116
117         var buf [1024]byte
118         n, err := c.Read(buf[:])
119         if err != nil {
120                 t.Fatalf("server Read: %v", err)
121         }
122         got := string(buf[:n])
123         if got != magic {
124                 t.Fatalf("server: got %q want %q", got, magic)
125         }
126
127         n, err = c.Extended(1).Read(buf[:])
128         if err != nil {
129                 t.Fatalf("server Read: %v", err)
130         }
131
132         got = string(buf[:n])
133         if got != magicExt {
134                 t.Fatalf("server: got %q want %q", got, magic)
135         }
136 }
137
138 func TestMuxChannelOverflow(t *testing.T) {
139         reader, writer, mux := channelPair(t)
140         defer reader.Close()
141         defer writer.Close()
142         defer mux.Close()
143
144         wDone := make(chan int, 1)
145         go func() {
146                 if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
147                         t.Errorf("could not fill window: %v", err)
148                 }
149                 writer.Write(make([]byte, 1))
150                 wDone <- 1
151         }()
152         writer.remoteWin.waitWriterBlocked()
153
154         // Send 1 byte.
155         packet := make([]byte, 1+4+4+1)
156         packet[0] = msgChannelData
157         marshalUint32(packet[1:], writer.remoteId)
158         marshalUint32(packet[5:], uint32(1))
159         packet[9] = 42
160
161         if err := writer.mux.conn.writePacket(packet); err != nil {
162                 t.Errorf("could not send packet")
163         }
164         if _, err := reader.SendRequest("hello", true, nil); err == nil {
165                 t.Errorf("SendRequest succeeded.")
166         }
167         <-wDone
168 }
169
170 func TestMuxChannelCloseWriteUnblock(t *testing.T) {
171         reader, writer, mux := channelPair(t)
172         defer reader.Close()
173         defer writer.Close()
174         defer mux.Close()
175
176         wDone := make(chan int, 1)
177         go func() {
178                 if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
179                         t.Errorf("could not fill window: %v", err)
180                 }
181                 if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
182                         t.Errorf("got %v, want EOF for unblock write", err)
183                 }
184                 wDone <- 1
185         }()
186
187         writer.remoteWin.waitWriterBlocked()
188         reader.Close()
189         <-wDone
190 }
191
192 func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
193         reader, writer, mux := channelPair(t)
194         defer reader.Close()
195         defer writer.Close()
196         defer mux.Close()
197
198         wDone := make(chan int, 1)
199         go func() {
200                 if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
201                         t.Errorf("could not fill window: %v", err)
202                 }
203                 if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
204                         t.Errorf("got %v, want EOF for unblock write", err)
205                 }
206                 wDone <- 1
207         }()
208
209         writer.remoteWin.waitWriterBlocked()
210         mux.Close()
211         <-wDone
212 }
213
214 func TestMuxReject(t *testing.T) {
215         client, server := muxPair()
216         defer server.Close()
217         defer client.Close()
218
219         go func() {
220                 ch, ok := <-server.incomingChannels
221                 if !ok {
222                         t.Fatalf("Accept")
223                 }
224                 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
225                         t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
226                 }
227                 ch.Reject(RejectionReason(42), "message")
228         }()
229
230         ch, err := client.openChannel("ch", []byte("extra"))
231         if ch != nil {
232                 t.Fatal("openChannel not rejected")
233         }
234
235         ocf, ok := err.(*OpenChannelError)
236         if !ok {
237                 t.Errorf("got %#v want *OpenChannelError", err)
238         } else if ocf.Reason != 42 || ocf.Message != "message" {
239                 t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
240         }
241
242         want := "ssh: rejected: unknown reason 42 (message)"
243         if err.Error() != want {
244                 t.Errorf("got %q, want %q", err.Error(), want)
245         }
246 }
247
248 func TestMuxChannelRequest(t *testing.T) {
249         client, server, mux := channelPair(t)
250         defer server.Close()
251         defer client.Close()
252         defer mux.Close()
253
254         var received int
255         var wg sync.WaitGroup
256         wg.Add(1)
257         go func() {
258                 for r := range server.incomingRequests {
259                         received++
260                         r.Reply(r.Type == "yes", nil)
261                 }
262                 wg.Done()
263         }()
264         _, err := client.SendRequest("yes", false, nil)
265         if err != nil {
266                 t.Fatalf("SendRequest: %v", err)
267         }
268         ok, err := client.SendRequest("yes", true, nil)
269         if err != nil {
270                 t.Fatalf("SendRequest: %v", err)
271         }
272
273         if !ok {
274                 t.Errorf("SendRequest(yes): %v", ok)
275
276         }
277
278         ok, err = client.SendRequest("no", true, nil)
279         if err != nil {
280                 t.Fatalf("SendRequest: %v", err)
281         }
282         if ok {
283                 t.Errorf("SendRequest(no): %v", ok)
284
285         }
286
287         client.Close()
288         wg.Wait()
289
290         if received != 3 {
291                 t.Errorf("got %d requests, want %d", received, 3)
292         }
293 }
294
295 func TestMuxGlobalRequest(t *testing.T) {
296         clientMux, serverMux := muxPair()
297         defer serverMux.Close()
298         defer clientMux.Close()
299
300         var seen bool
301         go func() {
302                 for r := range serverMux.incomingRequests {
303                         seen = seen || r.Type == "peek"
304                         if r.WantReply {
305                                 err := r.Reply(r.Type == "yes",
306                                         append([]byte(r.Type), r.Payload...))
307                                 if err != nil {
308                                         t.Errorf("AckRequest: %v", err)
309                                 }
310                         }
311                 }
312         }()
313
314         _, _, err := clientMux.SendRequest("peek", false, nil)
315         if err != nil {
316                 t.Errorf("SendRequest: %v", err)
317         }
318
319         ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
320         if !ok || string(data) != "yesa" || err != nil {
321                 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
322                         ok, data, err)
323         }
324         if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
325                 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
326                         ok, data, err)
327         }
328
329         if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
330                 t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
331                         ok, data, err)
332         }
333
334         if !seen {
335                 t.Errorf("never saw 'peek' request")
336         }
337 }
338
339 func TestMuxGlobalRequestUnblock(t *testing.T) {
340         clientMux, serverMux := muxPair()
341         defer serverMux.Close()
342         defer clientMux.Close()
343
344         result := make(chan error, 1)
345         go func() {
346                 _, _, err := clientMux.SendRequest("hello", true, nil)
347                 result <- err
348         }()
349
350         <-serverMux.incomingRequests
351         serverMux.conn.Close()
352         err := <-result
353
354         if err != io.EOF {
355                 t.Errorf("want EOF, got %v", io.EOF)
356         }
357 }
358
359 func TestMuxChannelRequestUnblock(t *testing.T) {
360         a, b, connB := channelPair(t)
361         defer a.Close()
362         defer b.Close()
363         defer connB.Close()
364
365         result := make(chan error, 1)
366         go func() {
367                 _, err := a.SendRequest("hello", true, nil)
368                 result <- err
369         }()
370
371         <-b.incomingRequests
372         connB.conn.Close()
373         err := <-result
374
375         if err != io.EOF {
376                 t.Errorf("want EOF, got %v", err)
377         }
378 }
379
380 func TestMuxCloseChannel(t *testing.T) {
381         r, w, mux := channelPair(t)
382         defer mux.Close()
383         defer r.Close()
384         defer w.Close()
385
386         result := make(chan error, 1)
387         go func() {
388                 var b [1024]byte
389                 _, err := r.Read(b[:])
390                 result <- err
391         }()
392         if err := w.Close(); err != nil {
393                 t.Errorf("w.Close: %v", err)
394         }
395
396         if _, err := w.Write([]byte("hello")); err != io.EOF {
397                 t.Errorf("got err %v, want io.EOF after Close", err)
398         }
399
400         if err := <-result; err != io.EOF {
401                 t.Errorf("got %v (%T), want io.EOF", err, err)
402         }
403 }
404
405 func TestMuxCloseWriteChannel(t *testing.T) {
406         r, w, mux := channelPair(t)
407         defer mux.Close()
408
409         result := make(chan error, 1)
410         go func() {
411                 var b [1024]byte
412                 _, err := r.Read(b[:])
413                 result <- err
414         }()
415         if err := w.CloseWrite(); err != nil {
416                 t.Errorf("w.CloseWrite: %v", err)
417         }
418
419         if _, err := w.Write([]byte("hello")); err != io.EOF {
420                 t.Errorf("got err %v, want io.EOF after CloseWrite", err)
421         }
422
423         if err := <-result; err != io.EOF {
424                 t.Errorf("got %v (%T), want io.EOF", err, err)
425         }
426 }
427
428 func TestMuxInvalidRecord(t *testing.T) {
429         a, b := muxPair()
430         defer a.Close()
431         defer b.Close()
432
433         packet := make([]byte, 1+4+4+1)
434         packet[0] = msgChannelData
435         marshalUint32(packet[1:], 29348723 /* invalid channel id */)
436         marshalUint32(packet[5:], 1)
437         packet[9] = 42
438
439         a.conn.writePacket(packet)
440         go a.SendRequest("hello", false, nil)
441         // 'a' wrote an invalid packet, so 'b' has exited.
442         req, ok := <-b.incomingRequests
443         if ok {
444                 t.Errorf("got request %#v after receiving invalid packet", req)
445         }
446 }
447
448 func TestZeroWindowAdjust(t *testing.T) {
449         a, b, mux := channelPair(t)
450         defer a.Close()
451         defer b.Close()
452         defer mux.Close()
453
454         go func() {
455                 io.WriteString(a, "hello")
456                 // bogus adjust.
457                 a.sendMessage(windowAdjustMsg{})
458                 io.WriteString(a, "world")
459                 a.Close()
460         }()
461
462         want := "helloworld"
463         c, _ := ioutil.ReadAll(b)
464         if string(c) != want {
465                 t.Errorf("got %q want %q", c, want)
466         }
467 }
468
469 func TestMuxMaxPacketSize(t *testing.T) {
470         a, b, mux := channelPair(t)
471         defer a.Close()
472         defer b.Close()
473         defer mux.Close()
474
475         large := make([]byte, a.maxRemotePayload+1)
476         packet := make([]byte, 1+4+4+1+len(large))
477         packet[0] = msgChannelData
478         marshalUint32(packet[1:], a.remoteId)
479         marshalUint32(packet[5:], uint32(len(large)))
480         packet[9] = 42
481
482         if err := a.mux.conn.writePacket(packet); err != nil {
483                 t.Errorf("could not send packet")
484         }
485
486         go a.SendRequest("hello", false, nil)
487
488         _, ok := <-b.incomingRequests
489         if ok {
490                 t.Errorf("connection still alive after receiving large packet.")
491         }
492 }
493
494 // Don't ship code with debug=true.
495 func TestDebug(t *testing.T) {
496         if debugMux {
497                 t.Error("mux debug switched on")
498         }
499         if debugHandshake {
500                 t.Error("handshake debug switched on")
501         }
502         if debugTransport {
503                 t.Error("transport debug switched on")
504         }
505 }