OSDN Git Service

Specify lan discovery service name (#407)
[bytom/vapor.git] / p2p / switch.go
index cfae512..72a995d 100644 (file)
@@ -2,38 +2,32 @@ package p2p
 
 import (
        "encoding/binary"
-       "encoding/hex"
-       "encoding/json"
        "fmt"
        "net"
        "sync"
        "time"
 
        log "github.com/sirupsen/logrus"
-       crypto "github.com/tendermint/go-crypto"
        cmn "github.com/tendermint/tmlibs/common"
 
        cfg "github.com/vapor/config"
        "github.com/vapor/consensus"
-       "github.com/vapor/crypto/ed25519"
        "github.com/vapor/crypto/sha3pool"
-       dbm "github.com/vapor/database/leveldb"
        "github.com/vapor/errors"
        "github.com/vapor/event"
        "github.com/vapor/p2p/connection"
        "github.com/vapor/p2p/discover/dht"
        "github.com/vapor/p2p/discover/mdns"
        "github.com/vapor/p2p/netutil"
-       "github.com/vapor/p2p/trust"
+       security "github.com/vapor/p2p/security"
+       "github.com/vapor/p2p/signlib"
        "github.com/vapor/version"
 )
 
 const (
-       bannedPeerKey      = "BannedPeer"
-       defaultBanDuration = time.Hour * 1
-       logModule          = "p2p"
+       logModule = "p2p"
 
-       minNumOutboundPeers = 4
+       minNumOutboundPeers = 3
        maxNumLANPeers      = 5
        //magicNumber used to generate unique netID
        magicNumber = uint64(0x054c5638)
@@ -41,10 +35,9 @@ const (
 
 //pre-define errors for connecting fail
 var (
-       ErrDuplicatePeer     = errors.New("Duplicate peer")
-       ErrConnectSelf       = errors.New("Connect self")
-       ErrConnectBannedPeer = errors.New("Connect banned peer")
-       ErrConnectSpvPeer    = errors.New("Outbound connect spv peer")
+       ErrDuplicatePeer  = errors.New("Duplicate peer")
+       ErrConnectSelf    = errors.New("Connect self")
+       ErrConnectSpvPeer = errors.New("Outbound connect spv peer")
 )
 
 type discv interface {
@@ -56,6 +49,13 @@ type lanDiscv interface {
        Stop()
 }
 
+type Security interface {
+       DoFilter(ip string, pubKey string) error
+       IsBanned(ip string, level byte, reason string) bool
+       RegisterFilter(filter security.Filter)
+       Start() error
+}
+
 // Switch handles peer connections and exposes an API to receive incoming messages
 // on `Reactors`.  Each `Reactor` is responsible for handling incoming messages of one
 // or more `Channels`.  So while sending outgoing messages is typically performed on the peer,
@@ -71,17 +71,15 @@ type Switch struct {
        reactorsByCh map[byte]Reactor
        peers        *PeerSet
        dialing      *cmn.CMap
-       nodeInfo     *NodeInfo             // our node info
-       nodePrivKey  crypto.PrivKeyEd25519 // our node privkey
+       nodeInfo     *NodeInfo       // our node info
+       nodePrivKey  signlib.PrivKey // our node privkey
        discv        discv
        lanDiscv     lanDiscv
-       bannedPeer   map[string]time.Time
-       db           dbm.DB
-       mtx          sync.Mutex
+       security     Security
 }
 
-// NewSwitch create a new Switch and set discover.
-func NewSwitch(config *cfg.Config) (*Switch, error) {
+// NewSwitchMaybeDiscover create a new Switch and set discover.
+func NewSwitchMaybeDiscover(config *cfg.Config) (*Switch, error) {
        var err error
        var l Listener
        var listenAddr string
@@ -98,35 +96,24 @@ func NewSwitch(config *cfg.Config) (*Switch, error) {
        sha3pool.Sum256(h[:], data)
        netID := binary.BigEndian.Uint64(h[:8])
 
-       blacklistDB := dbm.NewDB("trusthistory", config.DBBackend, config.DBDir())
-
-       _, yyy, _ := ed25519.GenerateKey(nil)
-       zzz := yyy.String()
-
-       bytes, err := hex.DecodeString(zzz)
-       if err != nil {
-               return nil, err
-       }
-       var newKey [64]byte
-       copy(newKey[:], bytes)
-       privKey := crypto.PrivKeyEd25519(newKey)
+       privateKey := config.PrivateKey()
        if !config.VaultMode {
                // Create listener
                l, listenAddr = GetListener(config.P2P)
-               discv, err = dht.NewDiscover(config, ed25519.PrivateKey(bytes), l.ExternalAddress().Port, netID)
+               discv, err = dht.NewDiscover(config, *privateKey, l.ExternalAddress().Port, netID)
                if err != nil {
                        return nil, err
                }
                if config.P2P.LANDiscover {
-                       lanDiscv = mdns.NewLANDiscover(mdns.NewProtocol(), int(l.ExternalAddress().Port))
+                       lanDiscv = mdns.NewLANDiscover(mdns.NewProtocol(config.ChainID), int(l.ExternalAddress().Port))
                }
        }
 
-       return newSwitch(config, discv, lanDiscv, blacklistDB, l, privKey, listenAddr, netID)
+       return NewSwitch(config, discv, lanDiscv, l, *privateKey, listenAddr, netID)
 }
 
 // newSwitch creates a new Switch with the given config.
-func newSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, blacklistDB dbm.DB, l Listener, priv crypto.PrivKeyEd25519, listenAddr string, netID uint64) (*Switch, error) {
+func NewSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, l Listener, privKey signlib.PrivKey, listenAddr string, netID uint64) (*Switch, error) {
        sw := &Switch{
                Config:       config,
                peerConfig:   DefaultPeerConfig(config.P2P),
@@ -135,24 +122,39 @@ func newSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, blacklistDB d
                reactorsByCh: make(map[byte]Reactor),
                peers:        NewPeerSet(),
                dialing:      cmn.NewCMap(),
-               nodePrivKey:  priv,
+               nodePrivKey:  privKey,
                discv:        discv,
                lanDiscv:     lanDiscv,
-               db:           blacklistDB,
-               nodeInfo:     NewNodeInfo(config, priv.PubKey().Unwrap().(crypto.PubKeyEd25519), listenAddr, netID),
-               bannedPeer:   make(map[string]time.Time),
-       }
-       if err := sw.loadBannedPeers(); err != nil {
-               return nil, err
+               nodeInfo:     NewNodeInfo(config, privKey.XPub(), listenAddr, netID),
+               security:     security.NewSecurity(config),
        }
 
        sw.AddListener(l)
        sw.BaseService = *cmn.NewBaseService(nil, "P2P Switch", sw)
-       trust.Init()
        log.WithFields(log.Fields{"module": logModule, "nodeInfo": sw.nodeInfo}).Info("init p2p network")
        return sw, nil
 }
 
+func (sw *Switch) GetDiscv() discv {
+       return sw.discv
+}
+
+func (sw *Switch) GetNodeInfo() *NodeInfo {
+       return sw.nodeInfo
+}
+
+func (sw *Switch) GetPeers() *PeerSet {
+       return sw.peers
+}
+
+func (sw *Switch) GetReactors() map[string]Reactor {
+       return sw.reactors
+}
+
+func (sw *Switch) GetSecurity() Security {
+       return sw.security
+}
+
 // OnStart implements BaseService. It starts all the reactors, peers, and listeners.
 func (sw *Switch) OnStart() error {
        for _, reactor := range sw.reactors {
@@ -160,6 +162,13 @@ func (sw *Switch) OnStart() error {
                        return err
                }
        }
+
+       sw.security.RegisterFilter(sw.nodeInfo)
+       sw.security.RegisterFilter(sw.peers)
+       if err := sw.security.Start(); err != nil {
+               return err
+       }
+
        for _, listener := range sw.listeners {
                go sw.listenerRoutine(listener)
        }
@@ -190,21 +199,6 @@ func (sw *Switch) OnStop() {
        }
 }
 
-//AddBannedPeer add peer to blacklist
-func (sw *Switch) AddBannedPeer(ip string) error {
-       sw.mtx.Lock()
-       defer sw.mtx.Unlock()
-
-       sw.bannedPeer[ip] = time.Now().Add(defaultBanDuration)
-       dataJSON, err := json.Marshal(sw.bannedPeer)
-       if err != nil {
-               return err
-       }
-
-       sw.db.Set([]byte(bannedPeerKey), dataJSON)
-       return nil
-}
-
 // AddPeer performs the P2P handshake with a peer
 // that already has a SecretConnection. If all goes well,
 // it starts the peer and adds it to the switch.
@@ -225,7 +219,7 @@ func (sw *Switch) AddPeer(pc *peerConn, isLAN bool) error {
        }
 
        peer := newPeer(pc, peerNodeInfo, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, isLAN)
-       if err := sw.filterConnByPeer(peer); err != nil {
+       if err := sw.security.DoFilter(peer.RemoteAddrHost(), peer.PubKey()); err != nil {
                return err
        }
 
@@ -272,7 +266,7 @@ func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error {
        log.WithFields(log.Fields{"module": logModule, "address": addr}).Debug("Dialing peer")
        sw.dialing.Set(addr.IP.String(), addr)
        defer sw.dialing.Delete(addr.IP.String())
-       if err := sw.filterConnByIP(addr.IP.String()); err != nil {
+       if err := sw.security.DoFilter(addr.IP.String(), ""); err != nil {
                return err
        }
 
@@ -291,8 +285,8 @@ func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error {
        return nil
 }
 
-func (sw *Switch) ID() [32]byte {
-       return sw.nodeInfo.PubKey
+func (sw *Switch) IsBanned(ip string, level byte, reason string) bool {
+       return sw.security.IsBanned(ip, level, reason)
 }
 
 //IsDialing prevent duplicate dialing
@@ -306,17 +300,6 @@ func (sw *Switch) IsListening() bool {
        return len(sw.listeners) > 0
 }
 
-// loadBannedPeers load banned peers from db
-func (sw *Switch) loadBannedPeers() error {
-       if dataJSON := sw.db.Get([]byte(bannedPeerKey)); dataJSON != nil {
-               if err := json.Unmarshal(dataJSON, &sw.bannedPeer); err != nil {
-                       return err
-               }
-       }
-
-       return nil
-}
-
 // Listeners returns the list of listeners the switch listens on.
 // NOTE: Not goroutine safe.
 func (sw *Switch) Listeners() []Listener {
@@ -378,22 +361,6 @@ func (sw *Switch) addPeerWithConnection(conn net.Conn) error {
        return nil
 }
 
-func (sw *Switch) checkBannedPeer(peer string) error {
-       sw.mtx.Lock()
-       defer sw.mtx.Unlock()
-
-       if banEnd, ok := sw.bannedPeer[peer]; ok {
-               if time.Now().Before(banEnd) {
-                       return ErrConnectBannedPeer
-               }
-
-               if err := sw.delBannedPeer(peer); err != nil {
-                       return err
-               }
-       }
-       return nil
-}
-
 func (sw *Switch) connectLANPeers(lanPeer mdns.LANPeerEvent) {
        lanPeers, _, _, numDialing := sw.NumPeers()
        numToDial := maxNumLANPeers - lanPeers
@@ -405,7 +372,7 @@ func (sw *Switch) connectLANPeers(lanPeer mdns.LANPeerEvent) {
        for i := 0; i < len(lanPeer.IP); i++ {
                addresses = append(addresses, NewLANNetAddressIPPort(lanPeer.IP[i], uint16(lanPeer.Port)))
        }
-       sw.dialPeers(addresses)
+       sw.DialPeers(addresses)
 }
 
 func (sw *Switch) connectLANPeersRoutine() {
@@ -438,42 +405,6 @@ func (sw *Switch) connectLANPeersRoutine() {
        }
 }
 
-func (sw *Switch) delBannedPeer(addr string) error {
-       sw.mtx.Lock()
-       defer sw.mtx.Unlock()
-
-       delete(sw.bannedPeer, addr)
-       datajson, err := json.Marshal(sw.bannedPeer)
-       if err != nil {
-               return err
-       }
-
-       sw.db.Set([]byte(bannedPeerKey), datajson)
-       return nil
-}
-
-func (sw *Switch) filterConnByIP(ip string) error {
-       if ip == sw.nodeInfo.listenHost() {
-               return ErrConnectSelf
-       }
-       return sw.checkBannedPeer(ip)
-}
-
-func (sw *Switch) filterConnByPeer(peer *Peer) error {
-       if err := sw.checkBannedPeer(peer.remoteAddrHost()); err != nil {
-               return err
-       }
-
-       if sw.nodeInfo.PubKey.Equals(peer.PubKey().Wrap()) {
-               return ErrConnectSelf
-       }
-
-       if sw.peers.Has(peer.Key) {
-               return ErrDuplicatePeer
-       }
-       return nil
-}
-
 func (sw *Switch) listenerRoutine(l Listener) {
        for {
                inConn, ok := <-l.Connections()
@@ -481,7 +412,7 @@ func (sw *Switch) listenerRoutine(l Listener) {
                        break
                }
 
-               // disconnect if we alrady have MaxNumPeers
+               // disconnect if we already have MaxNumPeers
                if sw.peers.Size() >= sw.Config.P2P.MaxNumPeers {
                        if err := inConn.Close(); err != nil {
                                log.WithFields(log.Fields{"module": logModule, "remote peer:": inConn.RemoteAddr().String(), " err:": err}).Error("closes connection err")
@@ -505,10 +436,10 @@ func (sw *Switch) dialPeerWorker(a *NetAddress, wg *sync.WaitGroup) {
        wg.Done()
 }
 
-func (sw *Switch) dialPeers(addresses []*NetAddress) {
+func (sw *Switch) DialPeers(addresses []*NetAddress) {
        connectedPeers := make(map[string]struct{})
        for _, peer := range sw.Peers().List() {
-               connectedPeers[peer.remoteAddrHost()] = struct{}{}
+               connectedPeers[peer.RemoteAddrHost()] = struct{}{}
        }
 
        var wg sync.WaitGroup
@@ -541,7 +472,7 @@ func (sw *Switch) ensureKeepConnectPeers() {
                addresses = append(addresses, address)
        }
 
-       sw.dialPeers(addresses)
+       sw.DialPeers(addresses)
 }
 
 func (sw *Switch) ensureOutboundPeers() {
@@ -559,7 +490,7 @@ func (sw *Switch) ensureOutboundPeers() {
                address := NewNetAddressIPPort(nodes[i].IP, nodes[i].TCP)
                addresses = append(addresses, address)
        }
-       sw.dialPeers(addresses)
+       sw.DialPeers(addresses)
 }
 
 func (sw *Switch) ensureOutboundPeersRoutine() {