OSDN Git Service

feat: init cross_tx keepers (#146)
[bytom/vapor.git] / vendor / github.com / jinzhu / gorm / callback_query_preload.go
1 package gorm
2
3 import (
4         "errors"
5         "fmt"
6         "reflect"
7         "strconv"
8         "strings"
9 )
10
11 // preloadCallback used to preload associations
12 func preloadCallback(scope *Scope) {
13         if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
14                 return
15         }
16
17         if ap, ok := scope.Get("gorm:auto_preload"); ok {
18                 // If gorm:auto_preload IS NOT a bool then auto preload.
19                 // Else if it IS a bool, use the value
20                 if apb, ok := ap.(bool); !ok {
21                         autoPreload(scope)
22                 } else if apb {
23                         autoPreload(scope)
24                 }
25         }
26
27         if scope.Search.preload == nil || scope.HasError() {
28                 return
29         }
30
31         var (
32                 preloadedMap = map[string]bool{}
33                 fields       = scope.Fields()
34         )
35
36         for _, preload := range scope.Search.preload {
37                 var (
38                         preloadFields = strings.Split(preload.schema, ".")
39                         currentScope  = scope
40                         currentFields = fields
41                 )
42
43                 for idx, preloadField := range preloadFields {
44                         var currentPreloadConditions []interface{}
45
46                         if currentScope == nil {
47                                 continue
48                         }
49
50                         // if not preloaded
51                         if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
52
53                                 // assign search conditions to last preload
54                                 if idx == len(preloadFields)-1 {
55                                         currentPreloadConditions = preload.conditions
56                                 }
57
58                                 for _, field := range currentFields {
59                                         if field.Name != preloadField || field.Relationship == nil {
60                                                 continue
61                                         }
62
63                                         switch field.Relationship.Kind {
64                                         case "has_one":
65                                                 currentScope.handleHasOnePreload(field, currentPreloadConditions)
66                                         case "has_many":
67                                                 currentScope.handleHasManyPreload(field, currentPreloadConditions)
68                                         case "belongs_to":
69                                                 currentScope.handleBelongsToPreload(field, currentPreloadConditions)
70                                         case "many_to_many":
71                                                 currentScope.handleManyToManyPreload(field, currentPreloadConditions)
72                                         default:
73                                                 scope.Err(errors.New("unsupported relation"))
74                                         }
75
76                                         preloadedMap[preloadKey] = true
77                                         break
78                                 }
79
80                                 if !preloadedMap[preloadKey] {
81                                         scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
82                                         return
83                                 }
84                         }
85
86                         // preload next level
87                         if idx < len(preloadFields)-1 {
88                                 currentScope = currentScope.getColumnAsScope(preloadField)
89                                 if currentScope != nil {
90                                         currentFields = currentScope.Fields()
91                                 }
92                         }
93                 }
94         }
95 }
96
97 func autoPreload(scope *Scope) {
98         for _, field := range scope.Fields() {
99                 if field.Relationship == nil {
100                         continue
101                 }
102
103                 if val, ok := field.TagSettingsGet("PRELOAD"); ok {
104                         if preload, err := strconv.ParseBool(val); err != nil {
105                                 scope.Err(errors.New("invalid preload option"))
106                                 return
107                         } else if !preload {
108                                 continue
109                         }
110                 }
111
112                 scope.Search.Preload(field.Name)
113         }
114 }
115
116 func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
117         var (
118                 preloadDB         = scope.NewDB()
119                 preloadConditions []interface{}
120         )
121
122         for _, condition := range conditions {
123                 if scopes, ok := condition.(func(*DB) *DB); ok {
124                         preloadDB = scopes(preloadDB)
125                 } else {
126                         preloadConditions = append(preloadConditions, condition)
127                 }
128         }
129
130         return preloadDB, preloadConditions
131 }
132
133 // handleHasOnePreload used to preload has one associations
134 func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
135         relation := field.Relationship
136
137         // get relations's primary keys
138         primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
139         if len(primaryKeys) == 0 {
140                 return
141         }
142
143         // preload conditions
144         preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
145
146         // find relations
147         query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
148         values := toQueryValues(primaryKeys)
149         if relation.PolymorphicType != "" {
150                 query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
151                 values = append(values, relation.PolymorphicValue)
152         }
153
154         results := makeSlice(field.Struct.Type)
155         scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
156
157         // assign find results
158         var (
159                 resultsValue       = indirect(reflect.ValueOf(results))
160                 indirectScopeValue = scope.IndirectValue()
161         )
162
163         if indirectScopeValue.Kind() == reflect.Slice {
164                 foreignValuesToResults := make(map[string]reflect.Value)
165                 for i := 0; i < resultsValue.Len(); i++ {
166                         result := resultsValue.Index(i)
167                         foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames))
168                         foreignValuesToResults[foreignValues] = result
169                 }
170                 for j := 0; j < indirectScopeValue.Len(); j++ {
171                         indirectValue := indirect(indirectScopeValue.Index(j))
172                         valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames))
173                         if result, found := foreignValuesToResults[valueString]; found {
174                                 indirectValue.FieldByName(field.Name).Set(result)
175                         }
176                 }
177         } else {
178                 for i := 0; i < resultsValue.Len(); i++ {
179                         result := resultsValue.Index(i)
180                         scope.Err(field.Set(result))
181                 }
182         }
183 }
184
185 // handleHasManyPreload used to preload has many associations
186 func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
187         relation := field.Relationship
188
189         // get relations's primary keys
190         primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
191         if len(primaryKeys) == 0 {
192                 return
193         }
194
195         // preload conditions
196         preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
197
198         // find relations
199         query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
200         values := toQueryValues(primaryKeys)
201         if relation.PolymorphicType != "" {
202                 query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
203                 values = append(values, relation.PolymorphicValue)
204         }
205
206         results := makeSlice(field.Struct.Type)
207         scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
208
209         // assign find results
210         var (
211                 resultsValue       = indirect(reflect.ValueOf(results))
212                 indirectScopeValue = scope.IndirectValue()
213         )
214
215         if indirectScopeValue.Kind() == reflect.Slice {
216                 preloadMap := make(map[string][]reflect.Value)
217                 for i := 0; i < resultsValue.Len(); i++ {
218                         result := resultsValue.Index(i)
219                         foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
220                         preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
221                 }
222
223                 for j := 0; j < indirectScopeValue.Len(); j++ {
224                         object := indirect(indirectScopeValue.Index(j))
225                         objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
226                         f := object.FieldByName(field.Name)
227                         if results, ok := preloadMap[toString(objectRealValue)]; ok {
228                                 f.Set(reflect.Append(f, results...))
229                         } else {
230                                 f.Set(reflect.MakeSlice(f.Type(), 0, 0))
231                         }
232                 }
233         } else {
234                 scope.Err(field.Set(resultsValue))
235         }
236 }
237
238 // handleBelongsToPreload used to preload belongs to associations
239 func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
240         relation := field.Relationship
241
242         // preload conditions
243         preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
244
245         // get relations's primary keys
246         primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
247         if len(primaryKeys) == 0 {
248                 return
249         }
250
251         // find relations
252         results := makeSlice(field.Struct.Type)
253         scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
254
255         // assign find results
256         var (
257                 resultsValue       = indirect(reflect.ValueOf(results))
258                 indirectScopeValue = scope.IndirectValue()
259         )
260
261         foreignFieldToObjects := make(map[string][]*reflect.Value)
262         if indirectScopeValue.Kind() == reflect.Slice {
263                 for j := 0; j < indirectScopeValue.Len(); j++ {
264                         object := indirect(indirectScopeValue.Index(j))
265                         valueString := toString(getValueFromFields(object, relation.ForeignFieldNames))
266                         foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object)
267                 }
268         }
269
270         for i := 0; i < resultsValue.Len(); i++ {
271                 result := resultsValue.Index(i)
272                 if indirectScopeValue.Kind() == reflect.Slice {
273                         valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames))
274                         if objects, found := foreignFieldToObjects[valueString]; found {
275                                 for _, object := range objects {
276                                         object.FieldByName(field.Name).Set(result)
277                                 }
278                         }
279                 } else {
280                         scope.Err(field.Set(result))
281                 }
282         }
283 }
284
285 // handleManyToManyPreload used to preload many to many associations
286 func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
287         var (
288                 relation         = field.Relationship
289                 joinTableHandler = relation.JoinTableHandler
290                 fieldType        = field.Struct.Type.Elem()
291                 foreignKeyValue  interface{}
292                 foreignKeyType   = reflect.ValueOf(&foreignKeyValue).Type()
293                 linkHash         = map[string][]reflect.Value{}
294                 isPtr            bool
295         )
296
297         if fieldType.Kind() == reflect.Ptr {
298                 isPtr = true
299                 fieldType = fieldType.Elem()
300         }
301
302         var sourceKeys = []string{}
303         for _, key := range joinTableHandler.SourceForeignKeys() {
304                 sourceKeys = append(sourceKeys, key.DBName)
305         }
306
307         // preload conditions
308         preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
309
310         // generate query with join table
311         newScope := scope.New(reflect.New(fieldType).Interface())
312         preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
313
314         if len(preloadDB.search.selects) == 0 {
315                 preloadDB = preloadDB.Select("*")
316         }
317
318         preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
319
320         // preload inline conditions
321         if len(preloadConditions) > 0 {
322                 preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
323         }
324
325         rows, err := preloadDB.Rows()
326
327         if scope.Err(err) != nil {
328                 return
329         }
330         defer rows.Close()
331
332         columns, _ := rows.Columns()
333         for rows.Next() {
334                 var (
335                         elem   = reflect.New(fieldType).Elem()
336                         fields = scope.New(elem.Addr().Interface()).Fields()
337                 )
338
339                 // register foreign keys in join tables
340                 var joinTableFields []*Field
341                 for _, sourceKey := range sourceKeys {
342                         joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
343                 }
344
345                 scope.scan(rows, columns, append(fields, joinTableFields...))
346
347                 scope.New(elem.Addr().Interface()).
348                         InstanceSet("gorm:skip_query_callback", true).
349                         callCallbacks(scope.db.parent.callbacks.queries)
350
351                 var foreignKeys = make([]interface{}, len(sourceKeys))
352                 // generate hashed forkey keys in join table
353                 for idx, joinTableField := range joinTableFields {
354                         if !joinTableField.Field.IsNil() {
355                                 foreignKeys[idx] = joinTableField.Field.Elem().Interface()
356                         }
357                 }
358                 hashedSourceKeys := toString(foreignKeys)
359
360                 if isPtr {
361                         linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
362                 } else {
363                         linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
364                 }
365         }
366
367         if err := rows.Err(); err != nil {
368                 scope.Err(err)
369         }
370
371         // assign find results
372         var (
373                 indirectScopeValue = scope.IndirectValue()
374                 fieldsSourceMap    = map[string][]reflect.Value{}
375                 foreignFieldNames  = []string{}
376         )
377
378         for _, dbName := range relation.ForeignFieldNames {
379                 if field, ok := scope.FieldByName(dbName); ok {
380                         foreignFieldNames = append(foreignFieldNames, field.Name)
381                 }
382         }
383
384         if indirectScopeValue.Kind() == reflect.Slice {
385                 for j := 0; j < indirectScopeValue.Len(); j++ {
386                         object := indirect(indirectScopeValue.Index(j))
387                         key := toString(getValueFromFields(object, foreignFieldNames))
388                         fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
389                 }
390         } else if indirectScopeValue.IsValid() {
391                 key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
392                 fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
393         }
394         for source, link := range linkHash {
395                 for i, field := range fieldsSourceMap[source] {
396                         //If not 0 this means Value is a pointer and we already added preloaded models to it
397                         if fieldsSourceMap[source][i].Len() != 0 {
398                                 continue
399                         }
400                         field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
401                 }
402
403         }
404 }