1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
// testChecker is a test host-key checker. Its Check method rejects the dial
// address "bad" and any remote address that is not a non-nil *net.TCPAddr,
// and otherwise records a description of the call in t.calls.
21 type testChecker struct {
// Check has the shape of a host key callback and is passed as
// ClientConfig.HostKeyCallback by the tests. It returns an error for
// dialAddr == "bad" or a non-TCP addr; otherwise it appends
// "<dialAddr> <addr> <key type> <hex key>" to t.calls.
// NOTE(review): the success-path return and closing braces are not visible
// in this excerpt.
25 func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
26 if dialAddr == "bad" {
27 return fmt.Errorf("dialAddr is bad")
// The handshake tests run over real TCP connections, so anything other than
// a *net.TCPAddr indicates the wrong address was plumbed through.
30 if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
31 return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
// Record the call so tests can inspect which host keys were checked.
34 t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
39 // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
40 // therefore is buffered (net.Pipe deadlocks if both sides start with
// NOTE(review): the end of the sentence above is missing from this excerpt.
// netPipe listens on an ephemeral loopback TCP port, dials it and accepts the
// resulting connection, giving two connected ends (presumably returned as
// dialed, accepted — the return statement is not visible). The listener is
// closed when netPipe returns; only the two connections stay open.
42 func netPipe() (net.Conn, net.Conn, error) {
43 listener, err := net.Listen("tcp", "127.0.0.1:0")
// Second attempt on the IPv6 loopback, presumably taken only when the IPv4
// listen fails (the guarding condition is not visible in this excerpt).
45 listener, err = net.Listen("tcp", "[::1]:0")
// The listener is only needed until one connection has been accepted.
50 defer listener.Close()
51 c1, err := net.Dial("tcp", listener.Addr().String())
56 c2, err := listener.Accept()
65 // noiseTransport inserts ignore and debug messages to check that the read loop
66 // and the key exchange filter out these messages.
// It wraps a keyingTransport, reached as t.keyingTransport (so presumably an
// embedded field — the struct body is not visible in this excerpt).
67 type noiseTransport struct {
// writePacket sends one msgIgnore packet, then one msgDebug packet, and then
// p, all through the wrapped transport. An error from either noise write is
// presumably returned before p is sent (those return lines are not visible).
71 func (t *noiseTransport) writePacket(p []byte) error {
72 ignore := []byte{msgIgnore}
73 if err := t.keyingTransport.writePacket(ignore); err != nil {
76 debug := []byte{msgDebug, 1, 2, 3}
77 if err := t.keyingTransport.writePacket(debug); err != nil {
81 return t.keyingTransport.writePacket(p)
// addNoiseTransport wraps t so that every packet written through the result
// is preceded by an ignore message and a debug message.
84 func addNoiseTransport(t keyingTransport) keyingTransport {
85 return &noiseTransport{t}
88 // handshakePair creates two handshakeTransports connected with each
89 // other. If the noise argument is true, both transports will try to
90 // confuse the other side by sending ignore and debug messages.
//
// clientConf is completed in place with SetDefaults (the caller's config is
// mutated). addr is passed to newClientTransport as the dial address, next
// to the real remote address of the client socket. The server is configured
// with the "ecdsa" and "rsa" test signers. Both sides have completed
// waitSession before the pair is returned.
91 func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
92 a, b, err := netPipe()
97 var trC, trS keyingTransport
// a is the client end and b the server end of the connection.
99 trC = newTransport(a, rand.Reader, true)
100 trS = newTransport(b, rand.Reader, false)
// Presumably guarded by `if noise` (the condition is not visible here).
102 trC = addNoiseTransport(trC)
103 trS = addNoiseTransport(trS)
105 clientConf.SetDefaults()
// Both transports use the same placeholder bytes for both version strings.
107 v := []byte("version")
108 client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
110 serverConf := &ServerConfig{}
111 serverConf.AddHostKey(testSigners["ecdsa"])
112 serverConf.AddHostKey(testSigners["rsa"])
113 serverConf.SetDefaults()
114 server = newServerTransport(trS, v, v, serverConf)
// Block until the initial key exchange has completed on both ends.
// NOTE(review): no Close of a/b is visible on these error paths — confirm
// the connections are not leaked when waitSession fails.
116 if err := server.waitSession(); err != nil {
117 return nil, nil, fmt.Errorf("server.waitSession: %v", err)
119 if err := client.waitSession(); err != nil {
120 return nil, nil, fmt.Errorf("client.waitSession: %v", err)
123 return client, server, nil
// TestHandshakeBasic checks that a key exchange requested by the client in
// the middle of a stream of packets does not corrupt or reorder them, and
// that exactly two host key checks happen: the initial kex plus the
// requested one.
126 func TestHandshakeBasic(t *testing.T) {
127 if runtime.GOOS == "plan9" {
128 t.Skip("see golang.org/issue/7237")
// syncChecker gates host key checks. Presumably each value sent on waitCall
// releases one Check call and each call is reported on called — the body of
// syncChecker.Check is not visible in this excerpt.
131 checker := &syncChecker{
132 waitCall: make(chan int, 10),
133 called: make(chan int, 10),
// Release the host key check of the initial handshake.
136 checker.waitCall <- 1
137 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
139 t.Fatalf("handshakePair: %v", err)
145 // Let first kex complete normally.
// NOTE(review): make(chan int, 0) is equivalent to the idiomatic make(chan int).
148 clientDone := make(chan int, 0)
149 gotHalf := make(chan int, 0)
// NOTE(review): the lines before this defer are not visible. If this code
// runs in a separate goroutine, the t.Fatalf below must become t.Errorf plus
// return: FailNow may only be called from the goroutine running the test.
153 defer close(clientDone)
154 // Client writes a bunch of stuff, and does a key
155 // change in the middle. This should not confuse the
156 // handshake in progress. We do this twice, so we test
157 // that the packet buffer is reset correctly.
158 for i := 0; i < N; i++ {
159 p := []byte{msgRequestSuccess, byte(i)}
160 if err := trC.writePacket(p); err != nil {
// NOTE(review): the message says "sendPacket" but the call is writePacket.
161 t.Fatalf("sendPacket: %v", err)
165 // halfway through, we request a key change.
166 trC.requestKeyExchange()
168 // Wait until we can be sure the key
169 // change has really started before we
174 // write some packets until the kex
175 // completes, to test buffering of
// Release the host key check of the requested key exchange.
177 checker.waitCall <- 1
182 // Server checks that client messages come in cleanly
187 p, err = trS.readPacket()
195 want := []byte{msgRequestSuccess, byte(i)}
// NOTE(review): bytes.Equal(p, want) is the idiomatic equality test.
196 if bytes.Compare(p, want) != 0 {
197 t.Errorf("message %d: got %v, want %v", i, p, want)
201 if err != nil && err != io.EOF {
202 t.Fatalf("server error: %v", err)
// NOTE(review): the message hard-codes 10; presumably N == 10 — verify, or
// format N into the message.
205 t.Errorf("received %d messages, want 10.", i)
// After close, a receive that still yields a value (ok == true) means a host
// key check happened beyond the expected ones. Presumably the two expected
// values were consumed earlier; those receives are not visible in this excerpt.
208 close(checker.called)
209 if _, ok := <-checker.called; ok {
210 // If all went well, we registered exactly 2 key changes: one
211 // that establishes the session, and one that we requested
213 t.Fatalf("got another host key checks after 2 handshakes")
// TestForceFirstKex checks that the server's first key exchange fails when
// the client sends a non-kex packet (a cleartext service request) before
// the handshake transports are created.
217 func TestForceFirstKex(t *testing.T) {
218 // like handshakePair, but must access the keyingTransport.
219 checker := &testChecker{}
220 clientConf := &ClientConfig{HostKeyCallback: checker.Check}
221 a, b, err := netPipe()
223 t.Fatalf("netPipe: %v", err)
226 var trC, trS keyingTransport
228 trC = newTransport(a, rand.Reader, true)
230 // This is the disallowed packet:
// It is written on the raw transport before newClientTransport is called
// below, so it precedes any key exchange.
// NOTE(review): the write error is ignored. A failed write could let the
// test pass for the wrong reason — consider checking it.
231 trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
233 // Rest of the setup.
234 trS = newTransport(b, rand.Reader, false)
235 clientConf.SetDefaults()
237 v := []byte("version")
238 client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
240 serverConf := &ServerConfig{}
241 serverConf.AddHostKey(testSigners["ecdsa"])
242 serverConf.AddHostKey(testSigners["rsa"])
243 serverConf.SetDefaults()
244 server := newServerTransport(trS, v, v, serverConf)
249 // We set up the initial key exchange, but the remote side
250 // tries to send serviceRequestMsg in cleartext, which is
253 if err := server.waitSession(); err == nil {
254 t.Errorf("server first kex init should reject unexpected packet")
// TestHandshakeAutoRekeyWrite checks that client writes crossing a small
// RekeyThreshold (500) trigger an automatic key exchange, and that the server
// still receives every packet intact.
258 func TestHandshakeAutoRekeyWrite(t *testing.T) {
// Only called is set; waitCall stays nil, so (presumably) Check never blocks.
259 checker := &syncChecker{
260 called: make(chan int, 10),
263 clientConf := &ClientConfig{HostKeyCallback: checker.Check}
// Low threshold so that a few 251-byte packets are enough to cross it.
264 clientConf.RekeyThreshold = 500
265 trC, trS, err := handshakePair(clientConf, "addr", false)
267 t.Fatalf("handshakePair: %v", err)
272 input := make([]byte, 251)
273 input[0] = msgRequestSuccess
275 done := make(chan int, 1)
// Reader side; presumably runs in a goroutine that signals done (the
// goroutine header is not visible in this excerpt).
280 for ; j < numPacket; j++ {
281 if p, err := trS.readPacket(); err != nil {
283 } else if !bytes.Equal(input, p) {
// NOTE(review): the comparison covers the whole packet, but the message only
// reports the type byte.
284 t.Errorf("got packet type %d, want %d", p[0], input[0])
// NOTE(review): the message hard-codes 5; presumably numPacket == 5.
289 t.Errorf("got %d, want 5 messages", j)
295 for i := 0; i < numPacket; i++ {
// A fresh buffer for every write; presumably filled from input on a line not
// visible here.
296 p := make([]byte, len(input))
298 if err := trC.writePacket(p); err != nil {
299 t.Errorf("writePacket: %v", err)
302 // Make sure the kex is in progress.
// syncChecker is a host key checker whose Check calls can be gated by the
// test. The tests set waitCall and/or called channels on it. Presumably Check
// waits for a value on waitCall when it is non-nil and reports each call on
// called — the struct fields and most of Check are not visible here; confirm.
310 type syncChecker struct {
// Check has the shape of a host key callback and is passed as
// ClientConfig.HostKeyCallback by the handshake tests.
315 func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
// A nil waitCall disables gating (TestHandshakeAutoRekeyWrite leaves it unset).
317 if c.waitCall != nil {
// TestHandshakeAutoRekeyRead checks that the client can read a packet larger
// than its RekeyThreshold (501 vs. 500 bytes) while the key change triggered
// by that read is in progress.
323 func TestHandshakeAutoRekeyRead(t *testing.T) {
// NOTE(review): this local shadows the imported sync package (used elsewhere
// in this file as sync.WaitGroup). Consider renaming it, e.g. to checker.
324 sync := &syncChecker{
325 called: make(chan int, 2),
328 clientConf := &ClientConfig{
329 HostKeyCallback: sync.Check,
331 clientConf.RekeyThreshold = 500
333 trC, trS, err := handshakePair(clientConf, "addr", false)
335 t.Fatalf("handshakePair: %v", err)
// One byte over the client's threshold.
340 packet := make([]byte, 501)
341 packet[0] = msgRequestSuccess
342 if err := trS.writePacket(packet); err != nil {
343 t.Fatalf("writePacket: %v", err)
346 // While we read out the packet, a key change will be
348 done := make(chan int, 1)
// NOTE(review): if this read runs in a goroutine (its header is not visible
// here), t.Fatalf must become t.Errorf plus return: FailNow may only be
// called from the test goroutine.
351 if _, err := trC.readPacket(); err != nil {
352 t.Fatalf("readPacket(client): %v", err)
361 // errorKeyingTransport generates errors after a given number of
362 // read/write operations.
// It is built positionally as errorKeyingTransport{conn, readLimit,
// writeLimit} and forwards to n.packetConn. Once readLeft/writeLeft reaches 0,
// the corresponding operation fails with "barf". The tests pass -1 to mean
// "never fail" (presumably the counters are only decremented while positive —
// not visible in this excerpt).
363 type errorKeyingTransport struct {
365 readLeft, writeLeft int
// prepareKeyChange and getSessionID complete the keyingTransport method set;
// their bodies are not visible in this excerpt.
368 func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
372 func (n *errorKeyingTransport) getSessionID() []byte {
// writePacket fails with "barf" once writeLeft is 0; otherwise it forwards
// the packet to the wrapped packetConn.
376 func (n *errorKeyingTransport) writePacket(packet []byte) error {
377 if n.writeLeft == 0 {
379 return errors.New("barf")
383 return n.packetConn.writePacket(packet)
// readPacket fails with "barf" once the read budget is exhausted (the guarding
// condition is not visible); otherwise it reads from the wrapped packetConn.
386 func (n *errorKeyingTransport) readPacket() ([]byte, error) {
389 return nil, errors.New("barf")
393 return n.packetConn.readPacket()
// TestHandshakeErrorHandlingRead makes the server's transport fail after
// i = 0..19 reads (writes unlimited), with coupled=false.
396 func TestHandshakeErrorHandlingRead(t *testing.T) {
397 for i := 0; i < 20; i++ {
398 testHandshakeErrorHandlingN(t, i, -1, false)
// TestHandshakeErrorHandlingWrite makes the server's transport fail after
// i = 0..19 writes (reads unlimited), with coupled=false.
402 func TestHandshakeErrorHandlingWrite(t *testing.T) {
403 for i := 0; i < 20; i++ {
404 testHandshakeErrorHandlingN(t, -1, i, false)
// TestHandshakeErrorHandlingReadCoupled makes the server's transport fail
// after i = 0..19 reads (writes unlimited), with coupled=true.
408 func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
409 for i := 0; i < 20; i++ {
410 testHandshakeErrorHandlingN(t, i, -1, true)
// TestHandshakeErrorHandlingWriteCoupled makes the server's transport fail
// after i = 0..19 writes (reads unlimited), with coupled=true.
414 func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
415 for i := 0; i < 20; i++ {
416 testHandshakeErrorHandlingN(t, -1, i, true)
420 // testHandshakeErrorHandlingN runs handshakes, injecting errors. If
421 // handshakeTransport deadlocks, the go runtime will detect it and
// NOTE(review): under `go test` the test-timeout timer is pending, which
// usually keeps the runtime's "all goroutines are asleep" check from firing.
// A hang is then more likely reported by the -timeout panic — confirm.
//
// readLimit/writeLimit bound the server's underlying reads/writes (-1 is
// passed for "no limit"); the client's transport is constructed with -1/-1.
// When coupled is true, each side presumably reads a packet and then writes
// msg from the same goroutine. Otherwise it uses separate writer and reader
// goroutines. The selecting condition is not visible in this excerpt.
423 func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
// A payload of minRekeyThreshold/4 bytes, so a few messages cross the
// server's rekey threshold.
424 msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
430 key := testSigners["ecdsa"]
// The server rekeys at the minimum threshold and the client at 10x that, so
// the server side presumably initiates the rekeys.
431 serverConf := Config{RekeyThreshold: minRekeyThreshold}
432 serverConf.SetDefaults()
// Only the server's keying transport injects failures.
433 serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
434 serverConn.hostKeys = []Signer{key}
435 go serverConn.readLoop()
436 go serverConn.kexLoop()
438 clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
439 clientConf.SetDefaults()
440 clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
441 clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
442 clientConn.hostKeyCallback = InsecureIgnoreHostKey()
443 go clientConn.readLoop()
444 go clientConn.kexLoop()
446 var wg sync.WaitGroup
448 for _, hs := range []packetConn{serverConn, clientConn} {
451 go func(c packetConn) {
// Each payload starts with the 8-hex-digit index and is padded to the same
// length as msg (8 + minRekeyThreshold/4 - 8).
453 str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
454 err := c.writePacket(Marshal(&serviceRequestMsg{str}))
462 go func(c packetConn) {
464 _, err := c.readPacket()
473 go func(c packetConn) {
475 _, err := c.readPacket()
479 if err := c.writePacket(msg); err != nil {
// TestDisconnect checks how the receiver handles a disconnect message:
//   - a normal packet sent before the disconnect is delivered;
//   - the disconnect message is returned from readPacket as an error equal
//     to the sent *disconnectMsg;
//   - a packet sent after the disconnect is not delivered.
491 func TestDisconnect(t *testing.T) {
492 if runtime.GOOS == "plan9" {
493 t.Skip("see golang.org/issue/7237")
495 checker := &testChecker{}
496 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
498 t.Fatalf("handshakePair: %v", err)
// Send: normal packet, disconnect, normal packet.
// NOTE(review): the writePacket errors are ignored in these three calls.
504 trC.writePacket([]byte{msgRequestSuccess, 0, 0})
505 errMsg := &disconnectMsg{
507 Message: "such is life",
509 trC.writePacket(Marshal(errMsg))
510 trC.writePacket([]byte{msgRequestSuccess, 0, 0})
512 packet, err := trS.readPacket()
514 t.Fatalf("readPacket 1: %v", err)
516 if packet[0] != msgRequestSuccess {
517 t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
// The disconnect must surface as the error value itself.
520 _, err = trS.readPacket()
522 t.Errorf("readPacket 2 succeeded")
523 } else if !reflect.DeepEqual(err, errMsg) {
524 t.Errorf("got error %#v, want %#v", err, errMsg)
// Reads after the disconnect must keep failing.
527 _, err = trS.readPacket()
529 t.Errorf("readPacket 3 succeeded")
// TestHandshakeRekeyDefault checks the client's default rekey byte budgets
// with the aes128-ctr cipher. After a handshake and one small packet, both
// the read and the write budget are expected to round to 64G.
533 func TestHandshakeRekeyDefault(t *testing.T) {
534 clientConf := &ClientConfig{
536 Ciphers: []string{"aes128-ctr"},
538 HostKeyCallback: InsecureIgnoreHostKey(),
540 trC, trS, err := handshakePair(clientConf, "addr", false)
542 t.Fatalf("handshakePair: %v", err)
547 trC.writePacket([]byte{msgRequestSuccess, 0, 0})
// >> 30 converts the remaining byte budgets to whole GiB. The +1024
// presumably compensates for the few bytes already consumed by the handshake
// and the packet above, so the result rounds to 64 — confirm.
550 rgb := (1024 + trC.readBytesLeft) >> 30
551 wgb := (1024 + trC.writeBytesLeft) >> 30
554 t.Errorf("got rekey after %dG read, want 64G", rgb)
557 t.Errorf("got rekey after %dG write, want 64G", wgb)