OSDN Git Service

feat: init cross_tx keepers (#146)
[bytom/vapor.git] / vendor / github.com / jinzhu / gorm / callback_create.go
1 package gorm
2
3 import (
4         "fmt"
5         "strings"
6 )
7
8 // Define callbacks for creating
9 func init() {
10         DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
11         DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
12         DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
13         DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
14         DefaultCallback.Create().Register("gorm:create", createCallback)
15         DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
16         DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
17         DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
18         DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
19 }
20
21 // beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
22 func beforeCreateCallback(scope *Scope) {
23         if !scope.HasError() {
24                 scope.CallMethod("BeforeSave")
25         }
26         if !scope.HasError() {
27                 scope.CallMethod("BeforeCreate")
28         }
29 }
30
31 // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
32 func updateTimeStampForCreateCallback(scope *Scope) {
33         if !scope.HasError() {
34                 now := NowFunc()
35
36                 if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
37                         if createdAtField.IsBlank {
38                                 createdAtField.Set(now)
39                         }
40                 }
41
42                 if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok {
43                         if updatedAtField.IsBlank {
44                                 updatedAtField.Set(now)
45                         }
46                 }
47         }
48 }
49
50 // createCallback the callback used to insert data into database
51 func createCallback(scope *Scope) {
52         if !scope.HasError() {
53                 defer scope.trace(NowFunc())
54
55                 var (
56                         columns, placeholders        []string
57                         blankColumnsWithDefaultValue []string
58                 )
59
60                 for _, field := range scope.Fields() {
61                         if scope.changeableField(field) {
62                                 if field.IsNormal && !field.IsIgnored {
63                                         if field.IsBlank && field.HasDefaultValue {
64                                                 blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
65                                                 scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
66                                         } else if !field.IsPrimaryKey || !field.IsBlank {
67                                                 columns = append(columns, scope.Quote(field.DBName))
68                                                 placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
69                                         }
70                                 } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
71                                         for _, foreignKey := range field.Relationship.ForeignDBNames {
72                                                 if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
73                                                         columns = append(columns, scope.Quote(foreignField.DBName))
74                                                         placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
75                                                 }
76                                         }
77                                 }
78                         }
79                 }
80
81                 var (
82                         returningColumn = "*"
83                         quotedTableName = scope.QuotedTableName()
84                         primaryField    = scope.PrimaryField()
85                         extraOption     string
86                 )
87
88                 if str, ok := scope.Get("gorm:insert_option"); ok {
89                         extraOption = fmt.Sprint(str)
90                 }
91
92                 if primaryField != nil {
93                         returningColumn = scope.Quote(primaryField.DBName)
94                 }
95
96                 lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
97
98                 if len(columns) == 0 {
99                         scope.Raw(fmt.Sprintf(
100                                 "INSERT INTO %v %v%v%v",
101                                 quotedTableName,
102                                 scope.Dialect().DefaultValueStr(),
103                                 addExtraSpaceIfExist(extraOption),
104                                 addExtraSpaceIfExist(lastInsertIDReturningSuffix),
105                         ))
106                 } else {
107                         scope.Raw(fmt.Sprintf(
108                                 "INSERT INTO %v (%v) VALUES (%v)%v%v",
109                                 scope.QuotedTableName(),
110                                 strings.Join(columns, ","),
111                                 strings.Join(placeholders, ","),
112                                 addExtraSpaceIfExist(extraOption),
113                                 addExtraSpaceIfExist(lastInsertIDReturningSuffix),
114                         ))
115                 }
116
117                 // execute create sql
118                 if lastInsertIDReturningSuffix == "" || primaryField == nil {
119                         if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
120                                 // set rows affected count
121                                 scope.db.RowsAffected, _ = result.RowsAffected()
122
123                                 // set primary value to primary field
124                                 if primaryField != nil && primaryField.IsBlank {
125                                         if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
126                                                 scope.Err(primaryField.Set(primaryValue))
127                                         }
128                                 }
129                         }
130                 } else {
131                         if primaryField.Field.CanAddr() {
132                                 if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
133                                         primaryField.IsBlank = false
134                                         scope.db.RowsAffected = 1
135                                 }
136                         } else {
137                                 scope.Err(ErrUnaddressable)
138                         }
139                 }
140         }
141 }
142
143 // forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
144 func forceReloadAfterCreateCallback(scope *Scope) {
145         if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
146                 db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
147                 for _, field := range scope.Fields() {
148                         if field.IsPrimaryKey && !field.IsBlank {
149                                 db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
150                         }
151                 }
152                 db.Scan(scope.Value)
153         }
154 }
155
156 // afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
157 func afterCreateCallback(scope *Scope) {
158         if !scope.HasError() {
159                 scope.CallMethod("AfterCreate")
160         }
161         if !scope.HasError() {
162                 scope.CallMethod("AfterSave")
163         }
164 }