OSDN Git Service

Modify access token db to sqldb dev_access_token_db
authormars <mars@bytom.io>
Wed, 8 May 2019 07:26:32 +0000 (15:26 +0800)
committermars <mars@bytom.io>
Wed, 8 May 2019 08:39:35 +0000 (16:39 +0800)
accesstoken/accesstoken.go
accesstoken/accesstoken_test.go
api/api_test.go
database/orm/accesstoken.go [new file with mode: 0644]
net/http/authn/authn_test.go
node/node.go

index 7b32b0e..7539e33 100644 (file)
@@ -4,12 +4,15 @@ package accesstoken
 
 import (
        "crypto/rand"
-       "encoding/json"
        "fmt"
        "regexp"
        "strings"
        "time"
 
+       "github.com/jinzhu/gorm"
+
+       "github.com/vapor/database/orm"
+
        "github.com/vapor/crypto/sha3pool"
        dbm "github.com/vapor/database/db"
        "github.com/vapor/errors"
@@ -42,13 +45,22 @@ type Token struct {
        Created time.Time `json:"created_at"`
 }
 
+func tokenFromOrmToken(ac orm.AccessToken) *Token {
+       return &Token{
+               ID:      ac.ID,
+               Token:   ac.Token,
+               Type:    ac.Type,
+               Created: ac.Created,
+       }
+}
+
 // CredentialStore store user access credential.
 type CredentialStore struct {
-       DB dbm.DB
+       DB dbm.SQLDB
 }
 
 // NewStore creates and returns a new Store object.
-func NewStore(db dbm.DB) *CredentialStore {
+func NewStore(db dbm.SQLDB) *CredentialStore {
        return &CredentialStore{
                DB: db,
        }
@@ -60,33 +72,30 @@ func (cs *CredentialStore) Create(id, typ string) (*Token, error) {
                return nil, errors.WithDetailf(ErrBadID, "invalid id %q", id)
        }
 
-       key := []byte(id)
-       if cs.DB.Get(key) != nil {
-               return nil, errors.WithDetailf(ErrDuplicateID, "id %q already in use", id)
-       }
-
-       secret := make([]byte, tokenSize)
-       if _, err := rand.Read(secret); err != nil {
-               return nil, err
-       }
-
-       hashedSecret := make([]byte, tokenSize)
-       sha3pool.Sum256(hashedSecret, secret)
-
-       token := &Token{
-               ID:      id,
-               Token:   fmt.Sprintf("%s:%x", id, hashedSecret),
-               Type:    typ,
-               Created: time.Now(),
-       }
+       accessToken := orm.AccessToken{ID: id}
 
-       value, err := json.Marshal(token)
-       if err != nil {
-               return nil, err
+       if err := cs.DB.Db().Where(&accessToken).Find(&accessToken).Error; err != nil {
+               if err != gorm.ErrRecordNotFound {
+                       return nil, err
+               }
+               secret := make([]byte, tokenSize)
+               if _, err := rand.Read(secret); err != nil {
+                       return nil, err
+               }
+               hashedSecret := make([]byte, tokenSize)
+               sha3pool.Sum256(hashedSecret, secret)
+               accessToken = orm.AccessToken{
+                       ID:      id,
+                       Token:   fmt.Sprintf("%s:%x", id, hashedSecret),
+                       Type:    typ,
+                       Created: time.Now(),
+               }
+               if err = cs.DB.Db().Create(&accessToken).Error; err != nil {
+                       return nil, err
+               }
+               return tokenFromOrmToken(accessToken), nil
        }
-       cs.DB.Set(key, value)
-
-       return token, nil
+       return nil, errors.WithDetailf(ErrDuplicateID, "id %q already in use", id)
 }
 
 // Check returns whether or not an id-secret pair is a valid access token.
@@ -94,18 +103,13 @@ func (cs *CredentialStore) Check(id string, secret string) error {
        if !validIDRegexp.MatchString(id) {
                return errors.WithDetailf(ErrBadID, "invalid id %q", id)
        }
+       accessToken := orm.AccessToken{ID: id}
 
-       var value []byte
-       token := &Token{}
-
-       if value = cs.DB.Get([]byte(id)); value == nil {
-               return errors.WithDetailf(ErrNoMatchID, "check id %q nonexisting", id)
-       }
-       if err := json.Unmarshal(value, token); err != nil {
+       if err := cs.DB.Db().Where(&accessToken).Find(&accessToken).Error; err != nil {
                return err
        }
 
-       if strings.Split(token.Token, ":")[1] == secret {
+       if strings.Split(accessToken.Token, ":")[1] == secret {
                return nil
        }
 
@@ -115,15 +119,16 @@ func (cs *CredentialStore) Check(id string, secret string) error {
 // List lists all access tokens.
 func (cs *CredentialStore) List() ([]*Token, error) {
        tokens := make([]*Token, 0)
-       iter := cs.DB.Iterator()
-       defer iter.Release()
-
-       for iter.Next() {
-               token := &Token{}
-               if err := json.Unmarshal(iter.Value(), token); err != nil {
+       rows, err := cs.DB.Db().Model(&orm.AccessToken{}).Rows()
+       if err != nil {
+               return nil, err
+       }
+       for rows.Next() {
+               accessToken := orm.AccessToken{}
+               if err := rows.Scan(&accessToken.ID, &accessToken.Token, &accessToken.Type, &accessToken.Created); err != nil {
                        return nil, err
                }
-               tokens = append(tokens, token)
+               tokens = append(tokens, tokenFromOrmToken(accessToken))
        }
        return tokens, nil
 }
@@ -133,11 +138,11 @@ func (cs *CredentialStore) Delete(id string) error {
        if !validIDRegexp.MatchString(id) {
                return errors.WithDetailf(ErrBadID, "invalid id %q", id)
        }
-
-       if value := cs.DB.Get([]byte(id)); value == nil {
-               return errors.WithDetailf(ErrNoMatchID, "check id %q", id)
+       if err := cs.DB.Db().Delete(&orm.AccessToken{ID: id}).Error; err != nil {
+               if err == gorm.ErrRecordNotFound {
+                       return errors.WithDetailf(ErrNoMatchID, "check id %q", id)
+               }
+               return err
        }
-
-       cs.DB.Delete([]byte(id))
        return nil
 }
index e5639cd..82b0860 100644 (file)
@@ -6,14 +6,25 @@ import (
        "strings"
        "testing"
 
+       "github.com/jinzhu/gorm"
+
+       "github.com/vapor/database/orm"
+
        dbm "github.com/vapor/database/db"
        _ "github.com/vapor/database/leveldb"
+       _ "github.com/vapor/database/sqlite"
        "github.com/vapor/errors"
 )
 
 func TestCreate(t *testing.T) {
-       testDB := dbm.NewDB("testdb", "leveldb", "temp")
-       defer os.RemoveAll("temp")
+       testDB := dbm.NewSqlDB("sql", "sqlitedb", "temp")
+       defer func() {
+               testDB.Db().Close()
+               os.RemoveAll("temp")
+       }()
+
+       testDB.Db().AutoMigrate(&orm.AccessToken{})
+
        cs := NewStore(testDB)
 
        cases := []struct {
@@ -37,8 +48,13 @@ func TestCreate(t *testing.T) {
 
 func TestList(t *testing.T) {
        ctx := context.Background()
-       testDB := dbm.NewDB("testdb", "leveldb", "temp")
-       defer os.RemoveAll("temp")
+       testDB := dbm.NewSqlDB("sql", "sqlitedb", "temp")
+       defer func() {
+               testDB.Db().Close()
+               os.RemoveAll("temp")
+       }()
+
+       testDB.Db().AutoMigrate(&orm.AccessToken{})
        cs := NewStore(testDB)
 
        tokenMap := make(map[string]*Token)
@@ -64,8 +80,13 @@ func TestList(t *testing.T) {
 
 func TestCheck(t *testing.T) {
        ctx := context.Background()
-       testDB := dbm.NewDB("testdb", "leveldb", "temp")
-       defer os.RemoveAll("temp")
+       testDB := dbm.NewSqlDB("sql", "sqlitedb", "temp")
+       defer func() {
+               testDB.Db().Close()
+               os.RemoveAll("temp")
+       }()
+
+       testDB.Db().AutoMigrate(&orm.AccessToken{})
        cs := NewStore(testDB)
 
        token := mustCreateToken(ctx, t, cs, "x", "client")
@@ -82,8 +103,13 @@ func TestCheck(t *testing.T) {
 
 func TestDelete(t *testing.T) {
        ctx := context.Background()
-       testDB := dbm.NewDB("testdb", "leveldb", "temp")
-       defer os.RemoveAll("temp")
+       testDB := dbm.NewSqlDB("sql", "sqlitedb", "temp")
+       defer func() {
+               testDB.Db().Close()
+               os.RemoveAll("temp")
+       }()
+
+       testDB.Db().AutoMigrate(&orm.AccessToken{})
        cs := NewStore(testDB)
 
        const id = "Y"
@@ -94,15 +120,35 @@ func TestDelete(t *testing.T) {
                t.Fatal(err)
        }
 
-       value := cs.DB.Get([]byte(id))
-       if len(value) > 0 {
+       accessToken := orm.AccessToken{ID: id}
+
+       err = cs.DB.Db().Where(&accessToken).Find(&accessToken).Error
+       if err != gorm.ErrRecordNotFound {
+               t.Fatal(err)
+       }
+
+       if err == nil {
                t.Fatal("delete fail")
        }
+
+       /*
+               cs.List
+
+               value := cs.DB.Get([]byte(id))
+               if len(value) > 0 {
+                       t.Fatal("delete fail")
+               }
+       */
 }
 
 func TestDeleteWithInvalidId(t *testing.T) {
-       testDB := dbm.NewDB("testdb", "leveldb", "temp")
-       defer os.RemoveAll("temp")
+       testDB := dbm.NewSqlDB("sql", "sqlitedb", "temp")
+       defer func() {
+               testDB.Db().Close()
+               os.RemoveAll("temp")
+       }()
+
+       testDB.Db().AutoMigrate(&orm.AccessToken{})
        cs := NewStore(testDB)
 
        err := cs.Delete("@")
index d82b26c..787523f 100644 (file)
@@ -14,6 +14,8 @@ import (
        "github.com/vapor/consensus"
        dbm "github.com/vapor/database/db"
        _ "github.com/vapor/database/leveldb"
+       "github.com/vapor/database/orm"
+       _ "github.com/vapor/database/sqlite"
        "github.com/vapor/testutil"
 )
 
@@ -27,8 +29,13 @@ func TestAPIHandler(t *testing.T) {
        defer server.Close()
 
        // create accessTokens
-       testDB := dbm.NewDB("testdb", "leveldb", "temp")
-       defer os.RemoveAll("temp")
+       testDB := dbm.NewSqlDB("sql", "sqlitedb", "temp")
+       defer func() {
+               testDB.Db().Close()
+               os.RemoveAll("temp")
+       }()
+
+       testDB.Db().AutoMigrate(&orm.AccessToken{})
        a.accessTokens = accesstoken.NewStore(testDB)
 
        client := &rpc.Client{
diff --git a/database/orm/accesstoken.go b/database/orm/accesstoken.go
new file mode 100644 (file)
index 0000000..df6e7bc
--- /dev/null
@@ -0,0 +1,10 @@
+package orm
+
+import "time"
+
+type AccessToken struct {
+       ID      string `gorm:"primary_key"`
+       Token   string
+       Type    string
+       Created time.Time
+}
index 71476cc..d63b7da 100644 (file)
@@ -9,12 +9,19 @@ import (
        "github.com/vapor/accesstoken"
        dbm "github.com/vapor/database/db"
        _ "github.com/vapor/database/leveldb"
+       "github.com/vapor/database/orm"
+       _ "github.com/vapor/database/sqlite"
        "github.com/vapor/errors"
 )
 
 func TestAuthenticate(t *testing.T) {
-       tokenDB := dbm.NewDB("testdb", "leveldb", "temp")
-       defer os.RemoveAll("temp")
+       tokenDB := dbm.NewSqlDB("sql", "sqlitedb", "temp")
+       defer func() {
+               tokenDB.Db().Close()
+               os.RemoveAll("temp")
+       }()
+
+       tokenDB.Db().AutoMigrate(&orm.AccessToken{})
        tokenStore := accesstoken.NewStore(tokenDB)
        token, err := tokenStore.Create("alice", "test")
        if err != nil {
index 7354485..94974a2 100644 (file)
@@ -92,8 +92,7 @@ func NewNode(config *cfg.Config) *Node {
        initDatabaseTable(sqlDB)
        sqlStore := database.NewSQLStore(sqlDB)
 
-       tokenDB := dbm.NewDB("accesstoken", config.DBBackend, config.DBDir())
-       accessTokens := accesstoken.NewStore(tokenDB)
+       accessTokens := accesstoken.NewStore(sqlDB)
 
        txPool := protocol.NewTxPool(sqlStore)
        chain, err := protocol.NewChain(sqlStore, txPool)