OSDN Git Service

Hulk did something
[bytom/vapor.git] / vendor / golang.org / x / crypto / ssh / mux.go
1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package ssh
6
7 import (
8         "encoding/binary"
9         "fmt"
10         "io"
11         "log"
12         "sync"
13         "sync/atomic"
14 )
15
16 // debugMux, if set, causes messages in the connection protocol to be
17 // logged.
18 const debugMux = false
19
20 // chanList is a thread safe channel list.
21 type chanList struct {
22         // protects concurrent access to chans
23         sync.Mutex
24
25         // chans are indexed by the local id of the channel, which the
26         // other side should send in the PeersId field.
27         chans []*channel
28
29         // This is a debugging aid: it offsets all IDs by this
30         // amount. This helps distinguish otherwise identical
31         // server/client muxes
32         offset uint32
33 }
34
35 // Assigns a channel ID to the given channel.
36 func (c *chanList) add(ch *channel) uint32 {
37         c.Lock()
38         defer c.Unlock()
39         for i := range c.chans {
40                 if c.chans[i] == nil {
41                         c.chans[i] = ch
42                         return uint32(i) + c.offset
43                 }
44         }
45         c.chans = append(c.chans, ch)
46         return uint32(len(c.chans)-1) + c.offset
47 }
48
49 // getChan returns the channel for the given ID.
50 func (c *chanList) getChan(id uint32) *channel {
51         id -= c.offset
52
53         c.Lock()
54         defer c.Unlock()
55         if id < uint32(len(c.chans)) {
56                 return c.chans[id]
57         }
58         return nil
59 }
60
61 func (c *chanList) remove(id uint32) {
62         id -= c.offset
63         c.Lock()
64         if id < uint32(len(c.chans)) {
65                 c.chans[id] = nil
66         }
67         c.Unlock()
68 }
69
70 // dropAll forgets all channels it knows, returning them in a slice.
71 func (c *chanList) dropAll() []*channel {
72         c.Lock()
73         defer c.Unlock()
74         var r []*channel
75
76         for _, ch := range c.chans {
77                 if ch == nil {
78                         continue
79                 }
80                 r = append(r, ch)
81         }
82         c.chans = nil
83         return r
84 }
85
86 // mux represents the state for the SSH connection protocol, which
87 // multiplexes many channels onto a single packet transport.
88 type mux struct {
89         conn     packetConn
90         chanList chanList
91
92         incomingChannels chan NewChannel
93
94         globalSentMu     sync.Mutex
95         globalResponses  chan interface{}
96         incomingRequests chan *Request
97
98         errCond *sync.Cond
99         err     error
100 }
101
102 // When debugging, each new chanList instantiation has a different
103 // offset.
104 var globalOff uint32
105
106 func (m *mux) Wait() error {
107         m.errCond.L.Lock()
108         defer m.errCond.L.Unlock()
109         for m.err == nil {
110                 m.errCond.Wait()
111         }
112         return m.err
113 }
114
115 // newMux returns a mux that runs over the given connection.
116 func newMux(p packetConn) *mux {
117         m := &mux{
118                 conn:             p,
119                 incomingChannels: make(chan NewChannel, chanSize),
120                 globalResponses:  make(chan interface{}, 1),
121                 incomingRequests: make(chan *Request, chanSize),
122                 errCond:          newCond(),
123         }
124         if debugMux {
125                 m.chanList.offset = atomic.AddUint32(&globalOff, 1)
126         }
127
128         go m.loop()
129         return m
130 }
131
132 func (m *mux) sendMessage(msg interface{}) error {
133         p := Marshal(msg)
134         if debugMux {
135                 log.Printf("send global(%d): %#v", m.chanList.offset, msg)
136         }
137         return m.conn.writePacket(p)
138 }
139
140 func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
141         if wantReply {
142                 m.globalSentMu.Lock()
143                 defer m.globalSentMu.Unlock()
144         }
145
146         if err := m.sendMessage(globalRequestMsg{
147                 Type:      name,
148                 WantReply: wantReply,
149                 Data:      payload,
150         }); err != nil {
151                 return false, nil, err
152         }
153
154         if !wantReply {
155                 return false, nil, nil
156         }
157
158         msg, ok := <-m.globalResponses
159         if !ok {
160                 return false, nil, io.EOF
161         }
162         switch msg := msg.(type) {
163         case *globalRequestFailureMsg:
164                 return false, msg.Data, nil
165         case *globalRequestSuccessMsg:
166                 return true, msg.Data, nil
167         default:
168                 return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
169         }
170 }
171
172 // ackRequest must be called after processing a global request that
173 // has WantReply set.
174 func (m *mux) ackRequest(ok bool, data []byte) error {
175         if ok {
176                 return m.sendMessage(globalRequestSuccessMsg{Data: data})
177         }
178         return m.sendMessage(globalRequestFailureMsg{Data: data})
179 }
180
181 func (m *mux) Close() error {
182         return m.conn.Close()
183 }
184
185 // loop runs the connection machine. It will process packets until an
186 // error is encountered. To synchronize on loop exit, use mux.Wait.
187 func (m *mux) loop() {
188         var err error
189         for err == nil {
190                 err = m.onePacket()
191         }
192
193         for _, ch := range m.chanList.dropAll() {
194                 ch.close()
195         }
196
197         close(m.incomingChannels)
198         close(m.incomingRequests)
199         close(m.globalResponses)
200
201         m.conn.Close()
202
203         m.errCond.L.Lock()
204         m.err = err
205         m.errCond.Broadcast()
206         m.errCond.L.Unlock()
207
208         if debugMux {
209                 log.Println("loop exit", err)
210         }
211 }
212
213 // onePacket reads and processes one packet.
214 func (m *mux) onePacket() error {
215         packet, err := m.conn.readPacket()
216         if err != nil {
217                 return err
218         }
219
220         if debugMux {
221                 if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
222                         log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
223                 } else {
224                         p, _ := decode(packet)
225                         log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
226                 }
227         }
228
229         switch packet[0] {
230         case msgChannelOpen:
231                 return m.handleChannelOpen(packet)
232         case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
233                 return m.handleGlobalPacket(packet)
234         }
235
236         // assume a channel packet.
237         if len(packet) < 5 {
238                 return parseError(packet[0])
239         }
240         id := binary.BigEndian.Uint32(packet[1:])
241         ch := m.chanList.getChan(id)
242         if ch == nil {
243                 return fmt.Errorf("ssh: invalid channel %d", id)
244         }
245
246         return ch.handlePacket(packet)
247 }
248
249 func (m *mux) handleGlobalPacket(packet []byte) error {
250         msg, err := decode(packet)
251         if err != nil {
252                 return err
253         }
254
255         switch msg := msg.(type) {
256         case *globalRequestMsg:
257                 m.incomingRequests <- &Request{
258                         Type:      msg.Type,
259                         WantReply: msg.WantReply,
260                         Payload:   msg.Data,
261                         mux:       m,
262                 }
263         case *globalRequestSuccessMsg, *globalRequestFailureMsg:
264                 m.globalResponses <- msg
265         default:
266                 panic(fmt.Sprintf("not a global message %#v", msg))
267         }
268
269         return nil
270 }
271
272 // handleChannelOpen schedules a channel to be Accept()ed.
273 func (m *mux) handleChannelOpen(packet []byte) error {
274         var msg channelOpenMsg
275         if err := Unmarshal(packet, &msg); err != nil {
276                 return err
277         }
278
279         if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
280                 failMsg := channelOpenFailureMsg{
281                         PeersId:  msg.PeersId,
282                         Reason:   ConnectionFailed,
283                         Message:  "invalid request",
284                         Language: "en_US.UTF-8",
285                 }
286                 return m.sendMessage(failMsg)
287         }
288
289         c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
290         c.remoteId = msg.PeersId
291         c.maxRemotePayload = msg.MaxPacketSize
292         c.remoteWin.add(msg.PeersWindow)
293         m.incomingChannels <- c
294         return nil
295 }
296
297 func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
298         ch, err := m.openChannel(chanType, extra)
299         if err != nil {
300                 return nil, nil, err
301         }
302
303         return ch, ch.incomingRequests, nil
304 }
305
306 func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
307         ch := m.newChannel(chanType, channelOutbound, extra)
308
309         ch.maxIncomingPayload = channelMaxPacket
310
311         open := channelOpenMsg{
312                 ChanType:         chanType,
313                 PeersWindow:      ch.myWindow,
314                 MaxPacketSize:    ch.maxIncomingPayload,
315                 TypeSpecificData: extra,
316                 PeersId:          ch.localId,
317         }
318         if err := m.sendMessage(open); err != nil {
319                 return nil, err
320         }
321
322         switch msg := (<-ch.msg).(type) {
323         case *channelOpenConfirmMsg:
324                 return ch, nil
325         case *channelOpenFailureMsg:
326                 return nil, &OpenChannelError{msg.Reason, msg.Message}
327         default:
328                 return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
329         }
330 }