log "github.com/sirupsen/logrus"
"github.com/tendermint/go-wire"
- "github.com/vapor/common"
- cfg "github.com/vapor/config"
- "github.com/vapor/crypto"
- "github.com/vapor/crypto/ed25519"
- "github.com/vapor/p2p/netutil"
- "github.com/vapor/version"
+ "github.com/bytom/vapor/common"
+ cfg "github.com/bytom/vapor/config"
+ "github.com/bytom/vapor/crypto"
+ "github.com/bytom/vapor/p2p/netutil"
+ "github.com/bytom/vapor/p2p/signlib"
+ "github.com/bytom/vapor/version"
)
const (
- Version = 4
+ //Version dht discover protocol version
+ Version = 5
logModule = "discover"
)
// Errors
var (
- errPacketTooSmall = errors.New("too small")
- errBadPrefix = errors.New("bad prefix")
- errExpired = errors.New("expired")
- errUnsolicitedReply = errors.New("unsolicited reply")
- errUnknownNode = errors.New("unknown node")
- errTimeout = errors.New("RPC timeout")
- errClockWarp = errors.New("reply deadline too far in the future")
- errClosed = errors.New("socket closed")
+ errPacketTooSmall = errors.New("too small")
+ errPrefixMismatch = errors.New("prefix mismatch")
+ errNetIDMismatch = errors.New("network id mismatch")
+ errPacketType = errors.New("unknown packet type")
)
// Timeouts
const (
respTimeout = 1 * time.Second
- queryDelay = 1000 * time.Millisecond
expiration = 20 * time.Second
-
- ntpFailureThreshold = 32 // Continuous timeouts after which to check NTP
- ntpWarningCooldown = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning
- driftThreshold = 10 * time.Second // Allowed clock drift before warning user
)
// ReadPacket is sent to the unhandled channel when it could not be processed
)
var (
- versionPrefix = []byte("bytom discovery")
- versionPrefixSize = len(versionPrefix)
- nodeIDSize = 32
- sigSize = 520 / 8
- headSize = versionPrefixSize + nodeIDSize + sigSize // space of packet frame data
+ netIDSize = 8
+ nodeIDSize = 32
+ sigSize = 520 / 8
+ headSize = netIDSize + nodeIDSize + sigSize // space of packet frame data
)
// Neighbors replies are sent across multiple packets to
// udp implements the RPC protocol.
type udp struct {
- conn conn
- priv ed25519.PrivateKey
+ conn conn
+ priv signlib.PrivKey
+ //netID used to isolate subnets
+ netID uint64
ourEndpoint rpcEndpoint
- //nat nat.Interface
- net netWork
+ net netWork
}
-func NewDiscover(config *cfg.Config, priv ed25519.PrivateKey, port uint16) (*Network, error) {
+//NewDiscover create new dht discover
+func NewDiscover(config *cfg.Config, privKey signlib.PrivKey, port uint16, netID uint64) (*Network, error) {
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("0.0.0.0", strconv.FormatUint(uint64(port), 10)))
if err != nil {
return nil, err
}
realaddr := conn.LocalAddr().(*net.UDPAddr)
- ntab, err := ListenUDP(priv, conn, realaddr, path.Join(config.DBDir(), "discover"), nil)
+ ntab, err := ListenUDP(privKey, conn, realaddr, path.Join(config.DBDir(), "discover"), nil, netID)
if err != nil {
return nil, err
}
}
// ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
- transport, err := listenUDP(priv, conn, realaddr)
+func ListenUDP(privKey signlib.PrivKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist, netID uint64) (*Network, error) {
+ transport, err := listenUDP(privKey, conn, realaddr, netID)
if err != nil {
return nil, err
}
- net, err := newNetwork(transport, priv.Public(), nodeDBPath, netrestrict)
+ net, err := newNetwork(transport, privKey.XPub(), nodeDBPath, netrestrict)
if err != nil {
return nil, err
}
return net, nil
}
-func listenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr) (*udp, error) {
- return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
+func listenUDP(priv signlib.PrivKey, conn conn, realaddr *net.UDPAddr, netID uint64) (*udp, error) {
+ return &udp{conn: conn, priv: priv, netID: netID, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
}
func (t *udp) localAddr() *net.UDPAddr {
func (t *udp) sendFindnodeHash(remote *Node, target common.Hash) {
t.sendPacket(remote.ID, remote.addr(), byte(findnodeHashPacket), findnodeHash{
- Target: common.Hash(target),
+ Target: target,
Expiration: uint64(time.Now().Add(expiration).Unix()),
})
}
}
func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) {
- packet, hash, err := encodePacket(t.priv, ptype, req)
+ packet, hash, err := encodePacket(t.priv, ptype, req, t.netID)
if err != nil {
return hash, err
}
// zeroed padding space for encodePacket.
var headSpace = make([]byte, headSize)
-func encodePacket(priv ed25519.PrivateKey, ptype byte, req interface{}) (p, hash []byte, err error) {
+func encodePacket(privKey signlib.PrivKey, ptype byte, req interface{}, netID uint64) (p, hash []byte, err error) {
b := new(bytes.Buffer)
b.Write(headSpace)
b.WriteByte(ptype)
return nil, nil, err
}
packet := b.Bytes()
- nodeID := priv.Public()
- sig := ed25519.Sign(priv, common.BytesToHash(packet[headSize:]).Bytes())
- copy(packet, versionPrefix)
- copy(packet[versionPrefixSize:], nodeID[:])
- copy(packet[versionPrefixSize+nodeIDSize:], sig)
-
- hash = common.BytesToHash(packet[versionPrefixSize:]).Bytes()
+ nodeID := privKey.XPub()
+ sig := privKey.Sign(common.BytesToHash(packet[headSize:]).Bytes())
+ id := []byte(strconv.FormatUint(netID, 16))
+ copy(packet[:], id[:])
+ copy(packet[netIDSize:], nodeID[:nodeIDSize])
+ copy(packet[netIDSize+nodeIDSize:], sig)
+
+ hash = common.BytesToHash(packet[:]).Bytes()
return packet, hash, nil
}
log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Read error")
return
}
- t.handlePacket(from, buf[:nbytes])
+ if err := t.handlePacket(from, buf[:nbytes]); err != nil {
+ log.WithFields(log.Fields{"module": logModule, "from": from, "error": err}).Error("handle packet err")
+ }
}
}
func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
pkt := ingressPacket{remoteAddr: from}
- if err := decodePacket(buf, &pkt); err != nil {
- log.WithFields(log.Fields{"module": logModule, "from": from, "error": err}).Error("Bad packet")
+ if err := decodePacket(buf, &pkt, t.netID); err != nil {
return err
}
t.net.reqReadPacket(pkt)
return nil
}
-func decodePacket(buffer []byte, pkt *ingressPacket) error {
+func (t *udp) getNetID() uint64 {
+ return t.netID
+}
+
+func decodePacket(buffer []byte, pkt *ingressPacket, netID uint64) error {
if len(buffer) < headSize+1 {
return errPacketTooSmall
}
buf := make([]byte, len(buffer))
copy(buf, buffer)
- prefix, fromID, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:versionPrefixSize+nodeIDSize], buf[headSize:]
- if !bytes.Equal(prefix, versionPrefix) {
- return errBadPrefix
+ fromID, sigdata := buf[netIDSize:netIDSize+nodeIDSize], buf[headSize:]
+
+ if !bytes.Equal(buf[:netIDSize], []byte(strconv.FormatUint(netID, 16))[:netIDSize]) {
+ return errNetIDMismatch
}
+
pkt.rawData = buf
- pkt.hash = common.BytesToHash(buf[versionPrefixSize:]).Bytes()
+ pkt.hash = common.BytesToHash(buf[:]).Bytes()
pkt.remoteID = ByteID(fromID)
switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev {
case pingPacket:
case topicNodesPacket:
pkt.data = new(topicNodes)
default:
- return fmt.Errorf("unknown packet type: %d", sigdata[0])
+ return errPacketType
}
var err error
wire.ReadJSON(pkt.data, sigdata[1:], &err)