--- /dev/null
+package gorm
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "strings"
+)
+
+// JoinTableHandlerInterface is an interface for how to handle many2many relations
+type JoinTableHandlerInterface interface {
+ // initialize join table handler
+ Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
+ // Table return join table's table name
+ Table(db *DB) string
+ // Add create relationship in join table for source and destination
+ Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
+ // Delete delete relationship in join table for sources
+ Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
+ // JoinWith query with `Join` conditions
+ JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
+ // SourceForeignKeys return source foreign keys
+ SourceForeignKeys() []JoinTableForeignKey
+ // DestinationForeignKeys return destination foreign keys
+ DestinationForeignKeys() []JoinTableForeignKey
+}
+
+// JoinTableForeignKey join table foreign key struct
+type JoinTableForeignKey struct {
+ DBName string
+ AssociationDBName string
+}
+
+// JoinTableSource is a struct that contains model type and foreign keys
+type JoinTableSource struct {
+ ModelType reflect.Type
+ ForeignKeys []JoinTableForeignKey
+}
+
+// JoinTableHandler default join table handler
+type JoinTableHandler struct {
+ TableName string `sql:"-"`
+ Source JoinTableSource `sql:"-"`
+ Destination JoinTableSource `sql:"-"`
+}
+
+// SourceForeignKeys return source foreign keys
+func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
+ return s.Source.ForeignKeys
+}
+
+// DestinationForeignKeys return destination foreign keys
+func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
+ return s.Destination.ForeignKeys
+}
+
+// Setup initialize a default join table handler
+func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
+ s.TableName = tableName
+
+ s.Source = JoinTableSource{ModelType: source}
+ s.Source.ForeignKeys = []JoinTableForeignKey{}
+ for idx, dbName := range relationship.ForeignFieldNames {
+ s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
+ DBName: relationship.ForeignDBNames[idx],
+ AssociationDBName: dbName,
+ })
+ }
+
+ s.Destination = JoinTableSource{ModelType: destination}
+ s.Destination.ForeignKeys = []JoinTableForeignKey{}
+ for idx, dbName := range relationship.AssociationForeignFieldNames {
+ s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
+ DBName: relationship.AssociationForeignDBNames[idx],
+ AssociationDBName: dbName,
+ })
+ }
+}
+
+// Table return join table's table name
+func (s JoinTableHandler) Table(db *DB) string {
+ return DefaultTableNameHandler(db, s.TableName)
+}
+
+func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) {
+ for _, source := range sources {
+ scope := db.NewScope(source)
+ modelType := scope.GetModelStruct().ModelType
+
+ for _, joinTableSource := range joinTableSources {
+ if joinTableSource.ModelType == modelType {
+ for _, foreignKey := range joinTableSource.ForeignKeys {
+ if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
+ conditionMap[foreignKey.DBName] = field.Field.Interface()
+ }
+ }
+ break
+ }
+ }
+ }
+}
+
+// Add create relationship in join table for source and destination
+func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error {
+ var (
+ scope = db.NewScope("")
+ conditionMap = map[string]interface{}{}
+ )
+
+ // Update condition map for source
+ s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source)
+
+ // Update condition map for destination
+ s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination)
+
+ var assignColumns, binVars, conditions []string
+ var values []interface{}
+ for key, value := range conditionMap {
+ assignColumns = append(assignColumns, scope.Quote(key))
+ binVars = append(binVars, `?`)
+ conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
+ values = append(values, value)
+ }
+
+ for _, value := range values {
+ values = append(values, value)
+ }
+
+ quotedTable := scope.Quote(handler.Table(db))
+ sql := fmt.Sprintf(
+ "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
+ quotedTable,
+ strings.Join(assignColumns, ","),
+ strings.Join(binVars, ","),
+ scope.Dialect().SelectFromDummyTable(),
+ quotedTable,
+ strings.Join(conditions, " AND "),
+ )
+
+ return db.Exec(sql, values...).Error
+}
+
+// Delete delete relationship in join table for sources
+func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
+ var (
+ scope = db.NewScope(nil)
+ conditions []string
+ values []interface{}
+ conditionMap = map[string]interface{}{}
+ )
+
+ s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...)
+
+ for key, value := range conditionMap {
+ conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
+ values = append(values, value)
+ }
+
+ return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
+}
+
+// JoinWith query with `Join` conditions
+func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
+ var (
+ scope = db.NewScope(source)
+ tableName = handler.Table(db)
+ quotedTableName = scope.Quote(tableName)
+ joinConditions []string
+ values []interface{}
+ )
+
+ if s.Source.ModelType == scope.GetModelStruct().ModelType {
+ destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
+ for _, foreignKey := range s.Destination.ForeignKeys {
+ joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTableName, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
+ }
+
+ var foreignDBNames []string
+ var foreignFieldNames []string
+
+ for _, foreignKey := range s.Source.ForeignKeys {
+ foreignDBNames = append(foreignDBNames, foreignKey.DBName)
+ if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok {
+ foreignFieldNames = append(foreignFieldNames, field.Name)
+ }
+ }
+
+ foreignFieldValues := scope.getColumnAsArray(foreignFieldNames, scope.Value)
+
+ var condString string
+ if len(foreignFieldValues) > 0 {
+ var quotedForeignDBNames []string
+ for _, dbName := range foreignDBNames {
+ quotedForeignDBNames = append(quotedForeignDBNames, tableName+"."+dbName)
+ }
+
+ condString = fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, quotedForeignDBNames), toQueryMarks(foreignFieldValues))
+
+ keys := scope.getColumnAsArray(foreignFieldNames, scope.Value)
+ values = append(values, toQueryValues(keys))
+ } else {
+ condString = fmt.Sprintf("1 <> 1")
+ }
+
+ return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTableName, strings.Join(joinConditions, " AND "))).
+ Where(condString, toQueryValues(foreignFieldValues)...)
+ }
+
+ db.Error = errors.New("wrong source type for join table handler")
+ return db
+}