From 9a70a12a341b8e7773efc661693e4b7e3199a4bd Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 31 Oct 2022 23:51:14 +0800 Subject: [PATCH] Merge db.Iterate and IterateObjects (#21641) These two functions are similiar, merge them. --- cmd/migrate_storage.go | 12 +++++----- models/db/context.go | 10 -------- models/db/iterate.go | 21 ++++++++++------- models/db/iterate_test.go | 44 +++++++++++++++++++++++++++++++++++ modules/doctor/breaking.go | 5 ++-- modules/doctor/fix16961.go | 5 +--- modules/doctor/mergebase.go | 5 ++-- modules/doctor/misc.go | 5 ++-- routers/private/mail.go | 3 ++- services/repository/avatar.go | 2 +- services/repository/check.go | 12 +++------- services/repository/hooks.go | 4 +--- 12 files changed, 77 insertions(+), 51 deletions(-) create mode 100644 models/db/iterate_test.go diff --git a/cmd/migrate_storage.go b/cmd/migrate_storage.go index a283f91401839..b6af5b96e87fc 100644 --- a/cmd/migrate_storage.go +++ b/cmd/migrate_storage.go @@ -83,35 +83,35 @@ var CmdMigrateStorage = cli.Command{ } func migrateAttachments(ctx context.Context, dstStorage storage.ObjectStorage) error { - return db.IterateObjects(ctx, func(attach *repo_model.Attachment) error { + return db.Iterate(ctx, nil, func(ctx context.Context, attach *repo_model.Attachment) error { _, err := storage.Copy(dstStorage, attach.RelativePath(), storage.Attachments, attach.RelativePath()) return err }) } func migrateLFS(ctx context.Context, dstStorage storage.ObjectStorage) error { - return db.IterateObjects(ctx, func(mo *git_model.LFSMetaObject) error { + return db.Iterate(ctx, nil, func(ctx context.Context, mo *git_model.LFSMetaObject) error { _, err := storage.Copy(dstStorage, mo.RelativePath(), storage.LFS, mo.RelativePath()) return err }) } func migrateAvatars(ctx context.Context, dstStorage storage.ObjectStorage) error { - return db.IterateObjects(ctx, func(user *user_model.User) error { + return db.Iterate(ctx, nil, func(ctx context.Context, user *user_model.User) error { _, err := storage.Copy(dstStorage, user.CustomAvatarRelativePath(), storage.Avatars, user.CustomAvatarRelativePath()) return err }) } func migrateRepoAvatars(ctx context.Context, dstStorage storage.ObjectStorage) error { - return db.IterateObjects(ctx, func(repo *repo_model.Repository) error { + return db.Iterate(ctx, nil, func(ctx context.Context, repo *repo_model.Repository) error { _, err := storage.Copy(dstStorage, repo.CustomAvatarRelativePath(), storage.RepoAvatars, repo.CustomAvatarRelativePath()) return err }) } func migrateRepoArchivers(ctx context.Context, dstStorage storage.ObjectStorage) error { - return db.IterateObjects(ctx, func(archiver *repo_model.RepoArchiver) error { + return db.Iterate(ctx, nil, func(ctx context.Context, archiver *repo_model.RepoArchiver) error { p := archiver.RelativePath() _, err := storage.Copy(dstStorage, p, storage.RepoArchives, p) return err @@ -119,7 +119,7 @@ func migrateRepoArchivers(ctx context.Context, dstStorage storage.ObjectStorage) } func migratePackages(ctx context.Context, dstStorage storage.ObjectStorage) error { - return db.IterateObjects(ctx, func(pb *packages_model.PackageBlob) error { + return db.Iterate(ctx, nil, func(ctx context.Context, pb *packages_model.PackageBlob) error { p := packages_module.KeyToRelativePath(packages_module.BlobHash256Key(pb.HashSHA256)) _, err := storage.Copy(dstStorage, p, storage.Packages, p) return err diff --git a/models/db/context.go b/models/db/context.go index 4fd35200cf71c..e90780e4e93e9 100644 --- a/models/db/context.go +++ b/models/db/context.go @@ -8,9 +8,6 @@ import ( "context" "database/sql" - "code.gitea.io/gitea/modules/setting" - - "xorm.io/builder" "xorm.io/xorm/schemas" ) @@ -121,13 +118,6 @@ func WithTx(f func(ctx context.Context) error, stdCtx ...context.Context) error return sess.Commit() } -// Iterate iterates the databases and doing something -func Iterate(ctx context.Context, tableBean interface{}, cond builder.Cond, fun func(idx int, bean interface{}) error) error { - return GetEngine(ctx).Where(cond). - BufferSize(setting.Database.IterateBufferSize). - Iterate(tableBean, fun) -} - // Insert inserts records into database func Insert(ctx context.Context, beans ...interface{}) error { _, err := GetEngine(ctx).Insert(beans...) diff --git a/models/db/iterate.go b/models/db/iterate.go index 3d4fa06eeb96e..cbd2feed280e8 100644 --- a/models/db/iterate.go +++ b/models/db/iterate.go @@ -8,25 +8,30 @@ import ( "context" "code.gitea.io/gitea/modules/setting" + + "xorm.io/builder" ) -// IterateObjects iterate all the Bean object -func IterateObjects[Object any](ctx context.Context, f func(repo *Object) error) error { +// Iterate iterate all the Bean object +func Iterate[Bean any](ctx context.Context, cond builder.Cond, f func(ctx context.Context, bean *Bean) error) error { var start int batchSize := setting.Database.IterateBufferSize sess := GetEngine(ctx) for { - repos := make([]*Object, 0, batchSize) - if err := sess.Limit(batchSize, start).Find(&repos); err != nil { + beans := make([]*Bean, 0, batchSize) + if cond != nil { + sess = sess.Where(cond) + } + if err := sess.Limit(batchSize, start).Find(&beans); err != nil { return err } - if len(repos) == 0 { + if len(beans) == 0 { return nil } - start += len(repos) + start += len(beans) - for _, repo := range repos { - if err := f(repo); err != nil { + for _, bean := range beans { + if err := f(ctx, bean); err != nil { return err } } diff --git a/models/db/iterate_test.go b/models/db/iterate_test.go new file mode 100644 index 0000000000000..5d03a6e9ceab8 --- /dev/null +++ b/models/db/iterate_test.go @@ -0,0 +1,44 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package db_test + +import ( + "context" + "testing" + + "code.gitea.io/gitea/models/db" + repo_model "code.gitea.io/gitea/models/repo" + "code.gitea.io/gitea/models/unittest" + + "github.com/stretchr/testify/assert" +) + +func TestIterate(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + xe := unittest.GetXORMEngine() + assert.NoError(t, xe.Sync(&repo_model.RepoUnit{})) + + var repoCnt int + err := db.Iterate(db.DefaultContext, nil, func(ctx context.Context, repo *repo_model.RepoUnit) error { + repoCnt++ + return nil + }) + assert.NoError(t, err) + assert.EqualValues(t, 79, repoCnt) + + err = db.Iterate(db.DefaultContext, nil, func(ctx context.Context, repoUnit *repo_model.RepoUnit) error { + reopUnit2 := repo_model.RepoUnit{ID: repoUnit.ID} + has, err := db.GetByBean(ctx, &reopUnit2) + if err != nil { + return err + } else if !has { + return db.ErrNotExist{Resource: "repo_unit", ID: repoUnit.ID} + } + assert.EqualValues(t, repoUnit.RepoID, repoUnit.RepoID) + assert.EqualValues(t, repoUnit.CreatedUnix, repoUnit.CreatedUnix) + return nil + }) + assert.NoError(t, err) +} diff --git a/modules/doctor/breaking.go b/modules/doctor/breaking.go index 51122d9a61a5c..474997acd855f 100644 --- a/modules/doctor/breaking.go +++ b/modules/doctor/breaking.go @@ -18,10 +18,9 @@ import ( func iterateUserAccounts(ctx context.Context, each func(*user.User) error) error { err := db.Iterate( ctx, - new(user.User), builder.Gt{"id": 0}, - func(idx int, bean interface{}) error { - return each(bean.(*user.User)) + func(ctx context.Context, bean *user.User) error { + return each(bean) }, ) return err diff --git a/modules/doctor/fix16961.go b/modules/doctor/fix16961.go index 307cfcd9ff877..d9f895739f2e7 100644 --- a/modules/doctor/fix16961.go +++ b/modules/doctor/fix16961.go @@ -269,13 +269,10 @@ func fixBrokenRepoUnits16961(ctx context.Context, logger log.Logger, autofix boo err := db.Iterate( ctx, - new(RepoUnit), builder.Gt{ "id": 0, }, - func(idx int, bean interface{}) error { - unit := bean.(*RepoUnit) - + func(ctx context.Context, unit *RepoUnit) error { bs := unit.Config repoUnit := &repo_model.RepoUnit{ ID: unit.ID, diff --git a/modules/doctor/mergebase.go b/modules/doctor/mergebase.go index b279c453f7995..9f5e336461a07 100644 --- a/modules/doctor/mergebase.go +++ b/modules/doctor/mergebase.go @@ -21,10 +21,9 @@ import ( func iteratePRs(ctx context.Context, repo *repo_model.Repository, each func(*repo_model.Repository, *issues_model.PullRequest) error) error { return db.Iterate( ctx, - new(issues_model.PullRequest), builder.Eq{"base_repo_id": repo.ID}, - func(idx int, bean interface{}) error { - return each(repo, bean.(*issues_model.PullRequest)) + func(ctx context.Context, bean *issues_model.PullRequest) error { + return each(repo, bean) }, ) } diff --git a/modules/doctor/misc.go b/modules/doctor/misc.go index 277d66a177725..6f0e066f54d13 100644 --- a/modules/doctor/misc.go +++ b/modules/doctor/misc.go @@ -30,10 +30,9 @@ import ( func iterateRepositories(ctx context.Context, each func(*repo_model.Repository) error) error { err := db.Iterate( ctx, - new(repo_model.Repository), builder.Gt{"id": 0}, - func(idx int, bean interface{}) error { - return each(bean.(*repo_model.Repository)) + func(ctx context.Context, bean *repo_model.Repository) error { + return each(bean) }, ) return err diff --git a/routers/private/mail.go b/routers/private/mail.go index e858992aee13b..255e1d901dfe7 100644 --- a/routers/private/mail.go +++ b/routers/private/mail.go @@ -5,6 +5,7 @@ package private import ( + stdCtx "context" "fmt" "net/http" "strconv" @@ -60,7 +61,7 @@ func SendEmail(ctx *context.PrivateContext) { } } } else { - err := db.IterateObjects(ctx, func(user *user_model.User) error { + err := db.Iterate(ctx, nil, func(ctx stdCtx.Context, user *user_model.User) error { if len(user.Email) > 0 && user.IsActive { emails = append(emails, user.Email) } diff --git a/services/repository/avatar.go b/services/repository/avatar.go index b80a8fb77588e..1cf9e869c0747 100644 --- a/services/repository/avatar.go +++ b/services/repository/avatar.go @@ -96,7 +96,7 @@ func DeleteAvatar(repo *repo_model.Repository) error { // RemoveRandomAvatars removes the randomly generated avatars that were created for repositories func RemoveRandomAvatars(ctx context.Context) error { - return db.IterateObjects(ctx, func(repository *repo_model.Repository) error { + return db.Iterate(ctx, nil, func(ctx context.Context, repository *repo_model.Repository) error { select { case <-ctx.Done(): return db.ErrCancelledf("before random avatars removed for %s", repository.FullName()) diff --git a/services/repository/check.go b/services/repository/check.go index 5529a61b396f4..5725f540b0cca 100644 --- a/services/repository/check.go +++ b/services/repository/check.go @@ -29,10 +29,8 @@ func GitFsck(ctx context.Context, timeout time.Duration, args []git.CmdArg) erro if err := db.Iterate( ctx, - new(repo_model.Repository), builder.Expr("id>0 AND is_fsck_enabled=?", true), - func(idx int, bean interface{}) error { - repo := bean.(*repo_model.Repository) + func(ctx context.Context, repo *repo_model.Repository) error { select { case <-ctx.Done(): return db.ErrCancelledf("before fsck of %s", repo.FullName()) @@ -64,10 +62,8 @@ func GitGcRepos(ctx context.Context, timeout time.Duration, args ...git.CmdArg) if err := db.Iterate( ctx, - new(repo_model.Repository), builder.Gt{"id": 0}, - func(idx int, bean interface{}) error { - repo := bean.(*repo_model.Repository) + func(ctx context.Context, repo *repo_model.Repository) error { select { case <-ctx.Done(): return db.ErrCancelledf("before GC of %s", repo.FullName()) @@ -113,10 +109,8 @@ func gatherMissingRepoRecords(ctx context.Context) ([]*repo_model.Repository, er repos := make([]*repo_model.Repository, 0, 10) if err := db.Iterate( ctx, - new(repo_model.Repository), builder.Gt{"id": 0}, - func(idx int, bean interface{}) error { - repo := bean.(*repo_model.Repository) + func(ctx context.Context, repo *repo_model.Repository) error { select { case <-ctx.Done(): return db.ErrCancelledf("during gathering missing repo records before checking %s", repo.FullName()) diff --git a/services/repository/hooks.go b/services/repository/hooks.go index d326cd26b168f..d29384e012486 100644 --- a/services/repository/hooks.go +++ b/services/repository/hooks.go @@ -25,10 +25,8 @@ func SyncRepositoryHooks(ctx context.Context) error { if err := db.Iterate( ctx, - new(repo_model.Repository), builder.Gt{"id": 0}, - func(idx int, bean interface{}) error { - repo := bean.(*repo_model.Repository) + func(ctx context.Context, repo *repo_model.Repository) error { select { case <-ctx.Done(): return db.ErrCancelledf("before sync repository hooks for %s", repo.FullName())