OSDN Git Service

Hulk did something
[bytom/vapor.git] / vendor / golang.org / x / crypto / ssh / mux.go
diff --git a/vendor/golang.org/x/crypto/ssh/mux.go b/vendor/golang.org/x/crypto/ssh/mux.go
new file mode 100644 (file)
index 0000000..27a527c
--- /dev/null
@@ -0,0 +1,330 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+       "encoding/binary"
+       "fmt"
+       "io"
+       "log"
+       "sync"
+       "sync/atomic"
+)
+
+// debugMux, if set, causes messages in the connection protocol to be
+// logged.
+const debugMux = false
+
+// chanList is a thread safe channel list.
+type chanList struct {
+       // protects concurrent access to chans
+       sync.Mutex
+
+       // chans are indexed by the local id of the channel, which the
+       // other side should send in the PeersId field.
+       chans []*channel
+
+       // This is a debugging aid: it offsets all IDs by this
+       // amount. This helps distinguish otherwise identical
+       // server/client muxes
+       offset uint32
+}
+
+// Assigns a channel ID to the given channel.
+func (c *chanList) add(ch *channel) uint32 {
+       c.Lock()
+       defer c.Unlock()
+       for i := range c.chans {
+               if c.chans[i] == nil {
+                       c.chans[i] = ch
+                       return uint32(i) + c.offset
+               }
+       }
+       c.chans = append(c.chans, ch)
+       return uint32(len(c.chans)-1) + c.offset
+}
+
+// getChan returns the channel for the given ID.
+func (c *chanList) getChan(id uint32) *channel {
+       id -= c.offset
+
+       c.Lock()
+       defer c.Unlock()
+       if id < uint32(len(c.chans)) {
+               return c.chans[id]
+       }
+       return nil
+}
+
+func (c *chanList) remove(id uint32) {
+       id -= c.offset
+       c.Lock()
+       if id < uint32(len(c.chans)) {
+               c.chans[id] = nil
+       }
+       c.Unlock()
+}
+
+// dropAll forgets all channels it knows, returning them in a slice.
+func (c *chanList) dropAll() []*channel {
+       c.Lock()
+       defer c.Unlock()
+       var r []*channel
+
+       for _, ch := range c.chans {
+               if ch == nil {
+                       continue
+               }
+               r = append(r, ch)
+       }
+       c.chans = nil
+       return r
+}
+
+// mux represents the state for the SSH connection protocol, which
+// multiplexes many channels onto a single packet transport.
+type mux struct {
+       conn     packetConn
+       chanList chanList
+
+       incomingChannels chan NewChannel
+
+       globalSentMu     sync.Mutex
+       globalResponses  chan interface{}
+       incomingRequests chan *Request
+
+       errCond *sync.Cond
+       err     error
+}
+
+// When debugging, each new chanList instantiation has a different
+// offset.
+var globalOff uint32
+
+func (m *mux) Wait() error {
+       m.errCond.L.Lock()
+       defer m.errCond.L.Unlock()
+       for m.err == nil {
+               m.errCond.Wait()
+       }
+       return m.err
+}
+
+// newMux returns a mux that runs over the given connection.
+func newMux(p packetConn) *mux {
+       m := &mux{
+               conn:             p,
+               incomingChannels: make(chan NewChannel, chanSize),
+               globalResponses:  make(chan interface{}, 1),
+               incomingRequests: make(chan *Request, chanSize),
+               errCond:          newCond(),
+       }
+       if debugMux {
+               m.chanList.offset = atomic.AddUint32(&globalOff, 1)
+       }
+
+       go m.loop()
+       return m
+}
+
+func (m *mux) sendMessage(msg interface{}) error {
+       p := Marshal(msg)
+       if debugMux {
+               log.Printf("send global(%d): %#v", m.chanList.offset, msg)
+       }
+       return m.conn.writePacket(p)
+}
+
+func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
+       if wantReply {
+               m.globalSentMu.Lock()
+               defer m.globalSentMu.Unlock()
+       }
+
+       if err := m.sendMessage(globalRequestMsg{
+               Type:      name,
+               WantReply: wantReply,
+               Data:      payload,
+       }); err != nil {
+               return false, nil, err
+       }
+
+       if !wantReply {
+               return false, nil, nil
+       }
+
+       msg, ok := <-m.globalResponses
+       if !ok {
+               return false, nil, io.EOF
+       }
+       switch msg := msg.(type) {
+       case *globalRequestFailureMsg:
+               return false, msg.Data, nil
+       case *globalRequestSuccessMsg:
+               return true, msg.Data, nil
+       default:
+               return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
+       }
+}
+
+// ackRequest must be called after processing a global request that
+// has WantReply set.
+func (m *mux) ackRequest(ok bool, data []byte) error {
+       if ok {
+               return m.sendMessage(globalRequestSuccessMsg{Data: data})
+       }
+       return m.sendMessage(globalRequestFailureMsg{Data: data})
+}
+
+func (m *mux) Close() error {
+       return m.conn.Close()
+}
+
+// loop runs the connection machine. It will process packets until an
+// error is encountered. To synchronize on loop exit, use mux.Wait.
+func (m *mux) loop() {
+       var err error
+       for err == nil {
+               err = m.onePacket()
+       }
+
+       for _, ch := range m.chanList.dropAll() {
+               ch.close()
+       }
+
+       close(m.incomingChannels)
+       close(m.incomingRequests)
+       close(m.globalResponses)
+
+       m.conn.Close()
+
+       m.errCond.L.Lock()
+       m.err = err
+       m.errCond.Broadcast()
+       m.errCond.L.Unlock()
+
+       if debugMux {
+               log.Println("loop exit", err)
+       }
+}
+
+// onePacket reads and processes one packet.
+func (m *mux) onePacket() error {
+       packet, err := m.conn.readPacket()
+       if err != nil {
+               return err
+       }
+
+       if debugMux {
+               if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
+                       log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
+               } else {
+                       p, _ := decode(packet)
+                       log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
+               }
+       }
+
+       switch packet[0] {
+       case msgChannelOpen:
+               return m.handleChannelOpen(packet)
+       case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
+               return m.handleGlobalPacket(packet)
+       }
+
+       // assume a channel packet.
+       if len(packet) < 5 {
+               return parseError(packet[0])
+       }
+       id := binary.BigEndian.Uint32(packet[1:])
+       ch := m.chanList.getChan(id)
+       if ch == nil {
+               return fmt.Errorf("ssh: invalid channel %d", id)
+       }
+
+       return ch.handlePacket(packet)
+}
+
+func (m *mux) handleGlobalPacket(packet []byte) error {
+       msg, err := decode(packet)
+       if err != nil {
+               return err
+       }
+
+       switch msg := msg.(type) {
+       case *globalRequestMsg:
+               m.incomingRequests <- &Request{
+                       Type:      msg.Type,
+                       WantReply: msg.WantReply,
+                       Payload:   msg.Data,
+                       mux:       m,
+               }
+       case *globalRequestSuccessMsg, *globalRequestFailureMsg:
+               m.globalResponses <- msg
+       default:
+               panic(fmt.Sprintf("not a global message %#v", msg))
+       }
+
+       return nil
+}
+
+// handleChannelOpen schedules a channel to be Accept()ed.
+func (m *mux) handleChannelOpen(packet []byte) error {
+       var msg channelOpenMsg
+       if err := Unmarshal(packet, &msg); err != nil {
+               return err
+       }
+
+       if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
+               failMsg := channelOpenFailureMsg{
+                       PeersId:  msg.PeersId,
+                       Reason:   ConnectionFailed,
+                       Message:  "invalid request",
+                       Language: "en_US.UTF-8",
+               }
+               return m.sendMessage(failMsg)
+       }
+
+       c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
+       c.remoteId = msg.PeersId
+       c.maxRemotePayload = msg.MaxPacketSize
+       c.remoteWin.add(msg.PeersWindow)
+       m.incomingChannels <- c
+       return nil
+}
+
+func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
+       ch, err := m.openChannel(chanType, extra)
+       if err != nil {
+               return nil, nil, err
+       }
+
+       return ch, ch.incomingRequests, nil
+}
+
+func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
+       ch := m.newChannel(chanType, channelOutbound, extra)
+
+       ch.maxIncomingPayload = channelMaxPacket
+
+       open := channelOpenMsg{
+               ChanType:         chanType,
+               PeersWindow:      ch.myWindow,
+               MaxPacketSize:    ch.maxIncomingPayload,
+               TypeSpecificData: extra,
+               PeersId:          ch.localId,
+       }
+       if err := m.sendMessage(open); err != nil {
+               return nil, err
+       }
+
+       switch msg := (<-ch.msg).(type) {
+       case *channelOpenConfirmMsg:
+               return ch, nil
+       case *channelOpenFailureMsg:
+               return nil, &OpenChannelError{msg.Reason, msg.Message}
+       default:
+               return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
+       }
+}