OSDN Git Service

feat: init cross_tx keepers (#146)
[bytom/vapor.git] / vendor / github.com / jinzhu / gorm / join_table_handler.go
1 package gorm
2
3 import (
4         "errors"
5         "fmt"
6         "reflect"
7         "strings"
8 )
9
10 // JoinTableHandlerInterface is an interface for how to handle many2many relations
11 type JoinTableHandlerInterface interface {
12         // initialize join table handler
13         Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
14         // Table return join table's table name
15         Table(db *DB) string
16         // Add create relationship in join table for source and destination
17         Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
18         // Delete delete relationship in join table for sources
19         Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
20         // JoinWith query with `Join` conditions
21         JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
22         // SourceForeignKeys return source foreign keys
23         SourceForeignKeys() []JoinTableForeignKey
24         // DestinationForeignKeys return destination foreign keys
25         DestinationForeignKeys() []JoinTableForeignKey
26 }
27
28 // JoinTableForeignKey join table foreign key struct
29 type JoinTableForeignKey struct {
30         DBName            string
31         AssociationDBName string
32 }
33
34 // JoinTableSource is a struct that contains model type and foreign keys
35 type JoinTableSource struct {
36         ModelType   reflect.Type
37         ForeignKeys []JoinTableForeignKey
38 }
39
40 // JoinTableHandler default join table handler
41 type JoinTableHandler struct {
42         TableName   string          `sql:"-"`
43         Source      JoinTableSource `sql:"-"`
44         Destination JoinTableSource `sql:"-"`
45 }
46
47 // SourceForeignKeys return source foreign keys
48 func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
49         return s.Source.ForeignKeys
50 }
51
52 // DestinationForeignKeys return destination foreign keys
53 func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
54         return s.Destination.ForeignKeys
55 }
56
57 // Setup initialize a default join table handler
58 func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
59         s.TableName = tableName
60
61         s.Source = JoinTableSource{ModelType: source}
62         s.Source.ForeignKeys = []JoinTableForeignKey{}
63         for idx, dbName := range relationship.ForeignFieldNames {
64                 s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
65                         DBName:            relationship.ForeignDBNames[idx],
66                         AssociationDBName: dbName,
67                 })
68         }
69
70         s.Destination = JoinTableSource{ModelType: destination}
71         s.Destination.ForeignKeys = []JoinTableForeignKey{}
72         for idx, dbName := range relationship.AssociationForeignFieldNames {
73                 s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
74                         DBName:            relationship.AssociationForeignDBNames[idx],
75                         AssociationDBName: dbName,
76                 })
77         }
78 }
79
80 // Table return join table's table name
81 func (s JoinTableHandler) Table(db *DB) string {
82         return DefaultTableNameHandler(db, s.TableName)
83 }
84
85 func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
86         for _, source := range sources {
87                 scope := db.NewScope(source)
88                 modelType := scope.GetModelStruct().ModelType
89
90                 for _, joinTableSource := range joinTableSources {
91                         if joinTableSource.ModelType == modelType {
92                                 for _, foreignKey := range joinTableSource.ForeignKeys {
93                                         if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
94                                                 conditionMap[foreignKey.DBName] = field.Field.Interface()
95                                         }
96                                 }
97                                 break
98                         }
99                 }
100         }
101 }
102
103 // Add create relationship in join table for source and destination
104 func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
105         var (
106                 scope        = db.NewScope("")
107                 conditionMap = map[string]interface{}{}
108         )
109
110         // Update condition map for source
111         s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
112
113         // Update condition map for destination
114         s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)
115
116         var assignColumns, binVars, conditions []string
117         var values []interface{}
118         for key, value := range conditionMap {
119                 assignColumns = append(assignColumns, scope.Quote(key))
120                 binVars = append(binVars, `?`)
121                 conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
122                 values = append(values, value)
123         }
124
125         for _, value := range values {
126                 values = append(values, value)
127         }
128
129         quotedTable := scope.Quote(handler.Table(db))
130         sql := fmt.Sprintf(
131                 "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
132                 quotedTable,
133                 strings.Join(assignColumns, ","),
134                 strings.Join(binVars, ","),
135                 scope.Dialect().SelectFromDummyTable(),
136                 quotedTable,
137                 strings.Join(conditions, " AND "),
138         )
139
140         return db.Exec(sql, values...).Error
141 }
142
143 // Delete delete relationship in join table for sources
144 func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
145         var (
146                 scope        = db.NewScope(nil)
147                 conditions   []string
148                 values       []interface{}
149                 conditionMap = map[string]interface{}{}
150         )
151
152         s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)
153
154         for key, value := range conditionMap {
155                 conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
156                 values = append(values, value)
157         }
158
159         return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
160 }
161
162 // JoinWith query with `Join` conditions
163 func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
164         var (
165                 scope           = db.NewScope(source)
166                 tableName       = handler.Table(db)
167                 quotedTableName = scope.Quote(tableName)
168                 joinConditions  []string
169                 values          []interface{}
170         )
171
172         if s.Source.ModelType == scope.GetModelStruct().ModelType {
173                 destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
174                 for _, foreignKey := range s.Destination.ForeignKeys {
175                         joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
176                 }
177
178                 var foreignDBNames []string
179                 var foreignFieldNames []string
180
181                 for _, foreignKey := range s.Source.ForeignKeys {
182                         foreignDBNames = append(foreignDBNames, foreignKey.DBName)
183                         if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
184                                 foreignFieldNames = append(foreignFieldNames, field.Name)
185                         }
186                 }
187
188                 foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
189
190                 var condString string
191                 if len(foreignFieldValues) > 0 {
192                         var quotedForeignDBNames []string
193                         for _, dbName := range foreignDBNames {
194                                 quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
195                         }
196
197                         condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
198
199                         keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
200                         values = append(values, toQueryValues(keys))
201                 } else {
202                         condString = fmt.Sprintf("1 <> 1")
203                 }
204
205                 return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
206                         Where(condString, toQueryValues(foreignFieldValues)...)
207         }
208
209         db.Error = errors.New("wrong source type for join table handler")
210         return db
211 }