Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace more db.DefaultContext #27628

Merged
merged 3 commits into from
Oct 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions contrib/fixtures/fixture_generation.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package main

import (
"context"
"fmt"
"os"
"path/filepath"
Expand All @@ -18,7 +19,7 @@ import (

var (
generators = []struct {
gen func() (string, error)
gen func(ctx context.Context) (string, error)
name string
}{
{
Expand All @@ -41,27 +42,28 @@ func main() {
fmt.Printf("PrepareTestDatabase: %+v\n", err)
os.Exit(1)
}
ctx := context.Background()
if len(os.Args) == 0 {
for _, r := range os.Args {
if err := generate(r); err != nil {
if err := generate(ctx, r); err != nil {
fmt.Printf("generate '%s': %+v\n", r, err)
os.Exit(1)
}
}
} else {
for _, g := range generators {
if err := generate(g.name); err != nil {
if err := generate(ctx, g.name); err != nil {
fmt.Printf("generate '%s': %+v\n", g.name, err)
os.Exit(1)
}
}
}
}

func generate(name string) error {
func generate(ctx context.Context, name string) error {
for _, g := range generators {
if g.name == name {
data, err := g.gen()
data, err := g.gen(ctx)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/ssh_key_authorized_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func appendAuthorizedKeysToFile(keys ...*PublicKey) error {
}

// RewriteAllPublicKeys removes any authorized key and rewrite all keys from database again.
// Note: db.GetEngine(db.DefaultContext).Iterate does not get latest data after insert/delete, so we have to call this function
// Note: db.GetEngine(ctx).Iterate does not get latest data after insert/delete, so we have to call this function
// outside any session scope independently.
func RewriteAllPublicKeys(ctx context.Context) error {
// Don't rewrite key if internal server
Expand Down
2 changes: 1 addition & 1 deletion models/asymkey/ssh_key_authorized_principals.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import (
const authorizedPrincipalsFile = "authorized_principals"

// RewriteAllPrincipalKeys removes any authorized principal and rewrite all keys from database again.
// Note: db.GetEngine(db.DefaultContext).Iterate does not get latest data after insert/delete, so we have to call this function
// Note: db.GetEngine(ctx).Iterate does not get latest data after insert/delete, so we have to call this function
// outside any session scope independently.
func RewriteAllPrincipalKeys(ctx context.Context) error {
// Don't rewrite key if internal server
Expand Down
11 changes: 6 additions & 5 deletions models/fixture_generation.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package models

import (
"context"
"fmt"
"strings"

Expand All @@ -14,23 +15,23 @@ import (

// GetYamlFixturesAccess returns a string containing the contents
// for the access table, as recalculated using repo.RecalculateAccesses()
func GetYamlFixturesAccess() (string, error) {
func GetYamlFixturesAccess(ctx context.Context) (string, error) {
repos := make([]*repo_model.Repository, 0, 50)
if err := db.GetEngine(db.DefaultContext).Find(&repos); err != nil {
if err := db.GetEngine(ctx).Find(&repos); err != nil {
return "", err
}

for _, repo := range repos {
repo.MustOwner(db.DefaultContext)
if err := access_model.RecalculateAccesses(db.DefaultContext, repo); err != nil {
repo.MustOwner(ctx)
if err := access_model.RecalculateAccesses(ctx, repo); err != nil {
return "", err
}
}

var b strings.Builder

accesses := make([]*access_model.Access, 0, 200)
if err := db.GetEngine(db.DefaultContext).OrderBy("user_id, repo_id").Find(&accesses); err != nil {
if err := db.GetEngine(ctx).OrderBy("user_id, repo_id").Find(&accesses); err != nil {
return "", err
}

Expand Down
8 changes: 5 additions & 3 deletions models/fixture_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
package models

import (
"context"
"os"
"path/filepath"
"testing"

"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"code.gitea.io/gitea/modules/util"

Expand All @@ -17,8 +19,8 @@ import (
func TestFixtureGeneration(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())

test := func(gen func() (string, error), name string) {
expected, err := gen()
test := func(ctx context.Context, gen func(ctx context.Context) (string, error), name string) {
expected, err := gen(ctx)
if !assert.NoError(t, err) {
return
}
Expand All @@ -31,5 +33,5 @@ func TestFixtureGeneration(t *testing.T) {
assert.EqualValues(t, expected, data, "Differences detected for %s", p)
}

test(GetYamlFixturesAccess, "access")
test(db.DefaultContext, GetYamlFixturesAccess, "access")
}
4 changes: 2 additions & 2 deletions models/org.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ func removeOrgUser(ctx context.Context, orgID, userID int64) error {
}

// RemoveOrgUser removes user from given organization.
func RemoveOrgUser(orgID, userID int64) error {
ctx, committer, err := db.TxContext(db.DefaultContext)
func RemoveOrgUser(ctx context.Context, orgID, userID int64) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
Expand Down
9 changes: 5 additions & 4 deletions models/org_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package models
import (
"testing"

"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/organization"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
Expand All @@ -20,15 +21,15 @@ func TestUser_RemoveMember(t *testing.T) {
// remove a user that is a member
unittest.AssertExistsAndLoadBean(t, &organization.OrgUser{UID: 4, OrgID: 3})
prevNumMembers := org.NumMembers
assert.NoError(t, RemoveOrgUser(org.ID, 4))
assert.NoError(t, RemoveOrgUser(db.DefaultContext, org.ID, 4))
unittest.AssertNotExistsBean(t, &organization.OrgUser{UID: 4, OrgID: 3})
org = unittest.AssertExistsAndLoadBean(t, &organization.Organization{ID: 3})
assert.Equal(t, prevNumMembers-1, org.NumMembers)

// remove a user that is not a member
unittest.AssertNotExistsBean(t, &organization.OrgUser{UID: 5, OrgID: 3})
prevNumMembers = org.NumMembers
assert.NoError(t, RemoveOrgUser(org.ID, 5))
assert.NoError(t, RemoveOrgUser(db.DefaultContext, org.ID, 5))
unittest.AssertNotExistsBean(t, &organization.OrgUser{UID: 5, OrgID: 3})
org = unittest.AssertExistsAndLoadBean(t, &organization.Organization{ID: 3})
assert.Equal(t, prevNumMembers, org.NumMembers)
Expand All @@ -44,15 +45,15 @@ func TestRemoveOrgUser(t *testing.T) {
if unittest.BeanExists(t, &organization.OrgUser{OrgID: orgID, UID: userID}) {
expectedNumMembers--
}
assert.NoError(t, RemoveOrgUser(orgID, userID))
assert.NoError(t, RemoveOrgUser(db.DefaultContext, orgID, userID))
unittest.AssertNotExistsBean(t, &organization.OrgUser{OrgID: orgID, UID: userID})
org = unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: orgID})
assert.EqualValues(t, expectedNumMembers, org.NumMembers)
}
testSuccess(3, 4)
testSuccess(3, 4)

err := RemoveOrgUser(7, 5)
err := RemoveOrgUser(db.DefaultContext, 7, 5)
assert.Error(t, err)
assert.True(t, organization.IsErrLastOrgOwner(err))
unittest.AssertExistsAndLoadBean(t, &organization.OrgUser{OrgID: 7, UID: 5})
Expand Down
4 changes: 2 additions & 2 deletions models/repo/repo_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ func FindUserCodeAccessibleOwnerRepoIDs(ctx context.Context, ownerID int64, user
}

// GetUserRepositories returns a list of repositories of given user.
func GetUserRepositories(opts *SearchRepoOptions) (RepositoryList, int64, error) {
func GetUserRepositories(ctx context.Context, opts *SearchRepoOptions) (RepositoryList, int64, error) {
if len(opts.OrderBy) == 0 {
opts.OrderBy = "updated_unix DESC"
}
Expand All @@ -734,7 +734,7 @@ func GetUserRepositories(opts *SearchRepoOptions) (RepositoryList, int64, error)
cond = cond.And(builder.In("lower_name", opts.LowerNames))
}

sess := db.GetEngine(db.DefaultContext)
sess := db.GetEngine(ctx)

count, err := sess.Where(cond).Count(new(Repository))
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions models/repo/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ func watchRepoMode(ctx context.Context, watch Watch, mode WatchMode) (err error)
}

// WatchRepoMode watch repository in specific mode.
func WatchRepoMode(userID, repoID int64, mode WatchMode) (err error) {
func WatchRepoMode(ctx context.Context, userID, repoID int64, mode WatchMode) (err error) {
var watch Watch
if watch, err = GetWatch(db.DefaultContext, userID, repoID); err != nil {
if watch, err = GetWatch(ctx, userID, repoID); err != nil {
return err
}
return watchRepoMode(db.DefaultContext, watch, mode)
return watchRepoMode(ctx, watch, mode)
}

// WatchRepo watch or unwatch repository.
Expand Down
8 changes: 4 additions & 4 deletions models/repo/watch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,18 @@ func TestWatchRepoMode(t *testing.T) {

unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1}, 0)

assert.NoError(t, repo_model.WatchRepoMode(12, 1, repo_model.WatchModeAuto))
assert.NoError(t, repo_model.WatchRepoMode(db.DefaultContext, 12, 1, repo_model.WatchModeAuto))
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1}, 1)
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1, Mode: repo_model.WatchModeAuto}, 1)

assert.NoError(t, repo_model.WatchRepoMode(12, 1, repo_model.WatchModeNormal))
assert.NoError(t, repo_model.WatchRepoMode(db.DefaultContext, 12, 1, repo_model.WatchModeNormal))
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1}, 1)
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1, Mode: repo_model.WatchModeNormal}, 1)

assert.NoError(t, repo_model.WatchRepoMode(12, 1, repo_model.WatchModeDont))
assert.NoError(t, repo_model.WatchRepoMode(db.DefaultContext, 12, 1, repo_model.WatchModeDont))
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1}, 1)
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1, Mode: repo_model.WatchModeDont}, 1)

assert.NoError(t, repo_model.WatchRepoMode(12, 1, repo_model.WatchModeNone))
assert.NoError(t, repo_model.WatchRepoMode(db.DefaultContext, 12, 1, repo_model.WatchModeNone))
unittest.AssertCount(t, &repo_model.Watch{UserID: 12, RepoID: 1}, 0)
}
8 changes: 4 additions & 4 deletions models/system/appstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ func init() {
}

// SaveAppStateContent saves the app state item to database
func SaveAppStateContent(key, content string) error {
return db.WithTx(db.DefaultContext, func(ctx context.Context) error {
func SaveAppStateContent(ctx context.Context, key, content string) error {
return db.WithTx(ctx, func(ctx context.Context) error {
eng := db.GetEngine(ctx)
// try to update existing row
res, err := eng.Exec("UPDATE app_state SET revision=revision+1, content=? WHERE id=?", content, key)
Expand All @@ -43,8 +43,8 @@ func SaveAppStateContent(key, content string) error {
}

// GetAppStateContent gets an app state from database
func GetAppStateContent(key string) (content string, err error) {
e := db.GetEngine(db.DefaultContext)
func GetAppStateContent(ctx context.Context, key string) (content string, err error) {
e := db.GetEngine(ctx)
appState := &AppState{ID: key}
has, err := e.Get(appState)
if err != nil {
Expand Down
6 changes: 4 additions & 2 deletions modules/system/appstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

package system

import "context"

// StateStore is the interface to get/set app state items
type StateStore interface {
Get(item StateItem) error
Set(item StateItem) error
Get(ctx context.Context, item StateItem) error
Set(ctx context.Context, item StateItem) error
}

// StateItem provides the name for a state item. the name will be used to generate filenames, etc
Expand Down
11 changes: 6 additions & 5 deletions modules/system/appstate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package system
import (
"testing"

"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -40,25 +41,25 @@ func TestAppStateDB(t *testing.T) {
as := &DBStore{}

item1 := new(testItem1)
assert.NoError(t, as.Get(item1))
assert.NoError(t, as.Get(db.DefaultContext, item1))
assert.Equal(t, "", item1.Val1)
assert.EqualValues(t, 0, item1.Val2)

item1 = new(testItem1)
item1.Val1 = "a"
item1.Val2 = 2
assert.NoError(t, as.Set(item1))
assert.NoError(t, as.Set(db.DefaultContext, item1))

item2 := new(testItem2)
item2.K = "V"
assert.NoError(t, as.Set(item2))
assert.NoError(t, as.Set(db.DefaultContext, item2))

item1 = new(testItem1)
assert.NoError(t, as.Get(item1))
assert.NoError(t, as.Get(db.DefaultContext, item1))
assert.Equal(t, "a", item1.Val1)
assert.EqualValues(t, 2, item1.Val2)

item2 = new(testItem2)
assert.NoError(t, as.Get(item2))
assert.NoError(t, as.Get(db.DefaultContext, item2))
assert.Equal(t, "V", item2.K)
}
10 changes: 6 additions & 4 deletions modules/system/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package system

import (
"context"

"code.gitea.io/gitea/models/system"
"code.gitea.io/gitea/modules/json"

Expand All @@ -14,8 +16,8 @@ import (
type DBStore struct{}

// Get reads the state item
func (f *DBStore) Get(item StateItem) error {
content, err := system.GetAppStateContent(item.Name())
func (f *DBStore) Get(ctx context.Context, item StateItem) error {
content, err := system.GetAppStateContent(ctx, item.Name())
if err != nil {
return err
}
Expand All @@ -26,10 +28,10 @@ func (f *DBStore) Get(item StateItem) error {
}

// Set saves the state item
func (f *DBStore) Set(item StateItem) error {
func (f *DBStore) Set(ctx context.Context, item StateItem) error {
b, err := json.Marshal(item)
if err != nil {
return err
}
return system.SaveAppStateContent(item.Name(), util.BytesToReadOnlyString(b))
return system.SaveAppStateContent(ctx, item.Name(), util.BytesToReadOnlyString(b))
}
Loading