1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
14 func muxPair() (*mux, *mux) {
23 // Returns both ends of a channel, and the mux for the 2nd
25 func channelPair(t *testing.T) (*channel, *channel, *mux) {
28 res := make(chan *channel, 1)
30 newCh, ok := <-s.incomingChannels
32 t.Fatalf("No incoming channel")
34 if newCh.ChannelType() != "chan" {
35 t.Fatalf("got type %q want chan", newCh.ChannelType())
37 ch, _, err := newCh.Accept()
39 t.Fatalf("Accept %v", err)
44 ch, err := c.openChannel("chan", nil)
46 t.Fatalf("OpenChannel: %v", err)
52 // Test that stderr and stdout can be addressed from different
53 // goroutines. This is intended for use with the race detector.
54 func TestMuxChannelExtendedThreadSafety(t *testing.T) {
55 writer, reader, mux := channelPair(t)
60 var wr, rd sync.WaitGroup
61 magic := "hello world"
65 io.WriteString(writer, magic)
69 io.WriteString(writer.Stderr(), magic)
75 c, err := ioutil.ReadAll(reader)
76 if string(c) != magic {
77 t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
82 c, err := ioutil.ReadAll(reader.Stderr())
83 if string(c) != magic {
84 t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
94 func TestMuxReadWrite(t *testing.T) {
95 s, c, mux := channelPair(t)
100 magic := "hello world"
101 magicExt := "hello stderr"
103 _, err := s.Write([]byte(magic))
105 t.Fatalf("Write: %v", err)
107 _, err = s.Extended(1).Write([]byte(magicExt))
109 t.Fatalf("Write: %v", err)
113 t.Fatalf("Close: %v", err)
118 n, err := c.Read(buf[:])
120 t.Fatalf("server Read: %v", err)
122 got := string(buf[:n])
124 t.Fatalf("server: got %q want %q", got, magic)
127 n, err = c.Extended(1).Read(buf[:])
129 t.Fatalf("server Read: %v", err)
132 got = string(buf[:n])
134 t.Fatalf("server: got %q want %q", got, magic)
138 func TestMuxChannelOverflow(t *testing.T) {
139 reader, writer, mux := channelPair(t)
144 wDone := make(chan int, 1)
146 if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
147 t.Errorf("could not fill window: %v", err)
149 writer.Write(make([]byte, 1))
152 writer.remoteWin.waitWriterBlocked()
155 packet := make([]byte, 1+4+4+1)
156 packet[0] = msgChannelData
157 marshalUint32(packet[1:], writer.remoteId)
158 marshalUint32(packet[5:], uint32(1))
161 if err := writer.mux.conn.writePacket(packet); err != nil {
162 t.Errorf("could not send packet")
164 if _, err := reader.SendRequest("hello", true, nil); err == nil {
165 t.Errorf("SendRequest succeeded.")
170 func TestMuxChannelCloseWriteUnblock(t *testing.T) {
171 reader, writer, mux := channelPair(t)
176 wDone := make(chan int, 1)
178 if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
179 t.Errorf("could not fill window: %v", err)
181 if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
182 t.Errorf("got %v, want EOF for unblock write", err)
187 writer.remoteWin.waitWriterBlocked()
192 func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
193 reader, writer, mux := channelPair(t)
198 wDone := make(chan int, 1)
200 if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
201 t.Errorf("could not fill window: %v", err)
203 if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
204 t.Errorf("got %v, want EOF for unblock write", err)
209 writer.remoteWin.waitWriterBlocked()
214 func TestMuxReject(t *testing.T) {
215 client, server := muxPair()
220 ch, ok := <-server.incomingChannels
224 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
225 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
227 ch.Reject(RejectionReason(42), "message")
230 ch, err := client.openChannel("ch", []byte("extra"))
232 t.Fatal("openChannel not rejected")
235 ocf, ok := err.(*OpenChannelError)
237 t.Errorf("got %#v want *OpenChannelError", err)
238 } else if ocf.Reason != 42 || ocf.Message != "message" {
239 t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
242 want := "ssh: rejected: unknown reason 42 (message)"
243 if err.Error() != want {
244 t.Errorf("got %q, want %q", err.Error(), want)
248 func TestMuxChannelRequest(t *testing.T) {
249 client, server, mux := channelPair(t)
255 var wg sync.WaitGroup
258 for r := range server.incomingRequests {
260 r.Reply(r.Type == "yes", nil)
264 _, err := client.SendRequest("yes", false, nil)
266 t.Fatalf("SendRequest: %v", err)
268 ok, err := client.SendRequest("yes", true, nil)
270 t.Fatalf("SendRequest: %v", err)
274 t.Errorf("SendRequest(yes): %v", ok)
278 ok, err = client.SendRequest("no", true, nil)
280 t.Fatalf("SendRequest: %v", err)
283 t.Errorf("SendRequest(no): %v", ok)
291 t.Errorf("got %d requests, want %d", received, 3)
295 func TestMuxGlobalRequest(t *testing.T) {
296 clientMux, serverMux := muxPair()
297 defer serverMux.Close()
298 defer clientMux.Close()
302 for r := range serverMux.incomingRequests {
303 seen = seen || r.Type == "peek"
305 err := r.Reply(r.Type == "yes",
306 append([]byte(r.Type), r.Payload...))
308 t.Errorf("AckRequest: %v", err)
314 _, _, err := clientMux.SendRequest("peek", false, nil)
316 t.Errorf("SendRequest: %v", err)
319 ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
320 if !ok || string(data) != "yesa" || err != nil {
321 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
324 if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
325 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
329 if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
330 t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
335 t.Errorf("never saw 'peek' request")
339 func TestMuxGlobalRequestUnblock(t *testing.T) {
340 clientMux, serverMux := muxPair()
341 defer serverMux.Close()
342 defer clientMux.Close()
344 result := make(chan error, 1)
346 _, _, err := clientMux.SendRequest("hello", true, nil)
350 <-serverMux.incomingRequests
351 serverMux.conn.Close()
355 t.Errorf("want EOF, got %v", io.EOF)
359 func TestMuxChannelRequestUnblock(t *testing.T) {
360 a, b, connB := channelPair(t)
365 result := make(chan error, 1)
367 _, err := a.SendRequest("hello", true, nil)
376 t.Errorf("want EOF, got %v", err)
380 func TestMuxCloseChannel(t *testing.T) {
381 r, w, mux := channelPair(t)
386 result := make(chan error, 1)
389 _, err := r.Read(b[:])
392 if err := w.Close(); err != nil {
393 t.Errorf("w.Close: %v", err)
396 if _, err := w.Write([]byte("hello")); err != io.EOF {
397 t.Errorf("got err %v, want io.EOF after Close", err)
400 if err := <-result; err != io.EOF {
401 t.Errorf("got %v (%T), want io.EOF", err, err)
405 func TestMuxCloseWriteChannel(t *testing.T) {
406 r, w, mux := channelPair(t)
409 result := make(chan error, 1)
412 _, err := r.Read(b[:])
415 if err := w.CloseWrite(); err != nil {
416 t.Errorf("w.CloseWrite: %v", err)
419 if _, err := w.Write([]byte("hello")); err != io.EOF {
420 t.Errorf("got err %v, want io.EOF after CloseWrite", err)
423 if err := <-result; err != io.EOF {
424 t.Errorf("got %v (%T), want io.EOF", err, err)
428 func TestMuxInvalidRecord(t *testing.T) {
433 packet := make([]byte, 1+4+4+1)
434 packet[0] = msgChannelData
435 marshalUint32(packet[1:], 29348723 /* invalid channel id */)
436 marshalUint32(packet[5:], 1)
439 a.conn.writePacket(packet)
440 go a.SendRequest("hello", false, nil)
441 // 'a' wrote an invalid packet, so 'b' has exited.
442 req, ok := <-b.incomingRequests
444 t.Errorf("got request %#v after receiving invalid packet", req)
448 func TestZeroWindowAdjust(t *testing.T) {
449 a, b, mux := channelPair(t)
455 io.WriteString(a, "hello")
457 a.sendMessage(windowAdjustMsg{})
458 io.WriteString(a, "world")
463 c, _ := ioutil.ReadAll(b)
464 if string(c) != want {
465 t.Errorf("got %q want %q", c, want)
469 func TestMuxMaxPacketSize(t *testing.T) {
470 a, b, mux := channelPair(t)
475 large := make([]byte, a.maxRemotePayload+1)
476 packet := make([]byte, 1+4+4+1+len(large))
477 packet[0] = msgChannelData
478 marshalUint32(packet[1:], a.remoteId)
479 marshalUint32(packet[5:], uint32(len(large)))
482 if err := a.mux.conn.writePacket(packet); err != nil {
483 t.Errorf("could not send packet")
486 go a.SendRequest("hello", false, nil)
488 _, ok := <-b.incomingRequests
490 t.Errorf("connection still alive after receiving large packet.")
494 // Don't ship code with debug=true.
495 func TestDebug(t *testing.T) {
497 t.Error("mux debug switched on")
500 t.Error("handshake debug switched on")
503 t.Error("transport debug switched on")