OSDN Git Service

feat: init cross_tx keepers (#146)
[bytom/vapor.git] / vendor / github.com / jinzhu / gorm / callback_update.go
1 package gorm
2
3 import (
4         "errors"
5         "fmt"
6         "sort"
7         "strings"
8 )
9
10 // Define callbacks for updating
11 func init() {
12         DefaultCallback.Update().Register("gorm:assign_updating_attributes", assignUpdatingAttributesCallback)
13         DefaultCallback.Update().Register("gorm:begin_transaction", beginTransactionCallback)
14         DefaultCallback.Update().Register("gorm:before_update", beforeUpdateCallback)
15         DefaultCallback.Update().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
16         DefaultCallback.Update().Register("gorm:update_time_stamp", updateTimeStampForUpdateCallback)
17         DefaultCallback.Update().Register("gorm:update", updateCallback)
18         DefaultCallback.Update().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
19         DefaultCallback.Update().Register("gorm:after_update", afterUpdateCallback)
20         DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
21 }
22
23 // assignUpdatingAttributesCallback assign updating attributes to model
24 func assignUpdatingAttributesCallback(scope *Scope) {
25         if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
26                 if updateMaps, hasUpdate := scope.updatedAttrsWithValues(attrs); hasUpdate {
27                         scope.InstanceSet("gorm:update_attrs", updateMaps)
28                 } else {
29                         scope.SkipLeft()
30                 }
31         }
32 }
33
34 // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating
35 func beforeUpdateCallback(scope *Scope) {
36         if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() {
37                 scope.Err(errors.New("Missing WHERE clause while updating"))
38                 return
39         }
40         if _, ok := scope.Get("gorm:update_column"); !ok {
41                 if !scope.HasError() {
42                         scope.CallMethod("BeforeSave")
43                 }
44                 if !scope.HasError() {
45                         scope.CallMethod("BeforeUpdate")
46                 }
47         }
48 }
49
50 // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating
51 func updateTimeStampForUpdateCallback(scope *Scope) {
52         if _, ok := scope.Get("gorm:update_column"); !ok {
53                 scope.SetColumn("UpdatedAt", NowFunc())
54         }
55 }
56
57 // updateCallback the callback used to update data to database
58 func updateCallback(scope *Scope) {
59         if !scope.HasError() {
60                 var sqls []string
61
62                 if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
63                         // Sort the column names so that the generated SQL is the same every time.
64                         updateMap := updateAttrs.(map[string]interface{})
65                         var columns []string
66                         for c := range updateMap {
67                                 columns = append(columns, c)
68                         }
69                         sort.Strings(columns)
70
71                         for _, column := range columns {
72                                 value := updateMap[column]
73                                 sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value)))
74                         }
75                 } else {
76                         for _, field := range scope.Fields() {
77                                 if scope.changeableField(field) {
78                                         if !field.IsPrimaryKey && field.IsNormal {
79                                                 if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
80                                                         sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
81                                                 }
82                                         } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
83                                                 for _, foreignKey := range relationship.ForeignDBNames {
84                                                         if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
85                                                                 sqls = append(sqls,
86                                                                         fmt.Sprintf("%v = %v", scope.Quote(foreignField.DBName), scope.AddToVars(foreignField.Field.Interface())))
87                                                         }
88                                                 }
89                                         }
90                                 }
91                         }
92                 }
93
94                 var extraOption string
95                 if str, ok := scope.Get("gorm:update_option"); ok {
96                         extraOption = fmt.Sprint(str)
97                 }
98
99                 if len(sqls) > 0 {
100                         scope.Raw(fmt.Sprintf(
101                                 "UPDATE %v SET %v%v%v",
102                                 scope.QuotedTableName(),
103                                 strings.Join(sqls, ", "),
104                                 addExtraSpaceIfExist(scope.CombinedConditionSql()),
105                                 addExtraSpaceIfExist(extraOption),
106                         )).Exec()
107                 }
108         }
109 }
110
111 // afterUpdateCallback will invoke `AfterUpdate`, `AfterSave` method after updating
112 func afterUpdateCallback(scope *Scope) {
113         if _, ok := scope.Get("gorm:update_column"); !ok {
114                 if !scope.HasError() {
115                         scope.CallMethod("AfterUpdate")
116                 }
117                 if !scope.HasError() {
118                         scope.CallMethod("AfterSave")
119                 }
120         }
121 }