add_new_casper_for_chain (#1931)
author Poseidon <shenao.78@163.com>
Mon, 10 May 2021 08:14:08 +0000 (16:14 +0800)
committer GitHub <noreply@github.com>
Mon, 10 May 2021 08:14:08 +0000 (16:14 +0800)
* add_new_casper_for_chain

* add checkpoint db implementation

* move casper to protocol package

* remove logs

* remove logs

* remove fmt import

Co-authored-by: Paladz <yzhu101@uottawa.ca>
21 files changed:
account/accounts_test.go
database/leveldb/db.go
database/leveldb/db_test.go [new file with mode: 0644]
database/leveldb/go_level_db.go
database/leveldb/go_level_db_test.go
database/leveldb/mem_db.go
database/leveldb/mem_db_test.go
database/store.go
database/store_test.go
node/node.go
protocol/apply_block.go [moved from protocol/consensus/apply_block.go with 93% similarity]
protocol/auth_verification.go [moved from protocol/consensus/auth_verification.go with 84% similarity]
protocol/block.go
protocol/casper.go [moved from protocol/consensus/casper.go with 90% similarity]
protocol/consensus.go [deleted file]
protocol/protocol.go
protocol/state/checkpoint.go
protocol/store.go
protocol/tree_node.go [moved from protocol/consensus/tree_node.go with 98% similarity]
protocol/txpool_test.go
test/utxo_view/utxo_view_test.go

diff --git a/account/accounts_test.go b/account/accounts_test.go
index 7b1bf1d..95332b2 100644
@@ -211,7 +211,7 @@ func mockAccountManager(t *testing.T) *Manager {
        }
        defer os.RemoveAll(dirPath)
 
-       testDB := dbm.NewDB("testdb", "memdb", dirPath)
+       testDB := dbm.NewDB("testdb", "leveldb", dirPath)
        dispatcher := event.NewDispatcher()
 
        store := database.NewStore(testDB)
diff --git a/database/leveldb/db.go b/database/leveldb/db.go
index 38cab5b..80eb08d 100644
@@ -1,4 +1,4 @@
-package db
+package leveldb
 
 import . "github.com/tendermint/tmlibs/common"
 
@@ -12,6 +12,7 @@ type DB interface {
        NewBatch() Batch
        Iterator() Iterator
        IteratorPrefix([]byte) Iterator
+       IteratorPrefixWithStart(Prefix, start []byte, isReverse bool) Iterator
 
        // For debugging
        Print()
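
[Editor's note] The new IteratorPrefixWithStart(prefix, start, isReverse) positions the iterator on the start key when one is given (with a nil start, Key() stays empty until the first Next, as the single-key test below shows), so callers read the current entry first and then advance; CheckpointsFromNode in database/store.go further down follows exactly this pattern. A hypothetical helper (name and keys assumed, placed alongside the package's DB and Iterator types) might look like:

// keysFrom collects the start key and every following key under prefix,
// in forward order. Purely illustrative; not part of the commit.
func keysFrom(db DB, prefix, start []byte) [][]byte {
	itr := db.IteratorPrefixWithStart(prefix, start, false)
	defer itr.Release()

	// With a non-nil start the iterator already sits on that key.
	keys := [][]byte{itr.Key()}
	for itr.Next() {
		keys = append(keys, itr.Key())
	}
	return keys
}
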
diff --git a/database/leveldb/db_test.go b/database/leveldb/db_test.go
new file mode 100644
index 0000000..d955da3
--- /dev/null
@@ -0,0 +1,134 @@
+package leveldb
+
+import (
+       "fmt"
+       "io/ioutil"
+       "os"
+       "testing"
+
+       "github.com/stretchr/testify/require"
+)
+
+func newTempDB(t *testing.T, backend string) (db DB, dbDir string) {
+       dirname, err := ioutil.TempDir("", "db_common_test")
+       require.Nil(t, err)
+       return NewDB("testdb", backend, dirname), dirname
+}
+
+func TestDBIteratorSingleKey(t *testing.T) {
+       for backend := range backends {
+               t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) {
+                       db, dir := newTempDB(t, backend)
+                       defer os.RemoveAll(dir)
+
+                       db.Set([]byte("1"), []byte("value_1"))
+                       itr := db.IteratorPrefixWithStart(nil, nil, false)
+                       require.Equal(t, []byte(""), itr.Key())
+                       require.Equal(t, true, itr.Next())
+                       require.Equal(t, []byte("1"), itr.Key())
+               })
+       }
+}
+
+func TestDBIteratorTwoKeys(t *testing.T) {
+       for backend := range backends {
+               t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) {
+                       db, dir := newTempDB(t, backend)
+                       defer os.RemoveAll(dir)
+
+                       db.SetSync([]byte("1"), []byte("value_1"))
+                       db.SetSync([]byte("2"), []byte("value_1"))
+
+                       itr := db.IteratorPrefixWithStart(nil, []byte("1"), false)
+
+                       require.Equal(t, []byte("1"), itr.Key())
+
+                       require.Equal(t, true, itr.Next())
+                       itr = db.IteratorPrefixWithStart(nil, []byte("2"), false)
+
+                       require.Equal(t, false, itr.Next())
+               })
+       }
+}
+
+func TestDBIterator(t *testing.T) {
+       dirname, err := ioutil.TempDir("", "db_common_test")
+       require.Nil(t, err)
+
+       db, err := NewGoLevelDB("testdb", dirname)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       defer func() {
+               db.Close()
+               os.RemoveAll(dirname)
+       }()
+
+       db.SetSync([]byte("aaa1"), []byte("value_1"))
+       db.SetSync([]byte("aaa22"), []byte("value_2"))
+       db.SetSync([]byte("bbb22"), []byte("value_3"))
+
+       itr := db.IteratorPrefixWithStart([]byte("aaa"), []byte("aaa1"), false)
+       defer itr.Release()
+
+       require.Equal(t, true, itr.Next())
+       require.Equal(t, []byte("aaa22"), itr.Key())
+
+       require.Equal(t, false, itr.Next())
+
+       itr = db.IteratorPrefixWithStart([]byte("aaa"), nil, false)
+
+       require.Equal(t, true, itr.Next())
+       require.Equal(t, []byte("aaa1"), itr.Key())
+
+       require.Equal(t, true, itr.Next())
+       require.Equal(t, []byte("aaa22"), itr.Key())
+
+       require.Equal(t, false, itr.Next())
+
+       itr = db.IteratorPrefixWithStart([]byte("bbb"), []byte("aaa1"), false)
+       require.Equal(t, false, itr.Next())
+}
+
+func TestDBIteratorReverse(t *testing.T) {
+       dirname, err := ioutil.TempDir("", "db_common_test")
+       require.Nil(t, err)
+
+       db, err := NewGoLevelDB("testdb", dirname)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       defer func() {
+               db.Close()
+               os.RemoveAll(dirname)
+       }()
+
+       db.SetSync([]byte("aaa1"), []byte("value_1"))
+       db.SetSync([]byte("aaa22"), []byte("value_2"))
+       db.SetSync([]byte("bbb22"), []byte("value_3"))
+
+       itr := db.IteratorPrefixWithStart([]byte("aaa"), []byte("aaa22"), true)
+       defer itr.Release()
+
+       require.Equal(t, true, itr.Next())
+       require.Equal(t, []byte("aaa1"), itr.Key())
+
+       require.Equal(t, false, itr.Next())
+
+       itr = db.IteratorPrefixWithStart([]byte("aaa"), nil, true)
+
+       require.Equal(t, true, itr.Next())
+       require.Equal(t, []byte("aaa22"), itr.Key())
+
+       require.Equal(t, true, itr.Next())
+       require.Equal(t, []byte("aaa1"), itr.Key())
+
+       require.Equal(t, false, itr.Next())
+
+       require.Equal(t, false, itr.Next())
+
+       itr = db.IteratorPrefixWithStart([]byte("bbb"), []byte("aaa1"), true)
+       require.Equal(t, false, itr.Next())
+}
diff --git a/database/leveldb/go_level_db.go b/database/leveldb/go_level_db.go
index e9e8d3d..99c28e0 100644
@@ -1,4 +1,4 @@
-package db
+package leveldb
 
 import (
        "fmt"
@@ -118,7 +118,28 @@ func (db *GoLevelDB) Stats() map[string]string {
 }
 
 type goLevelDBIterator struct {
-       source iterator.Iterator
+       source    iterator.Iterator
+       start     []byte
+       isReverse bool
+}
+
+func newGoLevelDBIterator(source iterator.Iterator, start []byte, isReverse bool) *goLevelDBIterator {
+       if start != nil {
+               valid := source.Seek(start)
+               if !valid && isReverse {
+                       source.Last()
+                       source.Next()
+               }
+       } else if isReverse {
+               source.Last()
+               source.Next()
+       }
+
+       return &goLevelDBIterator{
+               source:    source,
+               start:     start,
+               isReverse: isReverse,
+       }
 }
 
 // Key returns a copy of the current key.
@@ -148,6 +169,10 @@ func (it *goLevelDBIterator) Error() error {
 }
 
 func (it *goLevelDBIterator) Next() bool {
+       it.assertNoError()
+       if it.isReverse {
+               return it.source.Prev()
+       }
        return it.source.Next()
 }
 
@@ -155,12 +180,23 @@ func (it *goLevelDBIterator) Release() {
        it.source.Release()
 }
 
+func (it *goLevelDBIterator) assertNoError() {
+       if err := it.source.Error(); err != nil {
+               panic(err)
+       }
+}
+
 func (db *GoLevelDB) Iterator() Iterator {
-       return &goLevelDBIterator{db.db.NewIterator(nil, nil)}
+       return &goLevelDBIterator{source: db.db.NewIterator(nil, nil)}
 }
 
 func (db *GoLevelDB) IteratorPrefix(prefix []byte) Iterator {
-       return &goLevelDBIterator{db.db.NewIterator(util.BytesPrefix(prefix), nil)}
+       return &goLevelDBIterator{source: db.db.NewIterator(util.BytesPrefix(prefix), nil)}
+}
+
+func (db *GoLevelDB) IteratorPrefixWithStart(Prefix, start []byte, isReverse bool) Iterator {
+       itr := db.db.NewIterator(util.BytesPrefix(Prefix), nil)
+       return newGoLevelDBIterator(itr, start, isReverse)
 }
 
 func (db *GoLevelDB) NewBatch() Batch {
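
[Editor's note] Worth spelling out from the hunk above: when isReverse is true and the start key is missing (or nil), newGoLevelDBIterator parks the underlying cursor one step past the last entry (Last followed by Next), so the wrapper's first Next call, which maps to Prev, lands on the last key under the prefix. A standalone sketch of that behavior against goleveldb directly (temp directory and keys are illustrative only):

package main

import (
	"fmt"
	"os"

	"github.com/syndtr/goleveldb/leveldb"
	"github.com/syndtr/goleveldb/leveldb/util"
)

func main() {
	dir, err := os.MkdirTemp("", "reverse_demo")
	if err != nil {
		panic(err)
	}
	defer os.RemoveAll(dir)

	db, err := leveldb.OpenFile(dir, nil)
	if err != nil {
		panic(err)
	}
	defer db.Close()

	db.Put([]byte("aaa1"), []byte("v1"), nil)
	db.Put([]byte("aaa22"), []byte("v2"), nil)

	itr := db.NewIterator(util.BytesPrefix([]byte("aaa")), nil)
	defer itr.Release()

	itr.Last() // cursor on "aaa22", the last key under the prefix
	itr.Next() // step past the end; the cursor is now exhausted
	// The first backwards step lands on the last key again, which is what
	// goLevelDBIterator.Next relies on when isReverse is true.
	fmt.Println(itr.Prev(), string(itr.Key())) // true aaa22
	fmt.Println(itr.Prev(), string(itr.Key())) // true aaa1
	fmt.Println(itr.Prev())                    // false: walked off the front of the prefix
}
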
diff --git a/database/leveldb/go_level_db_test.go b/database/leveldb/go_level_db_test.go
index 2cd3192..fc40bfb 100644
@@ -1,4 +1,4 @@
-package db
+package leveldb
 
 import (
        "bytes"
@@ -30,7 +30,7 @@ func BenchmarkRandomReadsWrites(b *testing.B) {
                // Write something
                {
                        idx := (int64(RandInt()) % numItems)
-                       internal[idx] += 1
+                       internal[idx]++
                        val := internal[idx]
                        idxBytes := int642Bytes(int64(idx))
                        valBytes := int642Bytes(int64(val))
diff --git a/database/leveldb/mem_db.go b/database/leveldb/mem_db.go
index 62f40fc..9ab7052 100644
@@ -1,6 +1,7 @@
-package db
+package leveldb
 
 import (
+       "bytes"
        "fmt"
        "sort"
        "strings"
@@ -78,13 +79,29 @@ func (db *MemDB) Stats() map[string]string {
 type memDBIterator struct {
        last int
        keys []string
-       db   *MemDB
+       db   DB
+
+       start []byte
 }
 
 func newMemDBIterator() *memDBIterator {
        return &memDBIterator{}
 }
 
+// Keys is expected to be in reverse order for reverse iterators.
+func newMemDBIteratorWithArgs(db DB, keys []string, start []byte) *memDBIterator {
+       itr := &memDBIterator{
+               db:    db,
+               keys:  keys,
+               start: start,
+               last:  -1,
+       }
+       if start != nil {
+               itr.Seek(start)
+       }
+       return itr
+}
+
 func (it *memDBIterator) Next() bool {
        if it.last >= len(it.keys)-1 {
                return false
@@ -94,6 +111,9 @@ func (it *memDBIterator) Next() bool {
 }
 
 func (it *memDBIterator) Key() []byte {
+       if it.last < 0 {
+               return []byte("")
+       }
        return []byte(it.keys[it.last])
 }
 
@@ -143,10 +163,38 @@ func (db *MemDB) IteratorPrefix(prefix []byte) Iterator {
        return it
 }
 
+func (db *MemDB) IteratorPrefixWithStart(Prefix, start []byte, isReverse bool) Iterator {
+       db.mtx.Lock()
+       defer db.mtx.Unlock()
+
+       keys := db.getSortedKeys(start, isReverse)
+       return newMemDBIteratorWithArgs(db, keys, start)
+}
+
 func (db *MemDB) NewBatch() Batch {
        return &memDBBatch{db, nil}
 }
 
+func (db *MemDB) getSortedKeys(start []byte, reverse bool) []string {
+       keys := []string{}
+       for key := range db.db {
+               if bytes.Compare([]byte(key), start) < 0 {
+                       continue
+               }
+               keys = append(keys, key)
+       }
+       sort.Strings(keys)
+       if reverse {
+               nkeys := len(keys)
+               for i := 0; i < nkeys/2; i++ {
+                       temp := keys[i]
+                       keys[i] = keys[nkeys-i-1]
+                       keys[nkeys-i-1] = temp
+               }
+       }
+       return keys
+}
+
 //--------------------------------------------------------------------------------
 
 type memDBBatch struct {
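
[Editor's note] To make the MemDB ordering concrete, the filtering and reversal performed by getSortedKeys above can be traced with plain slices (values are illustrative only):

package main

import (
	"fmt"
	"sort"
)

func main() {
	// Mirrors MemDB.getSortedKeys(start=[]byte("2"), reverse=true) over keys "1", "2", "3".
	keys := []string{"3", "1", "2"}
	start := "2"

	out := []string{}
	for _, k := range keys {
		if k < start { // same test as bytes.Compare([]byte(key), start) < 0
			continue // "1" is dropped: it sorts before the start key
		}
		out = append(out, k)
	}
	sort.Strings(out) // ["2" "3"]

	for i, j := 0, len(out)-1; i < j; i, j = i+1, j-1 {
		out[i], out[j] = out[j], out[i] // reverse in place, as getSortedKeys does
	}
	fmt.Println(out) // [3 2] — reverse order, which newMemDBIteratorWithArgs expects
}
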
diff --git a/database/leveldb/mem_db_test.go b/database/leveldb/mem_db_test.go
index 503e361..459f872 100644
@@ -1,4 +1,4 @@
-package db
+package leveldb
 
 import (
        "testing"
@@ -23,7 +23,7 @@ func TestMemDbIterator(t *testing.T) {
        i := 0
        for iter.Next() {
                assert.Equal(t, db.Get(iter.Key()), iter.Value(), "values dont match for key")
-               i += 1
+               i++
        }
        assert.Equal(t, i, len(db.db), "iterator didnt cover whole db")
 }
diff --git a/database/store.go b/database/store.go
index 6d0e42c..a968373 100644
@@ -1,6 +1,7 @@
 package database
 
 import (
+       "encoding/binary"
        "encoding/json"
        "time"
 
@@ -19,8 +20,10 @@ import (
 const logModule = "leveldb"
 
 var (
+       // CheckpointPrefix represent the namespace of checkpoints in db
+       CheckpointPrefix = []byte("CP:")
        // BlockStoreKey block store key
-       BlockStoreKey          = []byte("blockStore")
+       BlockStoreKey = []byte("blockStore")
        // BlockHeaderIndexPrefix  block header index with height
        BlockHeaderIndexPrefix = []byte("BH:")
 )
@@ -226,7 +229,7 @@ func (s *Store) LoadBlockIndex(stateBestHeight uint64) (*state.BlockIndex, error
 }
 
 // SaveChainStatus save the core's newest status && delete old status
-func (s *Store) SaveChainStatus(node *state.BlockNode, view *state.UtxoViewpoint, contractView *state.ContractViewpoint) error {
+func (s *Store) SaveChainStatus(node *state.BlockNode, view *state.UtxoViewpoint, contractView *state.ContractViewpoint, finalizedHeight uint64, finalizedHash *bc.Hash) error {
        batch := s.db.NewBatch()
        if err := saveUtxoView(batch, view); err != nil {
                return err
@@ -240,7 +243,7 @@ func (s *Store) SaveChainStatus(node *state.BlockNode, view *state.UtxoViewpoint
                return err
        }
 
-       bytes, err := json.Marshal(protocol.BlockStoreState{Height: node.Height, Hash: &node.Hash})
+       bytes, err := json.Marshal(protocol.BlockStoreState{Height: node.Height, Hash: &node.Hash, FinalizedHeight: finalizedHeight, FinalizedHash: finalizedHash})
        if err != nil {
                return err
        }
@@ -250,16 +253,83 @@ func (s *Store) SaveChainStatus(node *state.BlockNode, view *state.UtxoViewpoint
        return nil
 }
 
-func (s *Store) GetCheckpoint(*bc.Hash) (*state.Checkpoint, error) {
-       return nil, nil
+func calcCheckpointKey(height uint64, hash *bc.Hash) []byte {
+       buf := make([]byte, 8)
+       binary.BigEndian.PutUint64(buf, height)
+       key := append(CheckpointPrefix, buf...)
+       if hash != nil {
+               key = append(key, hash.Bytes()...)
+       }
+       return key
+}
+
+func (s *Store) GetCheckpoint(hash *bc.Hash) (*state.Checkpoint, error) {
+       header, err := s.GetBlockHeader(hash)
+       if err != nil {
+               return nil, err
+       }
+
+       data := s.db.Get(calcCheckpointKey(header.Height, hash))
+       checkpoint := &state.Checkpoint{}
+       if err := json.Unmarshal(data, checkpoint); err != nil {
+               return nil, err
+       }
+
+       return checkpoint, nil
 }
 
 // GetCheckpointsByHeight return all checkpoints of specified block height
-func (s *Store) GetCheckpointsByHeight(uint64) ([]*state.Checkpoint, error) {
-       return nil, nil
+func (s *Store) GetCheckpointsByHeight(height uint64) ([]*state.Checkpoint, error) {
+       iter := s.db.IteratorPrefix(calcCheckpointKey(height, nil))
+       defer iter.Release()
+       return loadCheckpointsFromIter(iter)
+}
+
+// CheckpointsFromNode return all checkpoints from specified block height and hash
+func (s *Store) CheckpointsFromNode(height uint64, hash *bc.Hash) ([]*state.Checkpoint, error) {
+       startKey := calcCheckpointKey(height, hash)
+       iter := s.db.IteratorPrefixWithStart(CheckpointPrefix, startKey, false)
+
+       finalizedCheckpoint := &state.Checkpoint{}
+       if err := json.Unmarshal(iter.Value(), finalizedCheckpoint); err != nil {
+               return nil, err
+       }
+
+       checkpoints := []*state.Checkpoint{finalizedCheckpoint}
+       subs, err := loadCheckpointsFromIter(iter)
+       if err != nil {
+               return nil, err
+       }
+
+       checkpoints = append(checkpoints, subs...)
+       return checkpoints, nil
+}
+
+func loadCheckpointsFromIter(iter dbm.Iterator) ([]*state.Checkpoint, error) {
+       var checkpoints []*state.Checkpoint
+       defer iter.Release()
+       for iter.Next() {
+               checkpoint := &state.Checkpoint{}
+               if err := json.Unmarshal(iter.Value(), checkpoint); err != nil {
+                       return nil, err
+               }
+
+               checkpoints = append(checkpoints, checkpoint)
+       }
+       return checkpoints, nil
 }
 
 // SaveCheckpoints bulk save multiple checkpoint
-func (s *Store) SaveCheckpoints(...*state.Checkpoint) error {
+func (s *Store) SaveCheckpoints(checkpoints ...*state.Checkpoint) error {
+       batch := s.db.NewBatch()
+       for _, checkpoint := range checkpoints {
+               data, err := json.Marshal(checkpoint)
+               if err != nil {
+                       return err
+               }
+
+               batch.Set(calcCheckpointKey(checkpoint.Height, &checkpoint.Hash), data)
+       }
+       batch.Write()
        return nil
 }
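
[Editor's note] The checkpoint keys written above are CheckpointPrefix ("CP:") + 8-byte big-endian height + block hash. The big-endian encoding is what makes the prefix iterators used by GetCheckpointsByHeight and CheckpointsFromNode return checkpoints in ascending height order; that property can be checked in isolation (nothing bytom-specific assumed):

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

func main() {
	a := make([]byte, 8)
	b := make([]byte, 8)
	binary.BigEndian.PutUint64(a, 255)
	binary.BigEndian.PutUint64(b, 256)

	// Big-endian keys compare like the heights they encode, so lexicographic
	// iteration over the "CP:" prefix visits checkpoints from lowest to highest.
	fmt.Println(bytes.Compare(a, b) < 0) // true

	// Little-endian would break this: 255 -> ff00... sorts after 256 -> 0001...
	binary.LittleEndian.PutUint64(a, 255)
	binary.LittleEndian.PutUint64(b, 256)
	fmt.Println(bytes.Compare(a, b) < 0) // false
}
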
diff --git a/database/store_test.go b/database/store_test.go
index 5d912f8..1202646 100644
@@ -151,11 +151,11 @@ func TestSaveChainStatus(t *testing.T) {
        }
 
        contractView := state.NewContractViewpoint()
-       if err := store.SaveChainStatus(node, view, contractView); err != nil {
+       if err := store.SaveChainStatus(node, view, contractView, 0, &bc.Hash{}); err != nil {
                t.Fatal(err)
        }
 
-       expectStatus := &protocol.BlockStoreState{Height: node.Height, Hash: &node.Hash}
+       expectStatus := &protocol.BlockStoreState{Height: node.Height, Hash: &node.Hash, FinalizedHeight: 0, FinalizedHash: &bc.Hash{}}
        if !testutil.DeepEqual(store.GetStoreStatus(), expectStatus) {
                t.Errorf("got block status:%v, expect block status:%v", store.GetStoreStatus(), expectStatus)
        }
diff --git a/node/node.go b/node/node.go
index b94d69f..7adb99a 100644
@@ -8,10 +8,10 @@ import (
        _ "net/http/pprof"
        "path/filepath"
 
-       "github.com/prometheus/prometheus/util/flock"
        log "github.com/sirupsen/logrus"
        cmn "github.com/tendermint/tmlibs/common"
        browser "github.com/toqueteos/webbrowser"
+       "github.com/prometheus/prometheus/util/flock"
 
        "github.com/bytom/bytom/accesstoken"
        "github.com/bytom/bytom/account"
@@ -27,7 +27,6 @@ import (
        "github.com/bytom/bytom/env"
        "github.com/bytom/bytom/event"
        bytomLog "github.com/bytom/bytom/log"
-
        "github.com/bytom/bytom/net/websocket"
        "github.com/bytom/bytom/netsync"
        "github.com/bytom/bytom/protocol"
@@ -81,6 +80,7 @@ func NewNode(config *cfg.Config) *Node {
 
        dispatcher := event.NewDispatcher()
        txPool := protocol.NewTxPool(store, dispatcher)
+
        chain, err := protocol.NewChain(store, txPool)
        if err != nil {
                cmn.Exit(cmn.Fmt("Failed to create chain structure: %v", err))
diff --git a/protocol/consensus/apply_block.go b/protocol/apply_block.go
similarity index 93%
rename from protocol/consensus/apply_block.go
rename to protocol/apply_block.go
index 8c4871d..4803d5a 100644
@@ -1,11 +1,11 @@
-package consensus
+package protocol
 
 import (
        "encoding/hex"
 
+       "github.com/bytom/bytom/config"
        "github.com/bytom/bytom/errors"
        "github.com/bytom/bytom/math/checked"
-       "github.com/bytom/bytom/protocol"
        "github.com/bytom/bytom/protocol/bc"
        "github.com/bytom/bytom/protocol/bc/types"
        "github.com/bytom/bytom/protocol/state"
@@ -16,7 +16,7 @@ import (
 // the tree of checkpoint will grow with the arrival of new blocks
 // it will return verification when an epoch is reached and the current node is the validator, otherwise return nil
 // the chain module must broadcast the verification
-func (c *Casper) ApplyBlock(block *types.Block) (*protocol.Verification, error) {
+func (c *Casper) ApplyBlock(block *types.Block) (*Verification, error) {
        c.mu.Lock()
        defer c.mu.Unlock()
 
@@ -128,15 +128,15 @@ func (c *Casper) applySupLinks(target *state.Checkpoint, supLinks []*types.SupLi
        return nil
 }
 
-func (c *Casper) myVerification(target *state.Checkpoint, validators []*state.Validator) (*protocol.Verification, error) {
-       pubKey := c.prvKey.XPub().String()
+func (c *Casper) myVerification(target *state.Checkpoint, validators []*state.Validator) (*Verification, error) {
+       pubKey := config.CommonConfig.PrivateKey().XPub().String()
        if !isValidator(pubKey, validators) {
                return nil, nil
        }
 
        source := c.lastJustifiedCheckpointOfBranch(target)
        if source != nil {
-               v := &protocol.Verification{
+               v := &Verification{
                        SourceHash:   source.Hash,
                        TargetHash:   target.Hash,
                        SourceHeight: source.Height,
@@ -144,7 +144,8 @@ func (c *Casper) myVerification(target *state.Checkpoint, validators []*state.Va
                        PubKey:       pubKey,
                }
 
-               if err := v.Sign(c.prvKey); err != nil {
+               prvKey := config.CommonConfig.PrivateKey()
+               if err := v.Sign(*prvKey); err != nil {
                        return nil, err
                }
 
@@ -241,10 +242,10 @@ func (c *Casper) lastJustifiedCheckpointOfBranch(branch *state.Checkpoint) *stat
        return nil
 }
 
-func supLinkToVerifications(supLink *types.SupLink, validators []*state.Validator, targetHash bc.Hash, targetHeight uint64) []*protocol.Verification {
-       var result []*protocol.Verification
+func supLinkToVerifications(supLink *types.SupLink, validators []*state.Validator, targetHash bc.Hash, targetHeight uint64) []*Verification {
+       var result []*Verification
        for i, signature := range supLink.Signatures {
-               result = append(result, &protocol.Verification{
+               result = append(result, &Verification{
                        SourceHash:   supLink.SourceHash,
                        TargetHash:   targetHash,
                        SourceHeight: supLink.SourceHeight,
diff --git a/protocol/consensus/auth_verification.go b/protocol/auth_verification.go
similarity index 84%
rename from protocol/consensus/auth_verification.go
rename to protocol/auth_verification.go
index 5a4d2bd..8696f0b 100644
@@ -1,11 +1,10 @@
-package consensus
+package protocol
 
 import (
        "fmt"
 
        log "github.com/sirupsen/logrus"
 
-       "github.com/bytom/bytom/protocol"
        "github.com/bytom/bytom/protocol/bc"
        "github.com/bytom/bytom/protocol/state"
 )
@@ -14,7 +13,7 @@ import (
 // the status of source checkpoint must justified, and an individual validator ν must not publish two distinct Verification
 // ⟨ν,s1,t1,h(s1),h(t1)⟩ and ⟨ν,s2,t2,h(s2),h(t2)⟩, such that either:
 // h(t1) = h(t2) OR h(s1) < h(s2) < h(t2) < h(t1)
-func (c *Casper) AuthVerification(v *protocol.Verification) error {
+func (c *Casper) AuthVerification(v *Verification) error {
        if err := validate(v); err != nil {
                return err
        }
@@ -41,7 +40,7 @@ func (c *Casper) AuthVerification(v *protocol.Verification) error {
        return c.authVerification(v)
 }
 
-func (c *Casper) authVerification(v *protocol.Verification) error {
+func (c *Casper) authVerification(v *Verification) error {
        target, err := c.store.GetCheckpoint(&v.TargetHash)
        if err != nil {
                c.verificationCache.Add(verificationCacheKey(v.TargetHash, v.PubKey), v)
@@ -55,7 +54,7 @@ func (c *Casper) authVerification(v *protocol.Verification) error {
        return c.addVerificationToCheckpoint(target, v)
 }
 
-func (c *Casper) addVerificationToCheckpoint(target *state.Checkpoint, v *protocol.Verification) error {
+func (c *Casper) addVerificationToCheckpoint(target *state.Checkpoint, v *Verification) error {
        source, err := c.store.GetCheckpoint(&v.SourceHash)
        if err != nil {
                return err
@@ -75,7 +74,7 @@ func (c *Casper) addVerificationToCheckpoint(target *state.Checkpoint, v *protoc
        affectedCheckpoints := c.setJustified(source, target)
        _, newBestHash := c.BestChain()
        if oldBestHash != newBestHash {
-               c.rollbackNotifyCh <- newBestHash
+               c.rollbackNotifyCh <- nil
        }
 
        return c.store.SaveCheckpoints(affectedCheckpoints...)
@@ -131,7 +130,7 @@ func (c *Casper) authVerificationLoop() {
                        }
 
                        c.mu.Lock()
-                       if err := c.authVerification(verification.(*protocol.Verification)); err != nil {
+                       if err := c.authVerification(verification.(*Verification)); err != nil {
                                log.WithField("err", err).Error("auth verification in cache")
                        }
                        c.mu.Unlock()
@@ -141,7 +140,7 @@ func (c *Casper) authVerificationLoop() {
        }
 }
 
-func (c *Casper) verifyVerification(v *protocol.Verification, trackEvilValidator bool) error {
+func (c *Casper) verifyVerification(v *Verification, trackEvilValidator bool) error {
        if err := c.verifySameHeight(v, trackEvilValidator); err != nil {
                return err
        }
@@ -150,7 +149,7 @@ func (c *Casper) verifyVerification(v *protocol.Verification, trackEvilValidator
 }
 
 // a validator must not publish two distinct votes for the same target height
-func (c *Casper) verifySameHeight(v *protocol.Verification, trackEvilValidator bool) error {
+func (c *Casper) verifySameHeight(v *Verification, trackEvilValidator bool) error {
        checkpoints, err := c.store.GetCheckpointsByHeight(v.TargetHeight)
        if err != nil {
                return err
@@ -160,7 +159,7 @@ func (c *Casper) verifySameHeight(v *protocol.Verification, trackEvilValidator b
                for _, supLink := range checkpoint.SupLinks {
                        if _, ok := supLink.Signatures[v.PubKey]; ok && checkpoint.Hash != v.TargetHash {
                                if trackEvilValidator {
-                                       c.evilValidators[v.PubKey] = []*protocol.Verification{v, makeVerification(supLink, checkpoint, v.PubKey)}
+                                       c.evilValidators[v.PubKey] = []*Verification{v, makeVerification(supLink, checkpoint, v.PubKey)}
                                }
                                return errSameHeightInVerification
                        }
@@ -170,7 +169,7 @@ func (c *Casper) verifySameHeight(v *protocol.Verification, trackEvilValidator b
 }
 
 // a validator must not vote within the span of its other votes.
-func (c *Casper) verifySpanHeight(v *protocol.Verification, trackEvilValidator bool) error {
+func (c *Casper) verifySpanHeight(v *Verification, trackEvilValidator bool) error {
        if c.tree.findOnlyOne(func(checkpoint *state.Checkpoint) bool {
                if checkpoint.Height == v.TargetHeight {
                        return false
@@ -181,7 +180,7 @@ func (c *Casper) verifySpanHeight(v *protocol.Verification, trackEvilValidator b
                                if (checkpoint.Height < v.TargetHeight && supLink.SourceHeight > v.SourceHeight) ||
                                        (checkpoint.Height > v.TargetHeight && supLink.SourceHeight < v.SourceHeight) {
                                        if trackEvilValidator {
-                                               c.evilValidators[v.PubKey] = []*protocol.Verification{v, makeVerification(supLink, checkpoint, v.PubKey)}
+                                               c.evilValidators[v.PubKey] = []*Verification{v, makeVerification(supLink, checkpoint, v.PubKey)}
                                        }
                                        return true
                                }
@@ -194,8 +193,8 @@ func (c *Casper) verifySpanHeight(v *protocol.Verification, trackEvilValidator b
        return nil
 }
 
-func makeVerification(supLink *state.SupLink, checkpoint *state.Checkpoint, pubKey string) *protocol.Verification {
-       return &protocol.Verification{
+func makeVerification(supLink *state.SupLink, checkpoint *state.Checkpoint, pubKey string) *Verification {
+       return &Verification{
                SourceHash:   supLink.SourceHash,
                TargetHash:   checkpoint.Hash,
                SourceHeight: supLink.SourceHeight,
@@ -205,7 +204,7 @@ func makeVerification(supLink *state.SupLink, checkpoint *state.Checkpoint, pubK
        }
 }
 
-func validate(v *protocol.Verification) error {
+func validate(v *Verification) error {
        if v.SourceHeight%state.BlocksOfEpoch != 0 || v.TargetHeight%state.BlocksOfEpoch != 0 {
                return errVoteToGrowingCheckpoint
        }
diff --git a/protocol/block.go b/protocol/block.go
index 8f0891f..8375d3b 100644
@@ -294,16 +294,15 @@ func (c *Chain) ProcessBlock(block *types.Block) (bool, error) {
        return response.isOrphan, response.err
 }
 
-func (c *Chain) blockProcesser() {
+func (c *Chain) blockProcessor() {
        for {
                select {
                case msg := <-c.processBlockCh:
                        isOrphan, err := c.processBlock(msg.block)
                        msg.reply <- processBlockResponse{isOrphan: isOrphan, err: err}
-               case blockHash := <-c.rollbackBlockCh:
-                       if err := c.rollback(&blockHash); err != nil {
+               case <-c.rollbackNotifyCh:
+                       if err := c.rollback(); err != nil {
                                log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("fail on rollback block")
-                               c.rollbackBlockCh <- blockHash
                        }
                }
        }
@@ -337,8 +336,13 @@ func (c *Chain) processBlock(block *types.Block) (bool, error) {
        return false, nil
 }
 
-func (c *Chain) rollback(bestBlockHash *bc.Hash) error {
-       node := c.index.GetNode(bestBlockHash)
+func (c *Chain) rollback() error {
+       latestBestBlockHash := c.latestBestBlockHash()
+       if c.bestNode.Hash == *latestBestBlockHash {
+               return nil
+       }
+
+       node := c.index.GetNode(latestBestBlockHash)
        log.WithFields(log.Fields{"module": logModule}).Debug("start to reorganize chain")
        return c.reorganizeChain(node)
 }
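
[Editor's note] The hunks above replace the per-hash rollbackBlockCh with a bare notification channel: Casper only signals that its best chain may have moved, and rollback() re-derives the target from casper.BestChain() instead of trusting a hash carried on the channel. The protocol.go changes further down create that channel and hand it to NewCasper. A minimal standalone sketch of the pattern (bytom types replaced with strings, purely illustrative):

package main

import "fmt"

func main() {
	rollbackNotifyCh := make(chan interface{}, 1)

	bestNode := "old-best"
	casperBestChain := func() string { return "new-best" } // stands in for casper.BestChain()

	// Casper side (auth_verification.go): signal that the best chain may have changed.
	rollbackNotifyCh <- nil

	// Chain side (blockProcessor): consume the signal and reconcile against Casper.
	<-rollbackNotifyCh
	if latest := casperBestChain(); bestNode != latest {
		fmt.Println("reorganize chain to", latest) // stands in for rollback -> reorganizeChain
	}
}
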
diff --git a/protocol/consensus/casper.go b/protocol/casper.go
similarity index 90%
rename from protocol/consensus/casper.go
rename to protocol/casper.go
index 2574a57..492e8c1 100644
@@ -1,4 +1,4 @@
-package consensus
+package protocol
 
 import (
        "sync"
@@ -6,9 +6,7 @@ import (
        log "github.com/sirupsen/logrus"
 
        "github.com/bytom/bytom/common"
-       "github.com/bytom/bytom/crypto/ed25519/chainkd"
        "github.com/bytom/bytom/errors"
-       "github.com/bytom/bytom/protocol"
        "github.com/bytom/bytom/protocol/bc"
        "github.com/bytom/bytom/protocol/state"
 )
@@ -31,12 +29,11 @@ const minGuaranty = 1E14
 type Casper struct {
        mu               sync.RWMutex
        tree             *treeNode
-       rollbackNotifyCh chan bc.Hash
+       rollbackNotifyCh chan interface{}
        newEpochCh       chan bc.Hash
-       store            protocol.Store
-       prvKey           chainkd.XPrv
+       store            Store
        // pubKey -> conflicting verifications
-       evilValidators map[string][]*protocol.Verification
+       evilValidators map[string][]*Verification
        // block hash -> previous checkpoint hash
        prevCheckpointCache *common.Cache
        // block hash + pubKey -> verification
@@ -49,18 +46,17 @@ type Casper struct {
 // argument checkpoints load the checkpoints from leveldb
 // the first element of checkpoints must genesis checkpoint or the last finalized checkpoint in order to reduce memory space
 // the others must be successors of first one
-func NewCasper(store protocol.Store, prvKey chainkd.XPrv, checkpoints []*state.Checkpoint) *Casper {
+func NewCasper(store Store, checkpoints []*state.Checkpoint, rollbackNotifyCh chan interface{}) *Casper {
        if checkpoints[0].Height != 0 && checkpoints[0].Status != state.Finalized {
                log.Panic("first element of checkpoints must genesis or in finalized status")
        }
 
        casper := &Casper{
                tree:                  makeTree(checkpoints[0], checkpoints[1:]),
-               rollbackNotifyCh:      make(chan bc.Hash),
+               rollbackNotifyCh:      rollbackNotifyCh,
                newEpochCh:            make(chan bc.Hash),
                store:                 store,
-               prvKey:                prvKey,
-               evilValidators:        make(map[string][]*protocol.Verification),
+               evilValidators:        make(map[string][]*Verification),
                prevCheckpointCache:   common.NewCache(1024),
                verificationCache:     common.NewCache(1024),
                justifyingCheckpoints: make(map[bc.Hash][]*state.Checkpoint),
@@ -108,8 +104,8 @@ func (c *Casper) Validators(blockHash *bc.Hash) ([]*state.Validator, error) {
 // EvilValidator represent a validator who broadcast two distinct verification that violate the commandment
 type EvilValidator struct {
        PubKey string
-       V1     *protocol.Verification
-       V2     *protocol.Verification
+       V1     *Verification
+       V2     *Verification
 }
 
 // EvilValidators return all evil validators
diff --git a/protocol/consensus.go b/protocol/consensus.go
deleted file mode 100644
index 1513979..0000000
+++ /dev/null
@@ -1,26 +0,0 @@
-package protocol
-
-import (
-       "github.com/bytom/bytom/protocol/bc"
-       "github.com/bytom/bytom/protocol/bc/types"
-       "github.com/bytom/bytom/protocol/state"
-)
-
-// Casper is BFT based proof of stack consensus algorithm, it provides safety and liveness in theory
-type CasperConsensus interface {
-
-       // Best chain return the chain containing the justified checkpoint of the largest height
-       BestChain() (uint64, bc.Hash)
-
-       // LastFinalized return the block height and block hash which is finalized ast last
-       LastFinalized() (uint64, bc.Hash)
-
-       // AuthVerification verify whether the Verification is legal.
-       AuthVerification(v *Verification) error
-
-       // ApplyBlock apply block to the consensus module
-       ApplyBlock(block *types.Block) (*Verification, error)
-
-       // Validators return the validators by specified block hash
-       Validators(blockHash *bc.Hash) ([]*state.Validator, error)
-}
diff --git a/protocol/protocol.go b/protocol/protocol.go
index 845577c..b01b73a 100644
@@ -16,14 +16,14 @@ const maxProcessBlockChSize = 1024
 
 // Chain provides functions for working with the Bytom block chain.
 type Chain struct {
-       index           *state.BlockIndex
-       orphanManage    *OrphanManage
-       txPool          *TxPool
-       store           Store
-       processBlockCh  chan *processBlockMsg
-       rollbackBlockCh chan bc.Hash
-       casper          CasperConsensus
-       eventDispatcher *event.Dispatcher
+       index            *state.BlockIndex
+       orphanManage     *OrphanManage
+       txPool           *TxPool
+       store            Store
+       casper           *Casper
+       processBlockCh   chan *processBlockMsg
+       rollbackNotifyCh chan interface{}
+       eventDispatcher  *event.Dispatcher
 
        cond     sync.Cond
        bestNode *state.BlockNode
@@ -36,10 +36,11 @@ func NewChain(store Store, txPool *TxPool) (*Chain, error) {
 
 func NewChainWithOrphanManage(store Store, txPool *TxPool, manage *OrphanManage) (*Chain, error) {
        c := &Chain{
-               orphanManage:   manage,
-               txPool:         txPool,
-               store:          store,
-               processBlockCh: make(chan *processBlockMsg, maxProcessBlockChSize),
+               orphanManage:     manage,
+               txPool:           txPool,
+               store:            store,
+               rollbackNotifyCh: make(chan interface{}),
+               processBlockCh:   make(chan *processBlockMsg, maxProcessBlockChSize),
        }
        c.cond.L = new(sync.Mutex)
 
@@ -58,7 +59,14 @@ func NewChainWithOrphanManage(store Store, txPool *TxPool, manage *OrphanManage)
 
        c.bestNode = c.index.GetNode(storeStatus.Hash)
        c.index.SetMainChain(c.bestNode)
-       go c.blockProcesser()
+
+       casper, err := newCasper(store, storeStatus, c.rollbackNotifyCh)
+       if err != nil {
+               return nil, err
+       }
+
+       c.casper = casper
+       go c.blockProcessor()
        return c, nil
 }
 
@@ -68,6 +76,16 @@ func (c *Chain) initChainStatus() error {
                return err
        }
 
+       checkpoint := &state.Checkpoint{
+               Height:         0,
+               Hash:           genesisBlock.Hash(),
+               StartTimestamp: genesisBlock.Timestamp,
+               Status:         state.Justified,
+       }
+       if err := c.store.SaveCheckpoints(checkpoint); err != nil {
+               return err
+       }
+
        utxoView := state.NewUtxoViewpoint()
        bcBlock := types.MapBlock(genesisBlock)
        if err := utxoView.ApplyBlock(bcBlock); err != nil {
@@ -80,7 +98,16 @@ func (c *Chain) initChainStatus() error {
        }
 
        contractView := state.NewContractViewpoint()
-       return c.store.SaveChainStatus(node, utxoView, contractView)
+       return c.store.SaveChainStatus(node, utxoView, contractView, 0, &checkpoint.Hash)
+}
+
+func newCasper(store Store, storeStatus *BlockStoreState, rollbackNotifyCh chan interface{}) (*Casper, error) {
+       checkpoints, err := store.CheckpointsFromNode(storeStatus.FinalizedHeight, storeStatus.FinalizedHash)
+       if err != nil {
+               return nil, err
+       }
+
+       return NewCasper(store, checkpoints, rollbackNotifyCh), nil
 }
 
 // BestBlockHeight returns the last irreversible block header of the blockchain
@@ -109,6 +136,11 @@ func (c *Chain) BestBlockHash() *bc.Hash {
        return &c.bestNode.Hash
 }
 
+func (c *Chain) latestBestBlockHash() *bc.Hash {
+       _, hash := c.casper.BestChain()
+       return &hash
+}
+
 // BestBlockHeader returns the chain tail block
 func (c *Chain) BestBlockHeader() *types.BlockHeader {
        node := c.index.BestNode()
@@ -134,7 +166,8 @@ func (c *Chain) SignBlockHeader(blockHeader *types.BlockHeader) {
 
 // This function must be called with mu lock in above level
 func (c *Chain) setState(node *state.BlockNode, view *state.UtxoViewpoint, contractView *state.ContractViewpoint) error {
-       if err := c.store.SaveChainStatus(node, view, contractView); err != nil {
+       finalizedHeight, finalizedHash := c.casper.LastFinalized()
+       if err := c.store.SaveChainStatus(node, view, contractView, finalizedHeight, &finalizedHash); err != nil {
                return err
        }
 
diff --git a/protocol/state/checkpoint.go b/protocol/state/checkpoint.go
index 53fafd8..f174006 100644
@@ -48,11 +48,11 @@ func (s *SupLink) IsMajority() bool {
 // Casper only considers checkpoints for finalization. When a checkpoint is explicitly finalized,
 // all ancestor blocks of the checkpoint are implicitly finalized.
 type Checkpoint struct {
-       Height         uint64
-       Hash           bc.Hash
-       ParentHash     bc.Hash
+       Height     uint64
+       Hash       bc.Hash
+       ParentHash bc.Hash
        // only save in the memory, not be persisted
-       Parent         *Checkpoint
+       Parent         *Checkpoint `json:"-"`
        StartTimestamp uint64
        SupLinks       []*SupLink
        Status         CheckpointStatus
diff --git a/protocol/store.go b/protocol/store.go
index 31ebe36..e87180f 100644
@@ -19,16 +19,19 @@ type Store interface {
        GetContract(hash [32]byte) ([]byte, error)
 
        GetCheckpoint(*bc.Hash) (*state.Checkpoint, error)
+       CheckpointsFromNode(height uint64, hash *bc.Hash) ([]*state.Checkpoint, error)
        GetCheckpointsByHeight(uint64) ([]*state.Checkpoint, error)
        SaveCheckpoints(...*state.Checkpoint) error
 
        LoadBlockIndex(uint64) (*state.BlockIndex, error)
        SaveBlock(*types.Block) error
-       SaveChainStatus(*state.BlockNode, *state.UtxoViewpoint, *state.ContractViewpoint) error
+       SaveChainStatus(*state.BlockNode, *state.UtxoViewpoint, *state.ContractViewpoint, uint64, *bc.Hash) error
 }
 
 // BlockStoreState represents the core's db status
 type BlockStoreState struct {
-       Height uint64
-       Hash   *bc.Hash
+       Height          uint64
+       Hash            *bc.Hash
+       FinalizedHeight uint64
+       FinalizedHash   *bc.Hash
 }
diff --git a/protocol/consensus/tree_node.go b/protocol/tree_node.go
similarity index 98%
rename from protocol/consensus/tree_node.go
rename to protocol/tree_node.go
index d3b451f..ac3da93 100644
@@ -1,4 +1,4 @@
-package consensus
+package protocol
 
 import (
        "errors"
diff --git a/protocol/txpool_test.go b/protocol/txpool_test.go
index 98be029..9b4773b 100644
@@ -101,6 +101,7 @@ func (s *mockStore) GetBlockHeader(hash *bc.Hash) (*types.BlockHeader, error)
 func (s *mockStore) GetCheckpoint(hash *bc.Hash) (*state.Checkpoint, error)       { return nil, nil }
 func (s *mockStore) GetCheckpointsByHeight(u uint64) ([]*state.Checkpoint, error) { return nil, nil }
 func (s *mockStore) SaveCheckpoints(...*state.Checkpoint) error                   { return nil }
+func (s *mockStore) CheckpointsFromNode(height uint64, hash *bc.Hash) ([]*state.Checkpoint, error)      { return nil, nil }
 func (s *mockStore) BlockExist(hash *bc.Hash) bool                                { return false }
 func (s *mockStore) GetBlock(*bc.Hash) (*types.Block, error)                      { return nil, nil }
 func (s *mockStore) GetStoreStatus() *BlockStoreState                             { return nil }
@@ -109,7 +110,7 @@ func (s *mockStore) GetUtxo(*bc.Hash) (*storage.UtxoEntry, error)
 func (s *mockStore) GetContract(hash [32]byte) ([]byte, error)                    { return nil, nil }
 func (s *mockStore) LoadBlockIndex(uint64) (*state.BlockIndex, error)             { return nil, nil }
 func (s *mockStore) SaveBlock(*types.Block) error                                 { return nil }
-func (s *mockStore) SaveChainStatus(*state.BlockNode, *state.UtxoViewpoint, *state.ContractViewpoint) error {
+func (s *mockStore) SaveChainStatus(*state.BlockNode, *state.UtxoViewpoint, *state.ContractViewpoint, uint64, *bc.Hash) error {
        return nil
 }
 
@@ -596,6 +597,7 @@ func (s *mockStore1) GetBlockHeader(hash *bc.Hash) (*types.BlockHeader, error)
 func (s *mockStore1) GetCheckpoint(hash *bc.Hash) (*state.Checkpoint, error)       { return nil, nil }
 func (s *mockStore1) GetCheckpointsByHeight(u uint64) ([]*state.Checkpoint, error) { return nil, nil }
 func (s *mockStore1) SaveCheckpoints(...*state.Checkpoint) error                   { return nil }
+func (s *mockStore1) CheckpointsFromNode(height uint64, hash *bc.Hash) ([]*state.Checkpoint, error)      { return nil, nil }
 func (s *mockStore1) BlockExist(hash *bc.Hash) bool                                { return false }
 func (s *mockStore1) GetBlock(*bc.Hash) (*types.Block, error)                      { return nil, nil }
 func (s *mockStore1) GetStoreStatus() *BlockStoreState                             { return nil }
@@ -610,7 +612,7 @@ func (s *mockStore1) GetUtxo(*bc.Hash) (*storage.UtxoEntry, error)        { retu
 func (s *mockStore1) GetContract(hash [32]byte) ([]byte, error)           { return nil, nil }
 func (s *mockStore1) LoadBlockIndex(uint64) (*state.BlockIndex, error)    { return nil, nil }
 func (s *mockStore1) SaveBlock(*types.Block) error { return nil }
-func (s *mockStore1) SaveChainStatus(*state.BlockNode, *state.UtxoViewpoint, *state.ContractViewpoint) error { return nil}
+func (s *mockStore1) SaveChainStatus(*state.BlockNode, *state.UtxoViewpoint, *state.ContractViewpoint, uint64, *bc.Hash) error { return nil}
 
 func TestProcessTransaction(t *testing.T) {
        txPool := &TxPool{
diff --git a/test/utxo_view/utxo_view_test.go b/test/utxo_view/utxo_view_test.go
index c7b4907..b70dd15 100644
@@ -293,7 +293,7 @@ func TestAttachOrDetachBlocks(t *testing.T) {
                        utxoViewpoint0.Entries[k] = v
                }
                contractView := state.NewContractViewpoint()
-               if err := store.SaveChainStatus(node, utxoViewpoint0, contractView); err != nil {
+               if err := store.SaveChainStatus(node, utxoViewpoint0, contractView, 0, &bc.Hash{}); err != nil {
                        t.Error(err)
                }
 
@@ -315,7 +315,7 @@ func TestAttachOrDetachBlocks(t *testing.T) {
                                t.Error(err)
                        }
                }
-               if err := store.SaveChainStatus(node, utxoViewpoint, contractView); err != nil {
+               if err := store.SaveChainStatus(node, utxoViewpoint, contractView, 0, &bc.Hash{}); err != nil {
                        t.Error(err)
                }