OSDN Git Service

feat: init cross_tx keepers (#146)
[bytom/vapor.git] / vendor / github.com / jinzhu / gorm / dialects / mssql / mssql.go
1 package mssql
2
3 import (
4         "database/sql/driver"
5         "encoding/json"
6         "errors"
7         "fmt"
8         "reflect"
9         "strconv"
10         "strings"
11         "time"
12
13         // Importing mssql driver package only in dialect file, otherwide not needed
14         _ "github.com/denisenkom/go-mssqldb"
15         "github.com/jinzhu/gorm"
16 )
17
18 func setIdentityInsert(scope *gorm.Scope) {
19         if scope.Dialect().GetName() == "mssql" {
20                 for _, field := range scope.PrimaryFields() {
21                         if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsBlank {
22                                 scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v ON", scope.TableName()))
23                                 scope.InstanceSet("mssql:identity_insert_on", true)
24                         }
25                 }
26         }
27 }
28
29 func turnOffIdentityInsert(scope *gorm.Scope) {
30         if scope.Dialect().GetName() == "mssql" {
31                 if _, ok := scope.InstanceGet("mssql:identity_insert_on"); ok {
32                         scope.NewDB().Exec(fmt.Sprintf("SET IDENTITY_INSERT %v OFF", scope.TableName()))
33                 }
34         }
35 }
36
37 func init() {
38         gorm.DefaultCallback.Create().After("gorm:begin_transaction").Register("mssql:set_identity_insert", setIdentityInsert)
39         gorm.DefaultCallback.Create().Before("gorm:commit_or_rollback_transaction").Register("mssql:turn_off_identity_insert", turnOffIdentityInsert)
40         gorm.RegisterDialect("mssql", &mssql{})
41 }
42
43 type mssql struct {
44         db gorm.SQLCommon
45         gorm.DefaultForeignKeyNamer
46 }
47
48 func (mssql) GetName() string {
49         return "mssql"
50 }
51
52 func (s *mssql) SetDB(db gorm.SQLCommon) {
53         s.db = db
54 }
55
56 func (mssql) BindVar(i int) string {
57         return "$$$" // ?
58 }
59
60 func (mssql) Quote(key string) string {
61         return fmt.Sprintf(`[%s]`, key)
62 }
63
64 func (s *mssql) DataTypeOf(field *gorm.StructField) string {
65         var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s)
66
67         if sqlType == "" {
68                 switch dataValue.Kind() {
69                 case reflect.Bool:
70                         sqlType = "bit"
71                 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
72                         if s.fieldCanAutoIncrement(field) {
73                                 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
74                                 sqlType = "int IDENTITY(1,1)"
75                         } else {
76                                 sqlType = "int"
77                         }
78                 case reflect.Int64, reflect.Uint64:
79                         if s.fieldCanAutoIncrement(field) {
80                                 field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
81                                 sqlType = "bigint IDENTITY(1,1)"
82                         } else {
83                                 sqlType = "bigint"
84                         }
85                 case reflect.Float32, reflect.Float64:
86                         sqlType = "float"
87                 case reflect.String:
88                         if size > 0 && size < 8000 {
89                                 sqlType = fmt.Sprintf("nvarchar(%d)", size)
90                         } else {
91                                 sqlType = "nvarchar(max)"
92                         }
93                 case reflect.Struct:
94                         if _, ok := dataValue.Interface().(time.Time); ok {
95                                 sqlType = "datetimeoffset"
96                         }
97                 default:
98                         if gorm.IsByteArrayOrSlice(dataValue) {
99                                 if size > 0 && size < 8000 {
100                                         sqlType = fmt.Sprintf("varbinary(%d)", size)
101                                 } else {
102                                         sqlType = "varbinary(max)"
103                                 }
104                         }
105                 }
106         }
107
108         if sqlType == "" {
109                 panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
110         }
111
112         if strings.TrimSpace(additionalType) == "" {
113                 return sqlType
114         }
115         return fmt.Sprintf("%v %v", sqlType, additionalType)
116 }
117
118 func (s mssql) fieldCanAutoIncrement(field *gorm.StructField) bool {
119         if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
120                 return value != "FALSE"
121         }
122         return field.IsPrimaryKey
123 }
124
125 func (s mssql) HasIndex(tableName string, indexName string) bool {
126         var count int
127         s.db.QueryRow("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Scan(&count)
128         return count > 0
129 }
130
131 func (s mssql) RemoveIndex(tableName string, indexName string) error {
132         _, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, s.Quote(tableName)))
133         return err
134 }
135
136 func (s mssql) HasForeignKey(tableName string, foreignKeyName string) bool {
137         var count int
138         currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
139         s.db.QueryRow(`SELECT count(*) 
140         FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id 
141                 inner join information_schema.tables as I on I.TABLE_NAME = T.name 
142         WHERE F.name = ? 
143                 AND T.Name = ? AND I.TABLE_CATALOG = ?;`, foreignKeyName, tableName, currentDatabase).Scan(&count)
144         return count > 0
145 }
146
147 func (s mssql) HasTable(tableName string) bool {
148         var count int
149         currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
150         s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, currentDatabase).Scan(&count)
151         return count > 0
152 }
153
154 func (s mssql) HasColumn(tableName string, columnName string) bool {
155         var count int
156         currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
157         s.db.QueryRow("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
158         return count > 0
159 }
160
161 func (s mssql) ModifyColumn(tableName string, columnName string, typ string) error {
162         _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v %v", tableName, columnName, typ))
163         return err
164 }
165
166 func (s mssql) CurrentDatabase() (name string) {
167         s.db.QueryRow("SELECT DB_NAME() AS [Current Database]").Scan(&name)
168         return
169 }
170
171 func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
172         if offset != nil {
173                 if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
174                         sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
175                 }
176         }
177         if limit != nil {
178                 if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
179                         if sql == "" {
180                                 // add default zero offset
181                                 sql += " OFFSET 0 ROWS"
182                         }
183                         sql += fmt.Sprintf(" FETCH NEXT %d ROWS ONLY", parsedLimit)
184                 }
185         }
186         return
187 }
188
189 func (mssql) SelectFromDummyTable() string {
190         return ""
191 }
192
193 func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
194         return ""
195 }
196
197 func (mssql) DefaultValueStr() string {
198         return "DEFAULT VALUES"
199 }
200
201 func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
202         if strings.Contains(tableName, ".") {
203                 splitStrings := strings.SplitN(tableName, ".", 2)
204                 return splitStrings[0], splitStrings[1]
205         }
206         return dialect.CurrentDatabase(), tableName
207 }
208
209 // JSON type to support easy handling of JSON data in character table fields
210 // using golang json.RawMessage for deferred decoding/encoding
211 type JSON struct {
212         json.RawMessage
213 }
214
215 // Value get value of JSON
216 func (j JSON) Value() (driver.Value, error) {
217         if len(j.RawMessage) == 0 {
218                 return nil, nil
219         }
220         return j.MarshalJSON()
221 }
222
223 // Scan scan value into JSON
224 func (j *JSON) Scan(value interface{}) error {
225         str, ok := value.(string)
226         if !ok {
227                 return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value))
228         }
229         bytes := []byte(str)
230         return json.Unmarshal(bytes, j)
231 }