OSDN Git Service

filter out known txs
authorWei Wang <apollo.mobility@gmail.com>
Sat, 13 Jul 2019 13:33:35 +0000 (21:33 +0800)
committerWei Wang <apollo.mobility@gmail.com>
Sat, 13 Jul 2019 13:35:26 +0000 (21:35 +0800)
common/ordered_set.go [new file with mode: 0644]
netsync/chainmgr/handle.go
netsync/chainmgr/tx_keeper.go

diff --git a/common/ordered_set.go b/common/ordered_set.go
new file mode 100644 (file)
index 0000000..ee9ec39
--- /dev/null
@@ -0,0 +1,71 @@
+package common
+
+import (
+       "errors"
+       "sync"
+)
+
+// OrderedSet is a set with limited capacity.
+// Items are evicted according to their insertion order.
+type OrderedSet struct {
+       capacity int
+       set      map[interface{}]struct{}
+       queue    []interface{}
+       start    int
+       end      int
+
+       lock sync.RWMutex
+}
+
+// NewOrderedSet creates an ordered set with given capacity
+func NewOrderedSet(capacity int) (*OrderedSet, error) {
+       if capacity < 1 {
+               return nil, errors.New("capacity must be a positive integer")
+       }
+
+       return &OrderedSet{
+               capacity: capacity,
+               set:      map[interface{}]struct{}{},
+               queue:    make([]interface{}, capacity),
+               end:      -1,
+       }, nil
+}
+
+// Add inserts items into the set.
+// If capacity is reached, oldest items are evicted
+func (os *OrderedSet) Add(items ...interface{}) {
+       os.lock.Lock()
+       defer os.lock.Unlock()
+
+       for _, item := range items {
+               if _, ok := os.set[item]; ok {
+                       continue
+               }
+
+               next := (os.end + 1) % os.capacity
+               if os.end != -1 && next == os.start {
+                       delete(os.set, os.queue[os.start])
+                       os.start = (os.start + 1) % os.capacity
+               }
+               os.end = next
+               os.queue[os.end] = item
+               os.set[item] = struct{}{}
+       }
+}
+
+// Has checks if certain items exists in the set
+func (os *OrderedSet) Has(item interface{}) bool {
+       os.lock.RLock()
+       defer os.lock.RUnlock()
+
+       _, ok := os.set[item]
+       return ok
+}
+
+// Size returns the size of the set
+func (os *OrderedSet) Size() int {
+       os.lock.RLock()
+       defer os.lock.RUnlock()
+
+       return len(os.set)
+}
index 91f3949..8792583 100644 (file)
@@ -6,6 +6,7 @@ import (
 
        log "github.com/sirupsen/logrus"
 
+       "github.com/vapor/common"
        cfg "github.com/vapor/config"
        "github.com/vapor/consensus"
        dbm "github.com/vapor/database/leveldb"
@@ -20,6 +21,8 @@ import (
 )
 
 const (
+       maxKnownTxs = 32768 // Maximum transactions hashes to keep in the known list (prevent DOS)
+
        logModule = "netsync"
 )
 
@@ -66,10 +69,12 @@ type Manager struct {
 
        eventDispatcher *event.Dispatcher
        txMsgSub        *event.Subscription
+       knownTxs        *common.OrderedSet // Set of transaction hashes known so far
 }
 
-//NewChainManager create a chain sync manager.
+//NewManager create a chain sync manager.
 func NewManager(config *cfg.Config, sw Switch, chain Chain, mempool Mempool, dispatcher *event.Dispatcher, peers *peers.PeerSet, fastSyncDB dbm.DB) (*Manager, error) {
+       knownTxs, _ := common.NewOrderedSet(maxKnownTxs)
        manager := &Manager{
                sw:              sw,
                mempool:         mempool,
@@ -80,6 +85,7 @@ func NewManager(config *cfg.Config, sw Switch, chain Chain, mempool Mempool, dis
                quit:            make(chan struct{}),
                config:          config,
                eventDispatcher: dispatcher,
+               knownTxs:        knownTxs,
        }
 
        if !config.VaultMode {
@@ -253,6 +259,11 @@ func (m *Manager) handleTransactionMsg(peer *peers.Peer, msg *msgs.TransactionMe
                return
        }
 
+       if m.knownTxs.Has(tx.ID.String()) {
+               return
+       }
+
+       m.knownTxs.Add(tx.ID.String())
        m.peers.MarkTx(peer.ID(), tx.ID)
        if isOrphan, err := m.chain.ValidateTx(tx); err != nil && err != core.ErrDustTx && !isOrphan {
                m.peers.ProcessIllegal(peer.ID(), security.LevelMsgIllegal, "fail on validate tx transaction")
@@ -272,6 +283,11 @@ func (m *Manager) handleTransactionsMsg(peer *peers.Peer, msg *msgs.Transactions
        }
 
        for _, tx := range txs {
+               if m.knownTxs.Has(tx.ID.String()) {
+                       continue
+               }
+
+               m.knownTxs.Add(tx.ID.String())
                m.peers.MarkTx(peer.ID(), tx.ID)
                if isOrphan, err := m.chain.ValidateTx(tx); err != nil && !isOrphan {
                        m.peers.ProcessIllegal(peer.ID(), security.LevelMsgIllegal, "fail on validate tx transaction")
index 6c6f5f9..af5a183 100644 (file)
@@ -50,7 +50,9 @@ func (m *Manager) broadcastTxsLoop() {
                        }
 
                        if ev.TxMsg.MsgType == core.MsgNewTx {
-                               if err := m.peers.BroadcastTx(ev.TxMsg.Tx); err != nil {
+                               tx := ev.TxMsg.Tx
+                               m.knownTxs.Add(tx.ID.String())
+                               if err := m.peers.BroadcastTx(tx); err != nil {
                                        log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on broadcast new tx.")
                                        continue
                                }