1 // Copyright 2012 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
5 // +build darwin dragonfly freebsd linux netbsd openbsd
// closeWriter matches connections that expose CloseWrite. The forwarding
// test type-asserts its dialed net.Conn to this interface so it can signal
// the end of its outgoing data while continuing to read the echoed bytes.
type closeWriter interface {
	// CloseWrite shuts down the writing side of the connection.
	CloseWrite() error
}
23 func testPortForward(t *testing.T, n, listenAddr string) {
24 server := newServer(t)
25 defer server.Shutdown()
26 conn := server.Dial(clientConfig())
29 sshListener, err := conn.Listen(n, listenAddr)
35 sshConn, err := sshListener.Accept()
37 t.Fatalf("listen.Accept failed: %v", err)
40 _, err = io.Copy(sshConn, sshConn)
41 if err != nil && err != io.EOF {
42 t.Fatalf("ssh client copy: %v", err)
47 forwardedAddr := sshListener.Addr().String()
48 netConn, err := net.Dial(n, forwardedAddr)
50 t.Fatalf("net dial failed: %v", err)
53 readChan := make(chan []byte)
55 data, _ := ioutil.ReadAll(netConn)
60 data := make([]byte, 100*1000)
62 data[i] = byte(i % 255)
66 for len(sent) < 1000*1000 {
67 // Send random sized chunks
68 m := rand.Intn(len(data))
69 n, err := netConn.Write(data[:m])
73 sent = append(sent, data[:n]...)
75 if err := netConn.(closeWriter).CloseWrite(); err != nil {
76 t.Errorf("netConn.CloseWrite: %v", err)
81 if len(sent) != len(read) {
82 t.Fatalf("got %d bytes, want %d", len(read), len(sent))
84 if bytes.Compare(sent, read) != 0 {
85 t.Fatalf("read back data does not match")
88 if err := sshListener.Close(); err != nil {
89 t.Fatalf("sshListener.Close: %v", err)
92 // Check that the forward disappeared.
93 netConn, err = net.Dial(n, forwardedAddr)
96 t.Errorf("still listening to %s after closing", forwardedAddr)
100 func TestPortForwardTCP(t *testing.T) {
101 testPortForward(t, "tcp", "localhost:0")
104 func TestPortForwardUnix(t *testing.T) {
105 addr, cleanup := newTempSocket(t)
107 testPortForward(t, "unix", addr)
110 func testAcceptClose(t *testing.T, n, listenAddr string) {
111 server := newServer(t)
112 defer server.Shutdown()
113 conn := server.Dial(clientConfig())
115 sshListener, err := conn.Listen(n, listenAddr)
120 quit := make(chan error, 1)
123 c, err := sshListener.Accept()
134 case <-time.After(1 * time.Second):
135 t.Errorf("timeout: listener did not close.")
137 t.Logf("quit as expected (error %v)", err)
141 func TestAcceptCloseTCP(t *testing.T) {
142 testAcceptClose(t, "tcp", "localhost:0")
145 func TestAcceptCloseUnix(t *testing.T) {
146 addr, cleanup := newTempSocket(t)
148 testAcceptClose(t, "unix", addr)
151 // Check that listeners exit if the underlying client transport dies.
152 func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
153 server := newServer(t)
154 defer server.Shutdown()
155 conn := server.Dial(clientConfig())
157 sshListener, err := conn.Listen(n, listenAddr)
162 quit := make(chan error, 1)
165 c, err := sshListener.Accept()
174 // It would be even nicer if we closed the server side, but it
175 // is more involved as the fd for that side is dup()ed.
176 server.clientConn.Close()
179 case <-time.After(1 * time.Second):
180 t.Errorf("timeout: listener did not close.")
182 t.Logf("quit as expected (error %v)", err)
186 func TestPortForwardConnectionCloseTCP(t *testing.T) {
187 testPortForwardConnectionClose(t, "tcp", "localhost:0")
190 func TestPortForwardConnectionCloseUnix(t *testing.T) {
191 addr, cleanup := newTempSocket(t)
193 testPortForwardConnectionClose(t, "unix", addr)