OSDN Git Service

Only allow ed25519 pubkeys when connecting (#1789)
[bytom/bytom.git] / p2p / connection / secret_connection.go
index 5bd8f9c..788a015 100644 (file)
@@ -1,9 +1,3 @@
-// Uses nacl's secret_box to encrypt a net.Conn.
-// It is (meant to be) an implementation of the STS protocol.
-// Note we do not (yet) assume that a remote peer's pubkey
-// is known ahead of time, and thus we are technically
-// still vulnerable to MITM. (TODO!)
-// See docs/sts-final.pdf for more info
 package connection
 
 import (
@@ -16,23 +10,30 @@ import (
        "net"
        "time"
 
+       log "github.com/sirupsen/logrus"
        "golang.org/x/crypto/nacl/box"
        "golang.org/x/crypto/nacl/secretbox"
        "golang.org/x/crypto/ripemd160"
 
        "github.com/tendermint/go-crypto"
-       "github.com/tendermint/go-wire"
+       wire "github.com/tendermint/go-wire"
        cmn "github.com/tendermint/tmlibs/common"
 )
 
-// 2 + 1024 == 1026 total frame size
-const dataLenSize = 2 // uint16 to describe the length, is <= dataMaxSize
-const dataMaxSize = 1024
-const totalFrameSize = dataMaxSize + dataLenSize
-const sealedFrameSize = totalFrameSize + secretbox.Overhead
-const authSigMsgSize = (32 + 1) + (64 + 1) // fixed size (length prefixed) byte arrays
+const (
+       dataLenSize     = 2 // uint16 to describe the length, is <= dataMaxSize
+       dataMaxSize     = 1024
+       totalFrameSize  = dataMaxSize + dataLenSize
+       sealedFrameSize = totalFrameSize + secretbox.Overhead
+       authSigMsgSize  = (32 + 1) + (64 + 1) // fixed size (length prefixed) byte arrays
+)
+
+type authSigMessage struct {
+       Key crypto.PubKey
+       Sig crypto.Signature
+}
 
-// Implements net.Conn
+// SecretConnection implements net.Conn
 type SecretConnection struct {
        conn       io.ReadWriteCloser
        recvBuffer []byte
@@ -42,12 +43,8 @@ type SecretConnection struct {
        shrSecret  *[32]byte // shared secret
 }
 
-// Performs handshake and returns a new authenticated SecretConnection.
-// Returns nil if error in handshake.
-// Caller should call conn.Close()
-// See docs/sts-final.pdf for more information.
+// MakeSecretConnection performs handshake and returns a new authenticated SecretConnection.
 func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKeyEd25519) (*SecretConnection, error) {
-
        locPubKey := locPrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
 
        // Generate ephemeral keys for perfect forward secrecy.
@@ -90,17 +87,52 @@ func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKeyEd25
        if err != nil {
                return nil, err
        }
+
        remPubKey, remSignature := authSigMsg.Key, authSigMsg.Sig
+       if _, ok := remPubKey.PubKeyInner.(crypto.PubKeyEd25519); !ok {
+               return nil, errors.New("peer sent a nil public key")
+       }
+
        if !remPubKey.VerifyBytes(challenge[:], remSignature) {
                return nil, errors.New("Challenge verification failed")
        }
 
-       // We've authorized.
        sc.remPubKey = remPubKey.Unwrap().(crypto.PubKeyEd25519)
        return sc, nil
 }
 
-// Returns authenticated remote pubkey
+// CONTRACT: data smaller than dataMaxSize is read atomically.
+func (sc *SecretConnection) Read(data []byte) (n int, err error) {
+       if 0 < len(sc.recvBuffer) {
+               n_ := copy(data, sc.recvBuffer)
+               sc.recvBuffer = sc.recvBuffer[n_:]
+               return
+       }
+
+       sealedFrame := make([]byte, sealedFrameSize)
+       if _, err = io.ReadFull(sc.conn, sealedFrame); err != nil {
+               return
+       }
+
+       // decrypt the frame
+       frame := make([]byte, totalFrameSize)
+       if _, ok := secretbox.Open(frame[:0], sealedFrame, sc.recvNonce, sc.shrSecret); !ok {
+               return n, errors.New("Failed to decrypt SecretConnection")
+       }
+
+       incr2Nonce(sc.recvNonce)
+       chunkLength := binary.BigEndian.Uint16(frame) // read the first two bytes
+       if chunkLength > dataMaxSize {
+               return 0, errors.New("chunkLength is greater than dataMaxSize")
+       }
+
+       chunk := frame[dataLenSize : dataLenSize+chunkLength]
+       n = copy(data, chunk)
+       sc.recvBuffer = chunk[n:]
+       return
+}
+
+// RemotePubKey returns authenticated remote pubkey
 func (sc *SecretConnection) RemotePubKey() crypto.PubKeyEd25519 {
        return sc.remPubKey
 }
@@ -109,8 +141,8 @@ func (sc *SecretConnection) RemotePubKey() crypto.PubKeyEd25519 {
 // CONTRACT: data smaller than dataMaxSize is read atomically.
 func (sc *SecretConnection) Write(data []byte) (n int, err error) {
        for 0 < len(data) {
-               var frame []byte = make([]byte, totalFrameSize)
                var chunk []byte
+               frame := make([]byte, totalFrameSize)
                if dataMaxSize < len(data) {
                        chunk = data[:dataMaxSize]
                        data = data[dataMaxSize:]
@@ -118,140 +150,89 @@ func (sc *SecretConnection) Write(data []byte) (n int, err error) {
                        chunk = data
                        data = nil
                }
-               chunkLength := len(chunk)
-               binary.BigEndian.PutUint16(frame, uint16(chunkLength))
+               binary.BigEndian.PutUint16(frame, uint16(len(chunk)))
                copy(frame[dataLenSize:], chunk)
 
                // encrypt the frame
-               var sealedFrame = make([]byte, sealedFrameSize)
+               sealedFrame := make([]byte, sealedFrameSize)
                secretbox.Seal(sealedFrame[:0], frame, sc.sendNonce, sc.shrSecret)
-               // fmt.Printf("secretbox.Seal(sealed:%X,sendNonce:%X,shrSecret:%X\n", sealedFrame, sc.sendNonce, sc.shrSecret)
                incr2Nonce(sc.sendNonce)
-               // end encryption
 
-               _, err := sc.conn.Write(sealedFrame)
-               if err != nil {
+               if _, err := sc.conn.Write(sealedFrame); err != nil {
                        return n, err
-               } else {
-                       n += len(chunk)
                }
+
+               n += len(chunk)
        }
        return
 }
 
-// CONTRACT: data smaller than dataMaxSize is read atomically.
-func (sc *SecretConnection) Read(data []byte) (n int, err error) {
-       if 0 < len(sc.recvBuffer) {
-               n_ := copy(data, sc.recvBuffer)
-               sc.recvBuffer = sc.recvBuffer[n_:]
-               return
-       }
-
-       sealedFrame := make([]byte, sealedFrameSize)
-       _, err = io.ReadFull(sc.conn, sealedFrame)
-       if err != nil {
-               return
-       }
-
-       // decrypt the frame
-       var frame = make([]byte, totalFrameSize)
-       // fmt.Printf("secretbox.Open(sealed:%X,recvNonce:%X,shrSecret:%X\n", sealedFrame, sc.recvNonce, sc.shrSecret)
-       _, ok := secretbox.Open(frame[:0], sealedFrame, sc.recvNonce, sc.shrSecret)
-       if !ok {
-               return n, errors.New("Failed to decrypt SecretConnection")
-       }
-       incr2Nonce(sc.recvNonce)
-       // end decryption
+// Close implements net.Conn
+func (sc *SecretConnection) Close() error { return sc.conn.Close() }
 
-       var chunkLength = binary.BigEndian.Uint16(frame) // read the first two bytes
-       if chunkLength > dataMaxSize {
-               return 0, errors.New("chunkLength is greater than dataMaxSize")
-       }
-       var chunk = frame[dataLenSize : dataLenSize+chunkLength]
+// LocalAddr implements net.Conn
+func (sc *SecretConnection) LocalAddr() net.Addr { return sc.conn.(net.Conn).LocalAddr() }
 
-       n = copy(data, chunk)
-       sc.recvBuffer = chunk[n:]
-       return
-}
+// RemoteAddr implements net.Conn
+func (sc *SecretConnection) RemoteAddr() net.Addr { return sc.conn.(net.Conn).RemoteAddr() }
 
-// Implements net.Conn
-func (sc *SecretConnection) Close() error                  { return sc.conn.Close() }
-func (sc *SecretConnection) LocalAddr() net.Addr           { return sc.conn.(net.Conn).LocalAddr() }
-func (sc *SecretConnection) RemoteAddr() net.Addr          { return sc.conn.(net.Conn).RemoteAddr() }
+// SetDeadline implements net.Conn
 func (sc *SecretConnection) SetDeadline(t time.Time) error { return sc.conn.(net.Conn).SetDeadline(t) }
+
+// SetReadDeadline implements net.Conn
 func (sc *SecretConnection) SetReadDeadline(t time.Time) error {
        return sc.conn.(net.Conn).SetReadDeadline(t)
 }
+
+// SetWriteDeadline implements net.Conn
 func (sc *SecretConnection) SetWriteDeadline(t time.Time) error {
        return sc.conn.(net.Conn).SetWriteDeadline(t)
 }
 
-func genEphKeys() (ephPub, ephPriv *[32]byte) {
-       var err error
-       ephPub, ephPriv, err = box.GenerateKey(crand.Reader)
-       if err != nil {
-               cmn.PanicCrisis("Could not generate ephemeral keypairs")
-       }
+func computeSharedSecret(remPubKey, locPrivKey *[32]byte) (shrSecret *[32]byte) {
+       shrSecret = new([32]byte)
+       box.Precompute(shrSecret, remPubKey, locPrivKey)
        return
 }
 
-func shareEphPubKey(conn io.ReadWriteCloser, locEphPub *[32]byte) (remEphPub *[32]byte, err error) {
-       var err1, err2 error
-
-       cmn.Parallel(
-               func() {
-                       _, err1 = conn.Write(locEphPub[:])
-               },
-               func() {
-                       remEphPub = new([32]byte)
-                       _, err2 = io.ReadFull(conn, remEphPub[:])
-               },
-       )
-
-       if err1 != nil {
-               return nil, err1
-       }
-       if err2 != nil {
-               return nil, err2
-       }
+func genChallenge(loPubKey, hiPubKey *[32]byte) (challenge *[32]byte) {
+       return hash32(append(loPubKey[:], hiPubKey[:]...))
+}
 
-       return remEphPub, nil
+// increment nonce big-endian by 2 with wraparound.
+func incr2Nonce(nonce *[24]byte) {
+       incrNonce(nonce)
+       incrNonce(nonce)
 }
 
-func computeSharedSecret(remPubKey, locPrivKey *[32]byte) (shrSecret *[32]byte) {
-       shrSecret = new([32]byte)
-       box.Precompute(shrSecret, remPubKey, locPrivKey)
-       return
+// increment nonce big-endian by 1 with wraparound.
+func incrNonce(nonce *[24]byte) {
+       for i := 23; 0 <= i; i-- {
+               nonce[i]++
+               if nonce[i] != 0 {
+                       return
+               }
+       }
 }
 
-func sort32(foo, bar *[32]byte) (lo, hi *[32]byte) {
-       if bytes.Compare(foo[:], bar[:]) < 0 {
-               lo = foo
-               hi = bar
-       } else {
-               lo = bar
-               hi = foo
+func genEphKeys() (ephPub, ephPriv *[32]byte) {
+       var err error
+       ephPub, ephPriv, err = box.GenerateKey(crand.Reader)
+       if err != nil {
+               log.Panic("Could not generate ephemeral keypairs")
        }
        return
 }
 
-func genNonces(loPubKey, hiPubKey *[32]byte, locIsLo bool) (recvNonce, sendNonce *[24]byte) {
+func genNonces(loPubKey, hiPubKey *[32]byte, locIsLo bool) (*[24]byte, *[24]byte) {
        nonce1 := hash24(append(loPubKey[:], hiPubKey[:]...))
        nonce2 := new([24]byte)
        copy(nonce2[:], nonce1[:])
        nonce2[len(nonce2)-1] ^= 0x01
        if locIsLo {
-               recvNonce = nonce1
-               sendNonce = nonce2
-       } else {
-               recvNonce = nonce2
-               sendNonce = nonce1
+               return nonce1, nonce2
        }
-       return
-}
-
-func genChallenge(loPubKey, hiPubKey *[32]byte) (challenge *[32]byte) {
-       return hash32(append(loPubKey[:], hiPubKey[:]...))
+       return nonce2, nonce1
 }
 
 func signChallenge(challenge *[32]byte, locPrivKey crypto.PrivKeyEd25519) (signature crypto.SignatureEd25519) {
@@ -259,11 +240,6 @@ func signChallenge(challenge *[32]byte, locPrivKey crypto.PrivKeyEd25519) (signa
        return
 }
 
-type authSigMessage struct {
-       Key crypto.PubKey
-       Sig crypto.Signature
-}
-
 func shareAuthSignature(sc *SecretConnection, pubKey crypto.PubKeyEd25519, signature crypto.SignatureEd25519) (*authSigMessage, error) {
        var recvMsg authSigMessage
        var err1, err2 error
@@ -281,7 +257,8 @@ func shareAuthSignature(sc *SecretConnection, pubKey crypto.PubKeyEd25519, signa
                        }
                        n := int(0) // not used.
                        recvMsg = wire.ReadBinary(authSigMessage{}, bytes.NewBuffer(readBuffer), authSigMsgSize, &n, &err2).(authSigMessage)
-               })
+               },
+       )
 
        if err1 != nil {
                return nil, err1
@@ -289,15 +266,37 @@ func shareAuthSignature(sc *SecretConnection, pubKey crypto.PubKeyEd25519, signa
        if err2 != nil {
                return nil, err2
        }
-
        return &recvMsg, nil
 }
 
-func verifyChallengeSignature(challenge *[32]byte, remPubKey crypto.PubKeyEd25519, remSignature crypto.SignatureEd25519) bool {
-       return remPubKey.VerifyBytes(challenge[:], remSignature.Wrap())
+func shareEphPubKey(conn io.ReadWriteCloser, locEphPub *[32]byte) (remEphPub *[32]byte, err error) {
+       var err1, err2 error
+
+       cmn.Parallel(
+               func() {
+                       _, err1 = conn.Write(locEphPub[:])
+               },
+               func() {
+                       remEphPub = new([32]byte)
+                       _, err2 = io.ReadFull(conn, remEphPub[:])
+               },
+       )
+
+       if err1 != nil {
+               return nil, err1
+       }
+       if err2 != nil {
+               return nil, err2
+       }
+       return remEphPub, nil
 }
 
-//--------------------------------------------------------------------------------
+func sort32(foo, bar *[32]byte) (*[32]byte, *[32]byte) {
+       if bytes.Compare(foo[:], bar[:]) < 0 {
+               return foo, bar
+       }
+       return bar, foo
+}
 
 // sha256
 func hash32(input []byte) (res *[32]byte) {
@@ -318,29 +317,3 @@ func hash24(input []byte) (res *[24]byte) {
        copy(res[:], resSlice)
        return
 }
-
-// ripemd160
-func hash20(input []byte) (res *[20]byte) {
-       hasher := ripemd160.New()
-       hasher.Write(input) // does not error
-       resSlice := hasher.Sum(nil)
-       res = new([20]byte)
-       copy(res[:], resSlice)
-       return
-}
-
-// increment nonce big-endian by 2 with wraparound.
-func incr2Nonce(nonce *[24]byte) {
-       incrNonce(nonce)
-       incrNonce(nonce)
-}
-
-// increment nonce big-endian by 1 with wraparound.
-func incrNonce(nonce *[24]byte) {
-       for i := 23; 0 <= i; i-- {
-               nonce[i] += 1
-               if nonce[i] != 0 {
-                       return
-               }
-       }
-}