// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package ssh import ( "encoding/binary" "errors" "fmt" "io" "log" "sync" ) const ( minPacketLength = 9 // channelMaxPacket contains the maximum number of bytes that will be // sent in a single packet. As per RFC 4253, section 6.1, 32k is also // the minimum. channelMaxPacket = 1 << 15 // We follow OpenSSH here. channelWindowSize = 64 * channelMaxPacket ) // NewChannel represents an incoming request to a channel. It must either be // accepted for use by calling Accept, or rejected by calling Reject. type NewChannel interface { // Accept accepts the channel creation request. It returns the Channel // and a Go channel containing SSH requests. The Go channel must be // serviced otherwise the Channel will hang. Accept() (Channel, <-chan *Request, error) // Reject rejects the channel creation request. After calling // this, no other methods on the Channel may be called. Reject(reason RejectionReason, message string) error // ChannelType returns the type of the channel, as supplied by the // client. ChannelType() string // ExtraData returns the arbitrary payload for this channel, as supplied // by the client. This data is specific to the channel type. ExtraData() []byte } // A Channel is an ordered, reliable, flow-controlled, duplex stream // that is multiplexed over an SSH connection. type Channel interface { // Read reads up to len(data) bytes from the channel. Read(data []byte) (int, error) // Write writes len(data) bytes to the channel. Write(data []byte) (int, error) // Close signals end of channel use. No data may be sent after this // call. Close() error // CloseWrite signals the end of sending in-band // data. Requests may still be sent, and the other side may // still send data CloseWrite() error // SendRequest sends a channel request. If wantReply is true, // it will wait for a reply and return the result as a // boolean, otherwise the return value will be false. Channel // requests are out-of-band messages so they may be sent even // if the data stream is closed or blocked by flow control. // If the channel is closed before a reply is returned, io.EOF // is returned. SendRequest(name string, wantReply bool, payload []byte) (bool, error) // Stderr returns an io.ReadWriter that writes to this channel // with the extended data type set to stderr. Stderr may // safely be read and written from a different goroutine than // Read and Write respectively. Stderr() io.ReadWriter } // Request is a request sent outside of the normal stream of // data. Requests can either be specific to an SSH channel, or they // can be global. type Request struct { Type string WantReply bool Payload []byte ch *channel mux *mux } // Reply sends a response to a request. It must be called for all requests // where WantReply is true and is a no-op otherwise. The payload argument is // ignored for replies to channel-specific requests. func (r *Request) Reply(ok bool, payload []byte) error { if !r.WantReply { return nil } if r.ch == nil { return r.mux.ackRequest(ok, payload) } return r.ch.ackRequest(ok) } // RejectionReason is an enumeration used when rejecting channel creation // requests. See RFC 4254, section 5.1. type RejectionReason uint32 const ( Prohibited RejectionReason = iota + 1 ConnectionFailed UnknownChannelType ResourceShortage ) // String converts the rejection reason to human readable form. func (r RejectionReason) String() string { switch r { case Prohibited: return "administratively prohibited" case ConnectionFailed: return "connect failed" case UnknownChannelType: return "unknown channel type" case ResourceShortage: return "resource shortage" } return fmt.Sprintf("unknown reason %d", int(r)) } func min(a uint32, b int) uint32 { if a < uint32(b) { return a } return uint32(b) } type channelDirection uint8 const ( channelInbound channelDirection = iota channelOutbound ) // channel is an implementation of the Channel interface that works // with the mux class. type channel struct { // R/O after creation chanType string extraData []byte localId, remoteId uint32 // maxIncomingPayload and maxRemotePayload are the maximum // payload sizes of normal and extended data packets for // receiving and sending, respectively. The wire packet will // be 9 or 13 bytes larger (excluding encryption overhead). maxIncomingPayload uint32 maxRemotePayload uint32 mux *mux // decided is set to true if an accept or reject message has been sent // (for outbound channels) or received (for inbound channels). decided bool // direction contains either channelOutbound, for channels created // locally, or channelInbound, for channels created by the peer. direction channelDirection // Pending internal channel messages. msg chan interface{} // Since requests have no ID, there can be only one request // with WantReply=true outstanding. This lock is held by a // goroutine that has such an outgoing request pending. sentRequestMu sync.Mutex incomingRequests chan *Request sentEOF bool // thread-safe data remoteWin window pending *buffer extPending *buffer // windowMu protects myWindow, the flow-control window. windowMu sync.Mutex myWindow uint32 // writeMu serializes calls to mux.conn.writePacket() and // protects sentClose and packetPool. This mutex must be // different from windowMu, as writePacket can block if there // is a key exchange pending. writeMu sync.Mutex sentClose bool // packetPool has a buffer for each extended channel ID to // save allocations during writes. packetPool map[uint32][]byte } // writePacket sends a packet. If the packet is a channel close, it updates // sentClose. This method takes the lock c.writeMu. func (c *channel) writePacket(packet []byte) error { c.writeMu.Lock() if c.sentClose { c.writeMu.Unlock() return io.EOF } c.sentClose = (packet[0] == msgChannelClose) err := c.mux.conn.writePacket(packet) c.writeMu.Unlock() return err } func (c *channel) sendMessage(msg interface{}) error { if debugMux { log.Printf("send(%d): %#v", c.mux.chanList.offset, msg) } p := Marshal(msg) binary.BigEndian.PutUint32(p[1:], c.remoteId) return c.writePacket(p) } // WriteExtended writes data to a specific extended stream. These streams are // used, for example, for stderr. func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { if c.sentEOF { return 0, io.EOF } // 1 byte message type, 4 bytes remoteId, 4 bytes data length opCode := byte(msgChannelData) headerLength := uint32(9) if extendedCode > 0 { headerLength += 4 opCode = msgChannelExtendedData } c.writeMu.Lock() packet := c.packetPool[extendedCode] // We don't remove the buffer from packetPool, so // WriteExtended calls from different goroutines will be // flagged as errors by the race detector. c.writeMu.Unlock() for len(data) > 0 { space := min(c.maxRemotePayload, len(data)) if space, err = c.remoteWin.reserve(space); err != nil { return n, err } if want := headerLength + space; uint32(cap(packet)) < want { packet = make([]byte, want) } else { packet = packet[:want] } todo := data[:space] packet[0] = opCode binary.BigEndian.PutUint32(packet[1:], c.remoteId) if extendedCode > 0 { binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) } binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) copy(packet[headerLength:], todo) if err = c.writePacket(packet); err != nil { return n, err } n += len(todo) data = data[len(todo):] } c.writeMu.Lock() c.packetPool[extendedCode] = packet c.writeMu.Unlock() return n, err } func (c *channel) handleData(packet []byte) error { headerLen := 9 isExtendedData := packet[0] == msgChannelExtendedData if isExtendedData { headerLen = 13 } if len(packet) < headerLen { // malformed data packet return parseError(packet[0]) } var extended uint32 if isExtendedData { extended = binary.BigEndian.Uint32(packet[5:]) } length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen]) if length == 0 { return nil } if length > c.maxIncomingPayload { // TODO(hanwen): should send Disconnect? return errors.New("ssh: incoming packet exceeds maximum payload size") } data := packet[headerLen:] if length != uint32(len(data)) { return errors.New("ssh: wrong packet length") } c.windowMu.Lock() if c.myWindow < length { c.windowMu.Unlock() // TODO(hanwen): should send Disconnect with reason? return errors.New("ssh: remote side wrote too much") } c.myWindow -= length c.windowMu.Unlock() if extended == 1 { c.extPending.write(data) } else if extended > 0 { // discard other extended data. } else { c.pending.write(data) } return nil } func (c *channel) adjustWindow(n uint32) error { c.windowMu.Lock() // Since myWindow is managed on our side, and can never exceed // the initial window setting, we don't worry about overflow. c.myWindow += uint32(n) c.windowMu.Unlock() return c.sendMessage(windowAdjustMsg{ AdditionalBytes: uint32(n), }) } func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { switch extended { case 1: n, err = c.extPending.Read(data) case 0: n, err = c.pending.Read(data) default: return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended) } if n > 0 { err = c.adjustWindow(uint32(n)) // sendWindowAdjust can return io.EOF if the remote // peer has closed the connection, however we want to // defer forwarding io.EOF to the caller of Read until // the buffer has been drained. if n > 0 && err == io.EOF { err = nil } } return n, err } func (c *channel) close() { c.pending.eof() c.extPending.eof() close(c.msg) close(c.incomingRequests) c.writeMu.Lock() // This is not necessary for a normal channel teardown, but if // there was another error, it is. c.sentClose = true c.writeMu.Unlock() // Unblock writers. c.remoteWin.close() } // responseMessageReceived is called when a success or failure message is // received on a channel to check that such a message is reasonable for the // given channel. func (c *channel) responseMessageReceived() error { if c.direction == channelInbound { return errors.New("ssh: channel response message received on inbound channel") } if c.decided { return errors.New("ssh: duplicate response received for channel") } c.decided = true return nil } func (c *channel) handlePacket(packet []byte) error { switch packet[0] { case msgChannelData, msgChannelExtendedData: return c.handleData(packet) case msgChannelClose: c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) c.mux.chanList.remove(c.localId) c.close() return nil case msgChannelEOF: // RFC 4254 is mute on how EOF affects dataExt messages but // it is logical to signal EOF at the same time. c.extPending.eof() c.pending.eof() return nil } decoded, err := decode(packet) if err != nil { return err } switch msg := decoded.(type) { case *channelOpenFailureMsg: if err := c.responseMessageReceived(); err != nil { return err } c.mux.chanList.remove(msg.PeersId) c.msg <- msg case *channelOpenConfirmMsg: if err := c.responseMessageReceived(); err != nil { return err } if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) } c.remoteId = msg.MyId c.maxRemotePayload = msg.MaxPacketSize c.remoteWin.add(msg.MyWindow) c.msg <- msg case *windowAdjustMsg: if !c.remoteWin.add(msg.AdditionalBytes) { return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) } case *channelRequestMsg: req := Request{ Type: msg.Request, WantReply: msg.WantReply, Payload: msg.RequestSpecificData, ch: c, } c.incomingRequests <- &req default: c.msg <- msg } return nil } func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { ch := &channel{ remoteWin: window{Cond: newCond()}, myWindow: channelWindowSize, pending: newBuffer(), extPending: newBuffer(), direction: direction, incomingRequests: make(chan *Request, chanSize), msg: make(chan interface{}, chanSize), chanType: chanType, extraData: extraData, mux: m, packetPool: make(map[uint32][]byte), } ch.localId = m.chanList.add(ch) return ch } var errUndecided = errors.New("ssh: must Accept or Reject channel") var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") type extChannel struct { code uint32 ch *channel } func (e *extChannel) Write(data []byte) (n int, err error) { return e.ch.WriteExtended(data, e.code) } func (e *extChannel) Read(data []byte) (n int, err error) { return e.ch.ReadExtended(data, e.code) } func (c *channel) Accept() (Channel, <-chan *Request, error) { if c.decided { return nil, nil, errDecidedAlready } c.maxIncomingPayload = channelMaxPacket confirm := channelOpenConfirmMsg{ PeersId: c.remoteId, MyId: c.localId, MyWindow: c.myWindow, MaxPacketSize: c.maxIncomingPayload, } c.decided = true if err := c.sendMessage(confirm); err != nil { return nil, nil, err } return c, c.incomingRequests, nil } func (ch *channel) Reject(reason RejectionReason, message string) error { if ch.decided { return errDecidedAlready } reject := channelOpenFailureMsg{ PeersId: ch.remoteId, Reason: reason, Message: message, Language: "en", } ch.decided = true return ch.sendMessage(reject) } func (ch *channel) Read(data []byte) (int, error) { if !ch.decided { return 0, errUndecided } return ch.ReadExtended(data, 0) } func (ch *channel) Write(data []byte) (int, error) { if !ch.decided { return 0, errUndecided } return ch.WriteExtended(data, 0) } func (ch *channel) CloseWrite() error { if !ch.decided { return errUndecided } ch.sentEOF = true return ch.sendMessage(channelEOFMsg{ PeersId: ch.remoteId}) } func (ch *channel) Close() error { if !ch.decided { return errUndecided } return ch.sendMessage(channelCloseMsg{ PeersId: ch.remoteId}) } // Extended returns an io.ReadWriter that sends and receives data on the given, // SSH extended stream. Such streams are used, for example, for stderr. func (ch *channel) Extended(code uint32) io.ReadWriter { if !ch.decided { return nil } return &extChannel{code, ch} } func (ch *channel) Stderr() io.ReadWriter { return ch.Extended(1) } func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { if !ch.decided { return false, errUndecided } if wantReply { ch.sentRequestMu.Lock() defer ch.sentRequestMu.Unlock() } msg := channelRequestMsg{ PeersId: ch.remoteId, Request: name, WantReply: wantReply, RequestSpecificData: payload, } if err := ch.sendMessage(msg); err != nil { return false, err } if wantReply { m, ok := (<-ch.msg) if !ok { return false, io.EOF } switch m.(type) { case *channelRequestFailureMsg: return false, nil case *channelRequestSuccessMsg: return true, nil default: return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) } } return false, nil } // ackRequest either sends an ack or nack to the channel request. func (ch *channel) ackRequest(ok bool) error { if !ch.decided { return errUndecided } var msg interface{} if !ok { msg = channelRequestFailureMsg{ PeersId: ch.remoteId, } } else { msg = channelRequestSuccessMsg{ PeersId: ch.remoteId, } } return ch.sendMessage(msg) } func (ch *channel) ChannelType() string { return ch.chanType } func (ch *channel) ExtraData() []byte { return ch.extraData }