OSDN Git Service

new repo
[bytom/vapor.git] / vendor / golang.org / x / crypto / ssh / handshake.go
1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package ssh
6
7 import (
8         "crypto/rand"
9         "errors"
10         "fmt"
11         "io"
12         "log"
13         "net"
14         "sync"
15 )
16
17 // debugHandshake, if set, prints messages sent and received.  Key
18 // exchange messages are printed as if DH were used, so the debug
19 // messages are wrong when using ECDH.
20 const debugHandshake = false
21
22 // chanSize sets the amount of buffering SSH connections. This is
23 // primarily for testing: setting chanSize=0 uncovers deadlocks more
24 // quickly.
25 const chanSize = 16
26
27 // keyingTransport is a packet based transport that supports key
28 // changes. It need not be thread-safe. It should pass through
29 // msgNewKeys in both directions.
30 type keyingTransport interface {
31         packetConn
32
33         // prepareKeyChange sets up a key change. The key change for a
34         // direction will be effected if a msgNewKeys message is sent
35         // or received.
36         prepareKeyChange(*algorithms, *kexResult) error
37 }
38
39 // handshakeTransport implements rekeying on top of a keyingTransport
40 // and offers a thread-safe writePacket() interface.
41 type handshakeTransport struct {
42         conn   keyingTransport
43         config *Config
44
45         serverVersion []byte
46         clientVersion []byte
47
48         // hostKeys is non-empty if we are the server. In that case,
49         // it contains all host keys that can be used to sign the
50         // connection.
51         hostKeys []Signer
52
53         // hostKeyAlgorithms is non-empty if we are the client. In that case,
54         // we accept these key types from the server as host key.
55         hostKeyAlgorithms []string
56
57         // On read error, incoming is closed, and readError is set.
58         incoming  chan []byte
59         readError error
60
61         mu             sync.Mutex
62         writeError     error
63         sentInitPacket []byte
64         sentInitMsg    *kexInitMsg
65         pendingPackets [][]byte // Used when a key exchange is in progress.
66
67         // If the read loop wants to schedule a kex, it pings this
68         // channel, and the write loop will send out a kex
69         // message.
70         requestKex chan struct{}
71
72         // If the other side requests or confirms a kex, its kexInit
73         // packet is sent here for the write loop to find it.
74         startKex chan *pendingKex
75
76         // data for host key checking
77         hostKeyCallback HostKeyCallback
78         dialAddress     string
79         remoteAddr      net.Addr
80
81         // Algorithms agreed in the last key exchange.
82         algorithms *algorithms
83
84         readPacketsLeft uint32
85         readBytesLeft   int64
86
87         writePacketsLeft uint32
88         writeBytesLeft   int64
89
90         // The session ID or nil if first kex did not complete yet.
91         sessionID []byte
92 }
93
94 type pendingKex struct {
95         otherInit []byte
96         done      chan error
97 }
98
99 func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
100         t := &handshakeTransport{
101                 conn:          conn,
102                 serverVersion: serverVersion,
103                 clientVersion: clientVersion,
104                 incoming:      make(chan []byte, chanSize),
105                 requestKex:    make(chan struct{}, 1),
106                 startKex:      make(chan *pendingKex, 1),
107
108                 config: config,
109         }
110         t.resetReadThresholds()
111         t.resetWriteThresholds()
112
113         // We always start with a mandatory key exchange.
114         t.requestKex <- struct{}{}
115         return t
116 }
117
118 func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
119         t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
120         t.dialAddress = dialAddr
121         t.remoteAddr = addr
122         t.hostKeyCallback = config.HostKeyCallback
123         if config.HostKeyAlgorithms != nil {
124                 t.hostKeyAlgorithms = config.HostKeyAlgorithms
125         } else {
126                 t.hostKeyAlgorithms = supportedHostKeyAlgos
127         }
128         go t.readLoop()
129         go t.kexLoop()
130         return t
131 }
132
133 func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
134         t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
135         t.hostKeys = config.hostKeys
136         go t.readLoop()
137         go t.kexLoop()
138         return t
139 }
140
141 func (t *handshakeTransport) getSessionID() []byte {
142         return t.sessionID
143 }
144
145 // waitSession waits for the session to be established. This should be
146 // the first thing to call after instantiating handshakeTransport.
147 func (t *handshakeTransport) waitSession() error {
148         p, err := t.readPacket()
149         if err != nil {
150                 return err
151         }
152         if p[0] != msgNewKeys {
153                 return fmt.Errorf("ssh: first packet should be msgNewKeys")
154         }
155
156         return nil
157 }
158
159 func (t *handshakeTransport) id() string {
160         if len(t.hostKeys) > 0 {
161                 return "server"
162         }
163         return "client"
164 }
165
166 func (t *handshakeTransport) printPacket(p []byte, write bool) {
167         action := "got"
168         if write {
169                 action = "sent"
170         }
171
172         if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
173                 log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
174         } else {
175                 msg, err := decode(p)
176                 log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
177         }
178 }
179
180 func (t *handshakeTransport) readPacket() ([]byte, error) {
181         p, ok := <-t.incoming
182         if !ok {
183                 return nil, t.readError
184         }
185         return p, nil
186 }
187
188 func (t *handshakeTransport) readLoop() {
189         first := true
190         for {
191                 p, err := t.readOnePacket(first)
192                 first = false
193                 if err != nil {
194                         t.readError = err
195                         close(t.incoming)
196                         break
197                 }
198                 if p[0] == msgIgnore || p[0] == msgDebug {
199                         continue
200                 }
201                 t.incoming <- p
202         }
203
204         // Stop writers too.
205         t.recordWriteError(t.readError)
206
207         // Unblock the writer should it wait for this.
208         close(t.startKex)
209
210         // Don't close t.requestKex; it's also written to from writePacket.
211 }
212
213 func (t *handshakeTransport) pushPacket(p []byte) error {
214         if debugHandshake {
215                 t.printPacket(p, true)
216         }
217         return t.conn.writePacket(p)
218 }
219
220 func (t *handshakeTransport) getWriteError() error {
221         t.mu.Lock()
222         defer t.mu.Unlock()
223         return t.writeError
224 }
225
226 func (t *handshakeTransport) recordWriteError(err error) {
227         t.mu.Lock()
228         defer t.mu.Unlock()
229         if t.writeError == nil && err != nil {
230                 t.writeError = err
231         }
232 }
233
234 func (t *handshakeTransport) requestKeyExchange() {
235         select {
236         case t.requestKex <- struct{}{}:
237         default:
238                 // something already requested a kex, so do nothing.
239         }
240 }
241
242 func (t *handshakeTransport) resetWriteThresholds() {
243         t.writePacketsLeft = packetRekeyThreshold
244         if t.config.RekeyThreshold > 0 {
245                 t.writeBytesLeft = int64(t.config.RekeyThreshold)
246         } else if t.algorithms != nil {
247                 t.writeBytesLeft = t.algorithms.w.rekeyBytes()
248         } else {
249                 t.writeBytesLeft = 1 << 30
250         }
251 }
252
253 func (t *handshakeTransport) kexLoop() {
254
255 write:
256         for t.getWriteError() == nil {
257                 var request *pendingKex
258                 var sent bool
259
260                 for request == nil || !sent {
261                         var ok bool
262                         select {
263                         case request, ok = <-t.startKex:
264                                 if !ok {
265                                         break write
266                                 }
267                         case <-t.requestKex:
268                                 break
269                         }
270
271                         if !sent {
272                                 if err := t.sendKexInit(); err != nil {
273                                         t.recordWriteError(err)
274                                         break
275                                 }
276                                 sent = true
277                         }
278                 }
279
280                 if err := t.getWriteError(); err != nil {
281                         if request != nil {
282                                 request.done <- err
283                         }
284                         break
285                 }
286
287                 // We're not servicing t.requestKex, but that is OK:
288                 // we never block on sending to t.requestKex.
289
290                 // We're not servicing t.startKex, but the remote end
291                 // has just sent us a kexInitMsg, so it can't send
292                 // another key change request, until we close the done
293                 // channel on the pendingKex request.
294
295                 err := t.enterKeyExchange(request.otherInit)
296
297                 t.mu.Lock()
298                 t.writeError = err
299                 t.sentInitPacket = nil
300                 t.sentInitMsg = nil
301
302                 t.resetWriteThresholds()
303
304                 // we have completed the key exchange. Since the
305                 // reader is still blocked, it is safe to clear out
306                 // the requestKex channel. This avoids the situation
307                 // where: 1) we consumed our own request for the
308                 // initial kex, and 2) the kex from the remote side
309                 // caused another send on the requestKex channel,
310         clear:
311                 for {
312                         select {
313                         case <-t.requestKex:
314                                 //
315                         default:
316                                 break clear
317                         }
318                 }
319
320                 request.done <- t.writeError
321
322                 // kex finished. Push packets that we received while
323                 // the kex was in progress. Don't look at t.startKex
324                 // and don't increment writtenSinceKex: if we trigger
325                 // another kex while we are still busy with the last
326                 // one, things will become very confusing.
327                 for _, p := range t.pendingPackets {
328                         t.writeError = t.pushPacket(p)
329                         if t.writeError != nil {
330                                 break
331                         }
332                 }
333                 t.pendingPackets = t.pendingPackets[:0]
334                 t.mu.Unlock()
335         }
336
337         // drain startKex channel. We don't service t.requestKex
338         // because nobody does blocking sends there.
339         go func() {
340                 for init := range t.startKex {
341                         init.done <- t.writeError
342                 }
343         }()
344
345         // Unblock reader.
346         t.conn.Close()
347 }
348
349 // The protocol uses uint32 for packet counters, so we can't let them
350 // reach 1<<32.  We will actually read and write more packets than
351 // this, though: the other side may send more packets, and after we
352 // hit this limit on writing we will send a few more packets for the
353 // key exchange itself.
354 const packetRekeyThreshold = (1 << 31)
355
356 func (t *handshakeTransport) resetReadThresholds() {
357         t.readPacketsLeft = packetRekeyThreshold
358         if t.config.RekeyThreshold > 0 {
359                 t.readBytesLeft = int64(t.config.RekeyThreshold)
360         } else if t.algorithms != nil {
361                 t.readBytesLeft = t.algorithms.r.rekeyBytes()
362         } else {
363                 t.readBytesLeft = 1 << 30
364         }
365 }
366
367 func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
368         p, err := t.conn.readPacket()
369         if err != nil {
370                 return nil, err
371         }
372
373         if t.readPacketsLeft > 0 {
374                 t.readPacketsLeft--
375         } else {
376                 t.requestKeyExchange()
377         }
378
379         if t.readBytesLeft > 0 {
380                 t.readBytesLeft -= int64(len(p))
381         } else {
382                 t.requestKeyExchange()
383         }
384
385         if debugHandshake {
386                 t.printPacket(p, false)
387         }
388
389         if first && p[0] != msgKexInit {
390                 return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
391         }
392
393         if p[0] != msgKexInit {
394                 return p, nil
395         }
396
397         firstKex := t.sessionID == nil
398
399         kex := pendingKex{
400                 done:      make(chan error, 1),
401                 otherInit: p,
402         }
403         t.startKex <- &kex
404         err = <-kex.done
405
406         if debugHandshake {
407                 log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
408         }
409
410         if err != nil {
411                 return nil, err
412         }
413
414         t.resetReadThresholds()
415
416         // By default, a key exchange is hidden from higher layers by
417         // translating it into msgIgnore.
418         successPacket := []byte{msgIgnore}
419         if firstKex {
420                 // sendKexInit() for the first kex waits for
421                 // msgNewKeys so the authentication process is
422                 // guaranteed to happen over an encrypted transport.
423                 successPacket = []byte{msgNewKeys}
424         }
425
426         return successPacket, nil
427 }
428
429 // sendKexInit sends a key change message.
430 func (t *handshakeTransport) sendKexInit() error {
431         t.mu.Lock()
432         defer t.mu.Unlock()
433         if t.sentInitMsg != nil {
434                 // kexInits may be sent either in response to the other side,
435                 // or because our side wants to initiate a key change, so we
436                 // may have already sent a kexInit. In that case, don't send a
437                 // second kexInit.
438                 return nil
439         }
440
441         msg := &kexInitMsg{
442                 KexAlgos:                t.config.KeyExchanges,
443                 CiphersClientServer:     t.config.Ciphers,
444                 CiphersServerClient:     t.config.Ciphers,
445                 MACsClientServer:        t.config.MACs,
446                 MACsServerClient:        t.config.MACs,
447                 CompressionClientServer: supportedCompressions,
448                 CompressionServerClient: supportedCompressions,
449         }
450         io.ReadFull(rand.Reader, msg.Cookie[:])
451
452         if len(t.hostKeys) > 0 {
453                 for _, k := range t.hostKeys {
454                         msg.ServerHostKeyAlgos = append(
455                                 msg.ServerHostKeyAlgos, k.PublicKey().Type())
456                 }
457         } else {
458                 msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
459         }
460         packet := Marshal(msg)
461
462         // writePacket destroys the contents, so save a copy.
463         packetCopy := make([]byte, len(packet))
464         copy(packetCopy, packet)
465
466         if err := t.pushPacket(packetCopy); err != nil {
467                 return err
468         }
469
470         t.sentInitMsg = msg
471         t.sentInitPacket = packet
472
473         return nil
474 }
475
476 func (t *handshakeTransport) writePacket(p []byte) error {
477         switch p[0] {
478         case msgKexInit:
479                 return errors.New("ssh: only handshakeTransport can send kexInit")
480         case msgNewKeys:
481                 return errors.New("ssh: only handshakeTransport can send newKeys")
482         }
483
484         t.mu.Lock()
485         defer t.mu.Unlock()
486         if t.writeError != nil {
487                 return t.writeError
488         }
489
490         if t.sentInitMsg != nil {
491                 // Copy the packet so the writer can reuse the buffer.
492                 cp := make([]byte, len(p))
493                 copy(cp, p)
494                 t.pendingPackets = append(t.pendingPackets, cp)
495                 return nil
496         }
497
498         if t.writeBytesLeft > 0 {
499                 t.writeBytesLeft -= int64(len(p))
500         } else {
501                 t.requestKeyExchange()
502         }
503
504         if t.writePacketsLeft > 0 {
505                 t.writePacketsLeft--
506         } else {
507                 t.requestKeyExchange()
508         }
509
510         if err := t.pushPacket(p); err != nil {
511                 t.writeError = err
512         }
513
514         return nil
515 }
516
517 func (t *handshakeTransport) Close() error {
518         return t.conn.Close()
519 }
520
521 func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
522         if debugHandshake {
523                 log.Printf("%s entered key exchange", t.id())
524         }
525
526         otherInit := &kexInitMsg{}
527         if err := Unmarshal(otherInitPacket, otherInit); err != nil {
528                 return err
529         }
530
531         magics := handshakeMagics{
532                 clientVersion: t.clientVersion,
533                 serverVersion: t.serverVersion,
534                 clientKexInit: otherInitPacket,
535                 serverKexInit: t.sentInitPacket,
536         }
537
538         clientInit := otherInit
539         serverInit := t.sentInitMsg
540         if len(t.hostKeys) == 0 {
541                 clientInit, serverInit = serverInit, clientInit
542
543                 magics.clientKexInit = t.sentInitPacket
544                 magics.serverKexInit = otherInitPacket
545         }
546
547         var err error
548         t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
549         if err != nil {
550                 return err
551         }
552
553         // We don't send FirstKexFollows, but we handle receiving it.
554         //
555         // RFC 4253 section 7 defines the kex and the agreement method for
556         // first_kex_packet_follows. It states that the guessed packet
557         // should be ignored if the "kex algorithm and/or the host
558         // key algorithm is guessed wrong (server and client have
559         // different preferred algorithm), or if any of the other
560         // algorithms cannot be agreed upon". The other algorithms have
561         // already been checked above so the kex algorithm and host key
562         // algorithm are checked here.
563         if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
564                 // other side sent a kex message for the wrong algorithm,
565                 // which we have to ignore.
566                 if _, err := t.conn.readPacket(); err != nil {
567                         return err
568                 }
569         }
570
571         kex, ok := kexAlgoMap[t.algorithms.kex]
572         if !ok {
573                 return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
574         }
575
576         var result *kexResult
577         if len(t.hostKeys) > 0 {
578                 result, err = t.server(kex, t.algorithms, &magics)
579         } else {
580                 result, err = t.client(kex, t.algorithms, &magics)
581         }
582
583         if err != nil {
584                 return err
585         }
586
587         if t.sessionID == nil {
588                 t.sessionID = result.H
589         }
590         result.SessionID = t.sessionID
591
592         if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
593                 return err
594         }
595         if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
596                 return err
597         }
598         if packet, err := t.conn.readPacket(); err != nil {
599                 return err
600         } else if packet[0] != msgNewKeys {
601                 return unexpectedMessageError(msgNewKeys, packet[0])
602         }
603
604         return nil
605 }
606
607 func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
608         var hostKey Signer
609         for _, k := range t.hostKeys {
610                 if algs.hostKey == k.PublicKey().Type() {
611                         hostKey = k
612                 }
613         }
614
615         r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey)
616         return r, err
617 }
618
619 func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) {
620         result, err := kex.Client(t.conn, t.config.Rand, magics)
621         if err != nil {
622                 return nil, err
623         }
624
625         hostKey, err := ParsePublicKey(result.HostKey)
626         if err != nil {
627                 return nil, err
628         }
629
630         if err := verifyHostKeySignature(hostKey, result); err != nil {
631                 return nil, err
632         }
633
634         err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
635         if err != nil {
636                 return nil, err
637         }
638
639         return result, nil
640 }