OSDN Git Service

rename (#465)
[bytom/vapor.git] / p2p / discover / dht / udp.go
index 94a2cf2..5244d8c 100644 (file)
@@ -14,40 +14,32 @@ import (
        log "github.com/sirupsen/logrus"
        "github.com/tendermint/go-wire"
 
-       "github.com/vapor/common"
-       cfg "github.com/vapor/config"
-       "github.com/vapor/crypto"
-       "github.com/vapor/crypto/ed25519"
-       "github.com/vapor/p2p/netutil"
-       "github.com/vapor/version"
+       "github.com/bytom/vapor/common"
+       cfg "github.com/bytom/vapor/config"
+       "github.com/bytom/vapor/crypto"
+       "github.com/bytom/vapor/p2p/netutil"
+       "github.com/bytom/vapor/p2p/signlib"
+       "github.com/bytom/vapor/version"
 )
 
 const (
-       Version   = 4
+       //Version dht discover protocol version
+       Version   = 5
        logModule = "discover"
 )
 
 // Errors
 var (
-       errPacketTooSmall   = errors.New("too small")
-       errBadPrefix        = errors.New("bad prefix")
-       errExpired          = errors.New("expired")
-       errUnsolicitedReply = errors.New("unsolicited reply")
-       errUnknownNode      = errors.New("unknown node")
-       errTimeout          = errors.New("RPC timeout")
-       errClockWarp        = errors.New("reply deadline too far in the future")
-       errClosed           = errors.New("socket closed")
+       errPacketTooSmall = errors.New("too small")
+       errPrefixMismatch = errors.New("prefix mismatch")
+       errNetIDMismatch  = errors.New("network id mismatch")
+       errPacketType     = errors.New("unknown packet type")
 )
 
 // Timeouts
 const (
        respTimeout = 1 * time.Second
-       queryDelay  = 1000 * time.Millisecond
        expiration  = 20 * time.Second
-
-       ntpFailureThreshold = 32               // Continuous timeouts after which to check NTP
-       ntpWarningCooldown  = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning
-       driftThreshold      = 10 * time.Second // Allowed clock drift before warning user
 )
 
 // ReadPacket is sent to the unhandled channel when it could not be processed
@@ -158,11 +150,10 @@ type (
 )
 
 var (
-       versionPrefix     = []byte("bytom discovery")
-       versionPrefixSize = len(versionPrefix)
-       nodeIDSize        = 32
-       sigSize           = 520 / 8
-       headSize          = versionPrefixSize + nodeIDSize + sigSize // space of packet frame data
+       netIDSize  = 8
+       nodeIDSize = 32
+       sigSize    = 520 / 8
+       headSize   = netIDSize + nodeIDSize + sigSize // space of packet frame data
 )
 
 // Neighbors replies are sent across multiple packets to
@@ -254,14 +245,16 @@ type netWork interface {
 
 // udp implements the RPC protocol.
 type udp struct {
-       conn        conn
-       priv        ed25519.PrivateKey
+       conn conn
+       priv signlib.PrivKey
+       //netID used to isolate subnets
+       netID       uint64
        ourEndpoint rpcEndpoint
-       //nat         nat.Interface
-       net netWork
+       net         netWork
 }
 
-func NewDiscover(config *cfg.Config, priv ed25519.PrivateKey, port uint16) (*Network, error) {
+//NewDiscover create new dht discover
+func NewDiscover(config *cfg.Config, privKey signlib.PrivKey, port uint16, netID uint64) (*Network, error) {
        addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("0.0.0.0", strconv.FormatUint(uint64(port), 10)))
        if err != nil {
                return nil, err
@@ -273,7 +266,7 @@ func NewDiscover(config *cfg.Config, priv ed25519.PrivateKey, port uint16) (*Net
        }
 
        realaddr := conn.LocalAddr().(*net.UDPAddr)
-       ntab, err := ListenUDP(priv, conn, realaddr, path.Join(config.DBDir(), "discover"), nil)
+       ntab, err := ListenUDP(privKey, conn, realaddr, path.Join(config.DBDir(), "discover"), nil, netID)
        if err != nil {
                return nil, err
        }
@@ -302,13 +295,13 @@ func NewDiscover(config *cfg.Config, priv ed25519.PrivateKey, port uint16) (*Net
 }
 
 // ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
-       transport, err := listenUDP(priv, conn, realaddr)
+func ListenUDP(privKey signlib.PrivKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist, netID uint64) (*Network, error) {
+       transport, err := listenUDP(privKey, conn, realaddr, netID)
        if err != nil {
                return nil, err
        }
 
-       net, err := newNetwork(transport, priv.Public(), nodeDBPath, netrestrict)
+       net, err := newNetwork(transport, privKey.XPub(), nodeDBPath, netrestrict)
        if err != nil {
                return nil, err
        }
@@ -318,8 +311,8 @@ func ListenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDB
        return net, nil
 }
 
-func listenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr) (*udp, error) {
-       return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
+func listenUDP(priv signlib.PrivKey, conn conn, realaddr *net.UDPAddr, netID uint64) (*udp, error) {
+       return &udp{conn: conn, priv: priv, netID: netID, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
 }
 
 func (t *udp) localAddr() *net.UDPAddr {
@@ -368,7 +361,7 @@ func (t *udp) sendNeighbours(remote *Node, results []*Node) {
 
 func (t *udp) sendFindnodeHash(remote *Node, target common.Hash) {
        t.sendPacket(remote.ID, remote.addr(), byte(findnodeHashPacket), findnodeHash{
-               Target:     common.Hash(target),
+               Target:     target,
                Expiration: uint64(time.Now().Add(expiration).Unix()),
        })
 }
@@ -400,7 +393,7 @@ func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node)
 }
 
 func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) {
-       packet, hash, err := encodePacket(t.priv, ptype, req)
+       packet, hash, err := encodePacket(t.priv, ptype, req, t.netID)
        if err != nil {
                return hash, err
        }
@@ -414,7 +407,7 @@ func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req inter
 // zeroed padding space for encodePacket.
 var headSpace = make([]byte, headSize)
 
-func encodePacket(priv ed25519.PrivateKey, ptype byte, req interface{}) (p, hash []byte, err error) {
+func encodePacket(privKey signlib.PrivKey, ptype byte, req interface{}, netID uint64) (p, hash []byte, err error) {
        b := new(bytes.Buffer)
        b.Write(headSpace)
        b.WriteByte(ptype)
@@ -425,13 +418,14 @@ func encodePacket(priv ed25519.PrivateKey, ptype byte, req interface{}) (p, hash
                return nil, nil, err
        }
        packet := b.Bytes()
-       nodeID := priv.Public()
-       sig := ed25519.Sign(priv, common.BytesToHash(packet[headSize:]).Bytes())
-       copy(packet, versionPrefix)
-       copy(packet[versionPrefixSize:], nodeID[:])
-       copy(packet[versionPrefixSize+nodeIDSize:], sig)
-
-       hash = common.BytesToHash(packet[versionPrefixSize:]).Bytes()
+       nodeID := privKey.XPub()
+       sig := privKey.Sign(common.BytesToHash(packet[headSize:]).Bytes())
+       id := []byte(strconv.FormatUint(netID, 16))
+       copy(packet[:], id[:])
+       copy(packet[netIDSize:], nodeID[:nodeIDSize])
+       copy(packet[netIDSize+nodeIDSize:], sig)
+
+       hash = common.BytesToHash(packet[:]).Bytes()
        return packet, hash, nil
 }
 
@@ -454,32 +448,39 @@ func (t *udp) readLoop() {
                        log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Read error")
                        return
                }
-               t.handlePacket(from, buf[:nbytes])
+               if err := t.handlePacket(from, buf[:nbytes]); err != nil {
+                       log.WithFields(log.Fields{"module": logModule, "from": from, "error": err}).Error("handle packet err")
+               }
        }
 }
 
 func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
        pkt := ingressPacket{remoteAddr: from}
-       if err := decodePacket(buf, &pkt); err != nil {
-               log.WithFields(log.Fields{"module": logModule, "from": from, "error": err}).Error("Bad packet")
+       if err := decodePacket(buf, &pkt, t.netID); err != nil {
                return err
        }
        t.net.reqReadPacket(pkt)
        return nil
 }
 
-func decodePacket(buffer []byte, pkt *ingressPacket) error {
+func (t *udp) getNetID() uint64 {
+       return t.netID
+}
+
+func decodePacket(buffer []byte, pkt *ingressPacket, netID uint64) error {
        if len(buffer) < headSize+1 {
                return errPacketTooSmall
        }
        buf := make([]byte, len(buffer))
        copy(buf, buffer)
-       prefix, fromID, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:versionPrefixSize+nodeIDSize], buf[headSize:]
-       if !bytes.Equal(prefix, versionPrefix) {
-               return errBadPrefix
+       fromID, sigdata := buf[netIDSize:netIDSize+nodeIDSize], buf[headSize:]
+
+       if !bytes.Equal(buf[:netIDSize], []byte(strconv.FormatUint(netID, 16))[:netIDSize]) {
+               return errNetIDMismatch
        }
+
        pkt.rawData = buf
-       pkt.hash = common.BytesToHash(buf[versionPrefixSize:]).Bytes()
+       pkt.hash = common.BytesToHash(buf[:]).Bytes()
        pkt.remoteID = ByteID(fromID)
        switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev {
        case pingPacket:
@@ -499,7 +500,7 @@ func decodePacket(buffer []byte, pkt *ingressPacket) error {
        case topicNodesPacket:
                pkt.data = new(topicNodes)
        default:
-               return fmt.Errorf("unknown packet type: %d", sigdata[0])
+               return errPacketType
        }
        var err error
        wire.ReadJSON(pkt.data, sigdata[1:], &err)