OSDN Git Service

feat: init cross_tx keepers (#146)
[bytom/vapor.git] / vendor / github.com / jinzhu / gorm / main_test.go
1 package gorm_test
2
3 import (
4         "database/sql"
5         "database/sql/driver"
6         "fmt"
7         "os"
8         "path/filepath"
9         "reflect"
10         "strconv"
11         "strings"
12         "testing"
13         "time"
14
15         "github.com/erikstmartin/go-testdb"
16         "github.com/jinzhu/gorm"
17         _ "github.com/jinzhu/gorm/dialects/mssql"
18         _ "github.com/jinzhu/gorm/dialects/mysql"
19         "github.com/jinzhu/gorm/dialects/postgres"
20         _ "github.com/jinzhu/gorm/dialects/sqlite"
21         "github.com/jinzhu/now"
22 )
23
24 var (
25         DB                 *gorm.DB
26         t1, t2, t3, t4, t5 time.Time
27 )
28
29 func init() {
30         var err error
31
32         if DB, err = OpenTestConnection(); err != nil {
33                 panic(fmt.Sprintf("No error should happen when connecting to test database, but got err=%+v", err))
34         }
35
36         runMigration()
37 }
38
39 func OpenTestConnection() (db *gorm.DB, err error) {
40         dbDSN := os.Getenv("GORM_DSN")
41         switch os.Getenv("GORM_DIALECT") {
42         case "mysql":
43                 fmt.Println("testing mysql...")
44                 if dbDSN == "" {
45                         dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True"
46                 }
47                 db, err = gorm.Open("mysql", dbDSN)
48         case "postgres":
49                 fmt.Println("testing postgres...")
50                 if dbDSN == "" {
51                         dbDSN = "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable"
52                 }
53                 db, err = gorm.Open("postgres", dbDSN)
54         case "mssql":
55                 // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';
56                 // CREATE DATABASE gorm;
57                 // USE gorm;
58                 // CREATE USER gorm FROM LOGIN gorm;
59                 // sp_changedbowner 'gorm';
60                 fmt.Println("testing mssql...")
61                 if dbDSN == "" {
62                         dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
63                 }
64                 db, err = gorm.Open("mssql", dbDSN)
65         default:
66                 fmt.Println("testing sqlite3...")
67                 db, err = gorm.Open("sqlite3", filepath.Join(os.TempDir(), "gorm.db"))
68         }
69
70         // db.SetLogger(Logger{log.New(os.Stdout, "\r\n", 0)})
71         // db.SetLogger(log.New(os.Stdout, "\r\n", 0))
72         if debug := os.Getenv("DEBUG"); debug == "true" {
73                 db.LogMode(true)
74         } else if debug == "false" {
75                 db.LogMode(false)
76         }
77
78         db.DB().SetMaxIdleConns(10)
79
80         return
81 }
82
83 func TestOpen_ReturnsError_WithBadArgs(t *testing.T) {
84         stringRef := "foo"
85         testCases := []interface{}{42, time.Now(), &stringRef}
86         for _, tc := range testCases {
87                 t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) {
88                         _, err := gorm.Open("postgresql", tc)
89                         if err == nil {
90                                 t.Error("Should got error with invalid database source")
91                         }
92                         if !strings.HasPrefix(err.Error(), "invalid database source:") {
93                                 t.Errorf("Should got error starting with \"invalid database source:\", but got %q", err.Error())
94                         }
95                 })
96         }
97 }
98
99 func TestStringPrimaryKey(t *testing.T) {
100         type UUIDStruct struct {
101                 ID   string `gorm:"primary_key"`
102                 Name string
103         }
104         DB.DropTable(&UUIDStruct{})
105         DB.AutoMigrate(&UUIDStruct{})
106
107         data := UUIDStruct{ID: "uuid", Name: "hello"}
108         if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello" {
109                 t.Errorf("string primary key should not be populated")
110         }
111
112         data = UUIDStruct{ID: "uuid", Name: "hello world"}
113         if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" || data.Name != "hello world" {
114                 t.Errorf("string primary key should not be populated")
115         }
116 }
117
118 func TestExceptionsWithInvalidSql(t *testing.T) {
119         var columns []string
120         if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
121                 t.Errorf("Should got error with invalid SQL")
122         }
123
124         if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
125                 t.Errorf("Should got error with invalid SQL")
126         }
127
128         if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil {
129                 t.Errorf("Should got error with invalid SQL")
130         }
131
132         var count1, count2 int64
133         DB.Model(&User{}).Count(&count1)
134         if count1 <= 0 {
135                 t.Errorf("Should find some users")
136         }
137
138         if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil {
139                 t.Errorf("Should got error with invalid SQL")
140         }
141
142         DB.Model(&User{}).Count(&count2)
143         if count1 != count2 {
144                 t.Errorf("No user should not be deleted by invalid SQL")
145         }
146 }
147
148 func TestSetTable(t *testing.T) {
149         DB.Create(getPreparedUser("pluck_user1", "pluck_user"))
150         DB.Create(getPreparedUser("pluck_user2", "pluck_user"))
151         DB.Create(getPreparedUser("pluck_user3", "pluck_user"))
152
153         if err := DB.Table("users").Where("role = ?", "pluck_user").Pluck("age", &[]int{}).Error; err != nil {
154                 t.Error("No errors should happen if set table for pluck", err)
155         }
156
157         var users []User
158         if DB.Table("users").Find(&[]User{}).Error != nil {
159                 t.Errorf("No errors should happen if set table for find")
160         }
161
162         if DB.Table("invalid_table").Find(&users).Error == nil {
163                 t.Errorf("Should got error when table is set to an invalid table")
164         }
165
166         DB.Exec("drop table deleted_users;")
167         if DB.Table("deleted_users").CreateTable(&User{}).Error != nil {
168                 t.Errorf("Create table with specified table")
169         }
170
171         DB.Table("deleted_users").Save(&User{Name: "DeletedUser"})
172
173         var deletedUsers []User
174         DB.Table("deleted_users").Find(&deletedUsers)
175         if len(deletedUsers) != 1 {
176                 t.Errorf("Query from specified table")
177         }
178
179         DB.Save(getPreparedUser("normal_user", "reset_table"))
180         DB.Table("deleted_users").Save(getPreparedUser("deleted_user", "reset_table"))
181         var user1, user2, user3 User
182         DB.Where("role = ?", "reset_table").First(&user1).Table("deleted_users").First(&user2).Table("").First(&user3)
183         if (user1.Name != "normal_user") || (user2.Name != "deleted_user") || (user3.Name != "normal_user") {
184                 t.Errorf("unset specified table with blank string")
185         }
186 }
187
188 type Order struct {
189 }
190
191 type Cart struct {
192 }
193
194 func (c Cart) TableName() string {
195         return "shopping_cart"
196 }
197
198 func TestHasTable(t *testing.T) {
199         type Foo struct {
200                 Id    int
201                 Stuff string
202         }
203         DB.DropTable(&Foo{})
204
205         // Table should not exist at this point, HasTable should return false
206         if ok := DB.HasTable("foos"); ok {
207                 t.Errorf("Table should not exist, but does")
208         }
209         if ok := DB.HasTable(&Foo{}); ok {
210                 t.Errorf("Table should not exist, but does")
211         }
212
213         // We create the table
214         if err := DB.CreateTable(&Foo{}).Error; err != nil {
215                 t.Errorf("Table should be created")
216         }
217
218         // And now it should exits, and HasTable should return true
219         if ok := DB.HasTable("foos"); !ok {
220                 t.Errorf("Table should exist, but HasTable informs it does not")
221         }
222         if ok := DB.HasTable(&Foo{}); !ok {
223                 t.Errorf("Table should exist, but HasTable informs it does not")
224         }
225 }
226
227 func TestTableName(t *testing.T) {
228         DB := DB.Model("")
229         if DB.NewScope(Order{}).TableName() != "orders" {
230                 t.Errorf("Order's table name should be orders")
231         }
232
233         if DB.NewScope(&Order{}).TableName() != "orders" {
234                 t.Errorf("&Order's table name should be orders")
235         }
236
237         if DB.NewScope([]Order{}).TableName() != "orders" {
238                 t.Errorf("[]Order's table name should be orders")
239         }
240
241         if DB.NewScope(&[]Order{}).TableName() != "orders" {
242                 t.Errorf("&[]Order's table name should be orders")
243         }
244
245         DB.SingularTable(true)
246         if DB.NewScope(Order{}).TableName() != "order" {
247                 t.Errorf("Order's singular table name should be order")
248         }
249
250         if DB.NewScope(&Order{}).TableName() != "order" {
251                 t.Errorf("&Order's singular table name should be order")
252         }
253
254         if DB.NewScope([]Order{}).TableName() != "order" {
255                 t.Errorf("[]Order's singular table name should be order")
256         }
257
258         if DB.NewScope(&[]Order{}).TableName() != "order" {
259                 t.Errorf("&[]Order's singular table name should be order")
260         }
261
262         if DB.NewScope(&Cart{}).TableName() != "shopping_cart" {
263                 t.Errorf("&Cart's singular table name should be shopping_cart")
264         }
265
266         if DB.NewScope(Cart{}).TableName() != "shopping_cart" {
267                 t.Errorf("Cart's singular table name should be shopping_cart")
268         }
269
270         if DB.NewScope(&[]Cart{}).TableName() != "shopping_cart" {
271                 t.Errorf("&[]Cart's singular table name should be shopping_cart")
272         }
273
274         if DB.NewScope([]Cart{}).TableName() != "shopping_cart" {
275                 t.Errorf("[]Cart's singular table name should be shopping_cart")
276         }
277         DB.SingularTable(false)
278 }
279
280 func TestNullValues(t *testing.T) {
281         DB.DropTable(&NullValue{})
282         DB.AutoMigrate(&NullValue{})
283
284         if err := DB.Save(&NullValue{
285                 Name:    sql.NullString{String: "hello", Valid: true},
286                 Gender:  &sql.NullString{String: "M", Valid: true},
287                 Age:     sql.NullInt64{Int64: 18, Valid: true},
288                 Male:    sql.NullBool{Bool: true, Valid: true},
289                 Height:  sql.NullFloat64{Float64: 100.11, Valid: true},
290                 AddedAt: NullTime{Time: time.Now(), Valid: true},
291         }).Error; err != nil {
292                 t.Errorf("Not error should raise when test null value")
293         }
294
295         var nv NullValue
296         DB.First(&nv, "name = ?", "hello")
297
298         if nv.Name.String != "hello" || nv.Gender.String != "M" || nv.Age.Int64 != 18 || nv.Male.Bool != true || nv.Height.Float64 != 100.11 || nv.AddedAt.Valid != true {
299                 t.Errorf("Should be able to fetch null value")
300         }
301
302         if err := DB.Save(&NullValue{
303                 Name:    sql.NullString{String: "hello-2", Valid: true},
304                 Gender:  &sql.NullString{String: "F", Valid: true},
305                 Age:     sql.NullInt64{Int64: 18, Valid: false},
306                 Male:    sql.NullBool{Bool: true, Valid: true},
307                 Height:  sql.NullFloat64{Float64: 100.11, Valid: true},
308                 AddedAt: NullTime{Time: time.Now(), Valid: false},
309         }).Error; err != nil {
310                 t.Errorf("Not error should raise when test null value")
311         }
312
313         var nv2 NullValue
314         DB.First(&nv2, "name = ?", "hello-2")
315         if nv2.Name.String != "hello-2" || nv2.Gender.String != "F" || nv2.Age.Int64 != 0 || nv2.Male.Bool != true || nv2.Height.Float64 != 100.11 || nv2.AddedAt.Valid != false {
316                 t.Errorf("Should be able to fetch null value")
317         }
318
319         if err := DB.Save(&NullValue{
320                 Name:    sql.NullString{String: "hello-3", Valid: false},
321                 Gender:  &sql.NullString{String: "M", Valid: true},
322                 Age:     sql.NullInt64{Int64: 18, Valid: false},
323                 Male:    sql.NullBool{Bool: true, Valid: true},
324                 Height:  sql.NullFloat64{Float64: 100.11, Valid: true},
325                 AddedAt: NullTime{Time: time.Now(), Valid: false},
326         }).Error; err == nil {
327                 t.Errorf("Can't save because of name can't be null")
328         }
329 }
330
331 func TestNullValuesWithFirstOrCreate(t *testing.T) {
332         var nv1 = NullValue{
333                 Name:   sql.NullString{String: "first_or_create", Valid: true},
334                 Gender: &sql.NullString{String: "M", Valid: true},
335         }
336
337         var nv2 NullValue
338         result := DB.Where(nv1).FirstOrCreate(&nv2)
339
340         if result.RowsAffected != 1 {
341                 t.Errorf("RowsAffected should be 1 after create some record")
342         }
343
344         if result.Error != nil {
345                 t.Errorf("Should not raise any error, but got %v", result.Error)
346         }
347
348         if nv2.Name.String != "first_or_create" || nv2.Gender.String != "M" {
349                 t.Errorf("first or create with nullvalues")
350         }
351
352         if err := DB.Where(nv1).Assign(NullValue{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&nv2).Error; err != nil {
353                 t.Errorf("Should not raise any error, but got %v", err)
354         }
355
356         if nv2.Age.Int64 != 18 {
357                 t.Errorf("should update age to 18")
358         }
359 }
360
361 func TestTransaction(t *testing.T) {
362         tx := DB.Begin()
363         u := User{Name: "transcation"}
364         if err := tx.Save(&u).Error; err != nil {
365                 t.Errorf("No error should raise")
366         }
367
368         if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
369                 t.Errorf("Should find saved record")
370         }
371
372         if sqlTx, ok := tx.CommonDB().(*sql.Tx); !ok || sqlTx == nil {
373                 t.Errorf("Should return the underlying sql.Tx")
374         }
375
376         tx.Rollback()
377
378         if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
379                 t.Errorf("Should not find record after rollback")
380         }
381
382         tx2 := DB.Begin()
383         u2 := User{Name: "transcation-2"}
384         if err := tx2.Save(&u2).Error; err != nil {
385                 t.Errorf("No error should raise")
386         }
387
388         if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
389                 t.Errorf("Should find saved record")
390         }
391
392         tx2.Commit()
393
394         if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
395                 t.Errorf("Should be able to find committed record")
396         }
397 }
398
399 func TestRow(t *testing.T) {
400         user1 := User{Name: "RowUser1", Age: 1, Birthday: parseTime("2000-1-1")}
401         user2 := User{Name: "RowUser2", Age: 10, Birthday: parseTime("2010-1-1")}
402         user3 := User{Name: "RowUser3", Age: 20, Birthday: parseTime("2020-1-1")}
403         DB.Save(&user1).Save(&user2).Save(&user3)
404
405         row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row()
406         var age int64
407         row.Scan(&age)
408         if age != 10 {
409                 t.Errorf("Scan with Row")
410         }
411 }
412
413 func TestRows(t *testing.T) {
414         user1 := User{Name: "RowsUser1", Age: 1, Birthday: parseTime("2000-1-1")}
415         user2 := User{Name: "RowsUser2", Age: 10, Birthday: parseTime("2010-1-1")}
416         user3 := User{Name: "RowsUser3", Age: 20, Birthday: parseTime("2020-1-1")}
417         DB.Save(&user1).Save(&user2).Save(&user3)
418
419         rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
420         if err != nil {
421                 t.Errorf("Not error should happen, got %v", err)
422         }
423
424         count := 0
425         for rows.Next() {
426                 var name string
427                 var age int64
428                 rows.Scan(&name, &age)
429                 count++
430         }
431
432         if count != 2 {
433                 t.Errorf("Should found two records")
434         }
435 }
436
437 func TestScanRows(t *testing.T) {
438         user1 := User{Name: "ScanRowsUser1", Age: 1, Birthday: parseTime("2000-1-1")}
439         user2 := User{Name: "ScanRowsUser2", Age: 10, Birthday: parseTime("2010-1-1")}
440         user3 := User{Name: "ScanRowsUser3", Age: 20, Birthday: parseTime("2020-1-1")}
441         DB.Save(&user1).Save(&user2).Save(&user3)
442
443         rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows()
444         if err != nil {
445                 t.Errorf("Not error should happen, got %v", err)
446         }
447
448         type Result struct {
449                 Name string
450                 Age  int
451         }
452
453         var results []Result
454         for rows.Next() {
455                 var result Result
456                 if err := DB.ScanRows(rows, &result); err != nil {
457                         t.Errorf("should get no error, but got %v", err)
458                 }
459                 results = append(results, result)
460         }
461
462         if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
463                 t.Errorf("Should find expected results")
464         }
465 }
466
467 func TestScan(t *testing.T) {
468         user1 := User{Name: "ScanUser1", Age: 1, Birthday: parseTime("2000-1-1")}
469         user2 := User{Name: "ScanUser2", Age: 10, Birthday: parseTime("2010-1-1")}
470         user3 := User{Name: "ScanUser3", Age: 20, Birthday: parseTime("2020-1-1")}
471         DB.Save(&user1).Save(&user2).Save(&user3)
472
473         type result struct {
474                 Name string
475                 Age  int
476         }
477
478         var res result
479         DB.Table("users").Select("name, age").Where("name = ?", user3.Name).Scan(&res)
480         if res.Name != user3.Name {
481                 t.Errorf("Scan into struct should work")
482         }
483
484         var doubleAgeRes = &result{}
485         if err := DB.Table("users").Select("age + age as age").Where("name = ?", user3.Name).Scan(&doubleAgeRes).Error; err != nil {
486                 t.Errorf("Scan to pointer of pointer")
487         }
488         if doubleAgeRes.Age != res.Age*2 {
489                 t.Errorf("Scan double age as age")
490         }
491
492         var ress []result
493         DB.Table("users").Select("name, age").Where("name in (?)", []string{user2.Name, user3.Name}).Scan(&ress)
494         if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name {
495                 t.Errorf("Scan into struct map")
496         }
497 }
498
499 func TestRaw(t *testing.T) {
500         user1 := User{Name: "ExecRawSqlUser1", Age: 1, Birthday: parseTime("2000-1-1")}
501         user2 := User{Name: "ExecRawSqlUser2", Age: 10, Birthday: parseTime("2010-1-1")}
502         user3 := User{Name: "ExecRawSqlUser3", Age: 20, Birthday: parseTime("2020-1-1")}
503         DB.Save(&user1).Save(&user2).Save(&user3)
504
505         type result struct {
506                 Name  string
507                 Email string
508         }
509
510         var ress []result
511         DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&ress)
512         if len(ress) != 2 || ress[0].Name != user2.Name || ress[1].Name != user3.Name {
513                 t.Errorf("Raw with scan")
514         }
515
516         rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows()
517         count := 0
518         for rows.Next() {
519                 count++
520         }
521         if count != 1 {
522                 t.Errorf("Raw with Rows should find one record with name 3")
523         }
524
525         DB.Exec("update users set name=? where name in (?)", "jinzhu", []string{user1.Name, user2.Name, user3.Name})
526         if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound {
527                 t.Error("Raw sql to update records")
528         }
529 }
530
531 func TestGroup(t *testing.T) {
532         rows, err := DB.Select("name").Table("users").Group("name").Rows()
533
534         if err == nil {
535                 defer rows.Close()
536                 for rows.Next() {
537                         var name string
538                         rows.Scan(&name)
539                 }
540         } else {
541                 t.Errorf("Should not raise any error")
542         }
543 }
544
545 func TestJoins(t *testing.T) {
546         var user = User{
547                 Name:       "joins",
548                 CreditCard: CreditCard{Number: "411111111111"},
549                 Emails:     []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
550         }
551         DB.Save(&user)
552
553         var users1 []User
554         DB.Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins").Find(&users1)
555         if len(users1) != 2 {
556                 t.Errorf("should find two users using left join")
557         }
558
559         var users2 []User
560         DB.Joins("left join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Where("name = ?", "joins").First(&users2)
561         if len(users2) != 1 {
562                 t.Errorf("should find one users using left join with conditions")
563         }
564
565         var users3 []User
566         DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where("name = ?", "joins").First(&users3)
567         if len(users3) != 1 {
568                 t.Errorf("should find one users using multiple left join conditions")
569         }
570
571         var users4 []User
572         DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "422222222222").Where("name = ?", "joins").First(&users4)
573         if len(users4) != 0 {
574                 t.Errorf("should find no user when searching with unexisting credit card")
575         }
576
577         var users5 []User
578         db5 := DB.Joins("join emails on emails.user_id = users.id AND emails.email = ?", "join1@example.com").Joins("join credit_cards on credit_cards.user_id = users.id AND credit_cards.number = ?", "411111111111").Where(User{Id: 1}).Where(Email{Id: 1}).Not(Email{Id: 10}).First(&users5)
579         if db5.Error != nil {
580                 t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error())
581         }
582 }
583
584 type JoinedIds struct {
585         UserID           int64 `gorm:"column:id"`
586         BillingAddressID int64 `gorm:"column:id"`
587         EmailID          int64 `gorm:"column:id"`
588 }
589
590 func TestScanIdenticalColumnNames(t *testing.T) {
591         var user = User{
592                 Name:  "joinsIds",
593                 Email: "joinIds@example.com",
594                 BillingAddress: Address{
595                         Address1: "One Park Place",
596                 },
597                 Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
598         }
599         DB.Save(&user)
600
601         var users []JoinedIds
602         DB.Select("users.id, addresses.id, emails.id").Table("users").
603                 Joins("left join addresses on users.billing_address_id = addresses.id").
604                 Joins("left join emails on emails.user_id = users.id").
605                 Where("name = ?", "joinsIds").Scan(&users)
606
607         if len(users) != 2 {
608                 t.Fatal("should find two rows using left join")
609         }
610
611         if user.Id != users[0].UserID {
612                 t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[0].UserID)
613         }
614         if user.Id != users[1].UserID {
615                 t.Errorf("Expected result row to contain UserID %d, but got %d", user.Id, users[1].UserID)
616         }
617
618         if user.BillingAddressID.Int64 != users[0].BillingAddressID {
619                 t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID)
620         }
621         if user.BillingAddressID.Int64 != users[1].BillingAddressID {
622                 t.Errorf("Expected result row to contain BillingAddressID %d, but got %d", user.BillingAddressID.Int64, users[0].BillingAddressID)
623         }
624
625         if users[0].EmailID == users[1].EmailID {
626                 t.Errorf("Email ids should be unique. Got %d and %d", users[0].EmailID, users[1].EmailID)
627         }
628
629         if int64(user.Emails[0].Id) != users[0].EmailID && int64(user.Emails[1].Id) != users[0].EmailID {
630                 t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[0].EmailID)
631         }
632
633         if int64(user.Emails[0].Id) != users[1].EmailID && int64(user.Emails[1].Id) != users[1].EmailID {
634                 t.Errorf("Expected result row ID to be either %d or %d, but was %d", user.Emails[0].Id, user.Emails[1].Id, users[1].EmailID)
635         }
636 }
637
638 func TestJoinsWithSelect(t *testing.T) {
639         type result struct {
640                 Name  string
641                 Email string
642         }
643
644         user := User{
645                 Name:   "joins_with_select",
646                 Emails: []Email{{Email: "join1@example.com"}, {Email: "join2@example.com"}},
647         }
648         DB.Save(&user)
649
650         var results []result
651         DB.Table("users").Select("name, emails.email").Joins("left join emails on emails.user_id = users.id").Where("name = ?", "joins_with_select").Scan(&results)
652         if len(results) != 2 || results[0].Email != "join1@example.com" || results[1].Email != "join2@example.com" {
653                 t.Errorf("Should find all two emails with Join select")
654         }
655 }
656
657 func TestHaving(t *testing.T) {
658         rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows()
659
660         if err == nil {
661                 defer rows.Close()
662                 for rows.Next() {
663                         var name string
664                         var total int64
665                         rows.Scan(&name, &total)
666
667                         if name == "2" && total != 1 {
668                                 t.Errorf("Should have one user having name 2")
669                         }
670                         if name == "3" && total != 2 {
671                                 t.Errorf("Should have two users having name 3")
672                         }
673                 }
674         } else {
675                 t.Errorf("Should not raise any error")
676         }
677 }
678
679 func TestQueryBuilderSubselectInWhere(t *testing.T) {
680         user := User{Name: "query_expr_select_ruser1", Email: "root@user1.com", Age: 32}
681         DB.Save(&user)
682         user = User{Name: "query_expr_select_ruser2", Email: "nobody@user2.com", Age: 16}
683         DB.Save(&user)
684         user = User{Name: "query_expr_select_ruser3", Email: "root@user3.com", Age: 64}
685         DB.Save(&user)
686         user = User{Name: "query_expr_select_ruser4", Email: "somebody@user3.com", Age: 128}
687         DB.Save(&user)
688
689         var users []User
690         DB.Select("*").Where("name IN (?)", DB.
691                 Select("name").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users)
692
693         if len(users) != 4 {
694                 t.Errorf("Four users should be found, instead found %d", len(users))
695         }
696
697         DB.Select("*").Where("name LIKE ?", "query_expr_select%").Where("age >= (?)", DB.
698                 Select("AVG(age)").Table("users").Where("name LIKE ?", "query_expr_select%").QueryExpr()).Find(&users)
699
700         if len(users) != 2 {
701                 t.Errorf("Two users should be found, instead found %d", len(users))
702         }
703 }
704
705 func TestQueryBuilderRawQueryWithSubquery(t *testing.T) {
706         user := User{Name: "subquery_test_user1", Age: 10}
707         DB.Save(&user)
708         user = User{Name: "subquery_test_user2", Age: 11}
709         DB.Save(&user)
710         user = User{Name: "subquery_test_user3", Age: 12}
711         DB.Save(&user)
712
713         var count int
714         err := DB.Raw("select count(*) from (?) tmp",
715                 DB.Table("users").
716                         Select("name").
717                         Where("age >= ? and name in (?)", 10, []string{"subquery_test_user1", "subquery_test_user2"}).
718                         Group("name").
719                         QueryExpr(),
720         ).Count(&count).Error
721
722         if err != nil {
723                 t.Errorf("Expected to get no errors, but got %v", err)
724         }
725         if count != 2 {
726                 t.Errorf("Row count must be 2, instead got %d", count)
727         }
728
729         err = DB.Raw("select count(*) from (?) tmp",
730                 DB.Table("users").
731                         Select("name").
732                         Where("name LIKE ?", "subquery_test%").
733                         Not("age <= ?", 10).Not("name in (?)", []string{"subquery_test_user1", "subquery_test_user2"}).
734                         Group("name").
735                         QueryExpr(),
736         ).Count(&count).Error
737
738         if err != nil {
739                 t.Errorf("Expected to get no errors, but got %v", err)
740         }
741         if count != 1 {
742                 t.Errorf("Row count must be 1, instead got %d", count)
743         }
744 }
745
746 func TestQueryBuilderSubselectInHaving(t *testing.T) {
747         user := User{Name: "query_expr_having_ruser1", Email: "root@user1.com", Age: 64}
748         DB.Save(&user)
749         user = User{Name: "query_expr_having_ruser2", Email: "root@user2.com", Age: 128}
750         DB.Save(&user)
751         user = User{Name: "query_expr_having_ruser3", Email: "root@user1.com", Age: 64}
752         DB.Save(&user)
753         user = User{Name: "query_expr_having_ruser4", Email: "root@user2.com", Age: 128}
754         DB.Save(&user)
755
756         var users []User
757         DB.Select("AVG(age) as avgage").Where("name LIKE ?", "query_expr_having_%").Group("email").Having("AVG(age) > (?)", DB.
758                 Select("AVG(age)").Where("name LIKE ?", "query_expr_having_%").Table("users").QueryExpr()).Find(&users)
759
760         if len(users) != 1 {
761                 t.Errorf("Two user group should be found, instead found %d", len(users))
762         }
763 }
764
765 func DialectHasTzSupport() bool {
766         // NB: mssql and FoundationDB do not support time zones.
767         if dialect := os.Getenv("GORM_DIALECT"); dialect == "foundation" {
768                 return false
769         }
770         return true
771 }
772
773 func TestTimeWithZone(t *testing.T) {
774         var format = "2006-01-02 15:04:05 -0700"
775         var times []time.Time
776         GMT8, _ := time.LoadLocation("Asia/Shanghai")
777         times = append(times, time.Date(2013, 02, 19, 1, 51, 49, 123456789, GMT8))
778         times = append(times, time.Date(2013, 02, 18, 17, 51, 49, 123456789, time.UTC))
779
780         for index, vtime := range times {
781                 name := "time_with_zone_" + strconv.Itoa(index)
782                 user := User{Name: name, Birthday: &vtime}
783
784                 if !DialectHasTzSupport() {
785                         // If our driver dialect doesn't support TZ's, just use UTC for everything here.
786                         utcBirthday := user.Birthday.UTC()
787                         user.Birthday = &utcBirthday
788                 }
789
790                 DB.Save(&user)
791                 expectedBirthday := "2013-02-18 17:51:49 +0000"
792                 foundBirthday := user.Birthday.UTC().Format(format)
793                 if foundBirthday != expectedBirthday {
794                         t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
795                 }
796
797                 var findUser, findUser2, findUser3 User
798                 DB.First(&findUser, "name = ?", name)
799                 foundBirthday = findUser.Birthday.UTC().Format(format)
800                 if foundBirthday != expectedBirthday {
801                         t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
802                 }
803
804                 if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
805                         t.Errorf("User should be found")
806                 }
807
808                 if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() {
809                         t.Errorf("User should not be found")
810                 }
811         }
812 }
813
814 func TestHstore(t *testing.T) {
815         type Details struct {
816                 Id   int64
817                 Bulk postgres.Hstore
818         }
819
820         if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" {
821                 t.Skip()
822         }
823
824         if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS hstore").Error; err != nil {
825                 fmt.Println("\033[31mHINT: Must be superuser to create hstore extension (ALTER USER gorm WITH SUPERUSER;)\033[0m")
826                 panic(fmt.Sprintf("No error should happen when create hstore extension, but got %+v", err))
827         }
828
829         DB.Exec("drop table details")
830
831         if err := DB.CreateTable(&Details{}).Error; err != nil {
832                 panic(fmt.Sprintf("No error should happen when create table, but got %+v", err))
833         }
834
835         bankAccountId, phoneNumber, opinion := "123456", "14151321232", "sharkbait"
836         bulk := map[string]*string{
837                 "bankAccountId": &bankAccountId,
838                 "phoneNumber":   &phoneNumber,
839                 "opinion":       &opinion,
840         }
841         d := Details{Bulk: bulk}
842         DB.Save(&d)
843
844         var d2 Details
845         if err := DB.First(&d2).Error; err != nil {
846                 t.Errorf("Got error when tried to fetch details: %+v", err)
847         }
848
849         for k := range bulk {
850                 if r, ok := d2.Bulk[k]; ok {
851                         if res, _ := bulk[k]; *res != *r {
852                                 t.Errorf("Details should be equal")
853                         }
854                 } else {
855                         t.Errorf("Details should be existed")
856                 }
857         }
858 }
859
860 func TestSetAndGet(t *testing.T) {
861         if value, ok := DB.Set("hello", "world").Get("hello"); !ok {
862                 t.Errorf("Should be able to get setting after set")
863         } else {
864                 if value.(string) != "world" {
865                         t.Errorf("Setted value should not be changed")
866                 }
867         }
868
869         if _, ok := DB.Get("non_existing"); ok {
870                 t.Errorf("Get non existing key should return error")
871         }
872 }
873
874 func TestCompatibilityMode(t *testing.T) {
875         DB, _ := gorm.Open("testdb", "")
876         testdb.SetQueryFunc(func(query string) (driver.Rows, error) {
877                 columns := []string{"id", "name", "age"}
878                 result := `
879                 1,Tim,20
880                 2,Joe,25
881                 3,Bob,30
882                 `
883                 return testdb.RowsFromCSVString(columns, result), nil
884         })
885
886         var users []User
887         DB.Find(&users)
888         if (users[0].Name != "Tim") || len(users) != 3 {
889                 t.Errorf("Unexcepted result returned")
890         }
891 }
892
893 func TestOpenExistingDB(t *testing.T) {
894         DB.Save(&User{Name: "jnfeinstein"})
895         dialect := os.Getenv("GORM_DIALECT")
896
897         db, err := gorm.Open(dialect, DB.DB())
898         if err != nil {
899                 t.Errorf("Should have wrapped the existing DB connection")
900         }
901
902         var user User
903         if db.Where("name = ?", "jnfeinstein").First(&user).Error == gorm.ErrRecordNotFound {
904                 t.Errorf("Should have found existing record")
905         }
906 }
907
908 func TestDdlErrors(t *testing.T) {
909         var err error
910
911         if err = DB.Close(); err != nil {
912                 t.Errorf("Closing DDL test db connection err=%s", err)
913         }
914         defer func() {
915                 // Reopen DB connection.
916                 if DB, err = OpenTestConnection(); err != nil {
917                         t.Fatalf("Failed re-opening db connection: %s", err)
918                 }
919         }()
920
921         if err := DB.Find(&User{}).Error; err == nil {
922                 t.Errorf("Expected operation on closed db to produce an error, but err was nil")
923         }
924 }
925
926 func TestOpenWithOneParameter(t *testing.T) {
927         db, err := gorm.Open("dialect")
928         if db != nil {
929                 t.Error("Open with one parameter returned non nil for db")
930         }
931         if err == nil {
932                 t.Error("Open with one parameter returned err as nil")
933         }
934 }
935
936 func TestSaveAssociations(t *testing.T) {
937         db := DB.New()
938         deltaAddressCount := 0
939         if err := db.Model(&Address{}).Count(&deltaAddressCount).Error; err != nil {
940                 t.Errorf("failed to fetch address count")
941                 t.FailNow()
942         }
943
944         placeAddress := &Address{
945                 Address1: "somewhere on earth",
946         }
947         ownerAddress1 := &Address{
948                 Address1: "near place address",
949         }
950         ownerAddress2 := &Address{
951                 Address1: "address2",
952         }
953         db.Create(placeAddress)
954
955         addressCountShouldBe := func(t *testing.T, expectedCount int) {
956                 countFromDB := 0
957                 t.Helper()
958                 err := db.Model(&Address{}).Count(&countFromDB).Error
959                 if err != nil {
960                         t.Error("failed to fetch address count")
961                 }
962                 if countFromDB != expectedCount {
963                         t.Errorf("address count mismatch: %d", countFromDB)
964                 }
965         }
966         addressCountShouldBe(t, deltaAddressCount+1)
967
968         // owner address should be created, place address should be reused
969         place1 := &Place{
970                 PlaceAddressID: placeAddress.ID,
971                 PlaceAddress:   placeAddress,
972                 OwnerAddress:   ownerAddress1,
973         }
974         err := db.Create(place1).Error
975         if err != nil {
976                 t.Errorf("failed to store place: %s", err.Error())
977         }
978         addressCountShouldBe(t, deltaAddressCount+2)
979
980         // owner address should be created again, place address should be reused
981         place2 := &Place{
982                 PlaceAddressID: placeAddress.ID,
983                 PlaceAddress: &Address{
984                         ID:       777,
985                         Address1: "address1",
986                 },
987                 OwnerAddress:   ownerAddress2,
988                 OwnerAddressID: 778,
989         }
990         err = db.Create(place2).Error
991         if err != nil {
992                 t.Errorf("failed to store place: %s", err.Error())
993         }
994         addressCountShouldBe(t, deltaAddressCount+3)
995
996         count := 0
997         db.Model(&Place{}).Where(&Place{
998                 PlaceAddressID: placeAddress.ID,
999                 OwnerAddressID: ownerAddress1.ID,
1000         }).Count(&count)
1001         if count != 1 {
1002                 t.Errorf("only one instance of (%d, %d) should be available, found: %d",
1003                         placeAddress.ID, ownerAddress1.ID, count)
1004         }
1005
1006         db.Model(&Place{}).Where(&Place{
1007                 PlaceAddressID: placeAddress.ID,
1008                 OwnerAddressID: ownerAddress2.ID,
1009         }).Count(&count)
1010         if count != 1 {
1011                 t.Errorf("only one instance of (%d, %d) should be available, found: %d",
1012                         placeAddress.ID, ownerAddress2.ID, count)
1013         }
1014
1015         db.Model(&Place{}).Where(&Place{
1016                 PlaceAddressID: placeAddress.ID,
1017         }).Count(&count)
1018         if count != 2 {
1019                 t.Errorf("two instances of (%d) should be available, found: %d",
1020                         placeAddress.ID, count)
1021         }
1022 }
1023
1024 func TestBlockGlobalUpdate(t *testing.T) {
1025         db := DB.New()
1026         db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
1027
1028         err := db.Model(&Toy{}).Update("OwnerType", "Human").Error
1029         if err != nil {
1030                 t.Error("Unexpected error on global update")
1031         }
1032
1033         err = db.Delete(&Toy{}).Error
1034         if err != nil {
1035                 t.Error("Unexpected error on global delete")
1036         }
1037
1038         db.BlockGlobalUpdate(true)
1039
1040         db.Create(&Toy{Name: "Stuffed Animal", OwnerType: "Nobody"})
1041
1042         err = db.Model(&Toy{}).Update("OwnerType", "Human").Error
1043         if err == nil {
1044                 t.Error("Expected error on global update")
1045         }
1046
1047         err = db.Model(&Toy{}).Where(&Toy{OwnerType: "Martian"}).Update("OwnerType", "Astronaut").Error
1048         if err != nil {
1049                 t.Error("Unxpected error on conditional update")
1050         }
1051
1052         err = db.Delete(&Toy{}).Error
1053         if err == nil {
1054                 t.Error("Expected error on global delete")
1055         }
1056         err = db.Where(&Toy{OwnerType: "Martian"}).Delete(&Toy{}).Error
1057         if err != nil {
1058                 t.Error("Unexpected error on conditional delete")
1059         }
1060 }
1061
1062 func BenchmarkGorm(b *testing.B) {
1063         b.N = 2000
1064         for x := 0; x < b.N; x++ {
1065                 e := strconv.Itoa(x) + "benchmark@example.org"
1066                 now := time.Now()
1067                 email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now}
1068                 // Insert
1069                 DB.Save(&email)
1070                 // Query
1071                 DB.First(&EmailWithIdx{}, "email = ?", e)
1072                 // Update
1073                 DB.Model(&email).UpdateColumn("email", "new-"+e)
1074                 // Delete
1075                 DB.Delete(&email)
1076         }
1077 }
1078
1079 func BenchmarkRawSql(b *testing.B) {
1080         DB, _ := sql.Open("postgres", "user=gorm DB.ame=gorm sslmode=disable")
1081         DB.SetMaxIdleConns(10)
1082         insertSql := "INSERT INTO emails (user_id,email,user_agent,registered_at,created_at,updated_at) VALUES ($1,$2,$3,$4,$5,$6) RETURNING id"
1083         querySql := "SELECT * FROM emails WHERE email = $1 ORDER BY id LIMIT 1"
1084         updateSql := "UPDATE emails SET email = $1, updated_at = $2 WHERE id = $3"
1085         deleteSql := "DELETE FROM orders WHERE id = $1"
1086
1087         b.N = 2000
1088         for x := 0; x < b.N; x++ {
1089                 var id int64
1090                 e := strconv.Itoa(x) + "benchmark@example.org"
1091                 now := time.Now()
1092                 email := EmailWithIdx{Email: e, UserAgent: "pc", RegisteredAt: &now}
1093                 // Insert
1094                 DB.QueryRow(insertSql, email.UserId, email.Email, email.UserAgent, email.RegisteredAt, time.Now(), time.Now()).Scan(&id)
1095                 // Query
1096                 rows, _ := DB.Query(querySql, email.Email)
1097                 rows.Close()
1098                 // Update
1099                 DB.Exec(updateSql, "new-"+e, time.Now(), id)
1100                 // Delete
1101                 DB.Exec(deleteSql, id)
1102         }
1103 }
1104
1105 func parseTime(str string) *time.Time {
1106         t := now.New(time.Now().UTC()).MustParse(str)
1107         return &t
1108 }