Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wxiaoguang committed Dec 28, 2024
1 parent ce7d574 commit 67d2df7
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions models/unittest/fixtures_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ func (f *fixturesLoader) prepareFieldValue(v any) any {
return v
}

func (f *fixturesLoader) mssqlTableHasIdentityColumn(db *sql.DB, tableName string) (bool, error) {
row := db.QueryRow(`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)`, tableName)
func (f *fixturesLoader) mssqlTableHasIdentityColumn(q *sql.Tx, tableName string) (bool, error) {
row := q.QueryRow(`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)`, tableName)
var count int
if err := row.Scan(&count); err != nil {
return false, err
}
return count > 0, nil
}

func (f *fixturesLoader) loadFixtures(file string) error {
func (f *fixturesLoader) loadFixtures(tx *sql.Tx, file string) error {
data, err := os.ReadFile(file)
if err != nil {
return fmt.Errorf("failed to read file %q: %w", file, err)
Expand All @@ -57,25 +57,14 @@ func (f *fixturesLoader) loadFixtures(file string) error {

tableName, _, _ := strings.Cut(filepath.Base(file), ".")
tableNameQuoted := f.quoteObject(tableName)
_, err = f.engine.Table(tableName).Where("1=1").Delete() // sqlite3 doesn't support truncate
_, err = tx.Exec(fmt.Sprintf("DELETE FROM %s", tableNameQuoted)) // sqlite3 doesn't support truncate
if err != nil {
return err
}

goDB := f.engine.DB().DB
tx, err := goDB.Begin()
if err != nil {
return err
}
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()

switch f.engine.Dialect().URI().DBType {
case schemas.MSSQL:
hasIdentityColumn, err := f.mssqlTableHasIdentityColumn(goDB, tableName)
hasIdentityColumn, err := f.mssqlTableHasIdentityColumn(tx, tableName)
if err != nil {
return err
}
Expand All @@ -84,6 +73,7 @@ func (f *fixturesLoader) loadFixtures(file string) error {
if err != nil {
return err
}
defer func() { _, err = tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", tableNameQuoted)) }()
}
}

Expand Down Expand Up @@ -112,16 +102,20 @@ func (f *fixturesLoader) loadFixtures(file string) error {
sqlBuf = sqlBuf[:0]
sqlArguments = sqlArguments[:0]
}
err = tx.Commit()
tx = nil
return err
return nil
}

func (f *fixturesLoader) Load() error {
goDB := f.engine.DB().DB

switch f.engine.Dialect().URI().DBType {
case schemas.SQLITE:
f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) }
f.paramPlaceholder = func(idx int) string { return "?" }
if _, err := goDB.Exec("PRAGMA defer_foreign_keys = ON"); err != nil {
return err
}
defer func() { _, _ = goDB.Exec("PRAGMA defer_foreign_keys = OFF") }()
case schemas.POSTGRES:
f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) }
f.paramPlaceholder = func(idx int) string { return fmt.Sprintf(`$%d`, idx) }
Expand All @@ -141,13 +135,20 @@ func (f *fixturesLoader) Load() error {
f.opts.Files = append(f.opts.Files, e.Name())
}
}

tx, err := goDB.Begin()
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()

for _, file := range f.opts.Files {
if !filepath.IsAbs(file) {
file = filepath.Join(f.opts.Dir, file)
}
if err := f.loadFixtures(file); err != nil {
if err := f.loadFixtures(tx, file); err != nil {
return fmt.Errorf("failed to load fixtures from %s: %w", file, err)
}
}
return nil
return tx.Commit()
}

0 comments on commit 67d2df7

Please sign in to comment.