Skip to content

Commit b167f35

Browse files
authored
Add context parameter to some database functions (#26055)
To avoid deadlock problem, almost database related functions should be have ctx as the first parameter. This PR do a refactor for some of these functions.
1 parent c42b718 commit b167f35

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+209
-237
lines changed

models/activities/action.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -391,10 +391,10 @@ func (a *Action) GetIssueInfos() []string {
391391
}
392392

393393
// GetIssueTitle returns the title of first issue associated
394-
// with the action.
394+
// with the action. This function will be invoked in template so keep db.DefaultContext here
395395
func (a *Action) GetIssueTitle() string {
396396
index, _ := strconv.ParseInt(a.GetIssueInfos()[0], 10, 64)
397-
issue, err := issues_model.GetIssueByIndex(a.RepoID, index)
397+
issue, err := issues_model.GetIssueByIndex(db.DefaultContext, a.RepoID, index)
398398
if err != nil {
399399
log.Error("GetIssueByIndex: %v", err)
400400
return "500 when get issue"
@@ -404,9 +404,9 @@ func (a *Action) GetIssueTitle() string {
404404

405405
// GetIssueContent returns the content of first issue associated with
406406
// this action.
407-
func (a *Action) GetIssueContent() string {
407+
func (a *Action) GetIssueContent(ctx context.Context) string {
408408
index, _ := strconv.ParseInt(a.GetIssueInfos()[0], 10, 64)
409-
issue, err := issues_model.GetIssueByIndex(a.RepoID, index)
409+
issue, err := issues_model.GetIssueByIndex(ctx, a.RepoID, index)
410410
if err != nil {
411411
log.Error("GetIssueByIndex: %v", err)
412412
return "500 when get issue"

models/activities/repo_activity.go

+27-27
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,21 @@ type ActivityStats struct {
4747
func GetActivityStats(ctx context.Context, repo *repo_model.Repository, timeFrom time.Time, releases, issues, prs, code bool) (*ActivityStats, error) {
4848
stats := &ActivityStats{Code: &git.CodeActivityStats{}}
4949
if releases {
50-
if err := stats.FillReleases(repo.ID, timeFrom); err != nil {
50+
if err := stats.FillReleases(ctx, repo.ID, timeFrom); err != nil {
5151
return nil, fmt.Errorf("FillReleases: %w", err)
5252
}
5353
}
5454
if prs {
55-
if err := stats.FillPullRequests(repo.ID, timeFrom); err != nil {
55+
if err := stats.FillPullRequests(ctx, repo.ID, timeFrom); err != nil {
5656
return nil, fmt.Errorf("FillPullRequests: %w", err)
5757
}
5858
}
5959
if issues {
60-
if err := stats.FillIssues(repo.ID, timeFrom); err != nil {
60+
if err := stats.FillIssues(ctx, repo.ID, timeFrom); err != nil {
6161
return nil, fmt.Errorf("FillIssues: %w", err)
6262
}
6363
}
64-
if err := stats.FillUnresolvedIssues(repo.ID, timeFrom, issues, prs); err != nil {
64+
if err := stats.FillUnresolvedIssues(ctx, repo.ID, timeFrom, issues, prs); err != nil {
6565
return nil, fmt.Errorf("FillUnresolvedIssues: %w", err)
6666
}
6767
if code {
@@ -205,41 +205,41 @@ func (stats *ActivityStats) PublishedReleaseCount() int {
205205
}
206206

207207
// FillPullRequests returns pull request information for activity page
208-
func (stats *ActivityStats) FillPullRequests(repoID int64, fromTime time.Time) error {
208+
func (stats *ActivityStats) FillPullRequests(ctx context.Context, repoID int64, fromTime time.Time) error {
209209
var err error
210210
var count int64
211211

212212
// Merged pull requests
213-
sess := pullRequestsForActivityStatement(repoID, fromTime, true)
213+
sess := pullRequestsForActivityStatement(ctx, repoID, fromTime, true)
214214
sess.OrderBy("pull_request.merged_unix DESC")
215215
stats.MergedPRs = make(issues_model.PullRequestList, 0)
216216
if err = sess.Find(&stats.MergedPRs); err != nil {
217217
return err
218218
}
219-
if err = stats.MergedPRs.LoadAttributes(); err != nil {
219+
if err = stats.MergedPRs.LoadAttributes(ctx); err != nil {
220220
return err
221221
}
222222

223223
// Merged pull request authors
224-
sess = pullRequestsForActivityStatement(repoID, fromTime, true)
224+
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, true)
225225
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil {
226226
return err
227227
}
228228
stats.MergedPRAuthorCount = count
229229

230230
// Opened pull requests
231-
sess = pullRequestsForActivityStatement(repoID, fromTime, false)
231+
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false)
232232
sess.OrderBy("issue.created_unix ASC")
233233
stats.OpenedPRs = make(issues_model.PullRequestList, 0)
234234
if err = sess.Find(&stats.OpenedPRs); err != nil {
235235
return err
236236
}
237-
if err = stats.OpenedPRs.LoadAttributes(); err != nil {
237+
if err = stats.OpenedPRs.LoadAttributes(ctx); err != nil {
238238
return err
239239
}
240240

241241
// Opened pull request authors
242-
sess = pullRequestsForActivityStatement(repoID, fromTime, false)
242+
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false)
243243
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil {
244244
return err
245245
}
@@ -248,8 +248,8 @@ func (stats *ActivityStats) FillPullRequests(repoID int64, fromTime time.Time) e
248248
return nil
249249
}
250250

251-
func pullRequestsForActivityStatement(repoID int64, fromTime time.Time, merged bool) *xorm.Session {
252-
sess := db.GetEngine(db.DefaultContext).Where("pull_request.base_repo_id=?", repoID).
251+
func pullRequestsForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, merged bool) *xorm.Session {
252+
sess := db.GetEngine(ctx).Where("pull_request.base_repo_id=?", repoID).
253253
Join("INNER", "issue", "pull_request.issue_id = issue.id")
254254

255255
if merged {
@@ -264,35 +264,35 @@ func pullRequestsForActivityStatement(repoID int64, fromTime time.Time, merged b
264264
}
265265

266266
// FillIssues returns issue information for activity page
267-
func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error {
267+
func (stats *ActivityStats) FillIssues(ctx context.Context, repoID int64, fromTime time.Time) error {
268268
var err error
269269
var count int64
270270

271271
// Closed issues
272-
sess := issuesForActivityStatement(repoID, fromTime, true, false)
272+
sess := issuesForActivityStatement(ctx, repoID, fromTime, true, false)
273273
sess.OrderBy("issue.closed_unix DESC")
274274
stats.ClosedIssues = make(issues_model.IssueList, 0)
275275
if err = sess.Find(&stats.ClosedIssues); err != nil {
276276
return err
277277
}
278278

279279
// Closed issue authors
280-
sess = issuesForActivityStatement(repoID, fromTime, true, false)
280+
sess = issuesForActivityStatement(ctx, repoID, fromTime, true, false)
281281
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil {
282282
return err
283283
}
284284
stats.ClosedIssueAuthorCount = count
285285

286286
// New issues
287-
sess = issuesForActivityStatement(repoID, fromTime, false, false)
287+
sess = issuesForActivityStatement(ctx, repoID, fromTime, false, false)
288288
sess.OrderBy("issue.created_unix ASC")
289289
stats.OpenedIssues = make(issues_model.IssueList, 0)
290290
if err = sess.Find(&stats.OpenedIssues); err != nil {
291291
return err
292292
}
293293

294294
// Opened issue authors
295-
sess = issuesForActivityStatement(repoID, fromTime, false, false)
295+
sess = issuesForActivityStatement(ctx, repoID, fromTime, false, false)
296296
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil {
297297
return err
298298
}
@@ -302,12 +302,12 @@ func (stats *ActivityStats) FillIssues(repoID int64, fromTime time.Time) error {
302302
}
303303

304304
// FillUnresolvedIssues returns unresolved issue and pull request information for activity page
305-
func (stats *ActivityStats) FillUnresolvedIssues(repoID int64, fromTime time.Time, issues, prs bool) error {
305+
func (stats *ActivityStats) FillUnresolvedIssues(ctx context.Context, repoID int64, fromTime time.Time, issues, prs bool) error {
306306
// Check if we need to select anything
307307
if !issues && !prs {
308308
return nil
309309
}
310-
sess := issuesForActivityStatement(repoID, fromTime, false, true)
310+
sess := issuesForActivityStatement(ctx, repoID, fromTime, false, true)
311311
if !issues || !prs {
312312
sess.And("issue.is_pull = ?", prs)
313313
}
@@ -316,8 +316,8 @@ func (stats *ActivityStats) FillUnresolvedIssues(repoID int64, fromTime time.Tim
316316
return sess.Find(&stats.UnresolvedIssues)
317317
}
318318

319-
func issuesForActivityStatement(repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session {
320-
sess := db.GetEngine(db.DefaultContext).Where("issue.repo_id = ?", repoID).
319+
func issuesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session {
320+
sess := db.GetEngine(ctx).Where("issue.repo_id = ?", repoID).
321321
And("issue.is_closed = ?", closed)
322322

323323
if !unresolved {
@@ -336,20 +336,20 @@ func issuesForActivityStatement(repoID int64, fromTime time.Time, closed, unreso
336336
}
337337

338338
// FillReleases returns release information for activity page
339-
func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error {
339+
func (stats *ActivityStats) FillReleases(ctx context.Context, repoID int64, fromTime time.Time) error {
340340
var err error
341341
var count int64
342342

343343
// Published releases list
344-
sess := releasesForActivityStatement(repoID, fromTime)
344+
sess := releasesForActivityStatement(ctx, repoID, fromTime)
345345
sess.OrderBy("release.created_unix DESC")
346346
stats.PublishedReleases = make([]*repo_model.Release, 0)
347347
if err = sess.Find(&stats.PublishedReleases); err != nil {
348348
return err
349349
}
350350

351351
// Published releases authors
352-
sess = releasesForActivityStatement(repoID, fromTime)
352+
sess = releasesForActivityStatement(ctx, repoID, fromTime)
353353
if _, err = sess.Select("count(distinct release.publisher_id) as `count`").Table("release").Get(&count); err != nil {
354354
return err
355355
}
@@ -358,8 +358,8 @@ func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error
358358
return nil
359359
}
360360

361-
func releasesForActivityStatement(repoID int64, fromTime time.Time) *xorm.Session {
362-
return db.GetEngine(db.DefaultContext).Where("release.repo_id = ?", repoID).
361+
func releasesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time) *xorm.Session {
362+
return db.GetEngine(ctx).Where("release.repo_id = ?", repoID).
363363
And("release.is_draft = ?", false).
364364
And("release.created_unix >= ?", fromTime.Unix())
365365
}

models/issues/comment_list.go

+3-8
Original file line numberDiff line numberDiff line change
@@ -465,8 +465,9 @@ func (comments CommentList) loadReviews(ctx context.Context) error {
465465
return nil
466466
}
467467

468-
// loadAttributes loads all attributes
469-
func (comments CommentList) loadAttributes(ctx context.Context) (err error) {
468+
// LoadAttributes loads attributes of the comments, except for attachments and
469+
// comments
470+
func (comments CommentList) LoadAttributes(ctx context.Context) (err error) {
470471
if err = comments.LoadPosters(ctx); err != nil {
471472
return err
472473
}
@@ -501,9 +502,3 @@ func (comments CommentList) loadAttributes(ctx context.Context) (err error) {
501502

502503
return comments.loadDependentIssues(ctx)
503504
}
504-
505-
// LoadAttributes loads attributes of the comments, except for attachments and
506-
// comments
507-
func (comments CommentList) LoadAttributes() error {
508-
return comments.loadAttributes(db.DefaultContext)
509-
}

models/issues/issue.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ func (issue *Issue) LoadAttributes(ctx context.Context) (err error) {
354354
return err
355355
}
356356

357-
if err = issue.Comments.loadAttributes(ctx); err != nil {
357+
if err = issue.Comments.LoadAttributes(ctx); err != nil {
358358
return err
359359
}
360360
if issue.IsTimetrackerEnabled(ctx) {
@@ -502,15 +502,15 @@ func (issue *Issue) GetLastEventLabelFake() string {
502502
}
503503

504504
// GetIssueByIndex returns raw issue without loading attributes by index in a repository.
505-
func GetIssueByIndex(repoID, index int64) (*Issue, error) {
505+
func GetIssueByIndex(ctx context.Context, repoID, index int64) (*Issue, error) {
506506
if index < 1 {
507507
return nil, ErrIssueNotExist{}
508508
}
509509
issue := &Issue{
510510
RepoID: repoID,
511511
Index: index,
512512
}
513-
has, err := db.GetEngine(db.DefaultContext).Get(issue)
513+
has, err := db.GetEngine(ctx).Get(issue)
514514
if err != nil {
515515
return nil, err
516516
} else if !has {
@@ -520,12 +520,12 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) {
520520
}
521521

522522
// GetIssueWithAttrsByIndex returns issue by index in a repository.
523-
func GetIssueWithAttrsByIndex(repoID, index int64) (*Issue, error) {
524-
issue, err := GetIssueByIndex(repoID, index)
523+
func GetIssueWithAttrsByIndex(ctx context.Context, repoID, index int64) (*Issue, error) {
524+
issue, err := GetIssueByIndex(ctx, repoID, index)
525525
if err != nil {
526526
return nil, err
527527
}
528-
return issue, issue.LoadAttributes(db.DefaultContext)
528+
return issue, issue.LoadAttributes(ctx)
529529
}
530530

531531
// GetIssueByID returns an issue by given ID.
@@ -846,7 +846,7 @@ func GetPinnedIssues(ctx context.Context, repoID int64, isPull bool) ([]*Issue,
846846
return nil, err
847847
}
848848

849-
err = IssueList(issues).LoadAttributes()
849+
err = IssueList(issues).LoadAttributes(ctx)
850850
if err != nil {
851851
return nil, err
852852
}

models/issues/issue_list.go

+1-7
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ func (issues IssueList) loadTotalTrackedTimes(ctx context.Context) (err error) {
526526
}
527527

528528
// loadAttributes loads all attributes, expect for attachments and comments
529-
func (issues IssueList) loadAttributes(ctx context.Context) error {
529+
func (issues IssueList) LoadAttributes(ctx context.Context) error {
530530
if _, err := issues.LoadRepositories(ctx); err != nil {
531531
return fmt.Errorf("issue.loadAttributes: LoadRepositories: %w", err)
532532
}
@@ -562,12 +562,6 @@ func (issues IssueList) loadAttributes(ctx context.Context) error {
562562
return nil
563563
}
564564

565-
// LoadAttributes loads attributes of the issues, except for attachments and
566-
// comments
567-
func (issues IssueList) LoadAttributes() error {
568-
return issues.loadAttributes(db.DefaultContext)
569-
}
570-
571565
// LoadComments loads comments
572566
func (issues IssueList) LoadComments(ctx context.Context) error {
573567
return issues.loadComments(ctx, builder.NewCond())

models/issues/issue_list_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func TestIssueList_LoadAttributes(t *testing.T) {
3939
unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 4}),
4040
}
4141

42-
assert.NoError(t, issueList.LoadAttributes())
42+
assert.NoError(t, issueList.LoadAttributes(db.DefaultContext))
4343
for _, issue := range issueList {
4444
assert.EqualValues(t, issue.RepoID, issue.Repo.ID)
4545
for _, label := range issue.Labels {

models/issues/issue_search.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ func Issues(ctx context.Context, opts *IssuesOptions) ([]*Issue, error) {
440440
return nil, fmt.Errorf("unable to query Issues: %w", err)
441441
}
442442

443-
if err := issues.LoadAttributes(); err != nil {
443+
if err := issues.LoadAttributes(ctx); err != nil {
444444
return nil, fmt.Errorf("unable to LoadAttributes for Issues: %w", err)
445445
}
446446

models/issues/pull_list.go

+9-14
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,16 @@ func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xor
5151
}
5252

5353
// GetUnmergedPullRequestsByHeadInfo returns all pull requests that are open and has not been merged
54-
func GetUnmergedPullRequestsByHeadInfo(repoID int64, branch string) ([]*PullRequest, error) {
54+
func GetUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) {
5555
prs := make([]*PullRequest, 0, 2)
56-
sess := db.GetEngine(db.DefaultContext).
56+
sess := db.GetEngine(ctx).
5757
Join("INNER", "issue", "issue.id = pull_request.issue_id").
5858
Where("head_repo_id = ? AND head_branch = ? AND has_merged = ? AND issue.is_closed = ? AND flow = ?", repoID, branch, false, false, PullRequestFlowGithub)
5959
return prs, sess.Find(&prs)
6060
}
6161

6262
// CanMaintainerWriteToBranch check whether user is a maintainer and could write to the branch
63-
func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user *user_model.User) bool {
63+
func CanMaintainerWriteToBranch(ctx context.Context, p access_model.Permission, branch string, user *user_model.User) bool {
6464
if p.CanWrite(unit.TypeCode) {
6565
return true
6666
}
@@ -69,18 +69,18 @@ func CanMaintainerWriteToBranch(p access_model.Permission, branch string, user *
6969
return false
7070
}
7171

72-
prs, err := GetUnmergedPullRequestsByHeadInfo(p.Units[0].RepoID, branch)
72+
prs, err := GetUnmergedPullRequestsByHeadInfo(ctx, p.Units[0].RepoID, branch)
7373
if err != nil {
7474
return false
7575
}
7676

7777
for _, pr := range prs {
7878
if pr.AllowMaintainerEdit {
79-
err = pr.LoadBaseRepo(db.DefaultContext)
79+
err = pr.LoadBaseRepo(ctx)
8080
if err != nil {
8181
continue
8282
}
83-
prPerm, err := access_model.GetUserRepoPermission(db.DefaultContext, pr.BaseRepo, user)
83+
prPerm, err := access_model.GetUserRepoPermission(ctx, pr.BaseRepo, user)
8484
if err != nil {
8585
continue
8686
}
@@ -104,9 +104,9 @@ func HasUnmergedPullRequestsByHeadInfo(ctx context.Context, repoID int64, branch
104104

105105
// GetUnmergedPullRequestsByBaseInfo returns all pull requests that are open and has not been merged
106106
// by given base information (repo and branch).
107-
func GetUnmergedPullRequestsByBaseInfo(repoID int64, branch string) ([]*PullRequest, error) {
107+
func GetUnmergedPullRequestsByBaseInfo(ctx context.Context, repoID int64, branch string) ([]*PullRequest, error) {
108108
prs := make([]*PullRequest, 0, 2)
109-
return prs, db.GetEngine(db.DefaultContext).
109+
return prs, db.GetEngine(ctx).
110110
Where("base_repo_id=? AND base_branch=? AND has_merged=? AND issue.is_closed=?",
111111
repoID, branch, false, false).
112112
OrderBy("issue.updated_unix DESC").
@@ -154,7 +154,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest,
154154
// PullRequestList defines a list of pull requests
155155
type PullRequestList []*PullRequest
156156

157-
func (prs PullRequestList) loadAttributes(ctx context.Context) error {
157+
func (prs PullRequestList) LoadAttributes(ctx context.Context) error {
158158
if len(prs) == 0 {
159159
return nil
160160
}
@@ -199,8 +199,3 @@ func (prs PullRequestList) GetIssueIDs() []int64 {
199199
}
200200
return issueIDs
201201
}
202-
203-
// LoadAttributes load all the prs attributes
204-
func (prs PullRequestList) LoadAttributes() error {
205-
return prs.loadAttributes(db.DefaultContext)
206-
}

0 commit comments

Comments
 (0)