OSDN Git Service

Add p2p security module (#143)
authoryahtoo <yahtoo.ma@gmail.com>
Mon, 10 Jun 2019 02:41:36 +0000 (10:41 +0800)
committerPaladz <yzhu101@uottawa.ca>
Mon, 10 Jun 2019 02:41:36 +0000 (10:41 +0800)
* tmp

* Add p2p security module

* Fix review bugs

19 files changed:
netsync/chainmgr/block_keeper.go
netsync/chainmgr/handle.go
netsync/chainmgr/tool_test.go
netsync/consensusmgr/block_fetcher.go
netsync/consensusmgr/block_fetcher_test.go
netsync/consensusmgr/handle.go
netsync/peers/peer.go
p2p/node_info.go
p2p/peer_set.go
p2p/peer_test.go
p2p/security/banscore.go [moved from p2p/trust/banscore.go with 99% similarity]
p2p/security/banscore_test.go [moved from p2p/trust/banscore_test.go with 99% similarity]
p2p/security/blacklist.go [new file with mode: 0644]
p2p/security/filter.go [new file with mode: 0644]
p2p/security/score.go [new file with mode: 0644]
p2p/security/security.go [new file with mode: 0644]
p2p/switch.go
p2p/switch_test.go
p2p/test_util.go

index 8ce20ea..112fd50 100644 (file)
@@ -9,6 +9,7 @@ import (
        "github.com/vapor/consensus"
        "github.com/vapor/errors"
        "github.com/vapor/netsync/peers"
+       "github.com/vapor/p2p/security"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
 )
@@ -353,7 +354,7 @@ func (bk *blockKeeper) startSync() bool {
                bk.syncPeer = peer
                if err := bk.fastBlockSync(checkPoint); err != nil {
                        log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on fastBlockSync")
-                       bk.peers.ErrorHandler(peer.ID(), err)
+                       bk.peers.ErrorHandler(peer.ID(), security.LevelMsgIllegal, err)
                        return false
                }
                return true
@@ -370,7 +371,7 @@ func (bk *blockKeeper) startSync() bool {
 
                if err := bk.regularBlockSync(targetHeight); err != nil {
                        log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on regularBlockSync")
-                       bk.peers.ErrorHandler(peer.ID(), err)
+                       bk.peers.ErrorHandler(peer.ID(),security.LevelMsgIllegal, err)
                        return false
                }
                return true
index 51dee4c..0d13847 100644 (file)
@@ -12,6 +12,7 @@ import (
        msgs "github.com/vapor/netsync/messages"
        "github.com/vapor/netsync/peers"
        "github.com/vapor/p2p"
+       "github.com/vapor/p2p/security"
        core "github.com/vapor/protocol"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
@@ -37,7 +38,6 @@ type Chain interface {
 
 type Switch interface {
        AddReactor(name string, reactor p2p.Reactor) p2p.Reactor
-       AddBannedPeer(string) error
        Start() (bool, error)
        Stop() bool
        IsListening() bool
@@ -247,12 +247,12 @@ func (m *Manager) handleStatusMsg(basePeer peers.BasePeer, msg *msgs.StatusMessa
 func (m *Manager) handleTransactionMsg(peer *peers.Peer, msg *msgs.TransactionMessage) {
        tx, err := msg.GetTransaction()
        if err != nil {
-               m.peers.AddBanScore(peer.ID(), 0, 10, "fail on get tx from message")
+               m.peers.ProcessIllegal(peer.ID(), security.LevelConnException, "fail on get tx from message")
                return
        }
 
        if isOrphan, err := m.chain.ValidateTx(tx); err != nil && err != core.ErrDustTx && !isOrphan {
-               m.peers.AddBanScore(peer.ID(), 10, 0, "fail on validate tx transaction")
+               m.peers.ProcessIllegal(peer.ID(), security.LevelMsgIllegal, "fail on validate tx transaction")
        }
        m.peers.MarkTx(peer.ID(), tx.ID)
 }
@@ -260,18 +260,18 @@ func (m *Manager) handleTransactionMsg(peer *peers.Peer, msg *msgs.TransactionMe
 func (m *Manager) handleTransactionsMsg(peer *peers.Peer, msg *msgs.TransactionsMessage) {
        txs, err := msg.GetTransactions()
        if err != nil {
-               m.peers.AddBanScore(peer.ID(), 0, 20, "fail on get txs from message")
+               m.peers.ProcessIllegal(peer.ID(), security.LevelConnException, "fail on get txs from message")
                return
        }
 
        if len(txs) > msgs.TxsMsgMaxTxNum {
-               m.peers.AddBanScore(peer.ID(), 20, 0, "exceeded the maximum tx number limit")
+               m.peers.ProcessIllegal(peer.ID(), security.LevelMsgIllegal, "exceeded the maximum tx number limit")
                return
        }
 
        for _, tx := range txs {
                if isOrphan, err := m.chain.ValidateTx(tx); err != nil && !isOrphan {
-                       m.peers.AddBanScore(peer.ID(), 10, 0, "fail on validate tx transaction")
+                       m.peers.ProcessIllegal(peer.ID(), security.LevelMsgIllegal, "fail on validate tx transaction")
                        return
                }
                m.peers.MarkTx(peer.ID(), tx.ID)
index e354984..db17b4f 100644 (file)
@@ -89,8 +89,11 @@ func NewPeerSet() *PeerSet {
        return &PeerSet{}
 }
 
-func (ps *PeerSet) AddBannedPeer(string) error { return nil }
-func (ps *PeerSet) StopPeerGracefully(string)  {}
+func (ps *PeerSet) IsBanned(peerID string, level byte, reason string) bool {
+       return false
+}
+
+func (ps *PeerSet) StopPeerGracefully(string) {}
 
 type NetWork struct {
        nodes map[*Manager]P2PPeer
index 3b6c746..dc42923 100644 (file)
@@ -5,6 +5,7 @@ import (
        "gopkg.in/karalabe/cookiejar.v2/collections/prque"
 
        "github.com/vapor/netsync/peers"
+       "github.com/vapor/p2p/security"
        "github.com/vapor/protocol/bc"
 )
 
@@ -80,7 +81,7 @@ func (f *blockFetcher) insert(msg *blockMsg) {
                        return
                }
 
-               f.peers.AddBanScore(msg.peerID, 20, 0, err.Error())
+               f.peers.ProcessIllegal(msg.peerID, security.LevelMsgIllegal, err.Error())
                return
        }
 
index f7d6936..f5b44a8 100644 (file)
@@ -12,8 +12,8 @@ import (
 type peerMgr struct {
 }
 
-func (pm *peerMgr) AddBannedPeer(string) error {
-       return nil
+func (pm *peerMgr) IsBanned(peerID string, level byte, reason string) bool{
+       return false
 }
 
 func (pm *peerMgr) StopPeerGracefully(string) {
index 668a2f7..66fddd9 100644 (file)
@@ -8,6 +8,7 @@ import (
        "github.com/vapor/event"
        "github.com/vapor/netsync/peers"
        "github.com/vapor/p2p"
+       "github.com/vapor/p2p/security"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
 )
@@ -15,7 +16,6 @@ import (
 // Switch is the interface for p2p switch.
 type Switch interface {
        AddReactor(name string, reactor p2p.Reactor) p2p.Reactor
-       AddBannedPeer(string) error
 }
 
 // Chain is the interface for Bytom core.
@@ -96,7 +96,7 @@ func (m *Manager) handleBlockProposeMsg(peerID string, msg *BlockProposeMsg) {
 func (m *Manager) handleBlockSignatureMsg(peerID string, msg *BlockSignatureMsg) {
        blockHash := bc.NewHash(msg.BlockHash)
        if err := m.chain.ProcessBlockSignature(msg.Signature, msg.PubKey, &blockHash); err != nil {
-               m.peers.AddBanScore(peerID, 20, 0, err.Error())
+               m.peers.ProcessIllegal(peerID, security.LevelMsgIllegal, err.Error())
                return
        }
 }
index 0f62149..1f0ac24 100644 (file)
@@ -13,7 +13,6 @@ import (
        "github.com/vapor/consensus"
        "github.com/vapor/errors"
        msgs "github.com/vapor/netsync/messages"
-       "github.com/vapor/p2p/trust"
        "github.com/vapor/protocol/bc"
        "github.com/vapor/protocol/bc/types"
 )
@@ -22,7 +21,6 @@ const (
        maxKnownTxs           = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS)
        maxKnownSignatures    = 1024  // Maximum block signatures to keep in the known list (prevent DOS)
        maxKnownBlocks        = 1024  // Maximum block hashes to keep in the known list (prevent DOS)
-       defaultBanThreshold   = uint32(100)
        maxFilterAddressSize  = 50
        maxFilterAddressCount = 1000
 
@@ -46,8 +44,8 @@ type BasePeer interface {
 
 //BasePeerSet is the intergace for connection level peer manager
 type BasePeerSet interface {
-       AddBannedPeer(string) error
        StopPeerGracefully(string)
+       IsBanned(peerID string, level byte, reason string) bool
 }
 
 type BroadcastMsg interface {
@@ -79,7 +77,6 @@ type Peer struct {
        services        consensus.ServiceFlag
        height          uint64
        hash            *bc.Hash
-       banScore        trust.DynamicBanScore
        knownTxs        *set.Set // Set of transaction hashes known to be known by this peer
        knownBlocks     *set.Set // Set of block hashes known to be known by this peer
        knownSignatures *set.Set // Set of block signatures known to be known by this peer
@@ -104,30 +101,6 @@ func (p *Peer) Height() uint64 {
        return p.height
 }
 
-func (p *Peer) addBanScore(persistent, transient uint32, reason string) bool {
-       score := p.banScore.Increase(persistent, transient)
-       if score > defaultBanThreshold {
-               log.WithFields(log.Fields{
-                       "module":  logModule,
-                       "address": p.Addr(),
-                       "score":   score,
-                       "reason":  reason,
-               }).Errorf("banning and disconnecting")
-               return true
-       }
-
-       warnThreshold := defaultBanThreshold >> 1
-       if score > warnThreshold {
-               log.WithFields(log.Fields{
-                       "module":  logModule,
-                       "address": p.Addr(),
-                       "score":   score,
-                       "reason":  reason,
-               }).Warning("ban score increasing")
-       }
-       return false
-}
-
 func (p *Peer) AddFilterAddress(address []byte) {
        p.mtx.Lock()
        defer p.mtx.Unlock()
@@ -417,7 +390,7 @@ func NewPeerSet(basePeerSet BasePeerSet) *PeerSet {
        }
 }
 
-func (ps *PeerSet) AddBanScore(peerID string, persistent, transient uint32, reason string) {
+func (ps *PeerSet) ProcessIllegal(peerID string, level byte, reason string) {
        ps.mtx.Lock()
        peer := ps.peers[peerID]
        ps.mtx.Unlock()
@@ -425,13 +398,10 @@ func (ps *PeerSet) AddBanScore(peerID string, persistent, transient uint32, reas
        if peer == nil {
                return
        }
-       if ban := peer.addBanScore(persistent, transient, reason); !ban {
-               return
-       }
-       if err := ps.AddBannedPeer(peer.Addr().String()); err != nil {
-               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on add ban peer")
+       if banned := ps.IsBanned(peer.Addr().String(), level, reason); banned {
+               ps.RemovePeer(peerID)
        }
-       ps.RemovePeer(peerID)
+       return
 }
 
 func (ps *PeerSet) AddPeer(peer BasePeer) {
@@ -536,9 +506,9 @@ func (ps *PeerSet) BroadcastTx(tx *types.Tx) error {
        return nil
 }
 
-func (ps *PeerSet) ErrorHandler(peerID string, err error) {
+func (ps *PeerSet) ErrorHandler(peerID string, level byte, err error) {
        if errors.Root(err) == ErrPeerMisbehave {
-               ps.AddBanScore(peerID, 20, 0, err.Error())
+               ps.ProcessIllegal(peerID, level, err.Error())
        } else {
                ps.RemovePeer(peerID)
        }
index 00f818d..e602a0d 100644 (file)
@@ -71,6 +71,14 @@ func (info *NodeInfo) compatibleWith(other *NodeInfo, versionCompatibleWith Vers
        return nil
 }
 
+func (info NodeInfo) DoFilter(ip string, pubKey string) error {
+       if info.PubKey == pubKey {
+               return ErrConnectSelf
+       }
+
+       return nil
+}
+
 //listenHost peer listener ip address
 func (info NodeInfo) listenHost() string {
        host, _, _ := net.SplitHostPort(info.ListenAddr)
index e26746b..c652371 100644 (file)
@@ -50,6 +50,14 @@ func (ps *PeerSet) Add(peer *Peer) error {
        return nil
 }
 
+func (ps *PeerSet) DoFilter(ip string, pubKey string) error {
+       if ps.Has(pubKey) {
+               return ErrDuplicatePeer
+       }
+
+       return nil
+}
+
 // Get looks up a peer by the provided peerKey.
 func (ps *PeerSet) Get(peerKey string) *Peer {
        ps.mtx.Lock()
index 8559f11..bcaed18 100644 (file)
@@ -7,6 +7,7 @@ import (
        "time"
 
        cfg "github.com/vapor/config"
+       "github.com/vapor/consensus"
        conn "github.com/vapor/p2p/connection"
        "github.com/vapor/p2p/signlib"
        "github.com/vapor/version"
@@ -142,11 +143,12 @@ func (rp *remotePeer) accept(l net.Listener) {
                }
 
                _, err = pc.HandshakeTimeout(&NodeInfo{
-                       PubKey:     rp.PrivKey.XPub().String(),
-                       Moniker:    "remote_peer",
-                       Network:    rp.Config.ChainID,
-                       Version:    version.Version,
-                       ListenAddr: l.Addr().String(),
+                       PubKey:      rp.PrivKey.XPub().String(),
+                       Moniker:     "remote_peer",
+                       Network:     rp.Config.ChainID,
+                       Version:     version.Version,
+                       ListenAddr:  l.Addr().String(),
+                       ServiceFlag: consensus.DefaultServices,
                }, 5*time.Second)
                if err != nil {
                        fmt.Println("Failed to perform handshake:", err)
similarity index 99%
rename from p2p/trust/banscore.go
rename to p2p/security/banscore.go
index 892d653..5892a5f 100644 (file)
@@ -1,4 +1,4 @@
-package trust
+package security
 
 import (
        "fmt"
@@ -29,7 +29,7 @@ const (
 var precomputedFactor [precomputedLen]float64
 
 // init precomputes decay factors.
-func Init() {
+func init() {
        for i := range precomputedFactor {
                precomputedFactor[i] = math.Exp(-1.0 * float64(i) * lambda)
        }
similarity index 99%
rename from p2p/trust/banscore_test.go
rename to p2p/security/banscore_test.go
index a4a4fcf..6dd0944 100644 (file)
@@ -1,4 +1,4 @@
-package trust
+package security
 
 import (
        "math"
@@ -23,7 +23,6 @@ func TestInt(t *testing.T) {
                {bs: DynamicBanScore{lastUnix: 0, transient: math.MaxUint32, persistent: math.MaxUint32}, timeLapse: 0, wantValue: math.MaxUint32 - 1},
        }
 
-       Init()
        for i, intTest := range banScoreIntTests {
                rst := intTest.bs.int(time.Unix(intTest.timeLapse, 0))
                if rst != intTest.wantValue {
@@ -53,7 +52,6 @@ func TestIncrease(t *testing.T) {
                {bs: DynamicBanScore{lastUnix: 0, transient: 0, persistent: math.MaxUint32}, transientAdd: math.MaxUint32, persistentAdd: 0, timeLapse: Lifetime + 1, wantValue: math.MaxUint32 - 1},
        }
 
-       Init()
        for i, incTest := range banScoreIncreaseTests {
                rst := incTest.bs.increase(incTest.persistentAdd, incTest.transientAdd, time.Unix(incTest.timeLapse, 0))
                if rst != incTest.wantValue {
diff --git a/p2p/security/blacklist.go b/p2p/security/blacklist.go
new file mode 100644 (file)
index 0000000..f8ca05e
--- /dev/null
@@ -0,0 +1,91 @@
+package security
+
+import (
+       "encoding/json"
+       "errors"
+       "sync"
+       "time"
+
+       cfg "github.com/vapor/config"
+       dbm "github.com/vapor/database/leveldb"
+)
+
+const (
+       defaultBanDuration = time.Hour * 1
+       blacklistKey       = "BlacklistPeers"
+)
+
+var (
+       ErrConnectBannedPeer = errors.New("connect banned peer")
+)
+
+type Blacklist struct {
+       peers map[string]time.Time
+       db    dbm.DB
+
+       mtx sync.Mutex
+}
+
+func NewBlacklist(config *cfg.Config) *Blacklist {
+       return &Blacklist{
+               peers: make(map[string]time.Time),
+               db:    dbm.NewDB("blacklist", config.DBBackend, config.DBDir()),
+       }
+}
+
+//AddPeer add peer to blacklist
+func (bl *Blacklist) AddPeer(ip string) error {
+       bl.mtx.Lock()
+       defer bl.mtx.Unlock()
+
+       bl.peers[ip] = time.Now().Add(defaultBanDuration)
+       dataJSON, err := json.Marshal(bl.peers)
+       if err != nil {
+               return err
+       }
+
+       bl.db.Set([]byte(blacklistKey), dataJSON)
+       return nil
+}
+
+func (bl *Blacklist) delPeer(ip string) error {
+       delete(bl.peers, ip)
+       dataJson, err := json.Marshal(bl.peers)
+       if err != nil {
+               return err
+       }
+
+       bl.db.Set([]byte(blacklistKey), dataJson)
+       return nil
+}
+
+func (bl *Blacklist) DoFilter(ip string, pubKey string) error {
+       bl.mtx.Lock()
+       defer bl.mtx.Unlock()
+
+       if banEnd, ok := bl.peers[ip]; ok {
+               if time.Now().Before(banEnd) {
+                       return ErrConnectBannedPeer
+               }
+
+               if err := bl.delPeer(ip); err != nil {
+                       return err
+               }
+       }
+
+       return nil
+}
+
+// LoadPeers load banned peers from db
+func (bl *Blacklist) LoadPeers() error {
+       bl.mtx.Lock()
+       defer bl.mtx.Unlock()
+
+       if dataJSON := bl.db.Get([]byte(blacklistKey)); dataJSON != nil {
+               if err := json.Unmarshal(dataJSON, &bl.peers); err != nil {
+                       return err
+               }
+       }
+
+       return nil
+}
diff --git a/p2p/security/filter.go b/p2p/security/filter.go
new file mode 100644 (file)
index 0000000..409952a
--- /dev/null
@@ -0,0 +1,38 @@
+package security
+
+import "sync"
+
+type Filter interface {
+       DoFilter(string, string) error
+}
+
+type PeerFilter struct {
+       filterChain []Filter
+       mtx         sync.RWMutex
+}
+
+func NewPeerFilter() *PeerFilter {
+       return &PeerFilter{
+               filterChain: make([]Filter, 0),
+       }
+}
+
+func (pf *PeerFilter) register(filter Filter) {
+       pf.mtx.Lock()
+       defer pf.mtx.Unlock()
+
+       pf.filterChain = append(pf.filterChain, filter)
+}
+
+func (pf *PeerFilter) doFilter(ip string, pubKey string) error {
+       pf.mtx.RLock()
+       defer pf.mtx.RUnlock()
+
+       for _, filter := range pf.filterChain {
+               if err := filter.DoFilter(ip, pubKey); err != nil {
+                       return err
+               }
+       }
+
+       return nil
+}
diff --git a/p2p/security/score.go b/p2p/security/score.go
new file mode 100644 (file)
index 0000000..fea3149
--- /dev/null
@@ -0,0 +1,69 @@
+package security
+
+import (
+       "sync"
+
+       log "github.com/sirupsen/logrus"
+)
+
+const (
+       defaultBanThreshold  = uint32(100)
+       defaultWarnThreshold = uint32(50)
+
+       LevelMsgIllegal              = 0x01
+       levelMsgIllegalPersistent    = uint32(20)
+       levelMsgIllegalTransient     = uint32(0)
+       LevelConnException           = 0x02
+       levelConnExceptionPersistent = uint32(0)
+       levelConnExceptionTransient  = uint32(20)
+)
+
+type PeersBanScore struct {
+       peers map[string]*DynamicBanScore
+       mtx   sync.Mutex
+}
+
+func NewPeersScore() *PeersBanScore {
+       return &PeersBanScore{
+               peers: make(map[string]*DynamicBanScore),
+       }
+}
+
+func (ps *PeersBanScore) DelPeer(ip string) {
+       ps.mtx.Lock()
+       defer ps.mtx.Unlock()
+
+       delete(ps.peers, ip)
+}
+
+func (ps *PeersBanScore) Increase(ip string, level byte, reason string) bool {
+       ps.mtx.Lock()
+       defer ps.mtx.Unlock()
+
+       var persistent, transient uint32
+       switch level {
+       case LevelMsgIllegal:
+               persistent = levelMsgIllegalPersistent
+               transient = levelMsgIllegalTransient
+       case LevelConnException:
+               persistent = levelConnExceptionPersistent
+               transient = levelConnExceptionTransient
+       default:
+               return false
+       }
+       banScore, ok := ps.peers[ip]
+       if !ok {
+               banScore = &DynamicBanScore{}
+               ps.peers[ip] = banScore
+       }
+       score := banScore.Increase(persistent, transient)
+       if score > defaultBanThreshold {
+               log.WithFields(log.Fields{"module": logModule, "address": ip, "score": score, "reason": reason}).Errorf("banning and disconnecting")
+               return true
+       }
+
+       if score > defaultWarnThreshold {
+               log.WithFields(log.Fields{"module": logModule, "address": ip, "score": score, "reason": reason}).Warning("ban score increasing")
+       }
+       return false
+}
diff --git a/p2p/security/security.go b/p2p/security/security.go
new file mode 100644 (file)
index 0000000..7c9e945
--- /dev/null
@@ -0,0 +1,53 @@
+package security
+
+import (
+       log "github.com/sirupsen/logrus"
+
+       cfg "github.com/vapor/config"
+)
+
+const logModule = "p2p/security"
+
+type Security struct {
+       filter        *PeerFilter
+       blacklist     *Blacklist
+       peersBanScore *PeersBanScore
+}
+
+func NewSecurity(config *cfg.Config) *Security {
+       return &Security{
+               filter:        NewPeerFilter(),
+               blacklist:     NewBlacklist(config),
+               peersBanScore: NewPeersScore(),
+       }
+}
+
+func (s *Security) DoFilter(ip string, pubKey string) error {
+       return s.filter.doFilter(ip, pubKey)
+}
+
+func (s *Security) IsBanned(ip string, level byte, reason string) bool {
+       if ok := s.peersBanScore.Increase(ip, level, reason); !ok {
+               return false
+       }
+
+       if err := s.blacklist.AddPeer(ip); err != nil {
+               log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on add ban peer")
+       }
+       //clear peer score
+       s.peersBanScore.DelPeer(ip)
+       return true
+}
+
+func (s *Security) RegisterFilter(filter Filter) {
+       s.filter.register(filter)
+}
+
+func (s *Security) Start() error {
+       if err := s.blacklist.LoadPeers(); err != nil {
+               return err
+       }
+
+       s.filter.register(s.blacklist)
+       return nil
+}
index 62d47f1..ef8306d 100644 (file)
@@ -2,7 +2,6 @@ package p2p
 
 import (
        "encoding/binary"
-       "encoding/json"
        "fmt"
        "net"
        "sync"
@@ -14,7 +13,6 @@ import (
        cfg "github.com/vapor/config"
        "github.com/vapor/consensus"
        "github.com/vapor/crypto/sha3pool"
-       dbm "github.com/vapor/database/leveldb"
        "github.com/vapor/errors"
        "github.com/vapor/event"
        "github.com/vapor/p2p/connection"
@@ -22,14 +20,12 @@ import (
        "github.com/vapor/p2p/discover/mdns"
        "github.com/vapor/p2p/netutil"
        "github.com/vapor/p2p/signlib"
-       "github.com/vapor/p2p/trust"
+       security "github.com/vapor/p2p/security"
        "github.com/vapor/version"
 )
 
 const (
-       bannedPeerKey      = "BannedPeer"
-       defaultBanDuration = time.Hour * 1
-       logModule          = "p2p"
+       logModule = "p2p"
 
        minNumOutboundPeers = 4
        maxNumLANPeers      = 5
@@ -39,10 +35,9 @@ const (
 
 //pre-define errors for connecting fail
 var (
-       ErrDuplicatePeer     = errors.New("Duplicate peer")
-       ErrConnectSelf       = errors.New("Connect self")
-       ErrConnectBannedPeer = errors.New("Connect banned peer")
-       ErrConnectSpvPeer    = errors.New("Outbound connect spv peer")
+       ErrDuplicatePeer  = errors.New("Duplicate peer")
+       ErrConnectSelf    = errors.New("Connect self")
+       ErrConnectSpvPeer = errors.New("Outbound connect spv peer")
 )
 
 type discv interface {
@@ -54,6 +49,13 @@ type lanDiscv interface {
        Stop()
 }
 
+type Security interface {
+       DoFilter(ip string, pubKey string) error
+       IsBanned(ip string, level byte, reason string) bool
+       RegisterFilter(filter security.Filter)
+       Start() error
+}
+
 // Switch handles peer connections and exposes an API to receive incoming messages
 // on `Reactors`.  Each `Reactor` is responsible for handling incoming messages of one
 // or more `Channels`.  So while sending outgoing messages is typically performed on the peer,
@@ -73,9 +75,7 @@ type Switch struct {
        nodePrivKey  signlib.PrivKey // our node privkey
        discv        discv
        lanDiscv     lanDiscv
-       bannedPeer   map[string]time.Time
-       db           dbm.DB
-       mtx          sync.Mutex
+       security     Security
 }
 
 // NewSwitch create a new Switch and set discover.
@@ -96,7 +96,6 @@ func NewSwitch(config *cfg.Config) (*Switch, error) {
        sha3pool.Sum256(h[:], data)
        netID := binary.BigEndian.Uint64(h[:8])
 
-       blacklistDB := dbm.NewDB("trusthistory", config.DBBackend, config.DBDir())
        privateKey := config.PrivateKey()
        if !config.VaultMode {
                // Create listener
@@ -110,11 +109,11 @@ func NewSwitch(config *cfg.Config) (*Switch, error) {
                }
        }
 
-       return newSwitch(config, discv, lanDiscv, blacklistDB, l, *privateKey, listenAddr, netID)
+       return newSwitch(config, discv, lanDiscv, l, *privateKey, listenAddr, netID)
 }
 
 // newSwitch creates a new Switch with the given config.
-func newSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, blacklistDB dbm.DB, l Listener, privKey signlib.PrivKey, listenAddr string, netID uint64) (*Switch, error) {
+func newSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, l Listener, privKey signlib.PrivKey, listenAddr string, netID uint64) (*Switch, error) {
        sw := &Switch{
                Config:       config,
                peerConfig:   DefaultPeerConfig(config.P2P),
@@ -126,17 +125,12 @@ func newSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, blacklistDB d
                nodePrivKey:  privKey,
                discv:        discv,
                lanDiscv:     lanDiscv,
-               db:           blacklistDB,
                nodeInfo:     NewNodeInfo(config, privKey.XPub(), listenAddr, netID),
-               bannedPeer:   make(map[string]time.Time),
-       }
-       if err := sw.loadBannedPeers(); err != nil {
-               return nil, err
+               security:     security.NewSecurity(config),
        }
 
        sw.AddListener(l)
        sw.BaseService = *cmn.NewBaseService(nil, "P2P Switch", sw)
-       trust.Init()
        log.WithFields(log.Fields{"module": logModule, "nodeInfo": sw.nodeInfo}).Info("init p2p network")
        return sw, nil
 }
@@ -148,6 +142,13 @@ func (sw *Switch) OnStart() error {
                        return err
                }
        }
+
+       sw.security.RegisterFilter(sw.nodeInfo)
+       sw.security.RegisterFilter(sw.peers)
+       if err := sw.security.Start(); err != nil {
+               return err
+       }
+
        for _, listener := range sw.listeners {
                go sw.listenerRoutine(listener)
        }
@@ -178,21 +179,6 @@ func (sw *Switch) OnStop() {
        }
 }
 
-//AddBannedPeer add peer to blacklist
-func (sw *Switch) AddBannedPeer(ip string) error {
-       sw.mtx.Lock()
-       defer sw.mtx.Unlock()
-
-       sw.bannedPeer[ip] = time.Now().Add(defaultBanDuration)
-       dataJSON, err := json.Marshal(sw.bannedPeer)
-       if err != nil {
-               return err
-       }
-
-       sw.db.Set([]byte(bannedPeerKey), dataJSON)
-       return nil
-}
-
 // AddPeer performs the P2P handshake with a peer
 // that already has a SecretConnection. If all goes well,
 // it starts the peer and adds it to the switch.
@@ -213,7 +199,7 @@ func (sw *Switch) AddPeer(pc *peerConn, isLAN bool) error {
        }
 
        peer := newPeer(pc, peerNodeInfo, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, isLAN)
-       if err := sw.filterConnByPeer(peer); err != nil {
+       if err := sw.security.DoFilter(peer.remoteAddrHost(), peer.PubKey()); err != nil {
                return err
        }
 
@@ -260,7 +246,7 @@ func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error {
        log.WithFields(log.Fields{"module": logModule, "address": addr}).Debug("Dialing peer")
        sw.dialing.Set(addr.IP.String(), addr)
        defer sw.dialing.Delete(addr.IP.String())
-       if err := sw.filterConnByIP(addr.IP.String()); err != nil {
+       if err := sw.security.DoFilter(addr.IP.String(), ""); err != nil {
                return err
        }
 
@@ -279,6 +265,10 @@ func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error {
        return nil
 }
 
+func (sw *Switch) IsBanned(ip string, level byte, reason string) bool {
+       return sw.security.IsBanned(ip, level, reason)
+}
+
 //IsDialing prevent duplicate dialing
 func (sw *Switch) IsDialing(addr *NetAddress) bool {
        return sw.dialing.Has(addr.IP.String())
@@ -290,17 +280,6 @@ func (sw *Switch) IsListening() bool {
        return len(sw.listeners) > 0
 }
 
-// loadBannedPeers load banned peers from db
-func (sw *Switch) loadBannedPeers() error {
-       if dataJSON := sw.db.Get([]byte(bannedPeerKey)); dataJSON != nil {
-               if err := json.Unmarshal(dataJSON, &sw.bannedPeer); err != nil {
-                       return err
-               }
-       }
-
-       return nil
-}
-
 // Listeners returns the list of listeners the switch listens on.
 // NOTE: Not goroutine safe.
 func (sw *Switch) Listeners() []Listener {
@@ -362,22 +341,6 @@ func (sw *Switch) addPeerWithConnection(conn net.Conn) error {
        return nil
 }
 
-func (sw *Switch) checkBannedPeer(peer string) error {
-       sw.mtx.Lock()
-       defer sw.mtx.Unlock()
-
-       if banEnd, ok := sw.bannedPeer[peer]; ok {
-               if time.Now().Before(banEnd) {
-                       return ErrConnectBannedPeer
-               }
-
-               if err := sw.delBannedPeer(peer); err != nil {
-                       return err
-               }
-       }
-       return nil
-}
-
 func (sw *Switch) connectLANPeers(lanPeer mdns.LANPeerEvent) {
        lanPeers, _, _, numDialing := sw.NumPeers()
        numToDial := maxNumLANPeers - lanPeers
@@ -422,42 +385,6 @@ func (sw *Switch) connectLANPeersRoutine() {
        }
 }
 
-func (sw *Switch) delBannedPeer(addr string) error {
-       sw.mtx.Lock()
-       defer sw.mtx.Unlock()
-
-       delete(sw.bannedPeer, addr)
-       datajson, err := json.Marshal(sw.bannedPeer)
-       if err != nil {
-               return err
-       }
-
-       sw.db.Set([]byte(bannedPeerKey), datajson)
-       return nil
-}
-
-func (sw *Switch) filterConnByIP(ip string) error {
-       if ip == sw.nodeInfo.listenHost() {
-               return ErrConnectSelf
-       }
-       return sw.checkBannedPeer(ip)
-}
-
-func (sw *Switch) filterConnByPeer(peer *Peer) error {
-       if err := sw.checkBannedPeer(peer.remoteAddrHost()); err != nil {
-               return err
-       }
-
-       if sw.nodeInfo.PubKey == peer.PubKey() {
-               return ErrConnectSelf
-       }
-
-       if sw.peers.Has(peer.Key) {
-               return ErrDuplicatePeer
-       }
-       return nil
-}
-
 func (sw *Switch) listenerRoutine(l Listener) {
        for {
                inConn, ok := <-l.Connections()
index 6b2cdb8..19c475e 100644 (file)
@@ -12,6 +12,7 @@ import (
        dbm "github.com/vapor/database/leveldb"
        "github.com/vapor/errors"
        conn "github.com/vapor/p2p/connection"
+       "github.com/vapor/p2p/security"
        "github.com/vapor/p2p/signlib"
 )
 
@@ -125,7 +126,6 @@ func initSwitchFunc(sw *Switch) *Switch {
 
 //Test connect self.
 func TestFiltersOutItself(t *testing.T) {
-       t.Skip("skipping test")
        dirPath, err := ioutil.TempDir(".", "")
        if err != nil {
                t.Fatal(err)
@@ -134,6 +134,7 @@ func TestFiltersOutItself(t *testing.T) {
 
        testDB := dbm.NewDB("testdb", "leveldb", dirPath)
        cfg := *testCfg
+       cfg.DBPath = dirPath
        cfg.P2P.ListenAddress = "127.0.1.1:0"
        swPrivKey, err := signlib.NewPrivKey()
        if err != nil {
@@ -144,8 +145,15 @@ func TestFiltersOutItself(t *testing.T) {
        s1.Start()
        defer s1.Stop()
 
+       rmdirPath, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer os.RemoveAll(rmdirPath)
+
        // simulate s1 having a public key and creating a remote peer with the same key
        rpCfg := *testCfg
+       rpCfg.DBPath = rmdirPath
        rp := &remotePeer{PrivKey: s1.nodePrivKey, Config: &rpCfg}
        rp.Start()
        defer rp.Stop()
@@ -162,7 +170,6 @@ func TestFiltersOutItself(t *testing.T) {
 }
 
 func TestDialBannedPeer(t *testing.T) {
-       t.Skip("skipping test")
        dirPath, err := ioutil.TempDir(".", "")
        if err != nil {
                t.Fatal(err)
@@ -171,6 +178,7 @@ func TestDialBannedPeer(t *testing.T) {
 
        testDB := dbm.NewDB("testdb", "leveldb", dirPath)
        cfg := *testCfg
+       cfg.DBPath = dirPath
        cfg.P2P.ListenAddress = "127.0.1.1:0"
        swPrivKey, err := signlib.NewPrivKey()
        if err != nil {
@@ -180,7 +188,14 @@ func TestDialBannedPeer(t *testing.T) {
        s1.Start()
        defer s1.Stop()
 
+       rmdirPath, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer os.RemoveAll(rmdirPath)
+
        rpCfg := *testCfg
+       rpCfg.DBPath = rmdirPath
        remotePrivKey, err := signlib.NewPrivKey()
        if err != nil {
                t.Fatal(err)
@@ -189,19 +204,17 @@ func TestDialBannedPeer(t *testing.T) {
        rp := &remotePeer{PrivKey: remotePrivKey, Config: &rpCfg}
        rp.Start()
        defer rp.Stop()
-       s1.AddBannedPeer(rp.addr.IP.String())
-       if err := s1.DialPeerWithAddress(rp.addr); errors.Root(err) != ErrConnectBannedPeer {
-               t.Fatal(err)
+       for {
+               if ok := s1.security.IsBanned(rp.addr.IP.String(), security.LevelMsgIllegal, "test"); ok {
+                       break
+               }
        }
-
-       s1.delBannedPeer(rp.addr.IP.String())
-       if err := s1.DialPeerWithAddress(rp.addr); err != nil {
+       if err := s1.DialPeerWithAddress(rp.addr); errors.Root(err) != security.ErrConnectBannedPeer {
                t.Fatal(err)
        }
 }
 
 func TestDuplicateOutBoundPeer(t *testing.T) {
-       t.Skip("skipping test")
        dirPath, err := ioutil.TempDir(".", "")
        if err != nil {
                t.Fatal(err)
@@ -210,6 +223,7 @@ func TestDuplicateOutBoundPeer(t *testing.T) {
 
        testDB := dbm.NewDB("testdb", "leveldb", dirPath)
        cfg := *testCfg
+       cfg.DBPath = dirPath
        cfg.P2P.ListenAddress = "127.0.1.1:0"
        swPrivKey, err := signlib.NewPrivKey()
        if err != nil {
@@ -220,7 +234,14 @@ func TestDuplicateOutBoundPeer(t *testing.T) {
        s1.Start()
        defer s1.Stop()
 
+       rmdirPath, err := ioutil.TempDir(".", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer os.RemoveAll(rmdirPath)
+
        rpCfg := *testCfg
+       rpCfg.DBPath = rmdirPath
        remotePrivKey, err := signlib.NewPrivKey()
        if err != nil {
                t.Fatal(err)
@@ -240,7 +261,6 @@ func TestDuplicateOutBoundPeer(t *testing.T) {
 }
 
 func TestDuplicateInBoundPeer(t *testing.T) {
-       t.Skip("skipping test")
        dirPath, err := ioutil.TempDir(".", "")
        if err != nil {
                t.Fatal(err)
@@ -249,6 +269,7 @@ func TestDuplicateInBoundPeer(t *testing.T) {
 
        testDB := dbm.NewDB("testdb", "leveldb", dirPath)
        cfg := *testCfg
+       cfg.DBPath = dirPath
        cfg.P2P.ListenAddress = "127.0.1.1:0"
        swPrivKey, err := signlib.NewPrivKey()
        if err != nil {
@@ -282,7 +303,6 @@ func TestDuplicateInBoundPeer(t *testing.T) {
 }
 
 func TestAddInboundPeer(t *testing.T) {
-       t.Skip("skipping test")
        dirPath, err := ioutil.TempDir(".", "")
        if err != nil {
                t.Fatal(err)
@@ -291,6 +311,7 @@ func TestAddInboundPeer(t *testing.T) {
 
        testDB := dbm.NewDB("testdb", "leveldb", dirPath)
        cfg := *testCfg
+       cfg.DBPath = dirPath
        cfg.P2P.MaxNumPeers = 2
        cfg.P2P.ListenAddress = "127.0.1.1:0"
        swPrivKey, err := signlib.NewPrivKey()
@@ -344,7 +365,6 @@ func TestAddInboundPeer(t *testing.T) {
 }
 
 func TestStopPeer(t *testing.T) {
-       t.Skip("skipping test")
        dirPath, err := ioutil.TempDir(".", "")
        if err != nil {
                t.Fatal(err)
@@ -353,6 +373,7 @@ func TestStopPeer(t *testing.T) {
 
        testDB := dbm.NewDB("testdb", "leveldb", dirPath)
        cfg := *testCfg
+       cfg.DBPath = dirPath
        cfg.P2P.MaxNumPeers = 2
        cfg.P2P.ListenAddress = "127.0.1.1:0"
        swPrivKey, err := signlib.NewPrivKey()
index 76c52ad..fa5d631 100644 (file)
@@ -92,7 +92,7 @@ func MakeSwitch(cfg *cfg.Config, testdb dbm.DB, privKey signlib.PrivKey, initSwi
        // new switch, add reactors
        l, listenAddr := GetListener(cfg.P2P)
        cfg.P2P.LANDiscover = false
-       sw, err := newSwitch(cfg, new(mockDiscv), nil, testdb, l, privKey, listenAddr, 0)
+       sw, err := newSwitch(cfg, new(mockDiscv), nil, l, privKey, listenAddr, 0)
        if err != nil {
                log.Errorf("create switch error: %s", err)
                return nil