OSDN Git Service

Merge pull request #201 from Bytom/v0.1
[bytom/vapor.git] / p2p / switch.go
index d00c9aa..ef8306d 100644 (file)
@@ -1,40 +1,61 @@
 package p2p
 
 import (
-       "encoding/json"
+       "encoding/binary"
        "fmt"
        "net"
        "sync"
        "time"
 
        log "github.com/sirupsen/logrus"
-       "github.com/tendermint/go-crypto"
        cmn "github.com/tendermint/tmlibs/common"
-       dbm "github.com/tendermint/tmlibs/db"
 
        cfg "github.com/vapor/config"
        "github.com/vapor/consensus"
+       "github.com/vapor/crypto/sha3pool"
        "github.com/vapor/errors"
+       "github.com/vapor/event"
        "github.com/vapor/p2p/connection"
-       "github.com/vapor/p2p/discover"
-       "github.com/vapor/p2p/trust"
+       "github.com/vapor/p2p/discover/dht"
+       "github.com/vapor/p2p/discover/mdns"
+       "github.com/vapor/p2p/netutil"
+       "github.com/vapor/p2p/signlib"
+       security "github.com/vapor/p2p/security"
        "github.com/vapor/version"
 )
 
 const (
-       bannedPeerKey       = "BannedPeer"
-       defaultBanDuration  = time.Hour * 1
-       minNumOutboundPeers = 3
+       logModule = "p2p"
+
+       minNumOutboundPeers = 4
+       maxNumLANPeers      = 5
+       //magicNumber used to generate unique netID
+       magicNumber = uint64(0x054c5638)
 )
 
 //pre-define errors for connecting fail
 var (
-       ErrDuplicatePeer     = errors.New("Duplicate peer")
-       ErrConnectSelf       = errors.New("Connect self")
-       ErrConnectBannedPeer = errors.New("Connect banned peer")
-       ErrConnectSpvPeer    = errors.New("Outbound connect spv peer")
+       ErrDuplicatePeer  = errors.New("Duplicate peer")
+       ErrConnectSelf    = errors.New("Connect self")
+       ErrConnectSpvPeer = errors.New("Outbound connect spv peer")
 )
 
+type discv interface {
+       ReadRandomNodes(buf []*dht.Node) (n int)
+}
+
+type lanDiscv interface {
+       Subscribe() (*event.Subscription, error)
+       Stop()
+}
+
+type Security interface {
+       DoFilter(ip string, pubKey string) error
+       IsBanned(ip string, level byte, reason string) bool
+       RegisterFilter(filter security.Filter)
+       Start() error
+}
+
 // Switch handles peer connections and exposes an API to receive incoming messages
 // on `Reactors`.  Each `Reactor` is responsible for handling incoming messages of one
 // or more `Channels`.  So while sending outgoing messages is typically performed on the peer,
@@ -50,16 +71,49 @@ type Switch struct {
        reactorsByCh map[byte]Reactor
        peers        *PeerSet
        dialing      *cmn.CMap
-       nodeInfo     *NodeInfo             // our node info
-       nodePrivKey  crypto.PrivKeyEd25519 // our node privkey
-       discv        *discover.Network
-       bannedPeer   map[string]time.Time
-       db           dbm.DB
-       mtx          sync.Mutex
+       nodeInfo     *NodeInfo       // our node info
+       nodePrivKey  signlib.PrivKey // our node privkey
+       discv        discv
+       lanDiscv     lanDiscv
+       security     Security
+}
+
+// NewSwitch create a new Switch and set discover.
+func NewSwitch(config *cfg.Config) (*Switch, error) {
+       var err error
+       var l Listener
+       var listenAddr string
+       var discv *dht.Network
+       var lanDiscv *mdns.LANDiscover
+
+       //generate unique netID
+       var data []byte
+       var h [32]byte
+       data = append(data, cfg.GenesisBlock().Hash().Bytes()...)
+       magic := make([]byte, 8)
+       binary.BigEndian.PutUint64(magic, magicNumber)
+       data = append(data, magic[:]...)
+       sha3pool.Sum256(h[:], data)
+       netID := binary.BigEndian.Uint64(h[:8])
+
+       privateKey := config.PrivateKey()
+       if !config.VaultMode {
+               // Create listener
+               l, listenAddr = GetListener(config.P2P)
+               discv, err = dht.NewDiscover(config, *privateKey, l.ExternalAddress().Port, netID)
+               if err != nil {
+                       return nil, err
+               }
+               if config.P2P.LANDiscover {
+                       lanDiscv = mdns.NewLANDiscover(mdns.NewProtocol(), int(l.ExternalAddress().Port))
+               }
+       }
+
+       return newSwitch(config, discv, lanDiscv, l, *privateKey, listenAddr, netID)
 }
 
-// NewSwitch creates a new Switch with the given config.
-func NewSwitch(config *cfg.Config) *Switch {
+// newSwitch creates a new Switch with the given config.
+func newSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, l Listener, privKey signlib.PrivKey, listenAddr string, netID uint64) (*Switch, error) {
        sw := &Switch{
                Config:       config,
                peerConfig:   DefaultPeerConfig(config.P2P),
@@ -68,18 +122,17 @@ func NewSwitch(config *cfg.Config) *Switch {
                reactorsByCh: make(map[byte]Reactor),
                peers:        NewPeerSet(),
                dialing:      cmn.NewCMap(),
-               nodeInfo:     nil,
-               db:           dbm.NewDB("trusthistory", config.DBBackend, config.DBDir()),
+               nodePrivKey:  privKey,
+               discv:        discv,
+               lanDiscv:     lanDiscv,
+               nodeInfo:     NewNodeInfo(config, privKey.XPub(), listenAddr, netID),
+               security:     security.NewSecurity(config),
        }
+
+       sw.AddListener(l)
        sw.BaseService = *cmn.NewBaseService(nil, "P2P Switch", sw)
-       sw.bannedPeer = make(map[string]time.Time)
-       if datajson := sw.db.Get([]byte(bannedPeerKey)); datajson != nil {
-               if err := json.Unmarshal(datajson, &sw.bannedPeer); err != nil {
-                       return nil
-               }
-       }
-       trust.Init()
-       return sw
+       log.WithFields(log.Fields{"module": logModule, "nodeInfo": sw.nodeInfo}).Info("init p2p network")
+       return sw, nil
 }
 
 // OnStart implements BaseService. It starts all the reactors, peers, and listeners.
@@ -89,15 +142,28 @@ func (sw *Switch) OnStart() error {
                        return err
                }
        }
+
+       sw.security.RegisterFilter(sw.nodeInfo)
+       sw.security.RegisterFilter(sw.peers)
+       if err := sw.security.Start(); err != nil {
+               return err
+       }
+
        for _, listener := range sw.listeners {
                go sw.listenerRoutine(listener)
        }
        go sw.ensureOutboundPeersRoutine()
+       go sw.connectLANPeersRoutine()
+
        return nil
 }
 
 // OnStop implements BaseService. It stops all listeners, peers, and reactors.
 func (sw *Switch) OnStop() {
+       if sw.Config.P2P.LANDiscover {
+               sw.lanDiscv.Stop()
+       }
+
        for _, listener := range sw.listeners {
                listener.Stop()
        }
@@ -113,28 +179,13 @@ func (sw *Switch) OnStop() {
        }
 }
 
-//AddBannedPeer add peer to blacklist
-func (sw *Switch) AddBannedPeer(ip string) error {
-       sw.mtx.Lock()
-       defer sw.mtx.Unlock()
-
-       sw.bannedPeer[ip] = time.Now().Add(defaultBanDuration)
-       datajson, err := json.Marshal(sw.bannedPeer)
-       if err != nil {
-               return err
-       }
-
-       sw.db.Set([]byte(bannedPeerKey), datajson)
-       return nil
-}
-
 // AddPeer performs the P2P handshake with a peer
 // that already has a SecretConnection. If all goes well,
 // it starts the peer and adds it to the switch.
 // NOTE: This performs a blocking handshake before the peer is added.
 // CONTRACT: If error is returned, peer is nil, and conn is immediately closed.
-func (sw *Switch) AddPeer(pc *peerConn) error {
-       peerNodeInfo, err := pc.HandshakeTimeout(sw.nodeInfo, time.Duration(sw.peerConfig.HandshakeTimeout))
+func (sw *Switch) AddPeer(pc *peerConn, isLAN bool) error {
+       peerNodeInfo, err := pc.HandshakeTimeout(sw.nodeInfo, sw.peerConfig.HandshakeTimeout)
        if err != nil {
                return err
        }
@@ -142,12 +193,13 @@ func (sw *Switch) AddPeer(pc *peerConn) error {
        if err := version.Status.CheckUpdate(sw.nodeInfo.Version, peerNodeInfo.Version, peerNodeInfo.RemoteAddr); err != nil {
                return err
        }
-       if err := sw.nodeInfo.CompatibleWith(peerNodeInfo); err != nil {
+
+       if err := sw.nodeInfo.compatibleWith(peerNodeInfo, version.CompatibleWith); err != nil {
                return err
        }
 
-       peer := newPeer(pc, peerNodeInfo, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError)
-       if err := sw.filterConnByPeer(peer); err != nil {
+       peer := newPeer(pc, peerNodeInfo, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, isLAN)
+       if err := sw.security.DoFilter(peer.remoteAddrHost(), peer.PubKey()); err != nil {
                return err
        }
 
@@ -161,6 +213,7 @@ func (sw *Switch) AddPeer(pc *peerConn) error {
                        return err
                }
        }
+
        return sw.peers.Add(peer)
 }
 
@@ -190,28 +243,32 @@ func (sw *Switch) AddListener(l Listener) {
 
 //DialPeerWithAddress dial node from net address
 func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error {
-       log.Debug("Dialing peer address:", addr)
+       log.WithFields(log.Fields{"module": logModule, "address": addr}).Debug("Dialing peer")
        sw.dialing.Set(addr.IP.String(), addr)
        defer sw.dialing.Delete(addr.IP.String())
-       if err := sw.filterConnByIP(addr.IP.String()); err != nil {
+       if err := sw.security.DoFilter(addr.IP.String(), ""); err != nil {
                return err
        }
 
        pc, err := newOutboundPeerConn(addr, sw.nodePrivKey, sw.peerConfig)
        if err != nil {
-               log.WithFields(log.Fields{"address": addr, " err": err}).Debug("DialPeer fail on newOutboundPeerConn")
+               log.WithFields(log.Fields{"module": logModule, "address": addr, " err": err}).Error("DialPeer fail on newOutboundPeerConn")
                return err
        }
 
-       if err = sw.AddPeer(pc); err != nil {
-               log.WithFields(log.Fields{"address": addr, " err": err}).Debug("DialPeer fail on switch AddPeer")
+       if err = sw.AddPeer(pc, addr.isLAN); err != nil {
+               log.WithFields(log.Fields{"module": logModule, "address": addr, " err": err}).Error("DialPeer fail on switch AddPeer")
                pc.CloseConn()
                return err
        }
-       log.Debug("DialPeer added peer:", addr)
+       log.WithFields(log.Fields{"module": logModule, "address": addr, "peer num": sw.peers.Size()}).Debug("DialPeer added peer")
        return nil
 }
 
+func (sw *Switch) IsBanned(ip string, level byte, reason string) bool {
+       return sw.security.IsBanned(ip, level, reason)
+}
+
 //IsDialing prevent duplicate dialing
 func (sw *Switch) IsDialing(addr *NetAddress) bool {
        return sw.dialing.Has(addr.IP.String())
@@ -230,48 +287,30 @@ func (sw *Switch) Listeners() []Listener {
 }
 
 // NumPeers Returns the count of outbound/inbound and outbound-dialing peers.
-func (sw *Switch) NumPeers() (outbound, inbound, dialing int) {
+func (sw *Switch) NumPeers() (lan, outbound, inbound, dialing int) {
        peers := sw.peers.List()
        for _, peer := range peers {
-               if peer.outbound {
+               if peer.outbound && !peer.isLAN {
                        outbound++
                } else {
                        inbound++
                }
+               if peer.isLAN {
+                       lan++
+               }
        }
        dialing = sw.dialing.Size()
        return
 }
 
-// NodeInfo returns the switch's NodeInfo.
-// NOTE: Not goroutine safe.
-func (sw *Switch) NodeInfo() *NodeInfo {
-       return sw.nodeInfo
-}
-
 //Peers return switch peerset
 func (sw *Switch) Peers() *PeerSet {
        return sw.peers
 }
 
-// SetNodeInfo sets the switch's NodeInfo for checking compatibility and handshaking with other nodes.
-// NOTE: Not goroutine safe.
-func (sw *Switch) SetNodeInfo(nodeInfo *NodeInfo) {
-       sw.nodeInfo = nodeInfo
-}
-
-// SetNodePrivKey sets the switch's private key for authenticated encryption.
-// NOTE: Not goroutine safe.
-func (sw *Switch) SetNodePrivKey(nodePrivKey crypto.PrivKeyEd25519) {
-       sw.nodePrivKey = nodePrivKey
-       if sw.nodeInfo != nil {
-               sw.nodeInfo.PubKey = nodePrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
-       }
-}
-
 // StopPeerForError disconnects from a peer due to external error.
 func (sw *Switch) StopPeerForError(peer *Peer, reason interface{}) {
-       log.WithFields(log.Fields{"peer": peer, " err": reason}).Debug("stopping peer for error")
+       log.WithFields(log.Fields{"module": logModule, "peer": peer, " err": reason}).Debug("stopping peer for error")
        sw.stopAndRemovePeer(peer, reason)
 }
 
@@ -285,64 +324,65 @@ func (sw *Switch) StopPeerGracefully(peerID string) {
 func (sw *Switch) addPeerWithConnection(conn net.Conn) error {
        peerConn, err := newInboundPeerConn(conn, sw.nodePrivKey, sw.Config.P2P)
        if err != nil {
-               conn.Close()
+               if err := conn.Close(); err != nil {
+                       log.WithFields(log.Fields{"module": logModule, "remote peer:": conn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
+               }
                return err
        }
 
-       if err = sw.AddPeer(peerConn); err != nil {
-               conn.Close()
-               return err
-       }
-       return nil
-}
-
-func (sw *Switch) checkBannedPeer(peer string) error {
-       sw.mtx.Lock()
-       defer sw.mtx.Unlock()
-
-       if banEnd, ok := sw.bannedPeer[peer]; ok {
-               if time.Now().Before(banEnd) {
-                       return ErrConnectBannedPeer
+       if err = sw.AddPeer(peerConn, false); err != nil {
+               if err := conn.Close(); err != nil {
+                       log.WithFields(log.Fields{"module": logModule, "remote peer:": conn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
                }
-               sw.delBannedPeer(peer)
-       }
-       return nil
-}
-
-func (sw *Switch) delBannedPeer(addr string) error {
-       sw.mtx.Lock()
-       defer sw.mtx.Unlock()
-
-       delete(sw.bannedPeer, addr)
-       datajson, err := json.Marshal(sw.bannedPeer)
-       if err != nil {
                return err
        }
 
-       sw.db.Set([]byte(bannedPeerKey), datajson)
+       log.WithFields(log.Fields{"module": logModule, "address": conn.RemoteAddr().String(), "peer num": sw.peers.Size()}).Debug("add inbound peer")
        return nil
 }
 
-func (sw *Switch) filterConnByIP(ip string) error {
-       if ip == sw.nodeInfo.ListenHost() {
-               return ErrConnectSelf
+func (sw *Switch) connectLANPeers(lanPeer mdns.LANPeerEvent) {
+       lanPeers, _, _, numDialing := sw.NumPeers()
+       numToDial := maxNumLANPeers - lanPeers
+       log.WithFields(log.Fields{"module": logModule, "numDialing": numDialing, "numToDial": numToDial}).Debug("connect LAN peers")
+       if numToDial <= 0 {
+               return
+       }
+       addresses := make([]*NetAddress, 0)
+       for i := 0; i < len(lanPeer.IP); i++ {
+               addresses = append(addresses, NewLANNetAddressIPPort(lanPeer.IP[i], uint16(lanPeer.Port)))
        }
-       return sw.checkBannedPeer(ip)
+       sw.dialPeers(addresses)
 }
 
-func (sw *Switch) filterConnByPeer(peer *Peer) error {
-       if err := sw.checkBannedPeer(peer.RemoteAddrHost()); err != nil {
-               return err
+func (sw *Switch) connectLANPeersRoutine() {
+       if !sw.Config.P2P.LANDiscover {
+               return
        }
 
-       if sw.nodeInfo.PubKey.Equals(peer.PubKey().Wrap()) {
-               return ErrConnectSelf
+       lanPeerEventSub, err := sw.lanDiscv.Subscribe()
+       if err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("subscribe LAN Peer Event error")
+               return
        }
 
-       if sw.peers.Has(peer.Key) {
-               return ErrDuplicatePeer
+       for {
+               select {
+               case obj, ok := <-lanPeerEventSub.Chan():
+                       if !ok {
+                               log.WithFields(log.Fields{"module": logModule}).Warning("LAN peer event subscription channel closed")
+                               return
+                       }
+                       LANPeer, ok := obj.Data.(mdns.LANPeerEvent)
+                       if !ok {
+                               log.WithFields(log.Fields{"module": logModule}).Error("event type error")
+                               continue
+                       }
+                       sw.connectLANPeers(LANPeer)
+               case <-sw.Quit:
+                       return
+               }
        }
-       return nil
 }
 
 func (sw *Switch) listenerRoutine(l Listener) {
@@ -354,7 +394,9 @@ func (sw *Switch) listenerRoutine(l Listener) {
 
                // disconnect if we alrady have MaxNumPeers
                if sw.peers.Size() >= sw.Config.P2P.MaxNumPeers {
-                       inConn.Close()
+                       if err := inConn.Close(); err != nil {
+                               log.WithFields(log.Fields{"module": logModule, "remote peer:": inConn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
+                       }
                        log.Info("Ignoring inbound connection: already have enough peers.")
                        continue
                }
@@ -367,53 +409,72 @@ func (sw *Switch) listenerRoutine(l Listener) {
        }
 }
 
-// SetDiscv connect the discv model to the switch
-func (sw *Switch) SetDiscv(discv *discover.Network) {
-       sw.discv = discv
-}
-
 func (sw *Switch) dialPeerWorker(a *NetAddress, wg *sync.WaitGroup) {
        if err := sw.DialPeerWithAddress(a); err != nil {
-               log.WithFields(log.Fields{"addr": a, "err": err}).Error("dialPeerWorker fail on dial peer")
+               log.WithFields(log.Fields{"module": logModule, "addr": a, "err": err}).Error("dialPeerWorker fail on dial peer")
        }
        wg.Done()
 }
 
-func (sw *Switch) ensureOutboundPeers() {
-       numOutPeers, _, numDialing := sw.NumPeers()
-       numToDial := (minNumOutboundPeers - (numOutPeers + numDialing))
-       log.WithFields(log.Fields{"numOutPeers": numOutPeers, "numDialing": numDialing, "numToDial": numToDial}).Debug("ensure peers")
-       if numToDial <= 0 {
-               return
-       }
-
+func (sw *Switch) dialPeers(addresses []*NetAddress) {
        connectedPeers := make(map[string]struct{})
        for _, peer := range sw.Peers().List() {
-               connectedPeers[peer.RemoteAddrHost()] = struct{}{}
+               connectedPeers[peer.remoteAddrHost()] = struct{}{}
        }
 
        var wg sync.WaitGroup
-       nodes := make([]*discover.Node, numToDial)
-       n := sw.discv.ReadRandomNodes(nodes)
-       for i := 0; i < n; i++ {
-               try := NewNetAddressIPPort(nodes[i].IP, nodes[i].TCP)
-               if sw.NodeInfo().ListenAddr == try.String() {
+       for _, address := range addresses {
+               if sw.nodeInfo.ListenAddr == address.String() {
                        continue
                }
-               if dialling := sw.IsDialing(try); dialling {
+               if dialling := sw.IsDialing(address); dialling {
                        continue
                }
-               if _, ok := connectedPeers[try.IP.String()]; ok {
+               if _, ok := connectedPeers[address.IP.String()]; ok {
                        continue
                }
 
                wg.Add(1)
-               go sw.dialPeerWorker(try, &wg)
+               go sw.dialPeerWorker(address, &wg)
        }
        wg.Wait()
 }
 
+func (sw *Switch) ensureKeepConnectPeers() {
+       keepDials := netutil.CheckAndSplitAddresses(sw.Config.P2P.KeepDial)
+       addresses := make([]*NetAddress, 0)
+       for _, keepDial := range keepDials {
+               address, err := NewNetAddressString(keepDial)
+               if err != nil {
+                       log.WithFields(log.Fields{"module": logModule, "err": err, "address": keepDial}).Warn("parse address to NetAddress")
+                       continue
+               }
+               addresses = append(addresses, address)
+       }
+
+       sw.dialPeers(addresses)
+}
+
+func (sw *Switch) ensureOutboundPeers() {
+       lanPeers, numOutPeers, _, numDialing := sw.NumPeers()
+       numToDial := minNumOutboundPeers - (numOutPeers + numDialing)
+       log.WithFields(log.Fields{"module": logModule, "numOutPeers": numOutPeers, "LANPeers": lanPeers, "numDialing": numDialing, "numToDial": numToDial}).Debug("ensure peers")
+       if numToDial <= 0 {
+               return
+       }
+
+       nodes := make([]*dht.Node, numToDial)
+       n := sw.discv.ReadRandomNodes(nodes)
+       addresses := make([]*NetAddress, 0)
+       for i := 0; i < n; i++ {
+               address := NewNetAddressIPPort(nodes[i].IP, nodes[i].TCP)
+               addresses = append(addresses, address)
+       }
+       sw.dialPeers(addresses)
+}
+
 func (sw *Switch) ensureOutboundPeersRoutine() {
+       sw.ensureKeepConnectPeers()
        sw.ensureOutboundPeers()
 
        ticker := time.NewTicker(10 * time.Second)
@@ -422,6 +483,7 @@ func (sw *Switch) ensureOutboundPeersRoutine() {
        for {
                select {
                case <-ticker.C:
+                       sw.ensureKeepConnectPeers()
                        sw.ensureOutboundPeers()
                case <-sw.Quit:
                        return
@@ -430,7 +492,11 @@ func (sw *Switch) ensureOutboundPeersRoutine() {
 }
 
 func (sw *Switch) startInitPeer(peer *Peer) error {
-       peer.Start() // spawn send/recv routines
+       // spawn send/recv routines
+       if _, err := peer.Start(); err != nil {
+               log.WithFields(log.Fields{"module": logModule, "remote peer:": peer.RemoteAddr, " err:": err}).Error("init peer err")
+       }
+
        for _, reactor := range sw.reactors {
                if err := reactor.AddPeer(peer); err != nil {
                        return err
@@ -448,6 +514,7 @@ func (sw *Switch) stopAndRemovePeer(peer *Peer, reason interface{}) {
 
        sentStatus, receivedStatus := peer.TrafficStatus()
        log.WithFields(log.Fields{
+               "module":                logModule,
                "address":               peer.Addr().String(),
                "reason":                reason,
                "duration":              sentStatus.Duration.String(),
@@ -455,5 +522,6 @@ func (sw *Switch) stopAndRemovePeer(peer *Peer, reason interface{}) {
                "total_received":        receivedStatus.Bytes,
                "average_sent_rate":     sentStatus.AvgRate,
                "average_received_rate": receivedStatus.AvgRate,
+               "peer num":              sw.peers.Size(),
        }).Info("disconnect with peer")
 }