OSDN Git Service

Add subnet having the same network ID isolation function (#57)
authoryahtoo <yahtoo.ma@gmail.com>
Tue, 14 May 2019 09:19:14 +0000 (17:19 +0800)
committerPaladz <yzhu101@uottawa.ca>
Tue, 14 May 2019 09:19:14 +0000 (17:19 +0800)
* Network isolation main chain and side chain

* Add test cases

* Del unused code

* Uniform variable name

* Opz log

* Generate unique network ID using genesisBlock hash

* Change Bech32HRPSegwit to 'vp'

* Change netID generate method

12 files changed:
cmd/bytomd/commands/init.go
config/genesis.go
config/toml.go
consensus/general.go
netsync/handle.go
p2p/discover/dht/net.go
p2p/discover/dht/ntp.go [deleted file]
p2p/discover/dht/udp.go
p2p/discover/dht/udp_test.go [new file with mode: 0644]
p2p/node_info.go
p2p/switch.go
p2p/test_util.go

index 33e6c00..331a346 100644 (file)
@@ -17,7 +17,7 @@ var initFilesCmd = &cobra.Command{
 }
 
 func init() {
-       initFilesCmd.Flags().String("chain_id", config.ChainID, "Select [mainnet] or [testnet] or [solonet]")
+       initFilesCmd.Flags().String("chain_id", config.ChainID, "Select [mainnet] or [testnet] or [solonet] or [vapor]")
 
        RootCmd.AddCommand(initFilesCmd)
 }
@@ -30,7 +30,7 @@ func initFiles(cmd *cobra.Command, args []string) {
        }
 
        switch config.ChainID {
-       case "mainnet", "testnet":
+       case "mainnet", "testnet", "vapor":
                cfg.EnsureRoot(config.RootDir, config.ChainID)
        default:
                cfg.EnsureRoot(config.RootDir, "solonet")
index d25a78a..f588962 100644 (file)
@@ -124,8 +124,9 @@ func soloNetGenesisBlock() *types.Block {
 // GenesisBlock will return genesis block
 func GenesisBlock() *types.Block {
        return map[string]func() *types.Block{
-               "main": mainNetGenesisBlock,
-               "test": testNetGenesisBlock,
-               "solo": soloNetGenesisBlock,
+               "main":  mainNetGenesisBlock,
+               "test":  testNetGenesisBlock,
+               "solo":  soloNetGenesisBlock,
+               "vapor": soloNetGenesisBlock,
        }[consensus.ActiveNetParams.Name]()
 }
index 5774f6f..1d9b0c6 100644 (file)
@@ -45,6 +45,12 @@ laddr = "tcp://0.0.0.0:46658"
 seeds = ""
 `
 
+var vaporNetConfigTmpl = `chain_id = "vapor"
+[p2p]
+laddr = "tcp://0.0.0.0:56659"
+seeds = ""
+`
+
 // Select network seeds to merge a new string.
 func selectNetwork(network string) string {
        switch network {
@@ -52,6 +58,8 @@ func selectNetwork(network string) string {
                return defaultConfigTmpl + mainNetConfigTmpl
        case "testnet":
                return defaultConfigTmpl + testNetConfigTmpl
+       case "vapor":
+               return defaultConfigTmpl + vaporNetConfigTmpl
        default:
                return defaultConfigTmpl + soloNetConfigTmpl
        }
index a545694..6b67870 100644 (file)
@@ -108,6 +108,7 @@ var NetParams = map[string]Params{
        "mainnet": MainNetParams,
        "wisdom":  TestNetParams,
        "solonet": SoloNetParams,
+       "vapor":   VaporNetParams,
 }
 
 // MainNetParams is the config for production
@@ -161,3 +162,10 @@ var SoloNetParams = Params{
        Bech32HRPSegwit: "sm",
        Checkpoints:     []Checkpoint{},
 }
+
+// VaporNetParams is the config for vapor-net
+var VaporNetParams = Params{
+       Name:            "vapor",
+       Bech32HRPSegwit: "vp",
+       Checkpoints:     []Checkpoint{},
+}
index 67dd8c2..db50d8a 100644 (file)
@@ -6,7 +6,6 @@ import (
 
        log "github.com/sirupsen/logrus"
 
-       "github.com/tendermint/go-crypto"
        cfg "github.com/vapor/config"
        "github.com/vapor/consensus"
        "github.com/vapor/event"
@@ -352,9 +351,6 @@ func (sm *SyncManager) IsListening() bool {
 }
 
 func (sm *SyncManager) NodeInfo() *p2p.NodeInfo {
-       if sm.config.VaultMode {
-               return p2p.NewNodeInfo(sm.config, crypto.PubKeyEd25519{}, "")
-       }
        return sm.sw.NodeInfo()
 }
 
index eeb3c19..ee7aea1 100644 (file)
@@ -83,6 +83,7 @@ type transport interface {
        send(remote *Node, ptype nodeEvent, p interface{}) (hash []byte)
 
        localAddr() *net.UDPAddr
+       getNetID() uint64
        Close()
 }
 
@@ -1162,7 +1163,7 @@ func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket)
        case topicRegisterPacket:
                //fmt.Println("got topicRegisterPacket")
                regdata := pkt.data.(*topicRegister)
-               pong, err := net.checkTopicRegister(regdata)
+               pong, err := net.checkTopicRegister(regdata, net.conn.getNetID())
                if err != nil {
                        //fmt.Println(err)
                        return n.state, fmt.Errorf("bad waiting ticket: %v", err)
@@ -1198,9 +1199,9 @@ func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket)
        }
 }
 
-func (net *Network) checkTopicRegister(data *topicRegister) (*pong, error) {
+func (net *Network) checkTopicRegister(data *topicRegister, netID uint64) (*pong, error) {
        var pongpkt ingressPacket
-       if err := decodePacket(data.Pong, &pongpkt); err != nil {
+       if err := decodePacket(data.Pong, &pongpkt, netID); err != nil {
                return nil, err
        }
        if pongpkt.ev != pongPacket {
diff --git a/p2p/discover/dht/ntp.go b/p2p/discover/dht/ntp.go
deleted file mode 100644 (file)
index c3354ad..0000000
+++ /dev/null
@@ -1,110 +0,0 @@
-// Contains the NTP time drift detection via the SNTP protocol:
-//   https://tools.ietf.org/html/rfc4330
-
-package dht
-
-import (
-       "fmt"
-       "net"
-       "sort"
-       "strings"
-       "time"
-
-       log "github.com/sirupsen/logrus"
-)
-
-const (
-       ntpPool   = "pool.ntp.org" // ntpPool is the NTP server to query for the current time
-       ntpChecks = 3              // Number of measurements to do against the NTP server
-)
-
-// durationSlice attaches the methods of sort.Interface to []time.Duration,
-// sorting in increasing order.
-type durationSlice []time.Duration
-
-func (s durationSlice) Len() int           { return len(s) }
-func (s durationSlice) Less(i, j int) bool { return s[i] < s[j] }
-func (s durationSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
-
-// checkClockDrift queries an NTP server for clock drifts and warns the user if
-// one large enough is detected.
-func checkClockDrift() {
-       drift, err := sntpDrift(ntpChecks)
-       if err != nil {
-               return
-       }
-       if drift < -driftThreshold || drift > driftThreshold {
-               warning := fmt.Sprintf("System clock seems off by %v, which can prevent network connectivity", drift)
-               howtofix := fmt.Sprintf("Please enable network time synchronisation in system settings")
-               separator := strings.Repeat("-", len(warning))
-
-               log.WithFields(log.Fields{"module": logModule}).Warn(separator)
-               log.WithFields(log.Fields{"module": logModule}).Warn(warning)
-               log.WithFields(log.Fields{"module": logModule}).Warn(howtofix)
-               log.WithFields(log.Fields{"module": logModule}).Warn(separator)
-       } else {
-               log.WithFields(log.Fields{"module": logModule, "drift": drift}).Debug(fmt.Sprintf("Sanity NTP check reported all ok"))
-       }
-}
-
-// sntpDrift does a naive time resolution against an NTP server and returns the
-// measured drift. This method uses the simple version of NTP. It's not precise
-// but should be fine for these purposes.
-//
-// Note, it executes two extra measurements compared to the number of requested
-// ones to be able to discard the two extremes as outliers.
-func sntpDrift(measurements int) (time.Duration, error) {
-       // Resolve the address of the NTP server
-       addr, err := net.ResolveUDPAddr("udp", ntpPool+":123")
-       if err != nil {
-               return 0, err
-       }
-       // Construct the time request (empty package with only 2 fields set):
-       //   Bits 3-5: Protocol version, 3
-       //   Bits 6-8: Mode of operation, client, 3
-       request := make([]byte, 48)
-       request[0] = 3<<3 | 3
-
-       // Execute each of the measurements
-       drifts := []time.Duration{}
-       for i := 0; i < measurements+2; i++ {
-               // Dial the NTP server and send the time retrieval request
-               conn, err := net.DialUDP("udp", nil, addr)
-               if err != nil {
-                       return 0, err
-               }
-               defer conn.Close()
-
-               sent := time.Now()
-               if _, err = conn.Write(request); err != nil {
-                       return 0, err
-               }
-               // Retrieve the reply and calculate the elapsed time
-               conn.SetDeadline(time.Now().Add(5 * time.Second))
-
-               reply := make([]byte, 48)
-               if _, err = conn.Read(reply); err != nil {
-                       return 0, err
-               }
-               elapsed := time.Since(sent)
-
-               // Reconstruct the time from the reply data
-               sec := uint64(reply[43]) | uint64(reply[42])<<8 | uint64(reply[41])<<16 | uint64(reply[40])<<24
-               frac := uint64(reply[47]) | uint64(reply[46])<<8 | uint64(reply[45])<<16 | uint64(reply[44])<<24
-
-               nanosec := sec*1e9 + (frac*1e9)>>32
-
-               t := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC).Add(time.Duration(nanosec)).Local()
-
-               // Calculate the drift based on an assumed answer time of RRT/2
-               drifts = append(drifts, sent.Sub(t)+elapsed/2)
-       }
-       // Calculate average drif (drop two extremities to avoid outliers)
-       sort.Sort(durationSlice(drifts))
-
-       drift := time.Duration(0)
-       for i := 1; i < len(drifts)-1; i++ {
-               drift += drifts[i]
-       }
-       return drift / time.Duration(measurements), nil
-}
index 94a2cf2..a3c11ce 100644 (file)
@@ -23,31 +23,23 @@ import (
 )
 
 const (
-       Version   = 4
+       //Version dht discover protocol version
+       Version   = 5
        logModule = "discover"
 )
 
 // Errors
 var (
-       errPacketTooSmall   = errors.New("too small")
-       errBadPrefix        = errors.New("bad prefix")
-       errExpired          = errors.New("expired")
-       errUnsolicitedReply = errors.New("unsolicited reply")
-       errUnknownNode      = errors.New("unknown node")
-       errTimeout          = errors.New("RPC timeout")
-       errClockWarp        = errors.New("reply deadline too far in the future")
-       errClosed           = errors.New("socket closed")
+       errPacketTooSmall = errors.New("too small")
+       errPrefixMismatch = errors.New("prefix mismatch")
+       errNetIDMismatch  = errors.New("network id mismatch")
+       errPacketType     = errors.New("unknown packet type")
 )
 
 // Timeouts
 const (
        respTimeout = 1 * time.Second
-       queryDelay  = 1000 * time.Millisecond
        expiration  = 20 * time.Second
-
-       ntpFailureThreshold = 32               // Continuous timeouts after which to check NTP
-       ntpWarningCooldown  = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning
-       driftThreshold      = 10 * time.Second // Allowed clock drift before warning user
 )
 
 // ReadPacket is sent to the unhandled channel when it could not be processed
@@ -158,11 +150,10 @@ type (
 )
 
 var (
-       versionPrefix     = []byte("bytom discovery")
-       versionPrefixSize = len(versionPrefix)
-       nodeIDSize        = 32
-       sigSize           = 520 / 8
-       headSize          = versionPrefixSize + nodeIDSize + sigSize // space of packet frame data
+       netIDSize  = 8
+       nodeIDSize = 32
+       sigSize    = 520 / 8
+       headSize   = netIDSize + nodeIDSize + sigSize // space of packet frame data
 )
 
 // Neighbors replies are sent across multiple packets to
@@ -254,14 +245,16 @@ type netWork interface {
 
 // udp implements the RPC protocol.
 type udp struct {
-       conn        conn
-       priv        ed25519.PrivateKey
+       conn conn
+       priv ed25519.PrivateKey
+       //netID used to isolate subnets
+       netID       uint64
        ourEndpoint rpcEndpoint
-       //nat         nat.Interface
-       net netWork
+       net         netWork
 }
 
-func NewDiscover(config *cfg.Config, priv ed25519.PrivateKey, port uint16) (*Network, error) {
+//NewDiscover create new dht discover
+func NewDiscover(config *cfg.Config, priv ed25519.PrivateKey, port uint16, netID uint64) (*Network, error) {
        addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("0.0.0.0", strconv.FormatUint(uint64(port), 10)))
        if err != nil {
                return nil, err
@@ -273,7 +266,7 @@ func NewDiscover(config *cfg.Config, priv ed25519.PrivateKey, port uint16) (*Net
        }
 
        realaddr := conn.LocalAddr().(*net.UDPAddr)
-       ntab, err := ListenUDP(priv, conn, realaddr, path.Join(config.DBDir(), "discover"), nil)
+       ntab, err := ListenUDP(priv, conn, realaddr, path.Join(config.DBDir(), "discover"), nil, netID)
        if err != nil {
                return nil, err
        }
@@ -302,8 +295,8 @@ func NewDiscover(config *cfg.Config, priv ed25519.PrivateKey, port uint16) (*Net
 }
 
 // ListenUDP returns a new table that listens for UDP packets on laddr.
-func ListenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
-       transport, err := listenUDP(priv, conn, realaddr)
+func ListenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist, netID uint64) (*Network, error) {
+       transport, err := listenUDP(priv, conn, realaddr, netID)
        if err != nil {
                return nil, err
        }
@@ -318,8 +311,8 @@ func ListenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDB
        return net, nil
 }
 
-func listenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr) (*udp, error) {
-       return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
+func listenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr, netID uint64) (*udp, error) {
+       return &udp{conn: conn, priv: priv, netID: netID, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
 }
 
 func (t *udp) localAddr() *net.UDPAddr {
@@ -368,7 +361,7 @@ func (t *udp) sendNeighbours(remote *Node, results []*Node) {
 
 func (t *udp) sendFindnodeHash(remote *Node, target common.Hash) {
        t.sendPacket(remote.ID, remote.addr(), byte(findnodeHashPacket), findnodeHash{
-               Target:     common.Hash(target),
+               Target:     target,
                Expiration: uint64(time.Now().Add(expiration).Unix()),
        })
 }
@@ -400,7 +393,7 @@ func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node)
 }
 
 func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) {
-       packet, hash, err := encodePacket(t.priv, ptype, req)
+       packet, hash, err := encodePacket(t.priv, ptype, req, t.netID)
        if err != nil {
                return hash, err
        }
@@ -414,7 +407,7 @@ func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req inter
 // zeroed padding space for encodePacket.
 var headSpace = make([]byte, headSize)
 
-func encodePacket(priv ed25519.PrivateKey, ptype byte, req interface{}) (p, hash []byte, err error) {
+func encodePacket(priv ed25519.PrivateKey, ptype byte, req interface{}, netID uint64) (p, hash []byte, err error) {
        b := new(bytes.Buffer)
        b.Write(headSpace)
        b.WriteByte(ptype)
@@ -427,11 +420,12 @@ func encodePacket(priv ed25519.PrivateKey, ptype byte, req interface{}) (p, hash
        packet := b.Bytes()
        nodeID := priv.Public()
        sig := ed25519.Sign(priv, common.BytesToHash(packet[headSize:]).Bytes())
-       copy(packet, versionPrefix)
-       copy(packet[versionPrefixSize:], nodeID[:])
-       copy(packet[versionPrefixSize+nodeIDSize:], sig)
+       id := []byte(strconv.FormatUint(netID, 16))
+       copy(packet[:], id[:])
+       copy(packet[netIDSize:], nodeID[:])
+       copy(packet[netIDSize+nodeIDSize:], sig)
 
-       hash = common.BytesToHash(packet[versionPrefixSize:]).Bytes()
+       hash = common.BytesToHash(packet[:]).Bytes()
        return packet, hash, nil
 }
 
@@ -454,32 +448,39 @@ func (t *udp) readLoop() {
                        log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Read error")
                        return
                }
-               t.handlePacket(from, buf[:nbytes])
+               if err := t.handlePacket(from, buf[:nbytes]); err != nil {
+                       log.WithFields(log.Fields{"module": logModule, "from": from, "error": err}).Error("handle packet err")
+               }
        }
 }
 
 func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
        pkt := ingressPacket{remoteAddr: from}
-       if err := decodePacket(buf, &pkt); err != nil {
-               log.WithFields(log.Fields{"module": logModule, "from": from, "error": err}).Error("Bad packet")
+       if err := decodePacket(buf, &pkt, t.netID); err != nil {
                return err
        }
        t.net.reqReadPacket(pkt)
        return nil
 }
 
-func decodePacket(buffer []byte, pkt *ingressPacket) error {
+func (t *udp) getNetID() uint64 {
+       return t.netID
+}
+
+func decodePacket(buffer []byte, pkt *ingressPacket, netID uint64) error {
        if len(buffer) < headSize+1 {
                return errPacketTooSmall
        }
        buf := make([]byte, len(buffer))
        copy(buf, buffer)
-       prefix, fromID, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:versionPrefixSize+nodeIDSize], buf[headSize:]
-       if !bytes.Equal(prefix, versionPrefix) {
-               return errBadPrefix
+       fromID, sigdata := buf[netIDSize:netIDSize+nodeIDSize], buf[headSize:]
+
+       if !bytes.Equal(buf[:netIDSize], []byte(strconv.FormatUint(netID, 16))[:netIDSize]) {
+               return errNetIDMismatch
        }
+
        pkt.rawData = buf
-       pkt.hash = common.BytesToHash(buf[versionPrefixSize:]).Bytes()
+       pkt.hash = common.BytesToHash(buf[:]).Bytes()
        pkt.remoteID = ByteID(fromID)
        switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev {
        case pingPacket:
@@ -499,7 +500,7 @@ func decodePacket(buffer []byte, pkt *ingressPacket) error {
        case topicNodesPacket:
                pkt.data = new(topicNodes)
        default:
-               return fmt.Errorf("unknown packet type: %d", sigdata[0])
+               return errPacketType
        }
        var err error
        wire.ReadJSON(pkt.data, sigdata[1:], &err)
diff --git a/p2p/discover/dht/udp_test.go b/p2p/discover/dht/udp_test.go
new file mode 100644 (file)
index 0000000..3353b20
--- /dev/null
@@ -0,0 +1,446 @@
+package dht
+
+import (
+       "bytes"
+       "net"
+       "reflect"
+       "testing"
+       "time"
+
+       "github.com/davecgh/go-spew/spew"
+       "github.com/vapor/common"
+       "github.com/vapor/crypto/ed25519"
+       "github.com/vapor/errors"
+)
+
+func TestPacketCodec(t *testing.T) {
+       var testPackets = []struct {
+               ptype      byte
+               wantErr    error
+               wantPacket interface{}
+       }{
+               {
+                       ptype:   byte(pingPacket),
+                       wantErr: nil,
+                       wantPacket: &ping{
+                               Version:    4,
+                               From:       rpcEndpoint{net.ParseIP("127.0.0.1").To4(), 3322, 5544},
+                               To:         rpcEndpoint{net.ParseIP("::1"), 2222, 3333},
+                               Expiration: 1136239445,
+                               Topics:     []Topic{"test topic"},
+                               Rest:       []byte{},
+                       },
+               },
+               {
+                       ptype:   byte(pingPacket),
+                       wantErr: nil,
+                       wantPacket: &ping{
+                               Version:    4,
+                               From:       rpcEndpoint{net.ParseIP("127.0.0.1").To4(), 3322, 5544},
+                               To:         rpcEndpoint{net.ParseIP("::1"), 2222, 3333},
+                               Expiration: 1136239445,
+                               Topics:     []Topic{"test topic"},
+                               Rest:       []byte{0x01, 0x02},
+                       },
+               },
+               {
+                       ptype:   byte(pingPacket),
+                       wantErr: nil,
+                       wantPacket: &ping{
+                               Version:    555,
+                               From:       rpcEndpoint{net.ParseIP("2001:db8:3c4d:15::abcd:ef12"), 3322, 5544},
+                               To:         rpcEndpoint{net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7348"), 2222, 33338},
+                               Expiration: 1136239445,
+                               Topics:     []Topic{"test topic"},
+                               Rest:       []byte{0xC5, 0x01, 0x02, 0x03, 0x04, 0x05},
+                       },
+               },
+               {
+                       ptype:   byte(pongPacket),
+                       wantErr: nil,
+                       wantPacket: &pong{
+                               To:          rpcEndpoint{net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7348"), 2222, 33338},
+                               ReplyTok:    []byte("fbc914b16819237dcd8801d7e53f69e9719adecb3cc0e790c57e91ca4461c954"),
+                               Expiration:  1136239445,
+                               WaitPeriods: []uint32{},
+                               Rest:        []byte{0xC6, 0x01, 0x02, 0x03, 0xC2, 0x04, 0x05, 0x06},
+                       },
+               },
+               {
+                       ptype:   byte(findnodePacket),
+                       wantErr: nil,
+                       wantPacket: &findnode{
+                               Target:     MustHexID("a2cb4c36765430f2e72564138c36f30fbc8af5a8bb91649822cd937dedbb8748"),
+                               Expiration: 1136239445,
+                               Rest:       []byte{0x82, 0x99, 0x99, 0x83, 0x99, 0x99, 0x99},
+                       },
+               },
+               {
+                       ptype:   byte(neighborsPacket),
+                       wantErr: nil,
+                       wantPacket: &neighbors{
+                               Nodes: []rpcNode{
+                                       {
+                                               ID:  MustHexID("a2cb4c36765430f2e72564138c36f30fbc8af5a8bb91649822cd937dedbb8748"),
+                                               IP:  net.ParseIP("99.33.22.55").To4(),
+                                               UDP: 4444,
+                                               TCP: 4445,
+                                       },
+                                       {
+                                               ID:  MustHexID("312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d2095"),
+                                               IP:  net.ParseIP("1.2.3.4").To4(),
+                                               UDP: 1,
+                                               TCP: 1,
+                                       },
+                                       {
+                                               ID:  MustHexID("38643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c"),
+                                               IP:  net.ParseIP("2001:db8:3c4d:15::abcd:ef12"),
+                                               UDP: 3333,
+                                               TCP: 3333,
+                                       },
+                                       {
+                                               ID:  MustHexID("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2"),
+                                               IP:  net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7348"),
+                                               UDP: 999,
+                                               TCP: 1000,
+                                       },
+                               },
+                               Expiration: 1136239445,
+                               Rest:       []byte{0x01, 0x02, 0x03},
+                       },
+               },
+               {
+                       ptype:   byte(findnodeHashPacket),
+                       wantErr: nil,
+                       wantPacket: &findnodeHash{
+                               Target:     common.Hash{0x0, 0x1, 0x2, 0x3},
+                               Expiration: 1136239445,
+                               Rest:       []byte{0x01, 0x02, 0x03},
+                       },
+               },
+               {
+                       ptype:   byte(topicRegisterPacket),
+                       wantErr: nil,
+                       wantPacket: &topicRegister{
+                               Topics: []Topic{"test topic"},
+                               Idx:    uint(0x01),
+                               Pong:   []byte{0x01, 0x02, 0x03},
+                       },
+               },
+               {
+                       ptype:   byte(topicQueryPacket),
+                       wantErr: nil,
+                       wantPacket: &topicQuery{
+                               Topic:      "test topic",
+                               Expiration: 1136239445,
+                       },
+               },
+               {
+                       ptype:   byte(topicNodesPacket),
+                       wantErr: nil,
+                       wantPacket: &topicNodes{
+                               Echo: common.Hash{0x00, 0x01, 0x02},
+                               Nodes: []rpcNode{
+                                       {
+                                               ID:  MustHexID("a2cb4c36765430f2e72564138c36f30fbc8af5a8bb91649822cd937dedbb8748"),
+                                               IP:  net.ParseIP("99.33.22.55").To4(),
+                                               UDP: 4444,
+                                               TCP: 4445,
+                                       },
+                                       {
+                                               ID:  MustHexID("312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d2095"),
+                                               IP:  net.ParseIP("1.2.3.4").To4(),
+                                               UDP: 1,
+                                               TCP: 1,
+                                       },
+                                       {
+                                               ID:  MustHexID("38643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c"),
+                                               IP:  net.ParseIP("2001:db8:3c4d:15::abcd:ef12"),
+                                               UDP: 3333,
+                                               TCP: 3333,
+                                       },
+                                       {
+                                               ID:  MustHexID("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2"),
+                                               IP:  net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7348"),
+                                               UDP: 999,
+                                               TCP: 1000,
+                                       },
+                               },
+                       },
+               },
+               {
+                       ptype:      byte(topicNodesPacket + 1),
+                       wantErr:    errPacketType,
+                       wantPacket: &topicNodes{},
+               },
+       }
+
+       _, privateKey, _ := ed25519.GenerateKey(nil)
+       netID := uint64(0x12345)
+       for i, test := range testPackets {
+               packet, h, err := encodePacket(privateKey, test.ptype, test.wantPacket, netID)
+               if err != nil {
+                       t.Fatal(err)
+               }
+
+               var pkt ingressPacket
+               if err := decodePacket(packet, &pkt, netID); err != nil {
+                       if errors.Root(err) != test.wantErr {
+                               t.Errorf("index %d did not accept packet %s\n%v", i, packet, err)
+                       }
+                       continue
+               }
+
+               if !reflect.DeepEqual(pkt.hash, h) {
+                       t.Fatalf("packet hash err. got %x, want %x", pkt.hash, h)
+               }
+
+               if !reflect.DeepEqual(pkt.data, test.wantPacket) {
+                       t.Errorf("got %s\nwant %s", spew.Sdump(pkt.data), spew.Sdump(test.wantPacket))
+               }
+       }
+}
+
+type testConn struct {
+       conn net.Conn
+}
+
+func (tc *testConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) {
+       n, err = tc.conn.Read(b)
+       return n, nil, err
+}
+
+func (tc *testConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) {
+       return tc.conn.Write(b)
+}
+
+func (tc *testConn) Close() error {
+       return tc.conn.Close()
+}
+
+func (tc *testConn) LocalAddr() net.Addr {
+       return tc.conn.LocalAddr()
+}
+
+type testNetWork struct {
+       read chan ingressPacket // ingress packets arrive here
+       IP   net.IP
+}
+
+func (tw *testNetWork) reqReadPacket(pkt ingressPacket) {
+       tw.read <- pkt
+}
+
+func (tw *testNetWork) selfIP() net.IP {
+       return tw.IP
+}
+
+func TestPacketTransport(t *testing.T) {
+       c1, c2 := net.Pipe()
+       inConn := &testConn{conn: c1}
+       realaddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40000}
+       toAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40000}
+       _, inPrivKey, _ := ed25519.GenerateKey(nil)
+       _, outPrivKey, _ := ed25519.GenerateKey(nil)
+       netID := uint64(0x12345)
+
+       udpInput, err := listenUDP(inPrivKey, inConn, realaddr, netID)
+       if err != nil {
+               t.Fatal(err)
+       }
+       node := &Node{ID: MustHexID("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2"),
+               IP:  net.ParseIP("99.33.22.55").To4(),
+               UDP: 4444,
+               TCP: 4445,
+       }
+
+       udpInput.net = &testNetWork{read: make(chan ingressPacket, 100)}
+       go udpInput.readLoop()
+
+       outConn := &testConn{conn: c2}
+       udp, err := listenUDP(outPrivKey, outConn, realaddr, netID)
+       if err != nil {
+               t.Fatal(err)
+       }
+       udp.net = &testNetWork{IP: node.IP}
+       var hash []byte
+
+       //test sendPing
+       hash = udp.sendPing(node, toAddr, nil)
+       pkts := receivePacket(udpInput)
+       if !bytes.Equal(pkts[0].hash, hash) {
+               t.Fatal("pingPacket transport err")
+       }
+
+       //test sendFindnodeHash
+       target := common.Hash{0x01, 0x02}
+       udp.sendFindnodeHash(node, target)
+       pkts = receivePacket(udpInput)
+       if !bytes.Equal(pkts[0].data.(*findnodeHash).Target.Bytes(), target.Bytes()) {
+               t.Fatal("findnodeHashPacket transport err")
+       }
+
+       //test sendNeighbours
+       nodes := []*Node{
+               {
+                       ID:  MustHexID("a2cb4c36765430f2e72564138c36f30fbc8af5a8bb91649822cd937dedbb8748"),
+                       IP:  net.ParseIP("99.33.22.55").To4(),
+                       UDP: 4444,
+                       TCP: 4445,
+               },
+               {
+                       ID:  MustHexID("312c55512422cf9b8a4097e9a6ad79402e87a15ae909a4bfefa22398f03d2095"),
+                       IP:  net.ParseIP("1.2.3.4").To4(),
+                       UDP: 1,
+                       TCP: 1,
+               },
+               {
+                       ID:  MustHexID("38643200b172dcfef857492156971f0e6aa2c538d8b74010f8e140811d53b98c"),
+                       IP:  net.ParseIP("2001:db8:3c4d:15::abcd:ef12"),
+                       UDP: 3333,
+                       TCP: 3333,
+               },
+               {
+                       ID:  MustHexID("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2"),
+                       IP:  net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:7348"),
+                       UDP: 999,
+                       TCP: 1000,
+               },
+       }
+
+       udp.sendNeighbours(node, nodes)
+       pkts = receivePacket(udpInput)
+       var gotNodes []rpcNode
+       for _, pkt := range pkts {
+               gotNodes = append(gotNodes, pkt.data.(*neighbors).Nodes[:]...)
+       }
+       for i := 0; i < len(nodes); i++ {
+               if !reflect.DeepEqual(nodeToRPC(nodes[i]), gotNodes[i]) {
+                       t.Fatal("sendNeighboursPacket transport err")
+               }
+       }
+
+       //test sendFindnode
+       targetNode := NodeID{0x01, 0x02, 0x03}
+       udp.sendFindnode(node, targetNode)
+       pkts = receivePacket(udpInput)
+       if pkts[0].data.(*findnode).Target != targetNode {
+               t.Fatal("sendFindnode transport err")
+       }
+
+       //test sendTopicRegister
+       topics := []Topic{"topic1", "topic2", "topic3"}
+       idx := 0xff
+       pong := []byte{0x01, 0x02, 0x03}
+       udp.sendTopicRegister(node, topics, idx, pong)
+       pkts = receivePacket(udpInput)
+       if !bytes.Equal(pkts[0].data.(*topicRegister).Pong, pong) {
+               t.Fatal("sendTopicRegister pong field err")
+       }
+       if pkts[0].data.(*topicRegister).Idx != uint(idx) {
+               t.Fatal("sendTopicRegister idx field err")
+       }
+       if !reflect.DeepEqual(pkts[0].data.(*topicRegister).Topics, topics) {
+               t.Fatal("sendTopicRegister topic field err")
+       }
+
+       //test sendTopicNodes
+       queryHash := common.Hash{0x01, 0x02, 0x03}
+       udp.sendTopicNodes(node, queryHash, nodes)
+       pkts = receivePacket(udpInput)
+       gotNodes = gotNodes[:0]
+       for _, pkt := range pkts {
+               gotNodes = append(gotNodes, pkt.data.(*topicNodes).Nodes[:]...)
+       }
+
+       for i := 0; i < 2; i++ {
+               if !reflect.DeepEqual(nodeToRPC(nodes[i]), gotNodes[i]) {
+                       t.Fatal("sendTopicNodes node field err")
+               }
+       }
+
+       if pkts[0].data.(*topicNodes).Echo != queryHash {
+               t.Fatal("sendTopicNodes echo field err")
+       }
+}
+
+func TestSendTopicNodes(t *testing.T) {
+       c1, c2 := net.Pipe()
+       inConn := &testConn{conn: c1}
+       realaddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 40000}
+       _, inPrivKey, _ := ed25519.GenerateKey(nil)
+       _, outPrivKey, _ := ed25519.GenerateKey(nil)
+       netID := uint64(0x12345)
+
+       udpInput, err := listenUDP(inPrivKey, inConn, realaddr, netID)
+       if err != nil {
+               t.Fatal(err)
+       }
+       node := &Node{ID: MustHexID("8dcab8618c3253b558d459da53bd8fa68935a719aff8b811197101a4b2b47dd2"),
+               IP:  net.ParseIP("99.33.22.55").To4(),
+               UDP: 4444,
+               TCP: 4445,
+       }
+
+       udpInput.net = &testNetWork{read: make(chan ingressPacket, 100)}
+       go udpInput.readLoop()
+
+       outConn := &testConn{conn: c2}
+       udp, err := listenUDP(outPrivKey, outConn, realaddr, netID)
+       if err != nil {
+               t.Fatal(err)
+       }
+       udp.net = &testNetWork{IP: node.IP}
+
+       //test sendTopicNodes
+       queryHash := common.Hash{0x01, 0x02, 0x03}
+       var nodes []*Node
+       var gotNodes []rpcNode
+       for i := 0; i < 100; i++ {
+               node := &Node{
+                       ID:  MustHexID("a2cb4c36765430f2e72564138c36f30fbc8af5a8bb91649822cd937dedbb8748"),
+                       IP:  net.ParseIP("1.2.3.4").To4(),
+                       UDP: uint16(i),
+                       TCP: uint16(i),
+               }
+               nodes = append(nodes, node)
+       }
+       udp.sendTopicNodes(node, queryHash, nodes)
+       pkts := receivePacket(udpInput)
+       for _, pkt := range pkts {
+               gotNodes = append(gotNodes, pkt.data.(*topicNodes).Nodes[:]...)
+       }
+       for i := 0; i < len(gotNodes); i++ {
+               if !reflect.DeepEqual(nodeToRPC(nodes[i]), gotNodes[i]) {
+                       t.Fatal("sendTopicNodes node field err")
+               }
+       }
+
+       nodes = nodes[:0]
+       gotNodes = gotNodes[:0]
+       udp.sendTopicNodes(node, queryHash, nodes)
+       pkts = receivePacket(udpInput)
+       for _, pkt := range pkts {
+               gotNodes = append(gotNodes, pkt.data.(*topicNodes).Nodes[:]...)
+       }
+       for i := 0; i < len(gotNodes); i++ {
+               if !reflect.DeepEqual(nodeToRPC(nodes[i]), gotNodes[i]) {
+                       t.Fatal("sendTopicNodes node field err")
+               }
+       }
+}
+
+func receivePacket(udpInput *udp) []ingressPacket {
+       waitTicker := time.NewTimer(10 * time.Millisecond)
+       defer waitTicker.Stop()
+       var msgs []ingressPacket
+       for {
+               select {
+               case msg := <-udpInput.net.(*testNetWork).read:
+                       msgs = append(msgs, msg)
+               case <-waitTicker.C:
+                       return msgs
+               }
+       }
+       return msgs
+}
index 81dcac4..2e52e31 100644 (file)
@@ -16,18 +16,20 @@ const maxNodeInfoSize = 10240 // 10Kb
 
 //NodeInfo peer node info
 type NodeInfo struct {
-       PubKey     crypto.PubKeyEd25519 `json:"pub_key"`
-       Moniker    string               `json:"moniker"`
-       Network    string               `json:"network"`
-       RemoteAddr string               `json:"remote_addr"`
-       ListenAddr string               `json:"listen_addr"`
-       Version    string               `json:"version"` // major.minor.revision
+       PubKey  crypto.PubKeyEd25519 `json:"pub_key"`
+       Moniker string               `json:"moniker"`
+       Network string               `json:"network"`
+       //NetworkID used to isolate subnets with same network name
+       NetworkID  uint64 `json:"network_id"`
+       RemoteAddr string `json:"remote_addr"`
+       ListenAddr string `json:"listen_addr"`
+       Version    string `json:"version"` // major.minor.revision
        // other application specific data
        //field 0: node service flags. field 1: node alias.
        Other []string `json:"other"`
 }
 
-func NewNodeInfo(config *cfg.Config, pubkey crypto.PubKeyEd25519, listenAddr string) *NodeInfo {
+func NewNodeInfo(config *cfg.Config, pubkey crypto.PubKeyEd25519, listenAddr string, netID uint64) *NodeInfo {
        other := []string{strconv.FormatUint(uint64(consensus.DefaultServices), 10)}
        if config.NodeAlias != "" {
                other = append(other, config.NodeAlias)
@@ -36,6 +38,7 @@ func NewNodeInfo(config *cfg.Config, pubkey crypto.PubKeyEd25519, listenAddr str
                PubKey:     pubkey,
                Moniker:    config.Moniker,
                Network:    config.ChainID,
+               NetworkID:  netID,
                ListenAddr: listenAddr,
                Version:    version.Version,
                Other:      other,
@@ -45,6 +48,14 @@ func NewNodeInfo(config *cfg.Config, pubkey crypto.PubKeyEd25519, listenAddr str
 // CompatibleWith checks if two NodeInfo are compatible with eachother.
 // CONTRACT: two nodes are compatible if the major version matches and network match
 func (info *NodeInfo) CompatibleWith(other *NodeInfo) error {
+       if info.Network != other.Network {
+               return fmt.Errorf("Peer is on a different network. Peer network: %v, node network: %v", other.Network, info.Network)
+       }
+
+       if info.NetworkID != other.NetworkID {
+               return fmt.Errorf("Network id dismatch. Peer network id: %v, node network id: %v", other.NetworkID, info.NetworkID)
+       }
+
        compatible, err := version.CompatibleWith(other.Version)
        if err != nil {
                return err
@@ -53,9 +64,6 @@ func (info *NodeInfo) CompatibleWith(other *NodeInfo) error {
                return fmt.Errorf("Peer is on a different major version. Peer version: %v, node version: %v", other.Version, info.Version)
        }
 
-       if info.Network != other.Network {
-               return fmt.Errorf("Peer is on a different network. Peer network: %v, node network: %v", other.Network, info.Network)
-       }
        return nil
 }
 
index 4ed9c20..c64a69e 100644 (file)
@@ -1,6 +1,7 @@
 package p2p
 
 import (
+       "encoding/binary"
        "encoding/hex"
        "encoding/json"
        "fmt"
@@ -15,6 +16,7 @@ import (
        cfg "github.com/vapor/config"
        "github.com/vapor/consensus"
        "github.com/vapor/crypto/ed25519"
+       "github.com/vapor/crypto/sha3pool"
        dbm "github.com/vapor/database/leveldb"
        "github.com/vapor/errors"
        "github.com/vapor/event"
@@ -33,6 +35,8 @@ const (
 
        minNumOutboundPeers = 4
        maxNumLANPeers      = 5
+       //magicNumber used to generate unique netID
+       magicNumber = uint64(0x054c5638)
 )
 
 //pre-define errors for connecting fail
@@ -84,6 +88,16 @@ func NewSwitch(config *cfg.Config) (*Switch, error) {
        var discv *dht.Network
        var lanDiscv *mdns.LANDiscover
 
+       //generate unique netID
+       var data []byte
+       var h [32]byte
+       data = append(data, cfg.GenesisBlock().Hash().Bytes()...)
+       magic := make([]byte, 8)
+       binary.BigEndian.PutUint64(magic, magicNumber)
+       data = append(data, magic[:]...)
+       sha3pool.Sum256(h[:], data)
+       netID := binary.BigEndian.Uint64(h[:8])
+
        blacklistDB := dbm.NewDB("trusthistory", config.DBBackend, config.DBDir())
        config.P2P.PrivateKey, err = config.NodeKey()
        if err != nil {
@@ -101,7 +115,7 @@ func NewSwitch(config *cfg.Config) (*Switch, error) {
        if !config.VaultMode {
                // Create listener
                l, listenAddr = GetListener(config.P2P)
-               discv, err = dht.NewDiscover(config, ed25519.PrivateKey(bytes), l.ExternalAddress().Port)
+               discv, err = dht.NewDiscover(config, ed25519.PrivateKey(bytes), l.ExternalAddress().Port, netID)
                if err != nil {
                        return nil, err
                }
@@ -110,11 +124,11 @@ func NewSwitch(config *cfg.Config) (*Switch, error) {
                }
        }
 
-       return newSwitch(config, discv, lanDiscv, blacklistDB, l, privKey, listenAddr)
+       return newSwitch(config, discv, lanDiscv, blacklistDB, l, privKey, listenAddr, netID)
 }
 
 // newSwitch creates a new Switch with the given config.
-func newSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, blacklistDB dbm.DB, l Listener, priv crypto.PrivKeyEd25519, listenAddr string) (*Switch, error) {
+func newSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, blacklistDB dbm.DB, l Listener, priv crypto.PrivKeyEd25519, listenAddr string, netID uint64) (*Switch, error) {
        sw := &Switch{
                Config:       config,
                peerConfig:   DefaultPeerConfig(config.P2P),
@@ -127,7 +141,7 @@ func newSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, blacklistDB d
                discv:        discv,
                lanDiscv:     lanDiscv,
                db:           blacklistDB,
-               nodeInfo:     NewNodeInfo(config, priv.PubKey().Unwrap().(crypto.PubKeyEd25519), listenAddr),
+               nodeInfo:     NewNodeInfo(config, priv.PubKey().Unwrap().(crypto.PubKeyEd25519), listenAddr, netID),
                bannedPeer:   make(map[string]time.Time),
        }
        if err := sw.loadBannedPeers(); err != nil {
index a7afc0a..8af7dc5 100644 (file)
@@ -92,7 +92,7 @@ func MakeSwitch(cfg *cfg.Config, testdb dbm.DB, privKey crypto.PrivKeyEd25519, i
        // new switch, add reactors
        l, listenAddr := GetListener(cfg.P2P)
        cfg.P2P.LANDiscover = false
-       sw, err := newSwitch(cfg, new(mockDiscv), nil, testdb, l, privKey, listenAddr)
+       sw, err := newSwitch(cfg, new(mockDiscv), nil, testdb, l, privKey, listenAddr, 0)
        if err != nil {
                log.Errorf("create switch error: %s", err)
                return nil