From 807d99726f6a0610fa9c835e2aabd983801d3510 Mon Sep 17 00:00:00 2001 From: yahtoo Date: Mon, 10 Jun 2019 10:41:36 +0800 Subject: [PATCH 1/1] Add p2p security module (#143) * tmp * Add p2p security module * Fix review bugs --- netsync/chainmgr/block_keeper.go | 5 +- netsync/chainmgr/handle.go | 12 +-- netsync/chainmgr/tool_test.go | 7 +- netsync/consensusmgr/block_fetcher.go | 3 +- netsync/consensusmgr/block_fetcher_test.go | 4 +- netsync/consensusmgr/handle.go | 4 +- netsync/peers/peer.go | 44 ++-------- p2p/node_info.go | 8 ++ p2p/peer_set.go | 8 ++ p2p/peer_test.go | 12 +-- p2p/{trust => security}/banscore.go | 4 +- p2p/{trust => security}/banscore_test.go | 4 +- p2p/security/blacklist.go | 91 ++++++++++++++++++++ p2p/security/filter.go | 38 +++++++++ p2p/security/score.go | 69 +++++++++++++++ p2p/security/security.go | 53 ++++++++++++ p2p/switch.go | 131 +++++++---------------------- p2p/switch_test.go | 45 +++++++--- p2p/test_util.go | 2 +- 19 files changed, 367 insertions(+), 177 deletions(-) rename p2p/{trust => security}/banscore.go (99%) rename p2p/{trust => security}/banscore_test.go (99%) create mode 100644 p2p/security/blacklist.go create mode 100644 p2p/security/filter.go create mode 100644 p2p/security/score.go create mode 100644 p2p/security/security.go diff --git a/netsync/chainmgr/block_keeper.go b/netsync/chainmgr/block_keeper.go index 8ce20ea7..112fd50e 100644 --- a/netsync/chainmgr/block_keeper.go +++ b/netsync/chainmgr/block_keeper.go @@ -9,6 +9,7 @@ import ( "github.com/vapor/consensus" "github.com/vapor/errors" "github.com/vapor/netsync/peers" + "github.com/vapor/p2p/security" "github.com/vapor/protocol/bc" "github.com/vapor/protocol/bc/types" ) @@ -353,7 +354,7 @@ func (bk *blockKeeper) startSync() bool { bk.syncPeer = peer if err := bk.fastBlockSync(checkPoint); err != nil { log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on fastBlockSync") - bk.peers.ErrorHandler(peer.ID(), err) + bk.peers.ErrorHandler(peer.ID(), security.LevelMsgIllegal, err) return false } return true @@ -370,7 +371,7 @@ func (bk *blockKeeper) startSync() bool { if err := bk.regularBlockSync(targetHeight); err != nil { log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on regularBlockSync") - bk.peers.ErrorHandler(peer.ID(), err) + bk.peers.ErrorHandler(peer.ID(),security.LevelMsgIllegal, err) return false } return true diff --git a/netsync/chainmgr/handle.go b/netsync/chainmgr/handle.go index 51dee4c0..0d13847d 100644 --- a/netsync/chainmgr/handle.go +++ b/netsync/chainmgr/handle.go @@ -12,6 +12,7 @@ import ( msgs "github.com/vapor/netsync/messages" "github.com/vapor/netsync/peers" "github.com/vapor/p2p" + "github.com/vapor/p2p/security" core "github.com/vapor/protocol" "github.com/vapor/protocol/bc" "github.com/vapor/protocol/bc/types" @@ -37,7 +38,6 @@ type Chain interface { type Switch interface { AddReactor(name string, reactor p2p.Reactor) p2p.Reactor - AddBannedPeer(string) error Start() (bool, error) Stop() bool IsListening() bool @@ -247,12 +247,12 @@ func (m *Manager) handleStatusMsg(basePeer peers.BasePeer, msg *msgs.StatusMessa func (m *Manager) handleTransactionMsg(peer *peers.Peer, msg *msgs.TransactionMessage) { tx, err := msg.GetTransaction() if err != nil { - m.peers.AddBanScore(peer.ID(), 0, 10, "fail on get tx from message") + m.peers.ProcessIllegal(peer.ID(), security.LevelConnException, "fail on get tx from message") return } if isOrphan, err := m.chain.ValidateTx(tx); err != nil && err != core.ErrDustTx && !isOrphan { - m.peers.AddBanScore(peer.ID(), 10, 0, "fail on validate tx transaction") + m.peers.ProcessIllegal(peer.ID(), security.LevelMsgIllegal, "fail on validate tx transaction") } m.peers.MarkTx(peer.ID(), tx.ID) } @@ -260,18 +260,18 @@ func (m *Manager) handleTransactionMsg(peer *peers.Peer, msg *msgs.TransactionMe func (m *Manager) handleTransactionsMsg(peer *peers.Peer, msg *msgs.TransactionsMessage) { txs, err := msg.GetTransactions() if err != nil { - m.peers.AddBanScore(peer.ID(), 0, 20, "fail on get txs from message") + m.peers.ProcessIllegal(peer.ID(), security.LevelConnException, "fail on get txs from message") return } if len(txs) > msgs.TxsMsgMaxTxNum { - m.peers.AddBanScore(peer.ID(), 20, 0, "exceeded the maximum tx number limit") + m.peers.ProcessIllegal(peer.ID(), security.LevelMsgIllegal, "exceeded the maximum tx number limit") return } for _, tx := range txs { if isOrphan, err := m.chain.ValidateTx(tx); err != nil && !isOrphan { - m.peers.AddBanScore(peer.ID(), 10, 0, "fail on validate tx transaction") + m.peers.ProcessIllegal(peer.ID(), security.LevelMsgIllegal, "fail on validate tx transaction") return } m.peers.MarkTx(peer.ID(), tx.ID) diff --git a/netsync/chainmgr/tool_test.go b/netsync/chainmgr/tool_test.go index e3549841..db17b4ff 100644 --- a/netsync/chainmgr/tool_test.go +++ b/netsync/chainmgr/tool_test.go @@ -89,8 +89,11 @@ func NewPeerSet() *PeerSet { return &PeerSet{} } -func (ps *PeerSet) AddBannedPeer(string) error { return nil } -func (ps *PeerSet) StopPeerGracefully(string) {} +func (ps *PeerSet) IsBanned(peerID string, level byte, reason string) bool { + return false +} + +func (ps *PeerSet) StopPeerGracefully(string) {} type NetWork struct { nodes map[*Manager]P2PPeer diff --git a/netsync/consensusmgr/block_fetcher.go b/netsync/consensusmgr/block_fetcher.go index 3b6c746a..dc429236 100644 --- a/netsync/consensusmgr/block_fetcher.go +++ b/netsync/consensusmgr/block_fetcher.go @@ -5,6 +5,7 @@ import ( "gopkg.in/karalabe/cookiejar.v2/collections/prque" "github.com/vapor/netsync/peers" + "github.com/vapor/p2p/security" "github.com/vapor/protocol/bc" ) @@ -80,7 +81,7 @@ func (f *blockFetcher) insert(msg *blockMsg) { return } - f.peers.AddBanScore(msg.peerID, 20, 0, err.Error()) + f.peers.ProcessIllegal(msg.peerID, security.LevelMsgIllegal, err.Error()) return } diff --git a/netsync/consensusmgr/block_fetcher_test.go b/netsync/consensusmgr/block_fetcher_test.go index f7d6936b..f5b44a87 100644 --- a/netsync/consensusmgr/block_fetcher_test.go +++ b/netsync/consensusmgr/block_fetcher_test.go @@ -12,8 +12,8 @@ import ( type peerMgr struct { } -func (pm *peerMgr) AddBannedPeer(string) error { - return nil +func (pm *peerMgr) IsBanned(peerID string, level byte, reason string) bool{ + return false } func (pm *peerMgr) StopPeerGracefully(string) { diff --git a/netsync/consensusmgr/handle.go b/netsync/consensusmgr/handle.go index 668a2f7a..66fddd9e 100644 --- a/netsync/consensusmgr/handle.go +++ b/netsync/consensusmgr/handle.go @@ -8,6 +8,7 @@ import ( "github.com/vapor/event" "github.com/vapor/netsync/peers" "github.com/vapor/p2p" + "github.com/vapor/p2p/security" "github.com/vapor/protocol/bc" "github.com/vapor/protocol/bc/types" ) @@ -15,7 +16,6 @@ import ( // Switch is the interface for p2p switch. type Switch interface { AddReactor(name string, reactor p2p.Reactor) p2p.Reactor - AddBannedPeer(string) error } // Chain is the interface for Bytom core. @@ -96,7 +96,7 @@ func (m *Manager) handleBlockProposeMsg(peerID string, msg *BlockProposeMsg) { func (m *Manager) handleBlockSignatureMsg(peerID string, msg *BlockSignatureMsg) { blockHash := bc.NewHash(msg.BlockHash) if err := m.chain.ProcessBlockSignature(msg.Signature, msg.PubKey, &blockHash); err != nil { - m.peers.AddBanScore(peerID, 20, 0, err.Error()) + m.peers.ProcessIllegal(peerID, security.LevelMsgIllegal, err.Error()) return } } diff --git a/netsync/peers/peer.go b/netsync/peers/peer.go index 0f62149e..1f0ac249 100644 --- a/netsync/peers/peer.go +++ b/netsync/peers/peer.go @@ -13,7 +13,6 @@ import ( "github.com/vapor/consensus" "github.com/vapor/errors" msgs "github.com/vapor/netsync/messages" - "github.com/vapor/p2p/trust" "github.com/vapor/protocol/bc" "github.com/vapor/protocol/bc/types" ) @@ -22,7 +21,6 @@ const ( maxKnownTxs = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS) maxKnownSignatures = 1024 // Maximum block signatures to keep in the known list (prevent DOS) maxKnownBlocks = 1024 // Maximum block hashes to keep in the known list (prevent DOS) - defaultBanThreshold = uint32(100) maxFilterAddressSize = 50 maxFilterAddressCount = 1000 @@ -46,8 +44,8 @@ type BasePeer interface { //BasePeerSet is the intergace for connection level peer manager type BasePeerSet interface { - AddBannedPeer(string) error StopPeerGracefully(string) + IsBanned(peerID string, level byte, reason string) bool } type BroadcastMsg interface { @@ -79,7 +77,6 @@ type Peer struct { services consensus.ServiceFlag height uint64 hash *bc.Hash - banScore trust.DynamicBanScore knownTxs *set.Set // Set of transaction hashes known to be known by this peer knownBlocks *set.Set // Set of block hashes known to be known by this peer knownSignatures *set.Set // Set of block signatures known to be known by this peer @@ -104,30 +101,6 @@ func (p *Peer) Height() uint64 { return p.height } -func (p *Peer) addBanScore(persistent, transient uint32, reason string) bool { - score := p.banScore.Increase(persistent, transient) - if score > defaultBanThreshold { - log.WithFields(log.Fields{ - "module": logModule, - "address": p.Addr(), - "score": score, - "reason": reason, - }).Errorf("banning and disconnecting") - return true - } - - warnThreshold := defaultBanThreshold >> 1 - if score > warnThreshold { - log.WithFields(log.Fields{ - "module": logModule, - "address": p.Addr(), - "score": score, - "reason": reason, - }).Warning("ban score increasing") - } - return false -} - func (p *Peer) AddFilterAddress(address []byte) { p.mtx.Lock() defer p.mtx.Unlock() @@ -417,7 +390,7 @@ func NewPeerSet(basePeerSet BasePeerSet) *PeerSet { } } -func (ps *PeerSet) AddBanScore(peerID string, persistent, transient uint32, reason string) { +func (ps *PeerSet) ProcessIllegal(peerID string, level byte, reason string) { ps.mtx.Lock() peer := ps.peers[peerID] ps.mtx.Unlock() @@ -425,13 +398,10 @@ func (ps *PeerSet) AddBanScore(peerID string, persistent, transient uint32, reas if peer == nil { return } - if ban := peer.addBanScore(persistent, transient, reason); !ban { - return - } - if err := ps.AddBannedPeer(peer.Addr().String()); err != nil { - log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on add ban peer") + if banned := ps.IsBanned(peer.Addr().String(), level, reason); banned { + ps.RemovePeer(peerID) } - ps.RemovePeer(peerID) + return } func (ps *PeerSet) AddPeer(peer BasePeer) { @@ -536,9 +506,9 @@ func (ps *PeerSet) BroadcastTx(tx *types.Tx) error { return nil } -func (ps *PeerSet) ErrorHandler(peerID string, err error) { +func (ps *PeerSet) ErrorHandler(peerID string, level byte, err error) { if errors.Root(err) == ErrPeerMisbehave { - ps.AddBanScore(peerID, 20, 0, err.Error()) + ps.ProcessIllegal(peerID, level, err.Error()) } else { ps.RemovePeer(peerID) } diff --git a/p2p/node_info.go b/p2p/node_info.go index 00f818dd..e602a0d7 100644 --- a/p2p/node_info.go +++ b/p2p/node_info.go @@ -71,6 +71,14 @@ func (info *NodeInfo) compatibleWith(other *NodeInfo, versionCompatibleWith Vers return nil } +func (info NodeInfo) DoFilter(ip string, pubKey string) error { + if info.PubKey == pubKey { + return ErrConnectSelf + } + + return nil +} + //listenHost peer listener ip address func (info NodeInfo) listenHost() string { host, _, _ := net.SplitHostPort(info.ListenAddr) diff --git a/p2p/peer_set.go b/p2p/peer_set.go index e26746b4..c6523715 100644 --- a/p2p/peer_set.go +++ b/p2p/peer_set.go @@ -50,6 +50,14 @@ func (ps *PeerSet) Add(peer *Peer) error { return nil } +func (ps *PeerSet) DoFilter(ip string, pubKey string) error { + if ps.Has(pubKey) { + return ErrDuplicatePeer + } + + return nil +} + // Get looks up a peer by the provided peerKey. func (ps *PeerSet) Get(peerKey string) *Peer { ps.mtx.Lock() diff --git a/p2p/peer_test.go b/p2p/peer_test.go index 8559f119..bcaed18b 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -7,6 +7,7 @@ import ( "time" cfg "github.com/vapor/config" + "github.com/vapor/consensus" conn "github.com/vapor/p2p/connection" "github.com/vapor/p2p/signlib" "github.com/vapor/version" @@ -142,11 +143,12 @@ func (rp *remotePeer) accept(l net.Listener) { } _, err = pc.HandshakeTimeout(&NodeInfo{ - PubKey: rp.PrivKey.XPub().String(), - Moniker: "remote_peer", - Network: rp.Config.ChainID, - Version: version.Version, - ListenAddr: l.Addr().String(), + PubKey: rp.PrivKey.XPub().String(), + Moniker: "remote_peer", + Network: rp.Config.ChainID, + Version: version.Version, + ListenAddr: l.Addr().String(), + ServiceFlag: consensus.DefaultServices, }, 5*time.Second) if err != nil { fmt.Println("Failed to perform handshake:", err) diff --git a/p2p/trust/banscore.go b/p2p/security/banscore.go similarity index 99% rename from p2p/trust/banscore.go rename to p2p/security/banscore.go index 892d653f..5892a5f3 100644 --- a/p2p/trust/banscore.go +++ b/p2p/security/banscore.go @@ -1,4 +1,4 @@ -package trust +package security import ( "fmt" @@ -29,7 +29,7 @@ const ( var precomputedFactor [precomputedLen]float64 // init precomputes decay factors. -func Init() { +func init() { for i := range precomputedFactor { precomputedFactor[i] = math.Exp(-1.0 * float64(i) * lambda) } diff --git a/p2p/trust/banscore_test.go b/p2p/security/banscore_test.go similarity index 99% rename from p2p/trust/banscore_test.go rename to p2p/security/banscore_test.go index a4a4fcf7..6dd0944f 100644 --- a/p2p/trust/banscore_test.go +++ b/p2p/security/banscore_test.go @@ -1,4 +1,4 @@ -package trust +package security import ( "math" @@ -23,7 +23,6 @@ func TestInt(t *testing.T) { {bs: DynamicBanScore{lastUnix: 0, transient: math.MaxUint32, persistent: math.MaxUint32}, timeLapse: 0, wantValue: math.MaxUint32 - 1}, } - Init() for i, intTest := range banScoreIntTests { rst := intTest.bs.int(time.Unix(intTest.timeLapse, 0)) if rst != intTest.wantValue { @@ -53,7 +52,6 @@ func TestIncrease(t *testing.T) { {bs: DynamicBanScore{lastUnix: 0, transient: 0, persistent: math.MaxUint32}, transientAdd: math.MaxUint32, persistentAdd: 0, timeLapse: Lifetime + 1, wantValue: math.MaxUint32 - 1}, } - Init() for i, incTest := range banScoreIncreaseTests { rst := incTest.bs.increase(incTest.persistentAdd, incTest.transientAdd, time.Unix(incTest.timeLapse, 0)) if rst != incTest.wantValue { diff --git a/p2p/security/blacklist.go b/p2p/security/blacklist.go new file mode 100644 index 00000000..f8ca05e0 --- /dev/null +++ b/p2p/security/blacklist.go @@ -0,0 +1,91 @@ +package security + +import ( + "encoding/json" + "errors" + "sync" + "time" + + cfg "github.com/vapor/config" + dbm "github.com/vapor/database/leveldb" +) + +const ( + defaultBanDuration = time.Hour * 1 + blacklistKey = "BlacklistPeers" +) + +var ( + ErrConnectBannedPeer = errors.New("connect banned peer") +) + +type Blacklist struct { + peers map[string]time.Time + db dbm.DB + + mtx sync.Mutex +} + +func NewBlacklist(config *cfg.Config) *Blacklist { + return &Blacklist{ + peers: make(map[string]time.Time), + db: dbm.NewDB("blacklist", config.DBBackend, config.DBDir()), + } +} + +//AddPeer add peer to blacklist +func (bl *Blacklist) AddPeer(ip string) error { + bl.mtx.Lock() + defer bl.mtx.Unlock() + + bl.peers[ip] = time.Now().Add(defaultBanDuration) + dataJSON, err := json.Marshal(bl.peers) + if err != nil { + return err + } + + bl.db.Set([]byte(blacklistKey), dataJSON) + return nil +} + +func (bl *Blacklist) delPeer(ip string) error { + delete(bl.peers, ip) + dataJson, err := json.Marshal(bl.peers) + if err != nil { + return err + } + + bl.db.Set([]byte(blacklistKey), dataJson) + return nil +} + +func (bl *Blacklist) DoFilter(ip string, pubKey string) error { + bl.mtx.Lock() + defer bl.mtx.Unlock() + + if banEnd, ok := bl.peers[ip]; ok { + if time.Now().Before(banEnd) { + return ErrConnectBannedPeer + } + + if err := bl.delPeer(ip); err != nil { + return err + } + } + + return nil +} + +// LoadPeers load banned peers from db +func (bl *Blacklist) LoadPeers() error { + bl.mtx.Lock() + defer bl.mtx.Unlock() + + if dataJSON := bl.db.Get([]byte(blacklistKey)); dataJSON != nil { + if err := json.Unmarshal(dataJSON, &bl.peers); err != nil { + return err + } + } + + return nil +} diff --git a/p2p/security/filter.go b/p2p/security/filter.go new file mode 100644 index 00000000..409952aa --- /dev/null +++ b/p2p/security/filter.go @@ -0,0 +1,38 @@ +package security + +import "sync" + +type Filter interface { + DoFilter(string, string) error +} + +type PeerFilter struct { + filterChain []Filter + mtx sync.RWMutex +} + +func NewPeerFilter() *PeerFilter { + return &PeerFilter{ + filterChain: make([]Filter, 0), + } +} + +func (pf *PeerFilter) register(filter Filter) { + pf.mtx.Lock() + defer pf.mtx.Unlock() + + pf.filterChain = append(pf.filterChain, filter) +} + +func (pf *PeerFilter) doFilter(ip string, pubKey string) error { + pf.mtx.RLock() + defer pf.mtx.RUnlock() + + for _, filter := range pf.filterChain { + if err := filter.DoFilter(ip, pubKey); err != nil { + return err + } + } + + return nil +} diff --git a/p2p/security/score.go b/p2p/security/score.go new file mode 100644 index 00000000..fea3149c --- /dev/null +++ b/p2p/security/score.go @@ -0,0 +1,69 @@ +package security + +import ( + "sync" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultBanThreshold = uint32(100) + defaultWarnThreshold = uint32(50) + + LevelMsgIllegal = 0x01 + levelMsgIllegalPersistent = uint32(20) + levelMsgIllegalTransient = uint32(0) + LevelConnException = 0x02 + levelConnExceptionPersistent = uint32(0) + levelConnExceptionTransient = uint32(20) +) + +type PeersBanScore struct { + peers map[string]*DynamicBanScore + mtx sync.Mutex +} + +func NewPeersScore() *PeersBanScore { + return &PeersBanScore{ + peers: make(map[string]*DynamicBanScore), + } +} + +func (ps *PeersBanScore) DelPeer(ip string) { + ps.mtx.Lock() + defer ps.mtx.Unlock() + + delete(ps.peers, ip) +} + +func (ps *PeersBanScore) Increase(ip string, level byte, reason string) bool { + ps.mtx.Lock() + defer ps.mtx.Unlock() + + var persistent, transient uint32 + switch level { + case LevelMsgIllegal: + persistent = levelMsgIllegalPersistent + transient = levelMsgIllegalTransient + case LevelConnException: + persistent = levelConnExceptionPersistent + transient = levelConnExceptionTransient + default: + return false + } + banScore, ok := ps.peers[ip] + if !ok { + banScore = &DynamicBanScore{} + ps.peers[ip] = banScore + } + score := banScore.Increase(persistent, transient) + if score > defaultBanThreshold { + log.WithFields(log.Fields{"module": logModule, "address": ip, "score": score, "reason": reason}).Errorf("banning and disconnecting") + return true + } + + if score > defaultWarnThreshold { + log.WithFields(log.Fields{"module": logModule, "address": ip, "score": score, "reason": reason}).Warning("ban score increasing") + } + return false +} diff --git a/p2p/security/security.go b/p2p/security/security.go new file mode 100644 index 00000000..7c9e9454 --- /dev/null +++ b/p2p/security/security.go @@ -0,0 +1,53 @@ +package security + +import ( + log "github.com/sirupsen/logrus" + + cfg "github.com/vapor/config" +) + +const logModule = "p2p/security" + +type Security struct { + filter *PeerFilter + blacklist *Blacklist + peersBanScore *PeersBanScore +} + +func NewSecurity(config *cfg.Config) *Security { + return &Security{ + filter: NewPeerFilter(), + blacklist: NewBlacklist(config), + peersBanScore: NewPeersScore(), + } +} + +func (s *Security) DoFilter(ip string, pubKey string) error { + return s.filter.doFilter(ip, pubKey) +} + +func (s *Security) IsBanned(ip string, level byte, reason string) bool { + if ok := s.peersBanScore.Increase(ip, level, reason); !ok { + return false + } + + if err := s.blacklist.AddPeer(ip); err != nil { + log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on add ban peer") + } + //clear peer score + s.peersBanScore.DelPeer(ip) + return true +} + +func (s *Security) RegisterFilter(filter Filter) { + s.filter.register(filter) +} + +func (s *Security) Start() error { + if err := s.blacklist.LoadPeers(); err != nil { + return err + } + + s.filter.register(s.blacklist) + return nil +} diff --git a/p2p/switch.go b/p2p/switch.go index 62d47f1c..ef8306d3 100644 --- a/p2p/switch.go +++ b/p2p/switch.go @@ -2,7 +2,6 @@ package p2p import ( "encoding/binary" - "encoding/json" "fmt" "net" "sync" @@ -14,7 +13,6 @@ import ( cfg "github.com/vapor/config" "github.com/vapor/consensus" "github.com/vapor/crypto/sha3pool" - dbm "github.com/vapor/database/leveldb" "github.com/vapor/errors" "github.com/vapor/event" "github.com/vapor/p2p/connection" @@ -22,14 +20,12 @@ import ( "github.com/vapor/p2p/discover/mdns" "github.com/vapor/p2p/netutil" "github.com/vapor/p2p/signlib" - "github.com/vapor/p2p/trust" + security "github.com/vapor/p2p/security" "github.com/vapor/version" ) const ( - bannedPeerKey = "BannedPeer" - defaultBanDuration = time.Hour * 1 - logModule = "p2p" + logModule = "p2p" minNumOutboundPeers = 4 maxNumLANPeers = 5 @@ -39,10 +35,9 @@ const ( //pre-define errors for connecting fail var ( - ErrDuplicatePeer = errors.New("Duplicate peer") - ErrConnectSelf = errors.New("Connect self") - ErrConnectBannedPeer = errors.New("Connect banned peer") - ErrConnectSpvPeer = errors.New("Outbound connect spv peer") + ErrDuplicatePeer = errors.New("Duplicate peer") + ErrConnectSelf = errors.New("Connect self") + ErrConnectSpvPeer = errors.New("Outbound connect spv peer") ) type discv interface { @@ -54,6 +49,13 @@ type lanDiscv interface { Stop() } +type Security interface { + DoFilter(ip string, pubKey string) error + IsBanned(ip string, level byte, reason string) bool + RegisterFilter(filter security.Filter) + Start() error +} + // Switch handles peer connections and exposes an API to receive incoming messages // on `Reactors`. Each `Reactor` is responsible for handling incoming messages of one // or more `Channels`. So while sending outgoing messages is typically performed on the peer, @@ -73,9 +75,7 @@ type Switch struct { nodePrivKey signlib.PrivKey // our node privkey discv discv lanDiscv lanDiscv - bannedPeer map[string]time.Time - db dbm.DB - mtx sync.Mutex + security Security } // NewSwitch create a new Switch and set discover. @@ -96,7 +96,6 @@ func NewSwitch(config *cfg.Config) (*Switch, error) { sha3pool.Sum256(h[:], data) netID := binary.BigEndian.Uint64(h[:8]) - blacklistDB := dbm.NewDB("trusthistory", config.DBBackend, config.DBDir()) privateKey := config.PrivateKey() if !config.VaultMode { // Create listener @@ -110,11 +109,11 @@ func NewSwitch(config *cfg.Config) (*Switch, error) { } } - return newSwitch(config, discv, lanDiscv, blacklistDB, l, *privateKey, listenAddr, netID) + return newSwitch(config, discv, lanDiscv, l, *privateKey, listenAddr, netID) } // newSwitch creates a new Switch with the given config. -func newSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, blacklistDB dbm.DB, l Listener, privKey signlib.PrivKey, listenAddr string, netID uint64) (*Switch, error) { +func newSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, l Listener, privKey signlib.PrivKey, listenAddr string, netID uint64) (*Switch, error) { sw := &Switch{ Config: config, peerConfig: DefaultPeerConfig(config.P2P), @@ -126,17 +125,12 @@ func newSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, blacklistDB d nodePrivKey: privKey, discv: discv, lanDiscv: lanDiscv, - db: blacklistDB, nodeInfo: NewNodeInfo(config, privKey.XPub(), listenAddr, netID), - bannedPeer: make(map[string]time.Time), - } - if err := sw.loadBannedPeers(); err != nil { - return nil, err + security: security.NewSecurity(config), } sw.AddListener(l) sw.BaseService = *cmn.NewBaseService(nil, "P2P Switch", sw) - trust.Init() log.WithFields(log.Fields{"module": logModule, "nodeInfo": sw.nodeInfo}).Info("init p2p network") return sw, nil } @@ -148,6 +142,13 @@ func (sw *Switch) OnStart() error { return err } } + + sw.security.RegisterFilter(sw.nodeInfo) + sw.security.RegisterFilter(sw.peers) + if err := sw.security.Start(); err != nil { + return err + } + for _, listener := range sw.listeners { go sw.listenerRoutine(listener) } @@ -178,21 +179,6 @@ func (sw *Switch) OnStop() { } } -//AddBannedPeer add peer to blacklist -func (sw *Switch) AddBannedPeer(ip string) error { - sw.mtx.Lock() - defer sw.mtx.Unlock() - - sw.bannedPeer[ip] = time.Now().Add(defaultBanDuration) - dataJSON, err := json.Marshal(sw.bannedPeer) - if err != nil { - return err - } - - sw.db.Set([]byte(bannedPeerKey), dataJSON) - return nil -} - // AddPeer performs the P2P handshake with a peer // that already has a SecretConnection. If all goes well, // it starts the peer and adds it to the switch. @@ -213,7 +199,7 @@ func (sw *Switch) AddPeer(pc *peerConn, isLAN bool) error { } peer := newPeer(pc, peerNodeInfo, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, isLAN) - if err := sw.filterConnByPeer(peer); err != nil { + if err := sw.security.DoFilter(peer.remoteAddrHost(), peer.PubKey()); err != nil { return err } @@ -260,7 +246,7 @@ func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error { log.WithFields(log.Fields{"module": logModule, "address": addr}).Debug("Dialing peer") sw.dialing.Set(addr.IP.String(), addr) defer sw.dialing.Delete(addr.IP.String()) - if err := sw.filterConnByIP(addr.IP.String()); err != nil { + if err := sw.security.DoFilter(addr.IP.String(), ""); err != nil { return err } @@ -279,6 +265,10 @@ func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error { return nil } +func (sw *Switch) IsBanned(ip string, level byte, reason string) bool { + return sw.security.IsBanned(ip, level, reason) +} + //IsDialing prevent duplicate dialing func (sw *Switch) IsDialing(addr *NetAddress) bool { return sw.dialing.Has(addr.IP.String()) @@ -290,17 +280,6 @@ func (sw *Switch) IsListening() bool { return len(sw.listeners) > 0 } -// loadBannedPeers load banned peers from db -func (sw *Switch) loadBannedPeers() error { - if dataJSON := sw.db.Get([]byte(bannedPeerKey)); dataJSON != nil { - if err := json.Unmarshal(dataJSON, &sw.bannedPeer); err != nil { - return err - } - } - - return nil -} - // Listeners returns the list of listeners the switch listens on. // NOTE: Not goroutine safe. func (sw *Switch) Listeners() []Listener { @@ -362,22 +341,6 @@ func (sw *Switch) addPeerWithConnection(conn net.Conn) error { return nil } -func (sw *Switch) checkBannedPeer(peer string) error { - sw.mtx.Lock() - defer sw.mtx.Unlock() - - if banEnd, ok := sw.bannedPeer[peer]; ok { - if time.Now().Before(banEnd) { - return ErrConnectBannedPeer - } - - if err := sw.delBannedPeer(peer); err != nil { - return err - } - } - return nil -} - func (sw *Switch) connectLANPeers(lanPeer mdns.LANPeerEvent) { lanPeers, _, _, numDialing := sw.NumPeers() numToDial := maxNumLANPeers - lanPeers @@ -422,42 +385,6 @@ func (sw *Switch) connectLANPeersRoutine() { } } -func (sw *Switch) delBannedPeer(addr string) error { - sw.mtx.Lock() - defer sw.mtx.Unlock() - - delete(sw.bannedPeer, addr) - datajson, err := json.Marshal(sw.bannedPeer) - if err != nil { - return err - } - - sw.db.Set([]byte(bannedPeerKey), datajson) - return nil -} - -func (sw *Switch) filterConnByIP(ip string) error { - if ip == sw.nodeInfo.listenHost() { - return ErrConnectSelf - } - return sw.checkBannedPeer(ip) -} - -func (sw *Switch) filterConnByPeer(peer *Peer) error { - if err := sw.checkBannedPeer(peer.remoteAddrHost()); err != nil { - return err - } - - if sw.nodeInfo.PubKey == peer.PubKey() { - return ErrConnectSelf - } - - if sw.peers.Has(peer.Key) { - return ErrDuplicatePeer - } - return nil -} - func (sw *Switch) listenerRoutine(l Listener) { for { inConn, ok := <-l.Connections() diff --git a/p2p/switch_test.go b/p2p/switch_test.go index 6b2cdb8f..19c475e8 100644 --- a/p2p/switch_test.go +++ b/p2p/switch_test.go @@ -12,6 +12,7 @@ import ( dbm "github.com/vapor/database/leveldb" "github.com/vapor/errors" conn "github.com/vapor/p2p/connection" + "github.com/vapor/p2p/security" "github.com/vapor/p2p/signlib" ) @@ -125,7 +126,6 @@ func initSwitchFunc(sw *Switch) *Switch { //Test connect self. func TestFiltersOutItself(t *testing.T) { - t.Skip("skipping test") dirPath, err := ioutil.TempDir(".", "") if err != nil { t.Fatal(err) @@ -134,6 +134,7 @@ func TestFiltersOutItself(t *testing.T) { testDB := dbm.NewDB("testdb", "leveldb", dirPath) cfg := *testCfg + cfg.DBPath = dirPath cfg.P2P.ListenAddress = "127.0.1.1:0" swPrivKey, err := signlib.NewPrivKey() if err != nil { @@ -144,8 +145,15 @@ func TestFiltersOutItself(t *testing.T) { s1.Start() defer s1.Stop() + rmdirPath, err := ioutil.TempDir(".", "") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(rmdirPath) + // simulate s1 having a public key and creating a remote peer with the same key rpCfg := *testCfg + rpCfg.DBPath = rmdirPath rp := &remotePeer{PrivKey: s1.nodePrivKey, Config: &rpCfg} rp.Start() defer rp.Stop() @@ -162,7 +170,6 @@ func TestFiltersOutItself(t *testing.T) { } func TestDialBannedPeer(t *testing.T) { - t.Skip("skipping test") dirPath, err := ioutil.TempDir(".", "") if err != nil { t.Fatal(err) @@ -171,6 +178,7 @@ func TestDialBannedPeer(t *testing.T) { testDB := dbm.NewDB("testdb", "leveldb", dirPath) cfg := *testCfg + cfg.DBPath = dirPath cfg.P2P.ListenAddress = "127.0.1.1:0" swPrivKey, err := signlib.NewPrivKey() if err != nil { @@ -180,7 +188,14 @@ func TestDialBannedPeer(t *testing.T) { s1.Start() defer s1.Stop() + rmdirPath, err := ioutil.TempDir(".", "") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(rmdirPath) + rpCfg := *testCfg + rpCfg.DBPath = rmdirPath remotePrivKey, err := signlib.NewPrivKey() if err != nil { t.Fatal(err) @@ -189,19 +204,17 @@ func TestDialBannedPeer(t *testing.T) { rp := &remotePeer{PrivKey: remotePrivKey, Config: &rpCfg} rp.Start() defer rp.Stop() - s1.AddBannedPeer(rp.addr.IP.String()) - if err := s1.DialPeerWithAddress(rp.addr); errors.Root(err) != ErrConnectBannedPeer { - t.Fatal(err) + for { + if ok := s1.security.IsBanned(rp.addr.IP.String(), security.LevelMsgIllegal, "test"); ok { + break + } } - - s1.delBannedPeer(rp.addr.IP.String()) - if err := s1.DialPeerWithAddress(rp.addr); err != nil { + if err := s1.DialPeerWithAddress(rp.addr); errors.Root(err) != security.ErrConnectBannedPeer { t.Fatal(err) } } func TestDuplicateOutBoundPeer(t *testing.T) { - t.Skip("skipping test") dirPath, err := ioutil.TempDir(".", "") if err != nil { t.Fatal(err) @@ -210,6 +223,7 @@ func TestDuplicateOutBoundPeer(t *testing.T) { testDB := dbm.NewDB("testdb", "leveldb", dirPath) cfg := *testCfg + cfg.DBPath = dirPath cfg.P2P.ListenAddress = "127.0.1.1:0" swPrivKey, err := signlib.NewPrivKey() if err != nil { @@ -220,7 +234,14 @@ func TestDuplicateOutBoundPeer(t *testing.T) { s1.Start() defer s1.Stop() + rmdirPath, err := ioutil.TempDir(".", "") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(rmdirPath) + rpCfg := *testCfg + rpCfg.DBPath = rmdirPath remotePrivKey, err := signlib.NewPrivKey() if err != nil { t.Fatal(err) @@ -240,7 +261,6 @@ func TestDuplicateOutBoundPeer(t *testing.T) { } func TestDuplicateInBoundPeer(t *testing.T) { - t.Skip("skipping test") dirPath, err := ioutil.TempDir(".", "") if err != nil { t.Fatal(err) @@ -249,6 +269,7 @@ func TestDuplicateInBoundPeer(t *testing.T) { testDB := dbm.NewDB("testdb", "leveldb", dirPath) cfg := *testCfg + cfg.DBPath = dirPath cfg.P2P.ListenAddress = "127.0.1.1:0" swPrivKey, err := signlib.NewPrivKey() if err != nil { @@ -282,7 +303,6 @@ func TestDuplicateInBoundPeer(t *testing.T) { } func TestAddInboundPeer(t *testing.T) { - t.Skip("skipping test") dirPath, err := ioutil.TempDir(".", "") if err != nil { t.Fatal(err) @@ -291,6 +311,7 @@ func TestAddInboundPeer(t *testing.T) { testDB := dbm.NewDB("testdb", "leveldb", dirPath) cfg := *testCfg + cfg.DBPath = dirPath cfg.P2P.MaxNumPeers = 2 cfg.P2P.ListenAddress = "127.0.1.1:0" swPrivKey, err := signlib.NewPrivKey() @@ -344,7 +365,6 @@ func TestAddInboundPeer(t *testing.T) { } func TestStopPeer(t *testing.T) { - t.Skip("skipping test") dirPath, err := ioutil.TempDir(".", "") if err != nil { t.Fatal(err) @@ -353,6 +373,7 @@ func TestStopPeer(t *testing.T) { testDB := dbm.NewDB("testdb", "leveldb", dirPath) cfg := *testCfg + cfg.DBPath = dirPath cfg.P2P.MaxNumPeers = 2 cfg.P2P.ListenAddress = "127.0.1.1:0" swPrivKey, err := signlib.NewPrivKey() diff --git a/p2p/test_util.go b/p2p/test_util.go index 76c52ad9..fa5d631f 100644 --- a/p2p/test_util.go +++ b/p2p/test_util.go @@ -92,7 +92,7 @@ func MakeSwitch(cfg *cfg.Config, testdb dbm.DB, privKey signlib.PrivKey, initSwi // new switch, add reactors l, listenAddr := GetListener(cfg.P2P) cfg.P2P.LANDiscover = false - sw, err := newSwitch(cfg, new(mockDiscv), nil, testdb, l, privKey, listenAddr, 0) + sw, err := newSwitch(cfg, new(mockDiscv), nil, l, privKey, listenAddr, 0) if err != nil { log.Errorf("create switch error: %s", err) return nil -- 2.11.0