OSDN Git Service

Merge pull request #201 from Bytom/v0.1
[bytom/vapor.git] / vendor / github.com / jinzhu / gorm / callback_query_preload.go
diff --git a/vendor/github.com/jinzhu/gorm/callback_query_preload.go b/vendor/github.com/jinzhu/gorm/callback_query_preload.go
new file mode 100755 (executable)
index 0000000..d7c8a13
--- /dev/null
@@ -0,0 +1,404 @@
+package gorm
+
+import (
+       "errors"
+       "fmt"
+       "reflect"
+       "strconv"
+       "strings"
+)
+
+// preloadCallback used to preload associations
+func preloadCallback(scope *Scope) {
+       if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip {
+               return
+       }
+
+       if ap, ok := scope.Get("gorm:auto_preload"); ok {
+               // If gorm:auto_preload IS NOT a bool then auto preload.
+               // Else if it IS a bool, use the value
+               if apb, ok := ap.(bool); !ok {
+                       autoPreload(scope)
+               } else if apb {
+                       autoPreload(scope)
+               }
+       }
+
+       if scope.Search.preload == nil || scope.HasError() {
+               return
+       }
+
+       var (
+               preloadedMap = map[string]bool{}
+               fields       = scope.Fields()
+       )
+
+       for _, preload := range scope.Search.preload {
+               var (
+                       preloadFields = strings.Split(preload.schema, ".")
+                       currentScope  = scope
+                       currentFields = fields
+               )
+
+               for idx, preloadField := range preloadFields {
+                       var currentPreloadConditions []interface{}
+
+                       if currentScope == nil {
+                               continue
+                       }
+
+                       // if not preloaded
+                       if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
+
+                               // assign search conditions to last preload
+                               if idx == len(preloadFields)-1 {
+                                       currentPreloadConditions = preload.conditions
+                               }
+
+                               for _, field := range currentFields {
+                                       if field.Name != preloadField || field.Relationship == nil {
+                                               continue
+                                       }
+
+                                       switch field.Relationship.Kind {
+                                       case "has_one":
+                                               currentScope.handleHasOnePreload(field, currentPreloadConditions)
+                                       case "has_many":
+                                               currentScope.handleHasManyPreload(field, currentPreloadConditions)
+                                       case "belongs_to":
+                                               currentScope.handleBelongsToPreload(field, currentPreloadConditions)
+                                       case "many_to_many":
+                                               currentScope.handleManyToManyPreload(field, currentPreloadConditions)
+                                       default:
+                                               scope.Err(errors.New("unsupported relation"))
+                                       }
+
+                                       preloadedMap[preloadKey] = true
+                                       break
+                               }
+
+                               if !preloadedMap[preloadKey] {
+                                       scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
+                                       return
+                               }
+                       }
+
+                       // preload next level
+                       if idx < len(preloadFields)-1 {
+                               currentScope = currentScope.getColumnAsScope(preloadField)
+                               if currentScope != nil {
+                                       currentFields = currentScope.Fields()
+                               }
+                       }
+               }
+       }
+}
+
+func autoPreload(scope *Scope) {
+       for _, field := range scope.Fields() {
+               if field.Relationship == nil {
+                       continue
+               }
+
+               if val, ok := field.TagSettingsGet("PRELOAD"); ok {
+                       if preload, err := strconv.ParseBool(val); err != nil {
+                               scope.Err(errors.New("invalid preload option"))
+                               return
+                       } else if !preload {
+                               continue
+                       }
+               }
+
+               scope.Search.Preload(field.Name)
+       }
+}
+
+func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) {
+       var (
+               preloadDB         = scope.NewDB()
+               preloadConditions []interface{}
+       )
+
+       for _, condition := range conditions {
+               if scopes, ok := condition.(func(*DB) *DB); ok {
+                       preloadDB = scopes(preloadDB)
+               } else {
+                       preloadConditions = append(preloadConditions, condition)
+               }
+       }
+
+       return preloadDB, preloadConditions
+}
+
+// handleHasOnePreload used to preload has one associations
+func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
+       relation := field.Relationship
+
+       // get relations's primary keys
+       primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
+       if len(primaryKeys) == 0 {
+               return
+       }
+
+       // preload conditions
+       preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
+
+       // find relations
+       query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
+       values := toQueryValues(primaryKeys)
+       if relation.PolymorphicType != "" {
+               query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
+               values = append(values, relation.PolymorphicValue)
+       }
+
+       results := makeSlice(field.Struct.Type)
+       scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
+
+       // assign find results
+       var (
+               resultsValue       = indirect(reflect.ValueOf(results))
+               indirectScopeValue = scope.IndirectValue()
+       )
+
+       if indirectScopeValue.Kind() == reflect.Slice {
+               foreignValuesToResults := make(map[string]reflect.Value)
+               for i := 0; i < resultsValue.Len(); i++ {
+                       result := resultsValue.Index(i)
+                       foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames))
+                       foreignValuesToResults[foreignValues] = result
+               }
+               for j := 0; j < indirectScopeValue.Len(); j++ {
+                       indirectValue := indirect(indirectScopeValue.Index(j))
+                       valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames))
+                       if result, found := foreignValuesToResults[valueString]; found {
+                               indirectValue.FieldByName(field.Name).Set(result)
+                       }
+               }
+       } else {
+               for i := 0; i < resultsValue.Len(); i++ {
+                       result := resultsValue.Index(i)
+                       scope.Err(field.Set(result))
+               }
+       }
+}
+
+// handleHasManyPreload used to preload has many associations
+func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
+       relation := field.Relationship
+
+       // get relations's primary keys
+       primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value)
+       if len(primaryKeys) == 0 {
+               return
+       }
+
+       // preload conditions
+       preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
+
+       // find relations
+       query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys))
+       values := toQueryValues(primaryKeys)
+       if relation.PolymorphicType != "" {
+               query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName))
+               values = append(values, relation.PolymorphicValue)
+       }
+
+       results := makeSlice(field.Struct.Type)
+       scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error)
+
+       // assign find results
+       var (
+               resultsValue       = indirect(reflect.ValueOf(results))
+               indirectScopeValue = scope.IndirectValue()
+       )
+
+       if indirectScopeValue.Kind() == reflect.Slice {
+               preloadMap := make(map[string][]reflect.Value)
+               for i := 0; i < resultsValue.Len(); i++ {
+                       result := resultsValue.Index(i)
+                       foreignValues := getValueFromFields(result, relation.ForeignFieldNames)
+                       preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result)
+               }
+
+               for j := 0; j < indirectScopeValue.Len(); j++ {
+                       object := indirect(indirectScopeValue.Index(j))
+                       objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames)
+                       f := object.FieldByName(field.Name)
+                       if results, ok := preloadMap[toString(objectRealValue)]; ok {
+                               f.Set(reflect.Append(f, results...))
+                       } else {
+                               f.Set(reflect.MakeSlice(f.Type(), 0, 0))
+                       }
+               }
+       } else {
+               scope.Err(field.Set(resultsValue))
+       }
+}
+
+// handleBelongsToPreload used to preload belongs to associations
+func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
+       relation := field.Relationship
+
+       // preload conditions
+       preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
+
+       // get relations's primary keys
+       primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value)
+       if len(primaryKeys) == 0 {
+               return
+       }
+
+       // find relations
+       results := makeSlice(field.Struct.Type)
+       scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error)
+
+       // assign find results
+       var (
+               resultsValue       = indirect(reflect.ValueOf(results))
+               indirectScopeValue = scope.IndirectValue()
+       )
+
+       foreignFieldToObjects := make(map[string][]*reflect.Value)
+       if indirectScopeValue.Kind() == reflect.Slice {
+               for j := 0; j < indirectScopeValue.Len(); j++ {
+                       object := indirect(indirectScopeValue.Index(j))
+                       valueString := toString(getValueFromFields(object, relation.ForeignFieldNames))
+                       foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object)
+               }
+       }
+
+       for i := 0; i < resultsValue.Len(); i++ {
+               result := resultsValue.Index(i)
+               if indirectScopeValue.Kind() == reflect.Slice {
+                       valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames))
+                       if objects, found := foreignFieldToObjects[valueString]; found {
+                               for _, object := range objects {
+                                       object.FieldByName(field.Name).Set(result)
+                               }
+                       }
+               } else {
+                       scope.Err(field.Set(result))
+               }
+       }
+}
+
+// handleManyToManyPreload used to preload many to many associations
+func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) {
+       var (
+               relation         = field.Relationship
+               joinTableHandler = relation.JoinTableHandler
+               fieldType        = field.Struct.Type.Elem()
+               foreignKeyValue  interface{}
+               foreignKeyType   = reflect.ValueOf(&foreignKeyValue).Type()
+               linkHash         = map[string][]reflect.Value{}
+               isPtr            bool
+       )
+
+       if fieldType.Kind() == reflect.Ptr {
+               isPtr = true
+               fieldType = fieldType.Elem()
+       }
+
+       var sourceKeys = []string{}
+       for _, key := range joinTableHandler.SourceForeignKeys() {
+               sourceKeys = append(sourceKeys, key.DBName)
+       }
+
+       // preload conditions
+       preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
+
+       // generate query with join table
+       newScope := scope.New(reflect.New(fieldType).Interface())
+       preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
+
+       if len(preloadDB.search.selects) == 0 {
+               preloadDB = preloadDB.Select("*")
+       }
+
+       preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
+
+       // preload inline conditions
+       if len(preloadConditions) > 0 {
+               preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
+       }
+
+       rows, err := preloadDB.Rows()
+
+       if scope.Err(err) != nil {
+               return
+       }
+       defer rows.Close()
+
+       columns, _ := rows.Columns()
+       for rows.Next() {
+               var (
+                       elem   = reflect.New(fieldType).Elem()
+                       fields = scope.New(elem.Addr().Interface()).Fields()
+               )
+
+               // register foreign keys in join tables
+               var joinTableFields []*Field
+               for _, sourceKey := range sourceKeys {
+                       joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
+               }
+
+               scope.scan(rows, columns, append(fields, joinTableFields...))
+
+               scope.New(elem.Addr().Interface()).
+                       InstanceSet("gorm:skip_query_callback", true).
+                       callCallbacks(scope.db.parent.callbacks.queries)
+
+               var foreignKeys = make([]interface{}, len(sourceKeys))
+               // generate hashed forkey keys in join table
+               for idx, joinTableField := range joinTableFields {
+                       if !joinTableField.Field.IsNil() {
+                               foreignKeys[idx] = joinTableField.Field.Elem().Interface()
+                       }
+               }
+               hashedSourceKeys := toString(foreignKeys)
+
+               if isPtr {
+                       linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
+               } else {
+                       linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
+               }
+       }
+
+       if err := rows.Err(); err != nil {
+               scope.Err(err)
+       }
+
+       // assign find results
+       var (
+               indirectScopeValue = scope.IndirectValue()
+               fieldsSourceMap    = map[string][]reflect.Value{}
+               foreignFieldNames  = []string{}
+       )
+
+       for _, dbName := range relation.ForeignFieldNames {
+               if field, ok := scope.FieldByName(dbName); ok {
+                       foreignFieldNames = append(foreignFieldNames, field.Name)
+               }
+       }
+
+       if indirectScopeValue.Kind() == reflect.Slice {
+               for j := 0; j < indirectScopeValue.Len(); j++ {
+                       object := indirect(indirectScopeValue.Index(j))
+                       key := toString(getValueFromFields(object, foreignFieldNames))
+                       fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name))
+               }
+       } else if indirectScopeValue.IsValid() {
+               key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames))
+               fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name))
+       }
+       for source, link := range linkHash {
+               for i, field := range fieldsSourceMap[source] {
+                       //If not 0 this means Value is a pointer and we already added preloaded models to it
+                       if fieldsSourceMap[source][i].Len() != 0 {
+                               continue
+                       }
+                       field.Set(reflect.Append(fieldsSourceMap[source][i], link...))
+               }
+
+       }
+}