diff --git a/tests/connection_test.go b/tests/connection_test.go index 7e2cd7b..e60d8ed 100644 --- a/tests/connection_test.go +++ b/tests/connection_test.go @@ -40,6 +40,9 @@ package tests import ( "fmt" + "log" + "os" + "sync" "testing" "gorm.io/gorm" @@ -65,3 +68,91 @@ func TestWithSingleConnection(t *testing.T) { t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedString, actualString) } } + +func TestConnectionWithInvalidQuery(t *testing.T) { + err := DB.Connection(func(tx *gorm.DB) error { + return tx.Exec("SELECT * FROM non_existent_table").Error + }) + if err == nil { + t.Fatalf("Expected error for invalid query in Connection, got nil") + } +} + +func TestMultipleSequentialConnections(t *testing.T) { + for i := 0; i < 20; i++ { + var val int + err := DB.Connection(func(tx *gorm.DB) error { + return tx.Raw("SELECT 1 FROM dual").Scan(&val).Error + }) + if err != nil { + t.Fatalf("Sequential Connection #%d failed: %v", i+1, err) + } + if val != 1 { + t.Fatalf("Sequential Connection #%d got wrong result: %v", i+1, val) + } + } +} + +func TestConnectionAfterDBClose(t *testing.T) { + sqlDB, err := DB.DB() + if err != nil { + t.Fatalf("DB.DB() should not fail, got: %v", err) + } + err = sqlDB.Close() + if err != nil { + t.Fatalf("sqlDB.Close() failed: %v", err) + } + cerr := DB.Connection(func(tx *gorm.DB) error { + var v int + return tx.Raw("SELECT 1 FROM dual").Scan(&v).Error + }) + if cerr == nil { + t.Fatalf("Expected error when calling Connection after DB closed, got nil") + } + if DB, err = OpenTestConnection(&gorm.Config{Logger: newLogger}); err != nil { + log.Printf("failed to connect database, got error %v", err) + os.Exit(1) + } +} + +func TestConnectionHandlesPanic(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatalf("Expected panic inside Connection, but none occurred") + } + }() + DB.Connection(func(tx *gorm.DB) error { + panic("panic in connection callback") + }) + t.Fatalf("Should have panicked inside connection callback") +} + +func TestConcurrentConnections(t *testing.T) { + const goroutines = 10 + var wg sync.WaitGroup + wg.Add(goroutines) + errChan := make(chan error, goroutines) + + for i := 0; i < goroutines; i++ { + go func(i int) { + defer wg.Done() + var val int + err := DB.Connection(func(tx *gorm.DB) error { + return tx.Raw("SELECT ? FROM dual", i).Scan(&val).Error + }) + if err != nil { + errChan <- fmt.Errorf("goroutine #%d: connection err: %v", i, err) + return + } + if val != i { + errChan <- fmt.Errorf("goroutine #%d: got wrong result: %v", i, val) + } + }(i) + } + + wg.Wait() + close(errChan) + for err := range errChan { + t.Error(err) + } +} diff --git a/tests/passed-tests.txt b/tests/passed-tests.txt index 1c4eaf2..cce7342 100644 --- a/tests/passed-tests.txt +++ b/tests/passed-tests.txt @@ -38,6 +38,12 @@ TestPluginCallbacks TestCallbacksGet TestCallbacksRemove TestWithSingleConnection +TestConnectionWithInvalidQuery +TestNestedConnection +TestMultipleSequentialConnections +TestConnectionAfterDBClose +TestConnectionHandlesPanic +TestConcurrentConnections TestCountWithGroup TestCount TestCountOnEmptyTable