diff --git a/database/common/common.go b/database/common/common.go index 8ca57ac2..023a2057 100644 --- a/database/common/common.go +++ b/database/common/common.go @@ -27,16 +27,6 @@ type RepoStore interface { ListRepositories(ctx context.Context) ([]params.Repository, error) DeleteRepository(ctx context.Context, repoID string) error UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) - - CreateRepositoryPool(ctx context.Context, repoID string, param params.CreatePoolParams) (params.Pool, error) - - GetRepositoryPool(ctx context.Context, repoID, poolID string) (params.Pool, error) - DeleteRepositoryPool(ctx context.Context, repoID, poolID string) error - UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) - FindRepositoryPoolByTags(ctx context.Context, repoID string, tags []string) (params.Pool, error) - - ListRepoPools(ctx context.Context, repoID string) ([]params.Pool, error) - ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error) } type OrgStore interface { @@ -46,15 +36,6 @@ type OrgStore interface { ListOrganizations(ctx context.Context) ([]params.Organization, error) DeleteOrganization(ctx context.Context, orgID string) error UpdateOrganization(ctx context.Context, orgID string, param params.UpdateEntityParams) (params.Organization, error) - - CreateOrganizationPool(ctx context.Context, orgID string, param params.CreatePoolParams) (params.Pool, error) - GetOrganizationPool(ctx context.Context, orgID, poolID string) (params.Pool, error) - DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error - UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) - - FindOrganizationPoolByTags(ctx context.Context, orgID string, tags []string) (params.Pool, error) - ListOrgPools(ctx context.Context, orgID string) ([]params.Pool, error) - ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) } type EnterpriseStore interface { @@ -64,15 +45,6 @@ type EnterpriseStore interface { ListEnterprises(ctx context.Context) ([]params.Enterprise, error) DeleteEnterprise(ctx context.Context, enterpriseID string) error UpdateEnterprise(ctx context.Context, enterpriseID string, param params.UpdateEntityParams) (params.Enterprise, error) - - CreateEnterprisePool(ctx context.Context, enterpriseID string, param params.CreatePoolParams) (params.Pool, error) - GetEnterprisePool(ctx context.Context, enterpriseID, poolID string) (params.Pool, error) - DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error - UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error) - - FindEnterprisePoolByTags(ctx context.Context, enterpriseID string, tags []string) (params.Pool, error) - ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) - ListEnterpriseInstances(ctx context.Context, enterpriseID string) ([]params.Instance, error) } type PoolStore interface { @@ -130,6 +102,16 @@ type JobsStore interface { DeleteCompletedJobs(ctx context.Context) error } +type EntityPools interface { + CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) + GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) + DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error + UpdateEntityPool(ctx context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) + + ListEntityPools(ctx context.Context, entity params.GithubEntity) ([]params.Pool, error) + ListEntityInstances(ctx context.Context, entity params.GithubEntity) ([]params.Instance, error) +} + //go:generate mockery --name=Store type Store interface { RepoStore @@ -139,6 +121,7 @@ type Store interface { UserStore InstanceStore JobsStore + EntityPools ControllerInfo() (params.ControllerInfo, error) InitController() (params.ControllerInfo, error) diff --git a/database/common/mocks/Store.go b/database/common/mocks/Store.go index 81e47799..219057e4 100644 --- a/database/common/mocks/Store.go +++ b/database/common/mocks/Store.go @@ -106,27 +106,27 @@ func (_m *Store) CreateEnterprise(ctx context.Context, name string, credentialsN return r0, r1 } -// CreateEnterprisePool provides a mock function with given fields: ctx, enterpriseID, param -func (_m *Store) CreateEnterprisePool(ctx context.Context, enterpriseID string, param params.CreatePoolParams) (params.Pool, error) { - ret := _m.Called(ctx, enterpriseID, param) +// CreateEntityPool provides a mock function with given fields: ctx, entity, param +func (_m *Store) CreateEntityPool(ctx context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) { + ret := _m.Called(ctx, entity, param) if len(ret) == 0 { - panic("no return value specified for CreateEnterprisePool") + panic("no return value specified for CreateEntityPool") } var r0 params.Pool var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, params.CreatePoolParams) (params.Pool, error)); ok { - return rf(ctx, enterpriseID, param) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, params.CreatePoolParams) (params.Pool, error)); ok { + return rf(ctx, entity, param) } - if rf, ok := ret.Get(0).(func(context.Context, string, params.CreatePoolParams) params.Pool); ok { - r0 = rf(ctx, enterpriseID, param) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, params.CreatePoolParams) params.Pool); ok { + r0 = rf(ctx, entity, param) } else { r0 = ret.Get(0).(params.Pool) } - if rf, ok := ret.Get(1).(func(context.Context, string, params.CreatePoolParams) error); ok { - r1 = rf(ctx, enterpriseID, param) + if rf, ok := ret.Get(1).(func(context.Context, params.GithubEntity, params.CreatePoolParams) error); ok { + r1 = rf(ctx, entity, param) } else { r1 = ret.Error(1) } @@ -218,34 +218,6 @@ func (_m *Store) CreateOrganization(ctx context.Context, name string, credential return r0, r1 } -// CreateOrganizationPool provides a mock function with given fields: ctx, orgID, param -func (_m *Store) CreateOrganizationPool(ctx context.Context, orgID string, param params.CreatePoolParams) (params.Pool, error) { - ret := _m.Called(ctx, orgID, param) - - if len(ret) == 0 { - panic("no return value specified for CreateOrganizationPool") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, params.CreatePoolParams) (params.Pool, error)); ok { - return rf(ctx, orgID, param) - } - if rf, ok := ret.Get(0).(func(context.Context, string, params.CreatePoolParams) params.Pool); ok { - r0 = rf(ctx, orgID, param) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, params.CreatePoolParams) error); ok { - r1 = rf(ctx, orgID, param) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // CreateRepository provides a mock function with given fields: ctx, owner, name, credentialsName, webhookSecret, poolBalancerType func (_m *Store) CreateRepository(ctx context.Context, owner string, name string, credentialsName string, webhookSecret string, poolBalancerType params.PoolBalancerType) (params.Repository, error) { ret := _m.Called(ctx, owner, name, credentialsName, webhookSecret, poolBalancerType) @@ -274,34 +246,6 @@ func (_m *Store) CreateRepository(ctx context.Context, owner string, name string return r0, r1 } -// CreateRepositoryPool provides a mock function with given fields: ctx, repoID, param -func (_m *Store) CreateRepositoryPool(ctx context.Context, repoID string, param params.CreatePoolParams) (params.Pool, error) { - ret := _m.Called(ctx, repoID, param) - - if len(ret) == 0 { - panic("no return value specified for CreateRepositoryPool") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, params.CreatePoolParams) (params.Pool, error)); ok { - return rf(ctx, repoID, param) - } - if rf, ok := ret.Get(0).(func(context.Context, string, params.CreatePoolParams) params.Pool); ok { - r0 = rf(ctx, repoID, param) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, params.CreatePoolParams) error); ok { - r1 = rf(ctx, repoID, param) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // CreateUser provides a mock function with given fields: ctx, user func (_m *Store) CreateUser(ctx context.Context, user params.NewUserParams) (params.User, error) { ret := _m.Called(ctx, user) @@ -366,17 +310,17 @@ func (_m *Store) DeleteEnterprise(ctx context.Context, enterpriseID string) erro return r0 } -// DeleteEnterprisePool provides a mock function with given fields: ctx, enterpriseID, poolID -func (_m *Store) DeleteEnterprisePool(ctx context.Context, enterpriseID string, poolID string) error { - ret := _m.Called(ctx, enterpriseID, poolID) +// DeleteEntityPool provides a mock function with given fields: ctx, entity, poolID +func (_m *Store) DeleteEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) error { + ret := _m.Called(ctx, entity, poolID) if len(ret) == 0 { - panic("no return value specified for DeleteEnterprisePool") + panic("no return value specified for DeleteEntityPool") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, enterpriseID, poolID) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, string) error); ok { + r0 = rf(ctx, entity, poolID) } else { r0 = ret.Error(0) } @@ -438,24 +382,6 @@ func (_m *Store) DeleteOrganization(ctx context.Context, orgID string) error { return r0 } -// DeleteOrganizationPool provides a mock function with given fields: ctx, orgID, poolID -func (_m *Store) DeleteOrganizationPool(ctx context.Context, orgID string, poolID string) error { - ret := _m.Called(ctx, orgID, poolID) - - if len(ret) == 0 { - panic("no return value specified for DeleteOrganizationPool") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, orgID, poolID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // DeletePoolByID provides a mock function with given fields: ctx, poolID func (_m *Store) DeletePoolByID(ctx context.Context, poolID string) error { ret := _m.Called(ctx, poolID) @@ -492,80 +418,6 @@ func (_m *Store) DeleteRepository(ctx context.Context, repoID string) error { return r0 } -// DeleteRepositoryPool provides a mock function with given fields: ctx, repoID, poolID -func (_m *Store) DeleteRepositoryPool(ctx context.Context, repoID string, poolID string) error { - ret := _m.Called(ctx, repoID, poolID) - - if len(ret) == 0 { - panic("no return value specified for DeleteRepositoryPool") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, repoID, poolID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// FindEnterprisePoolByTags provides a mock function with given fields: ctx, enterpriseID, tags -func (_m *Store) FindEnterprisePoolByTags(ctx context.Context, enterpriseID string, tags []string) (params.Pool, error) { - ret := _m.Called(ctx, enterpriseID, tags) - - if len(ret) == 0 { - panic("no return value specified for FindEnterprisePoolByTags") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, []string) (params.Pool, error)); ok { - return rf(ctx, enterpriseID, tags) - } - if rf, ok := ret.Get(0).(func(context.Context, string, []string) params.Pool); ok { - r0 = rf(ctx, enterpriseID, tags) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { - r1 = rf(ctx, enterpriseID, tags) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// FindOrganizationPoolByTags provides a mock function with given fields: ctx, orgID, tags -func (_m *Store) FindOrganizationPoolByTags(ctx context.Context, orgID string, tags []string) (params.Pool, error) { - ret := _m.Called(ctx, orgID, tags) - - if len(ret) == 0 { - panic("no return value specified for FindOrganizationPoolByTags") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, []string) (params.Pool, error)); ok { - return rf(ctx, orgID, tags) - } - if rf, ok := ret.Get(0).(func(context.Context, string, []string) params.Pool); ok { - r0 = rf(ctx, orgID, tags) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { - r1 = rf(ctx, orgID, tags) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // FindPoolsMatchingAllTags provides a mock function with given fields: ctx, entityType, entityID, tags func (_m *Store) FindPoolsMatchingAllTags(ctx context.Context, entityType params.GithubEntityType, entityID string, tags []string) ([]params.Pool, error) { ret := _m.Called(ctx, entityType, entityID, tags) @@ -596,34 +448,6 @@ func (_m *Store) FindPoolsMatchingAllTags(ctx context.Context, entityType params return r0, r1 } -// FindRepositoryPoolByTags provides a mock function with given fields: ctx, repoID, tags -func (_m *Store) FindRepositoryPoolByTags(ctx context.Context, repoID string, tags []string) (params.Pool, error) { - ret := _m.Called(ctx, repoID, tags) - - if len(ret) == 0 { - panic("no return value specified for FindRepositoryPoolByTags") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, []string) (params.Pool, error)); ok { - return rf(ctx, repoID, tags) - } - if rf, ok := ret.Get(0).(func(context.Context, string, []string) params.Pool); ok { - r0 = rf(ctx, repoID, tags) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { - r1 = rf(ctx, repoID, tags) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // GetEnterprise provides a mock function with given fields: ctx, name func (_m *Store) GetEnterprise(ctx context.Context, name string) (params.Enterprise, error) { ret := _m.Called(ctx, name) @@ -680,27 +504,27 @@ func (_m *Store) GetEnterpriseByID(ctx context.Context, enterpriseID string) (pa return r0, r1 } -// GetEnterprisePool provides a mock function with given fields: ctx, enterpriseID, poolID -func (_m *Store) GetEnterprisePool(ctx context.Context, enterpriseID string, poolID string) (params.Pool, error) { - ret := _m.Called(ctx, enterpriseID, poolID) +// GetEntityPool provides a mock function with given fields: ctx, entity, poolID +func (_m *Store) GetEntityPool(ctx context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) { + ret := _m.Called(ctx, entity, poolID) if len(ret) == 0 { - panic("no return value specified for GetEnterprisePool") + panic("no return value specified for GetEntityPool") } var r0 params.Pool var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (params.Pool, error)); ok { - return rf(ctx, enterpriseID, poolID) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, string) (params.Pool, error)); ok { + return rf(ctx, entity, poolID) } - if rf, ok := ret.Get(0).(func(context.Context, string, string) params.Pool); ok { - r0 = rf(ctx, enterpriseID, poolID) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, string) params.Pool); ok { + r0 = rf(ctx, entity, poolID) } else { r0 = ret.Get(0).(params.Pool) } - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, enterpriseID, poolID) + if rf, ok := ret.Get(1).(func(context.Context, params.GithubEntity, string) error); ok { + r1 = rf(ctx, entity, poolID) } else { r1 = ret.Error(1) } @@ -820,34 +644,6 @@ func (_m *Store) GetOrganizationByID(ctx context.Context, orgID string) (params. return r0, r1 } -// GetOrganizationPool provides a mock function with given fields: ctx, orgID, poolID -func (_m *Store) GetOrganizationPool(ctx context.Context, orgID string, poolID string) (params.Pool, error) { - ret := _m.Called(ctx, orgID, poolID) - - if len(ret) == 0 { - panic("no return value specified for GetOrganizationPool") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (params.Pool, error)); ok { - return rf(ctx, orgID, poolID) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string) params.Pool); ok { - r0 = rf(ctx, orgID, poolID) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, orgID, poolID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // GetPoolByID provides a mock function with given fields: ctx, poolID func (_m *Store) GetPoolByID(ctx context.Context, poolID string) (params.Pool, error) { ret := _m.Called(ctx, poolID) @@ -960,34 +756,6 @@ func (_m *Store) GetRepositoryByID(ctx context.Context, repoID string) (params.R return r0, r1 } -// GetRepositoryPool provides a mock function with given fields: ctx, repoID, poolID -func (_m *Store) GetRepositoryPool(ctx context.Context, repoID string, poolID string) (params.Pool, error) { - ret := _m.Called(ctx, repoID, poolID) - - if len(ret) == 0 { - panic("no return value specified for GetRepositoryPool") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (params.Pool, error)); ok { - return rf(ctx, repoID, poolID) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string) params.Pool); ok { - r0 = rf(ctx, repoID, poolID) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, repoID, poolID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // GetUser provides a mock function with given fields: ctx, user func (_m *Store) GetUser(ctx context.Context, user string) (params.User, error) { ret := _m.Called(ctx, user) @@ -1180,29 +948,29 @@ func (_m *Store) ListAllPools(ctx context.Context) ([]params.Pool, error) { return r0, r1 } -// ListEnterpriseInstances provides a mock function with given fields: ctx, enterpriseID -func (_m *Store) ListEnterpriseInstances(ctx context.Context, enterpriseID string) ([]params.Instance, error) { - ret := _m.Called(ctx, enterpriseID) +// ListEnterprises provides a mock function with given fields: ctx +func (_m *Store) ListEnterprises(ctx context.Context) ([]params.Enterprise, error) { + ret := _m.Called(ctx) if len(ret) == 0 { - panic("no return value specified for ListEnterpriseInstances") + panic("no return value specified for ListEnterprises") } - var r0 []params.Instance + var r0 []params.Enterprise var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]params.Instance, error)); ok { - return rf(ctx, enterpriseID) + if rf, ok := ret.Get(0).(func(context.Context) ([]params.Enterprise, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func(context.Context, string) []params.Instance); ok { - r0 = rf(ctx, enterpriseID) + if rf, ok := ret.Get(0).(func(context.Context) []params.Enterprise); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Instance) + r0 = ret.Get(0).([]params.Enterprise) } } - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, enterpriseID) + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -1210,29 +978,29 @@ func (_m *Store) ListEnterpriseInstances(ctx context.Context, enterpriseID strin return r0, r1 } -// ListEnterprisePools provides a mock function with given fields: ctx, enterpriseID -func (_m *Store) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) { - ret := _m.Called(ctx, enterpriseID) +// ListEntityInstances provides a mock function with given fields: ctx, entity +func (_m *Store) ListEntityInstances(ctx context.Context, entity params.GithubEntity) ([]params.Instance, error) { + ret := _m.Called(ctx, entity) if len(ret) == 0 { - panic("no return value specified for ListEnterprisePools") + panic("no return value specified for ListEntityInstances") } - var r0 []params.Pool + var r0 []params.Instance var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]params.Pool, error)); ok { - return rf(ctx, enterpriseID) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity) ([]params.Instance, error)); ok { + return rf(ctx, entity) } - if rf, ok := ret.Get(0).(func(context.Context, string) []params.Pool); ok { - r0 = rf(ctx, enterpriseID) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity) []params.Instance); ok { + r0 = rf(ctx, entity) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Pool) + r0 = ret.Get(0).([]params.Instance) } } - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, enterpriseID) + if rf, ok := ret.Get(1).(func(context.Context, params.GithubEntity) error); ok { + r1 = rf(ctx, entity) } else { r1 = ret.Error(1) } @@ -1240,29 +1008,29 @@ func (_m *Store) ListEnterprisePools(ctx context.Context, enterpriseID string) ( return r0, r1 } -// ListEnterprises provides a mock function with given fields: ctx -func (_m *Store) ListEnterprises(ctx context.Context) ([]params.Enterprise, error) { - ret := _m.Called(ctx) +// ListEntityJobsByStatus provides a mock function with given fields: ctx, entityType, entityID, status +func (_m *Store) ListEntityJobsByStatus(ctx context.Context, entityType params.GithubEntityType, entityID string, status params.JobStatus) ([]params.Job, error) { + ret := _m.Called(ctx, entityType, entityID, status) if len(ret) == 0 { - panic("no return value specified for ListEnterprises") + panic("no return value specified for ListEntityJobsByStatus") } - var r0 []params.Enterprise + var r0 []params.Job var r1 error - if rf, ok := ret.Get(0).(func(context.Context) ([]params.Enterprise, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntityType, string, params.JobStatus) ([]params.Job, error)); ok { + return rf(ctx, entityType, entityID, status) } - if rf, ok := ret.Get(0).(func(context.Context) []params.Enterprise); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntityType, string, params.JobStatus) []params.Job); ok { + r0 = rf(ctx, entityType, entityID, status) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Enterprise) + r0 = ret.Get(0).([]params.Job) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, params.GithubEntityType, string, params.JobStatus) error); ok { + r1 = rf(ctx, entityType, entityID, status) } else { r1 = ret.Error(1) } @@ -1270,29 +1038,29 @@ func (_m *Store) ListEnterprises(ctx context.Context) ([]params.Enterprise, erro return r0, r1 } -// ListEntityJobsByStatus provides a mock function with given fields: ctx, entityType, entityID, status -func (_m *Store) ListEntityJobsByStatus(ctx context.Context, entityType params.GithubEntityType, entityID string, status params.JobStatus) ([]params.Job, error) { - ret := _m.Called(ctx, entityType, entityID, status) +// ListEntityPools provides a mock function with given fields: ctx, entity +func (_m *Store) ListEntityPools(ctx context.Context, entity params.GithubEntity) ([]params.Pool, error) { + ret := _m.Called(ctx, entity) if len(ret) == 0 { - panic("no return value specified for ListEntityJobsByStatus") + panic("no return value specified for ListEntityPools") } - var r0 []params.Job + var r0 []params.Pool var r1 error - if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntityType, string, params.JobStatus) ([]params.Job, error)); ok { - return rf(ctx, entityType, entityID, status) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity) ([]params.Pool, error)); ok { + return rf(ctx, entity) } - if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntityType, string, params.JobStatus) []params.Job); ok { - r0 = rf(ctx, entityType, entityID, status) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity) []params.Pool); ok { + r0 = rf(ctx, entity) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Job) + r0 = ret.Get(0).([]params.Pool) } } - if rf, ok := ret.Get(1).(func(context.Context, params.GithubEntityType, string, params.JobStatus) error); ok { - r1 = rf(ctx, entityType, entityID, status) + if rf, ok := ret.Get(1).(func(context.Context, params.GithubEntity) error); ok { + r1 = rf(ctx, entity) } else { r1 = ret.Error(1) } @@ -1360,66 +1128,6 @@ func (_m *Store) ListJobsByStatus(ctx context.Context, status params.JobStatus) return r0, r1 } -// ListOrgInstances provides a mock function with given fields: ctx, orgID -func (_m *Store) ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) { - ret := _m.Called(ctx, orgID) - - if len(ret) == 0 { - panic("no return value specified for ListOrgInstances") - } - - var r0 []params.Instance - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]params.Instance, error)); ok { - return rf(ctx, orgID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) []params.Instance); ok { - r0 = rf(ctx, orgID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Instance) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, orgID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ListOrgPools provides a mock function with given fields: ctx, orgID -func (_m *Store) ListOrgPools(ctx context.Context, orgID string) ([]params.Pool, error) { - ret := _m.Called(ctx, orgID) - - if len(ret) == 0 { - panic("no return value specified for ListOrgPools") - } - - var r0 []params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]params.Pool, error)); ok { - return rf(ctx, orgID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) []params.Pool); ok { - r0 = rf(ctx, orgID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Pool) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, orgID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // ListOrganizations provides a mock function with given fields: ctx func (_m *Store) ListOrganizations(ctx context.Context) ([]params.Organization, error) { ret := _m.Called(ctx) @@ -1480,66 +1188,6 @@ func (_m *Store) ListPoolInstances(ctx context.Context, poolID string) ([]params return r0, r1 } -// ListRepoInstances provides a mock function with given fields: ctx, repoID -func (_m *Store) ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error) { - ret := _m.Called(ctx, repoID) - - if len(ret) == 0 { - panic("no return value specified for ListRepoInstances") - } - - var r0 []params.Instance - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]params.Instance, error)); ok { - return rf(ctx, repoID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) []params.Instance); ok { - r0 = rf(ctx, repoID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Instance) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, repoID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ListRepoPools provides a mock function with given fields: ctx, repoID -func (_m *Store) ListRepoPools(ctx context.Context, repoID string) ([]params.Pool, error) { - ret := _m.Called(ctx, repoID) - - if len(ret) == 0 { - panic("no return value specified for ListRepoPools") - } - - var r0 []params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) ([]params.Pool, error)); ok { - return rf(ctx, repoID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) []params.Pool); ok { - r0 = rf(ctx, repoID) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]params.Pool) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, repoID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // ListRepositories provides a mock function with given fields: ctx func (_m *Store) ListRepositories(ctx context.Context) ([]params.Repository, error) { ret := _m.Called(ctx) @@ -1662,27 +1310,27 @@ func (_m *Store) UpdateEnterprise(ctx context.Context, enterpriseID string, para return r0, r1 } -// UpdateEnterprisePool provides a mock function with given fields: ctx, enterpriseID, poolID, param -func (_m *Store) UpdateEnterprisePool(ctx context.Context, enterpriseID string, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - ret := _m.Called(ctx, enterpriseID, poolID, param) +// UpdateEntityPool provides a mock function with given fields: ctx, entity, poolID, param +func (_m *Store) UpdateEntityPool(ctx context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) { + ret := _m.Called(ctx, entity, poolID, param) if len(ret) == 0 { - panic("no return value specified for UpdateEnterprisePool") + panic("no return value specified for UpdateEntityPool") } var r0 params.Pool var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, params.UpdatePoolParams) (params.Pool, error)); ok { - return rf(ctx, enterpriseID, poolID, param) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, string, params.UpdatePoolParams) (params.Pool, error)); ok { + return rf(ctx, entity, poolID, param) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, params.UpdatePoolParams) params.Pool); ok { - r0 = rf(ctx, enterpriseID, poolID, param) + if rf, ok := ret.Get(0).(func(context.Context, params.GithubEntity, string, params.UpdatePoolParams) params.Pool); ok { + r0 = rf(ctx, entity, poolID, param) } else { r0 = ret.Get(0).(params.Pool) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, params.UpdatePoolParams) error); ok { - r1 = rf(ctx, enterpriseID, poolID, param) + if rf, ok := ret.Get(1).(func(context.Context, params.GithubEntity, string, params.UpdatePoolParams) error); ok { + r1 = rf(ctx, entity, poolID, param) } else { r1 = ret.Error(1) } @@ -1746,34 +1394,6 @@ func (_m *Store) UpdateOrganization(ctx context.Context, orgID string, param par return r0, r1 } -// UpdateOrganizationPool provides a mock function with given fields: ctx, orgID, poolID, param -func (_m *Store) UpdateOrganizationPool(ctx context.Context, orgID string, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - ret := _m.Called(ctx, orgID, poolID, param) - - if len(ret) == 0 { - panic("no return value specified for UpdateOrganizationPool") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, params.UpdatePoolParams) (params.Pool, error)); ok { - return rf(ctx, orgID, poolID, param) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string, params.UpdatePoolParams) params.Pool); ok { - r0 = rf(ctx, orgID, poolID, param) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string, params.UpdatePoolParams) error); ok { - r1 = rf(ctx, orgID, poolID, param) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // UpdateRepository provides a mock function with given fields: ctx, repoID, param func (_m *Store) UpdateRepository(ctx context.Context, repoID string, param params.UpdateEntityParams) (params.Repository, error) { ret := _m.Called(ctx, repoID, param) @@ -1802,34 +1422,6 @@ func (_m *Store) UpdateRepository(ctx context.Context, repoID string, param para return r0, r1 } -// UpdateRepositoryPool provides a mock function with given fields: ctx, repoID, poolID, param -func (_m *Store) UpdateRepositoryPool(ctx context.Context, repoID string, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - ret := _m.Called(ctx, repoID, poolID, param) - - if len(ret) == 0 { - panic("no return value specified for UpdateRepositoryPool") - } - - var r0 params.Pool - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, params.UpdatePoolParams) (params.Pool, error)); ok { - return rf(ctx, repoID, poolID, param) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string, params.UpdatePoolParams) params.Pool); ok { - r0 = rf(ctx, repoID, poolID, param) - } else { - r0 = ret.Get(0).(params.Pool) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string, params.UpdatePoolParams) error); ok { - r1 = rf(ctx, repoID, poolID, param) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // UpdateUser provides a mock function with given fields: ctx, user, param func (_m *Store) UpdateUser(ctx context.Context, user string, param params.UpdateUserParams) (params.User, error) { ret := _m.Called(ctx, user, param) diff --git a/database/sql/enterprise.go b/database/sql/enterprise.go index f83dab8c..3eb53b9e 100644 --- a/database/sql/enterprise.go +++ b/database/sql/enterprise.go @@ -5,7 +5,6 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" - "gorm.io/datatypes" "gorm.io/gorm" runnerErrors "github.com/cloudbase/garm-provider-common/errors" @@ -134,145 +133,6 @@ func (s *sqlDatabase) UpdateEnterprise(ctx context.Context, enterpriseID string, return newParams, nil } -func (s *sqlDatabase) CreateEnterprisePool(ctx context.Context, enterpriseID string, param params.CreatePoolParams) (params.Pool, error) { - if len(param.Tags) == 0 { - return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified") - } - - enterprise, err := s.getEnterpriseByID(ctx, enterpriseID) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching enterprise") - } - - newPool := Pool{ - ProviderName: param.ProviderName, - MaxRunners: param.MaxRunners, - MinIdleRunners: param.MinIdleRunners, - RunnerPrefix: param.GetRunnerPrefix(), - Image: param.Image, - Flavor: param.Flavor, - OSType: param.OSType, - OSArch: param.OSArch, - EnterpriseID: &enterprise.ID, - Enabled: param.Enabled, - RunnerBootstrapTimeout: param.RunnerBootstrapTimeout, - GitHubRunnerGroup: param.GitHubRunnerGroup, - Priority: param.Priority, - } - - if len(param.ExtraSpecs) > 0 { - newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs) - } - - _, err = s.getEnterprisePoolByUniqueFields(ctx, enterpriseID, newPool.ProviderName, newPool.Image, newPool.Flavor) - if err != nil { - if !errors.Is(err, runnerErrors.ErrNotFound) { - return params.Pool{}, errors.Wrap(err, "creating pool") - } - } else { - return params.Pool{}, runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider") - } - - tags := []Tag{} - for _, val := range param.Tags { - t, err := s.getOrCreateTag(val) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching tag") - } - tags = append(tags, t) - } - - q := s.conn.Create(&newPool) - if q.Error != nil { - return params.Pool{}, errors.Wrap(q.Error, "adding pool") - } - - for i := range tags { - if err := s.conn.Model(&newPool).Association("Tags").Append(&tags[i]); err != nil { - return params.Pool{}, errors.Wrap(err, "saving tag") - } - } - - pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - - return s.sqlToCommonPool(pool) -} - -func (s *sqlDatabase) GetEnterprisePool(ctx context.Context, enterpriseID, poolID string) (params.Pool, error) { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeEnterprise, enterpriseID, poolID, "Tags", "Instances") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return s.sqlToCommonPool(pool) -} - -func (s *sqlDatabase) DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID string) error { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeEnterprise, enterpriseID, poolID) - if err != nil { - return errors.Wrap(err, "looking up enterprise pool") - } - q := s.conn.Unscoped().Delete(&pool) - if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) { - return errors.Wrap(q.Error, "deleting pool") - } - return nil -} - -func (s *sqlDatabase) UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeEnterprise, enterpriseID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - - return s.updatePool(pool, param) -} - -func (s *sqlDatabase) FindEnterprisePoolByTags(_ context.Context, enterpriseID string, tags []string) (params.Pool, error) { - pool, err := s.findPoolByTags(enterpriseID, params.GithubEntityTypeEnterprise, tags) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return pool[0], nil -} - -func (s *sqlDatabase) ListEnterprisePools(ctx context.Context, enterpriseID string) ([]params.Pool, error) { - pools, err := s.listEntityPools(ctx, params.GithubEntityTypeEnterprise, enterpriseID, "Tags", "Instances", "Enterprise") - if err != nil { - return nil, errors.Wrap(err, "fetching pools") - } - - ret := make([]params.Pool, len(pools)) - for idx, pool := range pools { - ret[idx], err = s.sqlToCommonPool(pool) - if err != nil { - return nil, errors.Wrap(err, "fetching pools") - } - } - - return ret, nil -} - -func (s *sqlDatabase) ListEnterpriseInstances(ctx context.Context, enterpriseID string) ([]params.Instance, error) { - pools, err := s.listEntityPools(ctx, params.GithubEntityTypeEnterprise, enterpriseID, "Instances", "Tags", "Instances.Job") - if err != nil { - return nil, errors.Wrap(err, "fetching enterprise") - } - ret := []params.Instance{} - for _, pool := range pools { - for _, instance := range pool.Instances { - paramsInstance, err := s.sqlToParamsInstance(instance) - if err != nil { - return nil, errors.Wrap(err, "fetching instance") - } - ret = append(ret, paramsInstance) - } - } - return ret, nil -} - func (s *sqlDatabase) getEnterprise(_ context.Context, name string) (Enterprise, error) { var enterprise Enterprise @@ -310,22 +170,3 @@ func (s *sqlDatabase) getEnterpriseByID(_ context.Context, id string, preload .. } return enterprise, nil } - -func (s *sqlDatabase) getEnterprisePoolByUniqueFields(ctx context.Context, enterpriseID string, provider, image, flavor string) (Pool, error) { - enterprise, err := s.getEnterpriseByID(ctx, enterpriseID) - if err != nil { - return Pool{}, errors.Wrap(err, "fetching enterprise") - } - - q := s.conn - var pool []Pool - err = q.Model(&enterprise).Association("Pools").Find(&pool, "provider_name = ? and image = ? and flavor = ?", provider, image, flavor) - if err != nil { - return Pool{}, errors.Wrap(err, "fetching pool") - } - if len(pool) == 0 { - return Pool{}, runnerErrors.ErrNotFound - } - - return pool[0], nil -} diff --git a/database/sql/enterprise_test.go b/database/sql/enterprise_test.go index fa709b89..f77ae3d5 100644 --- a/database/sql/enterprise_test.go +++ b/database/sql/enterprise_test.go @@ -405,7 +405,9 @@ func (s *EnterpriseTestSuite) TestGetEnterpriseByIDDBDecryptingErr() { } func (s *EnterpriseTestSuite) TestCreateEnterprisePool() { - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().Nil(err) @@ -422,56 +424,57 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePool() { func (s *EnterpriseTestSuite) TestCreateEnterprisePoolMissingTags() { s.Fixtures.CreatePoolParams.Tags = []string{} - - _, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("no tags specified", err.Error()) } func (s *EnterpriseTestSuite) TestCreateEnterprisePoolInvalidEnterpriseID() { - _, err := s.Store.CreateEnterprisePool(context.Background(), "dummy-enterprise-id", s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: "dummy-enterprise-id", + EntityType: params.GithubEntityTypeEnterprise, + } + _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) - s.Require().Equal("fetching enterprise: parsing id: invalid request", err.Error()) + s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBCreateErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). WithArgs(s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`enterprise_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WillReturnError(fmt.Errorf("mocked creating pool error")) - _, err := s.StoreSQLMocked.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("creating pool: fetching pool: mocked creating pool error", err.Error()) + s.Require().Equal("checking pool existence: mocked creating pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestCreateEnterpriseDBPoolAlreadyExistErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). WithArgs(s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`enterprise_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"enterprise_id", "provider_name", "image", "flavor"}). AddRow( s.Fixtures.Enterprises[0].ID, @@ -479,159 +482,141 @@ func (s *EnterpriseTestSuite) TestCreateEnterpriseDBPoolAlreadyExistErr() { s.Fixtures.CreatePoolParams.Image, s.Fixtures.CreatePoolParams.Flavor)) - _, err := s.StoreSQLMocked.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal(runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider"), err) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBFetchTagErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). WithArgs(s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`enterprise_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"enterprise_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnError(fmt.Errorf("mocked fetching tag error")) - _, err := s.StoreSQLMocked.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("fetching tag: fetching tag from database: mocked fetching tag error", err.Error()) + s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBAddingPoolErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} - - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). WithArgs(s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`enterprise_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"enterprise_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnError(fmt.Errorf("mocked adding pool error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("adding pool: mocked adding pool error", err.Error()) + s.Require().Equal("creating pool: mocked adding pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBSaveTagErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). WithArgs(s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`enterprise_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"enterprise_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("UPDATE `pools` SET")). WillReturnError(fmt.Errorf("mocked saving tag error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("saving tag: mocked saving tag error", err.Error()) + s.Require().Equal("associating tags: mocked saving tag error", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBFetchPoolErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). WithArgs(s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `enterprises` WHERE id = ? AND `enterprises`.`deleted_at` IS NULL ORDER BY `enterprises`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Enterprises[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`enterprise_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Enterprises[0].ID). WillReturnRows(sqlmock.NewRows([]string{"enterprise_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("UPDATE `pools` SET")). WillReturnResult(sqlmock.NewResult(1, 1)) @@ -646,125 +631,121 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolDBFetchPoolErr() { ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE id = ? AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"id"})) - _, err := s.StoreSQLMocked.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching pool: not found", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestListEnterprisePools() { enterprisePools := []params.Pool{} + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%v", i) - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } enterprisePools = append(enterprisePools, pool) } - pools, err := s.Store.ListEnterprisePools(context.Background(), s.Fixtures.Enterprises[0].ID) + pools, err := s.Store.ListEntityPools(context.Background(), entity) s.Require().Nil(err) garmTesting.EqualDBEntityID(s.T(), enterprisePools, pools) } func (s *EnterpriseTestSuite) TestListEnterprisePoolsInvalidEnterpriseID() { - _, err := s.Store.ListEnterprisePools(context.Background(), "dummy-enterprise-id") + entity := params.GithubEntity{ + ID: "dummy-enterprise-id", + EntityType: params.GithubEntityTypeEnterprise, + } + _, err := s.Store.ListEntityPools(context.Background(), entity) s.Require().NotNil(err) s.Require().Equal("fetching pools: parsing id: invalid request", err.Error()) } func (s *EnterpriseTestSuite) TestGetEnterprisePool() { - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } - enterprisePool, err := s.Store.GetEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, pool.ID) + enterprisePool, err := s.Store.GetEntityPool(context.Background(), entity, pool.ID) s.Require().Nil(err) s.Require().Equal(enterprisePool.ID, pool.ID) } func (s *EnterpriseTestSuite) TestGetEnterprisePoolInvalidEnterpriseID() { - _, err := s.Store.GetEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id") + entity := params.GithubEntity{ + ID: "dummy-enterprise-id", + EntityType: params.GithubEntityTypeEnterprise, + } + _, err := s.Store.GetEntityPool(context.Background(), entity, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) } func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() { - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } - err = s.Store.DeleteEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, pool.ID) + err = s.Store.DeleteEntityPool(context.Background(), entity, pool.ID) s.Require().Nil(err) - _, err = s.Store.GetEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, pool.ID) + _, err = s.Store.GetEntityPool(context.Background(), entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolInvalidEnterpriseID() { - err := s.Store.DeleteEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id") + entity := params.GithubEntity{ + ID: "dummy-enterprise-id", + EntityType: params.GithubEntityTypeEnterprise, + } + err := s.Store.DeleteEntityPool(context.Background(), entity, "dummy-pool-id") s.Require().NotNil(err) - s.Require().Equal("looking up enterprise pool: parsing id: invalid request", err.Error()) + s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDBDeleteErr() { - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (id = ? and enterprise_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). - WithArgs(pool.ID, s.Fixtures.Enterprises[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"enterprise_id", "id"}).AddRow(s.Fixtures.Enterprises[0].ID, pool.ID)) s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. - ExpectExec(regexp.QuoteMeta("DELETE FROM `pools` WHERE `pools`.`id` = ?")). - WithArgs(pool.ID). + ExpectExec(regexp.QuoteMeta("DELETE FROM `pools` WHERE id = ? and enterprise_id = ?")). + WithArgs(pool.ID, s.Fixtures.Enterprises[0].ID). WillReturnError(fmt.Errorf("mocked deleting pool error")) s.Fixtures.SQLMock.ExpectRollback() - err = s.StoreSQLMocked.DeleteEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, pool.ID) - - s.assertSQLMockExpectations() - s.Require().NotNil(err) - s.Require().Equal("deleting pool: mocked deleting pool error", err.Error()) -} - -func (s *EnterpriseTestSuite) TestFindEnterprisePoolByTags() { - enterprisePool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) - if err != nil { - s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) - } - - pool, err := s.Store.FindEnterprisePoolByTags(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams.Tags) - - s.Require().Nil(err) - s.Require().Equal(enterprisePool.ID, pool.ID) - s.Require().Equal(enterprisePool.Image, pool.Image) - s.Require().Equal(enterprisePool.Flavor, pool.Flavor) -} - -func (s *EnterpriseTestSuite) TestFindEnterprisePoolByTagsMissingTags() { - tags := []string{} - - _, err := s.Store.FindEnterprisePoolByTags(context.Background(), s.Fixtures.Enterprises[0].ID, tags) - + err = s.StoreSQLMocked.DeleteEntityPool(context.Background(), entity, pool.ID) s.Require().NotNil(err) - s.Require().Equal("fetching pool: missing tags", err.Error()) + s.Require().Equal("removing pool: mocked deleting pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *EnterpriseTestSuite) TestListEnterpriseInstances() { - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } @@ -778,26 +759,32 @@ func (s *EnterpriseTestSuite) TestListEnterpriseInstances() { poolInstances = append(poolInstances, instance) } - instances, err := s.Store.ListEnterpriseInstances(context.Background(), s.Fixtures.Enterprises[0].ID) + instances, err := s.Store.ListEntityInstances(context.Background(), entity) s.Require().Nil(err) s.equalInstancesByName(poolInstances, instances) } func (s *EnterpriseTestSuite) TestListEnterpriseInstancesInvalidEnterpriseID() { - _, err := s.Store.ListEnterpriseInstances(context.Background(), "dummy-enterprise-id") + entity := params.GithubEntity{ + ID: "dummy-enterprise-id", + EntityType: params.GithubEntityTypeEnterprise, + } + _, err := s.Store.ListEntityInstances(context.Background(), entity) s.Require().NotNil(err) - s.Require().Equal("fetching enterprise: parsing id: invalid request", err.Error()) + s.Require().Equal("fetching entity: parsing id: invalid request", err.Error()) } func (s *EnterpriseTestSuite) TestUpdateEnterprisePool() { - pool, err := s.Store.CreateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Enterprises[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } - pool, err = s.Store.UpdateEnterprisePool(context.Background(), s.Fixtures.Enterprises[0].ID, pool.ID, s.Fixtures.UpdatePoolParams) + pool, err = s.Store.UpdateEntityPool(context.Background(), entity, pool.ID, s.Fixtures.UpdatePoolParams) s.Require().Nil(err) s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners) @@ -807,7 +794,11 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprisePool() { } func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolInvalidEnterpriseID() { - _, err := s.Store.UpdateEnterprisePool(context.Background(), "dummy-enterprise-id", "dummy-pool-id", s.Fixtures.UpdatePoolParams) + entity := params.GithubEntity{ + ID: "dummy-enterprise-id", + EntityType: params.GithubEntityTypeEnterprise, + } + _, err := s.Store.UpdateEntityPool(context.Background(), entity, "dummy-pool-id", s.Fixtures.UpdatePoolParams) s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) diff --git a/database/sql/instances.go b/database/sql/instances.go index 4c475bf2..d24961da 100644 --- a/database/sql/instances.go +++ b/database/sql/instances.go @@ -48,8 +48,8 @@ func (s *sqlDatabase) unsealAndUnmarshal(data []byte, target interface{}) error return nil } -func (s *sqlDatabase) CreateInstance(ctx context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) { - pool, err := s.getPoolByID(ctx, poolID) +func (s *sqlDatabase) CreateInstance(_ context.Context, poolID string, param params.CreateInstanceParams) (params.Instance, error) { + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return params.Instance{}, errors.Wrap(err, "fetching pool") } @@ -108,8 +108,8 @@ func (s *sqlDatabase) getInstanceByID(_ context.Context, instanceID string) (Ins return instance, nil } -func (s *sqlDatabase) getPoolInstanceByName(ctx context.Context, poolID string, instanceName string) (Instance, error) { - pool, err := s.getPoolByID(ctx, poolID) +func (s *sqlDatabase) getPoolInstanceByName(poolID string, instanceName string) (Instance, error) { + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return Instance{}, errors.Wrap(err, "fetching pool") } @@ -152,8 +152,8 @@ func (s *sqlDatabase) getInstanceByName(_ context.Context, instanceName string, return instance, nil } -func (s *sqlDatabase) GetPoolInstanceByName(ctx context.Context, poolID string, instanceName string) (params.Instance, error) { - instance, err := s.getPoolInstanceByName(ctx, poolID, instanceName) +func (s *sqlDatabase) GetPoolInstanceByName(_ context.Context, poolID string, instanceName string) (params.Instance, error) { + instance, err := s.getPoolInstanceByName(poolID, instanceName) if err != nil { return params.Instance{}, errors.Wrap(err, "fetching instance") } @@ -170,8 +170,8 @@ func (s *sqlDatabase) GetInstanceByName(ctx context.Context, instanceName string return s.sqlToParamsInstance(instance) } -func (s *sqlDatabase) DeleteInstance(ctx context.Context, poolID string, instanceName string) error { - instance, err := s.getPoolInstanceByName(ctx, poolID, instanceName) +func (s *sqlDatabase) DeleteInstance(_ context.Context, poolID string, instanceName string) error { + instance, err := s.getPoolInstanceByName(poolID, instanceName) if err != nil { return errors.Wrap(err, "deleting instance") } @@ -337,8 +337,8 @@ func (s *sqlDatabase) ListAllInstances(_ context.Context) ([]params.Instance, er return ret, nil } -func (s *sqlDatabase) PoolInstanceCount(ctx context.Context, poolID string) (int64, error) { - pool, err := s.getPoolByID(ctx, poolID) +func (s *sqlDatabase) PoolInstanceCount(_ context.Context, poolID string) (int64, error) { + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return 0, errors.Wrap(err, "fetching pool") } diff --git a/database/sql/instances_test.go b/database/sql/instances_test.go index 0c0eadcf..b136c8ae 100644 --- a/database/sql/instances_test.go +++ b/database/sql/instances_test.go @@ -92,7 +92,9 @@ func (s *InstancesTestSuite) SetupTest() { OSType: "linux", Tags: []string{"self-hosted", "amd64", "linux"}, } - pool, err := s.Store.CreateOrganizationPool(context.Background(), org.ID, createPoolParams) + entity, err := org.GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, createPoolParams) if err != nil { s.FailNow(fmt.Sprintf("failed to create org pool: %s", err)) } diff --git a/database/sql/organizations.go b/database/sql/organizations.go index 4d246065..24704fd9 100644 --- a/database/sql/organizations.go +++ b/database/sql/organizations.go @@ -20,7 +20,6 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" - "gorm.io/datatypes" "gorm.io/gorm" runnerErrors "github.com/cloudbase/garm-provider-common/errors" @@ -151,169 +150,6 @@ func (s *sqlDatabase) GetOrganizationByID(ctx context.Context, orgID string) (pa return param, nil } -func (s *sqlDatabase) CreateOrganizationPool(ctx context.Context, orgID string, param params.CreatePoolParams) (params.Pool, error) { - if len(param.Tags) == 0 { - return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified") - } - - org, err := s.getOrgByID(ctx, orgID) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching org") - } - - newPool := Pool{ - ProviderName: param.ProviderName, - MaxRunners: param.MaxRunners, - MinIdleRunners: param.MinIdleRunners, - RunnerPrefix: param.GetRunnerPrefix(), - Image: param.Image, - Flavor: param.Flavor, - OSType: param.OSType, - OSArch: param.OSArch, - OrgID: &org.ID, - Enabled: param.Enabled, - RunnerBootstrapTimeout: param.RunnerBootstrapTimeout, - GitHubRunnerGroup: param.GitHubRunnerGroup, - Priority: param.Priority, - } - - if len(param.ExtraSpecs) > 0 { - newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs) - } - - _, err = s.getOrgPoolByUniqueFields(ctx, orgID, newPool.ProviderName, newPool.Image, newPool.Flavor) - if err != nil { - if !errors.Is(err, runnerErrors.ErrNotFound) { - return params.Pool{}, errors.Wrap(err, "creating pool") - } - } else { - return params.Pool{}, runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider") - } - - tags := []Tag{} - for _, val := range param.Tags { - t, err := s.getOrCreateTag(val) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching tag") - } - tags = append(tags, t) - } - - q := s.conn.Create(&newPool) - if q.Error != nil { - return params.Pool{}, errors.Wrap(q.Error, "adding pool") - } - - for i := range tags { - if err := s.conn.Model(&newPool).Association("Tags").Append(&tags[i]); err != nil { - return params.Pool{}, errors.Wrap(err, "saving tag") - } - } - - pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - - return s.sqlToCommonPool(pool) -} - -func (s *sqlDatabase) ListOrgPools(ctx context.Context, orgID string) ([]params.Pool, error) { - pools, err := s.listEntityPools(ctx, params.GithubEntityTypeOrganization, orgID, "Tags", "Instances", "Organization") - if err != nil { - return nil, errors.Wrap(err, "fetching pools") - } - - ret := make([]params.Pool, len(pools)) - for idx, pool := range pools { - ret[idx], err = s.sqlToCommonPool(pool) - if err != nil { - return nil, errors.Wrap(err, "fetching pool") - } - } - - return ret, nil -} - -func (s *sqlDatabase) GetOrganizationPool(ctx context.Context, orgID, poolID string) (params.Pool, error) { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeOrganization, orgID, poolID, "Tags", "Instances") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return s.sqlToCommonPool(pool) -} - -func (s *sqlDatabase) DeleteOrganizationPool(ctx context.Context, orgID, poolID string) error { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeOrganization, orgID, poolID) - if err != nil { - return errors.Wrap(err, "looking up org pool") - } - q := s.conn.Unscoped().Delete(&pool) - if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) { - return errors.Wrap(q.Error, "deleting pool") - } - return nil -} - -func (s *sqlDatabase) FindOrganizationPoolByTags(_ context.Context, orgID string, tags []string) (params.Pool, error) { - pool, err := s.findPoolByTags(orgID, params.GithubEntityTypeOrganization, tags) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return pool[0], nil -} - -func (s *sqlDatabase) ListOrgInstances(ctx context.Context, orgID string) ([]params.Instance, error) { - pools, err := s.listEntityPools(ctx, params.GithubEntityTypeOrganization, orgID, "Tags", "Instances", "Instances.Job") - if err != nil { - return nil, errors.Wrap(err, "fetching org") - } - ret := []params.Instance{} - for _, pool := range pools { - for _, instance := range pool.Instances { - paramsInstance, err := s.sqlToParamsInstance(instance) - if err != nil { - return nil, errors.Wrap(err, "fetching instance") - } - ret = append(ret, paramsInstance) - } - } - return ret, nil -} - -func (s *sqlDatabase) UpdateOrganizationPool(ctx context.Context, orgID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeOrganization, orgID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - - return s.updatePool(pool, param) -} - -func (s *sqlDatabase) getPoolByID(_ context.Context, poolID string, preload ...string) (Pool, error) { - u, err := uuid.Parse(poolID) - if err != nil { - return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") - } - var pool Pool - q := s.conn.Model(&Pool{}) - if len(preload) > 0 { - for _, item := range preload { - q = q.Preload(item) - } - } - - q = q.Where("id = ?", u).First(&pool) - - if q.Error != nil { - if errors.Is(q.Error, gorm.ErrRecordNotFound) { - return Pool{}, runnerErrors.ErrNotFound - } - return Pool{}, errors.Wrap(q.Error, "fetching org from database") - } - return pool, nil -} - func (s *sqlDatabase) getOrgByID(_ context.Context, id string, preload ...string) (Organization, error) { u, err := uuid.Parse(id) if err != nil { @@ -351,22 +187,3 @@ func (s *sqlDatabase) getOrg(_ context.Context, name string) (Organization, erro } return org, nil } - -func (s *sqlDatabase) getOrgPoolByUniqueFields(ctx context.Context, orgID string, provider, image, flavor string) (Pool, error) { - org, err := s.getOrgByID(ctx, orgID) - if err != nil { - return Pool{}, errors.Wrap(err, "fetching org") - } - - q := s.conn - var pool []Pool - err = q.Model(&org).Association("Pools").Find(&pool, "provider_name = ? and image = ? and flavor = ?", provider, image, flavor) - if err != nil { - return Pool{}, errors.Wrap(err, "fetching pool") - } - if len(pool) == 0 { - return Pool{}, runnerErrors.ErrNotFound - } - - return pool[0], nil -} diff --git a/database/sql/organizations_test.go b/database/sql/organizations_test.go index db4f8ccd..86d13d72 100644 --- a/database/sql/organizations_test.go +++ b/database/sql/organizations_test.go @@ -405,7 +405,9 @@ func (s *OrgTestSuite) TestGetOrganizationByIDDBDecryptingErr() { } func (s *OrgTestSuite) TestCreateOrganizationPool() { - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().Nil(err) @@ -422,56 +424,57 @@ func (s *OrgTestSuite) TestCreateOrganizationPool() { func (s *OrgTestSuite) TestCreateOrganizationPoolMissingTags() { s.Fixtures.CreatePoolParams.Tags = []string{} - - _, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + _, err = s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("no tags specified", err.Error()) } func (s *OrgTestSuite) TestCreateOrganizationPoolInvalidOrgID() { - _, err := s.Store.CreateOrganizationPool(context.Background(), "dummy-org-id", s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: "dummy-org-id", + EntityType: params.GithubEntityTypeOrganization, + } + _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) - s.Require().Equal("fetching org: parsing id: invalid request", err.Error()) + s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *OrgTestSuite) TestCreateOrganizationPoolDBCreateErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). WithArgs(s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`org_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WillReturnError(fmt.Errorf("mocked creating pool error")) - _, err := s.StoreSQLMocked.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("creating pool: fetching pool: mocked creating pool error", err.Error()) + s.Require().Equal("checking pool existence: mocked creating pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestCreateOrganizationDBPoolAlreadyExistErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). WithArgs(s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`org_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"org_id", "provider_name", "image", "flavor"}). AddRow( s.Fixtures.Orgs[0].ID, @@ -479,159 +482,142 @@ func (s *OrgTestSuite) TestCreateOrganizationDBPoolAlreadyExistErr() { s.Fixtures.CreatePoolParams.Image, s.Fixtures.CreatePoolParams.Flavor)) - _, err := s.StoreSQLMocked.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal(runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider"), err) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestCreateOrganizationPoolDBFetchTagErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). WithArgs(s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`org_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"org_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnError(fmt.Errorf("mocked fetching tag error")) - _, err := s.StoreSQLMocked.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("fetching tag: fetching tag from database: mocked fetching tag error", err.Error()) + s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestCreateOrganizationPoolDBAddingPoolErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). WithArgs(s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`org_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"org_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnError(fmt.Errorf("mocked adding pool error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("adding pool: mocked adding pool error", err.Error()) + s.Require().Equal("creating pool: mocked adding pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestCreateOrganizationPoolDBSaveTagErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). WithArgs(s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`org_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"org_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("UPDATE `pools` SET")). WillReturnError(fmt.Errorf("mocked saving tag error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("saving tag: mocked saving tag error", err.Error()) + s.Require().Equal("associating tags: mocked saving tag error", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestCreateOrganizationPoolDBFetchPoolErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). WithArgs(s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `organizations` WHERE id = ? AND `organizations`.`deleted_at` IS NULL ORDER BY `organizations`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Orgs[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`org_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Orgs[0].ID). WillReturnRows(sqlmock.NewRows([]string{"org_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("UPDATE `pools` SET")). WillReturnResult(sqlmock.NewResult(1, 1)) @@ -646,125 +632,123 @@ func (s *OrgTestSuite) TestCreateOrganizationPoolDBFetchPoolErr() { ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE id = ? AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"id"})) - _, err := s.StoreSQLMocked.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching pool: not found", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestListOrgPools() { orgPools := []params.Pool{} + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%v", i) - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } orgPools = append(orgPools, pool) } - - pools, err := s.Store.ListOrgPools(context.Background(), s.Fixtures.Orgs[0].ID) + pools, err := s.Store.ListEntityPools(context.Background(), entity) s.Require().Nil(err) garmTesting.EqualDBEntityID(s.T(), orgPools, pools) } func (s *OrgTestSuite) TestListOrgPoolsInvalidOrgID() { - _, err := s.Store.ListOrgPools(context.Background(), "dummy-org-id") + entity := params.GithubEntity{ + ID: "dummy-org-id", + EntityType: params.GithubEntityTypeOrganization, + } + _, err := s.Store.ListEntityPools(context.Background(), entity) s.Require().NotNil(err) s.Require().Equal("fetching pools: parsing id: invalid request", err.Error()) } func (s *OrgTestSuite) TestGetOrganizationPool() { - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } - orgPool, err := s.Store.GetOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, pool.ID) + orgPool, err := s.Store.GetEntityPool(context.Background(), entity, pool.ID) s.Require().Nil(err) s.Require().Equal(orgPool.ID, pool.ID) } func (s *OrgTestSuite) TestGetOrganizationPoolInvalidOrgID() { - _, err := s.Store.GetOrganizationPool(context.Background(), "dummy-org-id", "dummy-pool-id") + entity := params.GithubEntity{ + ID: "dummy-org-id", + EntityType: params.GithubEntityTypeOrganization, + } + _, err := s.Store.GetEntityPool(context.Background(), entity, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) } func (s *OrgTestSuite) TestDeleteOrganizationPool() { - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } - err = s.Store.DeleteOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, pool.ID) + err = s.Store.DeleteEntityPool(context.Background(), entity, pool.ID) s.Require().Nil(err) - _, err = s.Store.GetOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, pool.ID) + _, err = s.Store.GetEntityPool(context.Background(), entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } func (s *OrgTestSuite) TestDeleteOrganizationPoolInvalidOrgID() { - err := s.Store.DeleteOrganizationPool(context.Background(), "dummy-org-id", "dummy-pool-id") + entity := params.GithubEntity{ + ID: "dummy-org-id", + EntityType: params.GithubEntityTypeOrganization, + } + err := s.Store.DeleteEntityPool(context.Background(), entity, "dummy-pool-id") s.Require().NotNil(err) - s.Require().Equal("looking up org pool: parsing id: invalid request", err.Error()) + s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *OrgTestSuite) TestDeleteOrganizationPoolDBDeleteErr() { - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (id = ? and org_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). - WithArgs(pool.ID, s.Fixtures.Orgs[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"org_id", "id"}).AddRow(s.Fixtures.Orgs[0].ID, pool.ID)) s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. - ExpectExec(regexp.QuoteMeta("DELETE FROM `pools` WHERE `pools`.`id` = ?")). - WithArgs(pool.ID). + ExpectExec(regexp.QuoteMeta("DELETE FROM `pools` WHERE id = ? and org_id = ?")). + WithArgs(pool.ID, s.Fixtures.Orgs[0].ID). WillReturnError(fmt.Errorf("mocked deleting pool error")) s.Fixtures.SQLMock.ExpectRollback() - err = s.StoreSQLMocked.DeleteOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, pool.ID) + err = s.StoreSQLMocked.DeleteEntityPool(context.Background(), entity, pool.ID) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("deleting pool: mocked deleting pool error", err.Error()) -} - -func (s *OrgTestSuite) TestFindOrganizationPoolByTags() { - orgPool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) - if err != nil { - s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) - } - - pool, err := s.Store.FindOrganizationPoolByTags(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams.Tags) - - s.Require().Nil(err) - s.Require().Equal(orgPool.ID, pool.ID) - s.Require().Equal(orgPool.Image, pool.Image) - s.Require().Equal(orgPool.Flavor, pool.Flavor) -} - -func (s *OrgTestSuite) TestFindOrganizationPoolByTagsMissingTags() { - tags := []string{} - - _, err := s.Store.FindOrganizationPoolByTags(context.Background(), s.Fixtures.Orgs[0].ID, tags) - - s.Require().NotNil(err) - s.Require().Equal("fetching pool: missing tags", err.Error()) + s.Require().Equal("removing pool: mocked deleting pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *OrgTestSuite) TestListOrgInstances() { - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } @@ -778,26 +762,32 @@ func (s *OrgTestSuite) TestListOrgInstances() { poolInstances = append(poolInstances, instance) } - instances, err := s.Store.ListOrgInstances(context.Background(), s.Fixtures.Orgs[0].ID) + instances, err := s.Store.ListEntityInstances(context.Background(), entity) s.Require().Nil(err) s.equalInstancesByName(poolInstances, instances) } func (s *OrgTestSuite) TestListOrgInstancesInvalidOrgID() { - _, err := s.Store.ListOrgInstances(context.Background(), "dummy-org-id") + entity := params.GithubEntity{ + ID: "dummy-org-id", + EntityType: params.GithubEntityTypeOrganization, + } + _, err := s.Store.ListEntityInstances(context.Background(), entity) s.Require().NotNil(err) - s.Require().Equal("fetching org: parsing id: invalid request", err.Error()) + s.Require().Equal("fetching entity: parsing id: invalid request", err.Error()) } func (s *OrgTestSuite) TestUpdateOrganizationPool() { - pool, err := s.Store.CreateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Orgs[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } - pool, err = s.Store.UpdateOrganizationPool(context.Background(), s.Fixtures.Orgs[0].ID, pool.ID, s.Fixtures.UpdatePoolParams) + pool, err = s.Store.UpdateEntityPool(context.Background(), entity, pool.ID, s.Fixtures.UpdatePoolParams) s.Require().Nil(err) s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners) @@ -807,7 +797,11 @@ func (s *OrgTestSuite) TestUpdateOrganizationPool() { } func (s *OrgTestSuite) TestUpdateOrganizationPoolInvalidOrgID() { - _, err := s.Store.UpdateOrganizationPool(context.Background(), "dummy-org-id", "dummy-pool-id", s.Fixtures.UpdatePoolParams) + entity := params.GithubEntity{ + ID: "dummy-org-id", + EntityType: params.GithubEntityTypeOrganization, + } + _, err := s.Store.UpdateEntityPool(context.Background(), entity, "dummy-pool-id", s.Fixtures.UpdatePoolParams) s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) diff --git a/database/sql/pools.go b/database/sql/pools.go index 65aca8ba..1dd7e68d 100644 --- a/database/sql/pools.go +++ b/database/sql/pools.go @@ -20,6 +20,7 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" + "gorm.io/datatypes" "gorm.io/gorm" runnerErrors "github.com/cloudbase/garm-provider-common/errors" @@ -57,16 +58,16 @@ func (s *sqlDatabase) ListAllPools(_ context.Context) ([]params.Pool, error) { return ret, nil } -func (s *sqlDatabase) GetPoolByID(ctx context.Context, poolID string) (params.Pool, error) { - pool, err := s.getPoolByID(ctx, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") +func (s *sqlDatabase) GetPoolByID(_ context.Context, poolID string) (params.Pool, error) { + pool, err := s.getPoolByID(s.conn, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool by ID") } return s.sqlToCommonPool(pool) } -func (s *sqlDatabase) DeletePoolByID(ctx context.Context, poolID string) error { - pool, err := s.getPoolByID(ctx, poolID) +func (s *sqlDatabase) DeletePoolByID(_ context.Context, poolID string) error { + pool, err := s.getPoolByID(s.conn, poolID) if err != nil { return errors.Wrap(err, "fetching pool by ID") } @@ -78,7 +79,7 @@ func (s *sqlDatabase) DeletePoolByID(ctx context.Context, poolID string) error { return nil } -func (s *sqlDatabase) getEntityPool(_ context.Context, entityType params.GithubEntityType, entityID, poolID string, preload ...string) (Pool, error) { +func (s *sqlDatabase) getEntityPool(tx *gorm.DB, entityType params.GithubEntityType, entityID, poolID string, preload ...string) (Pool, error) { if entityID == "" { return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "missing entity id") } @@ -88,25 +89,30 @@ func (s *sqlDatabase) getEntityPool(_ context.Context, entityType params.GithubE return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } - q := s.conn - if len(preload) > 0 { - for _, item := range preload { - q = q.Preload(item) - } - } - var fieldName string + var entityField string switch entityType { case params.GithubEntityTypeRepository: fieldName = entityTypeRepoName + entityField = "Repository" case params.GithubEntityTypeOrganization: fieldName = entityTypeOrgName + entityField = "Organization" case params.GithubEntityTypeEnterprise: fieldName = entityTypeEnterpriseName + entityField = "Enterprise" default: return Pool{}, fmt.Errorf("invalid entityType: %v", entityType) } + q := tx + q = q.Preload(entityField) + if len(preload) > 0 { + for _, item := range preload { + q = q.Preload(item) + } + } + var pool Pool condition := fmt.Sprintf("id = ? and %s = ?", fieldName) err = q.Model(&Pool{}). @@ -122,30 +128,39 @@ func (s *sqlDatabase) getEntityPool(_ context.Context, entityType params.GithubE return pool, nil } -func (s *sqlDatabase) listEntityPools(_ context.Context, entityType params.GithubEntityType, entityID string, preload ...string) ([]Pool, error) { +func (s *sqlDatabase) listEntityPools(tx *gorm.DB, entityType params.GithubEntityType, entityID string, preload ...string) ([]Pool, error) { if _, err := uuid.Parse(entityID); err != nil { return nil, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") } - q := s.conn - if len(preload) > 0 { - for _, item := range preload { - q = q.Preload(item) - } + if err := s.hasGithubEntity(tx, entityType, entityID); err != nil { + return nil, errors.Wrap(err, "checking entity existence") } + var preloadEntity string var fieldName string switch entityType { case params.GithubEntityTypeRepository: fieldName = entityTypeRepoName + preloadEntity = "Repository" case params.GithubEntityTypeOrganization: fieldName = entityTypeOrgName + preloadEntity = "Organization" case params.GithubEntityTypeEnterprise: fieldName = entityTypeEnterpriseName + preloadEntity = "Enterprise" default: return nil, fmt.Errorf("invalid entityType: %v", entityType) } + q := tx + q = q.Preload(preloadEntity) + if len(preload) > 0 { + for _, item := range preload { + q = q.Preload(item) + } + } + var pools []Pool condition := fmt.Sprintf("%s = ?", fieldName) err := q.Model(&Pool{}). @@ -231,3 +246,176 @@ func (s *sqlDatabase) FindPoolsMatchingAllTags(_ context.Context, entityType par return pools, nil } + +func (s *sqlDatabase) CreateEntityPool(_ context.Context, entity params.GithubEntity, param params.CreatePoolParams) (params.Pool, error) { + if len(param.Tags) == 0 { + return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified") + } + + newPool := Pool{ + ProviderName: param.ProviderName, + MaxRunners: param.MaxRunners, + MinIdleRunners: param.MinIdleRunners, + RunnerPrefix: param.GetRunnerPrefix(), + Image: param.Image, + Flavor: param.Flavor, + OSType: param.OSType, + OSArch: param.OSArch, + Enabled: param.Enabled, + RunnerBootstrapTimeout: param.RunnerBootstrapTimeout, + GitHubRunnerGroup: param.GitHubRunnerGroup, + Priority: param.Priority, + } + if len(param.ExtraSpecs) > 0 { + newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs) + } + + entityID, err := uuid.Parse(entity.ID) + if err != nil { + return params.Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") + } + + switch entity.EntityType { + case params.GithubEntityTypeRepository: + newPool.RepoID = &entityID + case params.GithubEntityTypeOrganization: + newPool.OrgID = &entityID + case params.GithubEntityTypeEnterprise: + newPool.EnterpriseID = &entityID + } + err = s.conn.Transaction(func(tx *gorm.DB) error { + if err := s.hasGithubEntity(tx, entity.EntityType, entity.ID); err != nil { + return errors.Wrap(err, "checking entity existence") + } + + if _, err := s.getEntityPoolByUniqueFields(tx, entity, newPool.ProviderName, newPool.Image, newPool.Flavor); err != nil { + if !errors.Is(err, runnerErrors.ErrNotFound) { + return errors.Wrap(err, "checking pool existence") + } + } else { + return runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider") + } + + tags := []Tag{} + for _, val := range param.Tags { + t, err := s.getOrCreateTag(tx, val) + if err != nil { + return errors.Wrap(err, "creating tag") + } + tags = append(tags, t) + } + + q := tx.Create(&newPool) + if q.Error != nil { + return errors.Wrap(q.Error, "creating pool") + } + + for i := range tags { + if err := tx.Model(&newPool).Association("Tags").Append(&tags[i]); err != nil { + return errors.Wrap(err, "associating tags") + } + } + return nil + }) + if err != nil { + return params.Pool{}, err + } + + pool, err := s.getPoolByID(s.conn, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") + if err != nil { + return params.Pool{}, errors.Wrap(err, "fetching pool") + } + + return s.sqlToCommonPool(pool) +} + +func (s *sqlDatabase) GetEntityPool(_ context.Context, entity params.GithubEntity, poolID string) (params.Pool, error) { + pool, err := s.getEntityPool(s.conn, entity.EntityType, entity.ID, poolID, "Tags", "Instances") + if err != nil { + return params.Pool{}, fmt.Errorf("fetching pool: %w", err) + } + return s.sqlToCommonPool(pool) +} + +func (s *sqlDatabase) DeleteEntityPool(_ context.Context, entity params.GithubEntity, poolID string) error { + entityID, err := uuid.Parse(entity.ID) + if err != nil { + return errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") + } + + poolUUID, err := uuid.Parse(poolID) + if err != nil { + return errors.Wrap(runnerErrors.ErrBadRequest, "parsing pool id") + } + var fieldName string + switch entity.EntityType { + case params.GithubEntityTypeRepository: + fieldName = entityTypeRepoName + case params.GithubEntityTypeOrganization: + fieldName = entityTypeOrgName + case params.GithubEntityTypeEnterprise: + fieldName = entityTypeEnterpriseName + default: + return fmt.Errorf("invalid entityType: %v", entity.EntityType) + } + condition := fmt.Sprintf("id = ? and %s = ?", fieldName) + if err := s.conn.Unscoped().Where(condition, poolUUID, entityID).Delete(&Pool{}).Error; err != nil { + return errors.Wrap(err, "removing pool") + } + return nil +} + +func (s *sqlDatabase) UpdateEntityPool(_ context.Context, entity params.GithubEntity, poolID string, param params.UpdatePoolParams) (params.Pool, error) { + var updatedPool params.Pool + err := s.conn.Transaction(func(tx *gorm.DB) error { + pool, err := s.getEntityPool(tx, entity.EntityType, entity.ID, poolID, "Tags", "Instances") + if err != nil { + return errors.Wrap(err, "fetching pool") + } + + updatedPool, err = s.updatePool(tx, pool, param) + if err != nil { + return errors.Wrap(err, "updating pool") + } + return nil + }) + if err != nil { + return params.Pool{}, err + } + return updatedPool, nil +} + +func (s *sqlDatabase) ListEntityPools(_ context.Context, entity params.GithubEntity) ([]params.Pool, error) { + pools, err := s.listEntityPools(s.conn, entity.EntityType, entity.ID, "Tags") + if err != nil { + return nil, errors.Wrap(err, "fetching pools") + } + + ret := make([]params.Pool, len(pools)) + for idx, pool := range pools { + ret[idx], err = s.sqlToCommonPool(pool) + if err != nil { + return nil, errors.Wrap(err, "fetching pool") + } + } + + return ret, nil +} + +func (s *sqlDatabase) ListEntityInstances(_ context.Context, entity params.GithubEntity) ([]params.Instance, error) { + pools, err := s.listEntityPools(s.conn, entity.EntityType, entity.ID, "Instances", "Instances.Job") + if err != nil { + return nil, errors.Wrap(err, "fetching entity") + } + ret := []params.Instance{} + for _, pool := range pools { + for _, instance := range pool.Instances { + paramsInstance, err := s.sqlToParamsInstance(instance) + if err != nil { + return nil, errors.Wrap(err, "fetching instance") + } + ret = append(ret, paramsInstance) + } + } + return ret, nil +} diff --git a/database/sql/pools_test.go b/database/sql/pools_test.go index 33fe8725..c05711cb 100644 --- a/database/sql/pools_test.go +++ b/database/sql/pools_test.go @@ -66,12 +66,14 @@ func (s *PoolsTestSuite) SetupTest() { s.FailNow(fmt.Sprintf("failed to create org: %s", err)) } + entity, err := org.GetEntity() + s.Require().Nil(err) // create some pool objects in the database, for testing purposes orgPools := []params.Pool{} for i := 1; i <= 3; i++ { - pool, err := db.CreateOrganizationPool( + pool, err := db.CreateEntityPool( context.Background(), - org.ID, + entity, params.CreatePoolParams{ ProviderName: "test-provider", MaxRunners: 4, diff --git a/database/sql/repositories.go b/database/sql/repositories.go index f7671840..164c0197 100644 --- a/database/sql/repositories.go +++ b/database/sql/repositories.go @@ -20,7 +20,6 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" - "gorm.io/datatypes" "gorm.io/gorm" runnerErrors "github.com/cloudbase/garm-provider-common/errors" @@ -151,146 +150,6 @@ func (s *sqlDatabase) GetRepositoryByID(ctx context.Context, repoID string) (par return param, nil } -func (s *sqlDatabase) CreateRepositoryPool(ctx context.Context, repoID string, param params.CreatePoolParams) (params.Pool, error) { - if len(param.Tags) == 0 { - return params.Pool{}, runnerErrors.NewBadRequestError("no tags specified") - } - - repo, err := s.getRepoByID(ctx, repoID) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching repo") - } - - newPool := Pool{ - ProviderName: param.ProviderName, - MaxRunners: param.MaxRunners, - MinIdleRunners: param.MinIdleRunners, - RunnerPrefix: param.GetRunnerPrefix(), - Image: param.Image, - Flavor: param.Flavor, - OSType: param.OSType, - OSArch: param.OSArch, - RepoID: &repo.ID, - Enabled: param.Enabled, - RunnerBootstrapTimeout: param.RunnerBootstrapTimeout, - GitHubRunnerGroup: param.GitHubRunnerGroup, - Priority: param.Priority, - } - - if len(param.ExtraSpecs) > 0 { - newPool.ExtraSpecs = datatypes.JSON(param.ExtraSpecs) - } - - _, err = s.getRepoPoolByUniqueFields(ctx, repoID, newPool.ProviderName, newPool.Image, newPool.Flavor) - if err != nil { - if !errors.Is(err, runnerErrors.ErrNotFound) { - return params.Pool{}, errors.Wrap(err, "creating pool") - } - } else { - return params.Pool{}, runnerErrors.NewConflictError("pool with the same image and flavor already exists on this provider") - } - - tags := []Tag{} - for _, val := range param.Tags { - t, err := s.getOrCreateTag(val) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching tag") - } - tags = append(tags, t) - } - - q := s.conn.Create(&newPool) - if q.Error != nil { - return params.Pool{}, errors.Wrap(q.Error, "adding pool") - } - - for i := range tags { - if err := s.conn.Model(&newPool).Association("Tags").Append(&tags[i]); err != nil { - return params.Pool{}, errors.Wrap(err, "saving tag") - } - } - - pool, err := s.getPoolByID(ctx, newPool.ID.String(), "Tags", "Instances", "Enterprise", "Organization", "Repository") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - - return s.sqlToCommonPool(pool) -} - -func (s *sqlDatabase) ListRepoPools(ctx context.Context, repoID string) ([]params.Pool, error) { - pools, err := s.listEntityPools(ctx, params.GithubEntityTypeRepository, repoID, "Tags", "Instances", "Repository") - if err != nil { - return nil, errors.Wrap(err, "fetching pools") - } - - ret := make([]params.Pool, len(pools)) - for idx, pool := range pools { - ret[idx], err = s.sqlToCommonPool(pool) - if err != nil { - return nil, errors.Wrap(err, "fetching pool") - } - } - - return ret, nil -} - -func (s *sqlDatabase) GetRepositoryPool(ctx context.Context, repoID, poolID string) (params.Pool, error) { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeRepository, repoID, poolID, "Tags", "Instances") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return s.sqlToCommonPool(pool) -} - -func (s *sqlDatabase) DeleteRepositoryPool(ctx context.Context, repoID, poolID string) error { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeRepository, repoID, poolID) - if err != nil { - return errors.Wrap(err, "looking up repo pool") - } - q := s.conn.Unscoped().Delete(&pool) - if q.Error != nil && !errors.Is(q.Error, gorm.ErrRecordNotFound) { - return errors.Wrap(q.Error, "deleting pool") - } - return nil -} - -func (s *sqlDatabase) FindRepositoryPoolByTags(_ context.Context, repoID string, tags []string) (params.Pool, error) { - pool, err := s.findPoolByTags(repoID, params.GithubEntityTypeRepository, tags) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - return pool[0], nil -} - -func (s *sqlDatabase) ListRepoInstances(ctx context.Context, repoID string) ([]params.Instance, error) { - pools, err := s.listEntityPools(ctx, params.GithubEntityTypeRepository, repoID, "Tags", "Instances", "Instances.Job") - if err != nil { - return nil, errors.Wrap(err, "fetching repo") - } - - ret := []params.Instance{} - for _, pool := range pools { - for _, instance := range pool.Instances { - paramsInstance, err := s.sqlToParamsInstance(instance) - if err != nil { - return nil, errors.Wrap(err, "fetching instance") - } - ret = append(ret, paramsInstance) - } - } - return ret, nil -} - -func (s *sqlDatabase) UpdateRepositoryPool(ctx context.Context, repoID, poolID string, param params.UpdatePoolParams) (params.Pool, error) { - pool, err := s.getEntityPool(ctx, params.GithubEntityTypeRepository, repoID, poolID, "Tags", "Instances", "Enterprise", "Organization", "Repository") - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool") - } - - return s.updatePool(pool, param) -} - func (s *sqlDatabase) getRepo(_ context.Context, owner, name string) (Repository, error) { var repo Repository @@ -308,23 +167,29 @@ func (s *sqlDatabase) getRepo(_ context.Context, owner, name string) (Repository return repo, nil } -func (s *sqlDatabase) getRepoPoolByUniqueFields(ctx context.Context, repoID string, provider, image, flavor string) (Pool, error) { - repo, err := s.getRepoByID(ctx, repoID) - if err != nil { - return Pool{}, errors.Wrap(err, "fetching repo") +func (s *sqlDatabase) getEntityPoolByUniqueFields(tx *gorm.DB, entity params.GithubEntity, provider, image, flavor string) (pool Pool, err error) { + var entityField string + switch entity.EntityType { + case params.GithubEntityTypeRepository: + entityField = entityTypeRepoName + case params.GithubEntityTypeOrganization: + entityField = entityTypeOrgName + case params.GithubEntityTypeEnterprise: + entityField = entityTypeEnterpriseName } - - q := s.conn - var pool []Pool - err = q.Model(&repo).Association("Pools").Find(&pool, "provider_name = ? and image = ? and flavor = ?", provider, image, flavor) + entityID, err := uuid.Parse(entity.ID) if err != nil { - return Pool{}, errors.Wrap(err, "fetching pool") + return pool, fmt.Errorf("parsing entity ID: %w", err) } - if len(pool) == 0 { - return Pool{}, runnerErrors.ErrNotFound + poolQueryString := fmt.Sprintf("provider_name = ? and image = ? and flavor = ? and %s = ?", entityField) + err = tx.Where(poolQueryString, provider, image, flavor, entityID).First(&pool).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return pool, runnerErrors.ErrNotFound + } + return } - - return pool[0], nil + return Pool{}, nil } func (s *sqlDatabase) getRepoByID(_ context.Context, id string, preload ...string) (Repository, error) { diff --git a/database/sql/repositories_test.go b/database/sql/repositories_test.go index 796048ea..18126197 100644 --- a/database/sql/repositories_test.go +++ b/database/sql/repositories_test.go @@ -443,7 +443,9 @@ func (s *RepoTestSuite) TestGetRepositoryByIDDBDecryptingErr() { } func (s *RepoTestSuite) TestCreateRepositoryPool() { - pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().Nil(err) repo, err := s.Store.GetRepositoryByID(context.Background(), s.Fixtures.Repos[0].ID) @@ -459,56 +461,57 @@ func (s *RepoTestSuite) TestCreateRepositoryPool() { func (s *RepoTestSuite) TestCreateRepositoryPoolMissingTags() { s.Fixtures.CreatePoolParams.Tags = []string{} - - _, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + _, err = s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) s.Require().Equal("no tags specified", err.Error()) } func (s *RepoTestSuite) TestCreateRepositoryPoolInvalidRepoID() { - _, err := s.Store.CreateRepositoryPool(context.Background(), "dummy-repo-id", s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: "dummy-repo-id", + EntityType: params.GithubEntityTypeRepository, + } + _, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) - s.Require().Equal("fetching repo: parsing id: invalid request", err.Error()) + s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *RepoTestSuite) TestCreateRepositoryPoolDBCreateErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). WithArgs(s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`repo_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WillReturnError(fmt.Errorf("mocked creating pool error")) - _, err := s.StoreSQLMocked.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("creating pool: fetching pool: mocked creating pool error", err.Error()) + s.Require().Equal("checking pool existence: mocked creating pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestCreateRepositoryPoolDBPoolAlreadyExistErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). WithArgs(s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`repo_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"repo_id", "provider_name", "image", "flavor"}). AddRow( s.Fixtures.Repos[0].ID, @@ -516,159 +519,145 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBPoolAlreadyExistErr() { s.Fixtures.CreatePoolParams.Image, s.Fixtures.CreatePoolParams.Flavor)) - _, err := s.StoreSQLMocked.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("pool with the same image and flavor already exists on this provider", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestCreateRepositoryPoolDBFetchTagErr() { + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). WithArgs(s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`repo_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"repo_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnError(fmt.Errorf("mocked fetching tag error")) - _, err := s.StoreSQLMocked.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("fetching tag: fetching tag from database: mocked fetching tag error", err.Error()) + s.Require().Equal("creating tag: fetching tag from database: mocked fetching tag error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestCreateRepositoryPoolDBAddingPoolErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). WithArgs(s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`repo_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"repo_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnError(fmt.Errorf("mocked adding pool error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) - s.Require().Equal("adding pool: mocked adding pool error", err.Error()) + s.Require().Equal("creating pool: mocked adding pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestCreateRepositoryPoolDBSaveTagErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). WithArgs(s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`repo_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"repo_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("UPDATE `pools` SET")). WillReturnError(fmt.Errorf("mocked saving tag error")) s.Fixtures.SQLMock.ExpectRollback() - _, err := s.StoreSQLMocked.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) - s.assertSQLMockExpectations() + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) s.Require().NotNil(err) - s.Require().Equal("saving tag: mocked saving tag error", err.Error()) + s.Require().Equal("associating tags: mocked saving tag error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestCreateRepositoryPoolDBFetchPoolErr() { s.Fixtures.CreatePoolParams.Tags = []string{"linux"} + s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). WithArgs(s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `repositories` WHERE id = ? AND `repositories`.`deleted_at` IS NULL ORDER BY `repositories`.`id` LIMIT 1")). - WithArgs(s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(s.Fixtures.Repos[0].ID)) - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE `pools`.`repo_id` = ? AND (provider_name = ? and image = ? and flavor = ?) AND `pools`.`deleted_at` IS NULL")). + ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (provider_name = ? and image = ? and flavor = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WithArgs( - s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.ProviderName, s.Fixtures.CreatePoolParams.Image, - s.Fixtures.CreatePoolParams.Flavor). + s.Fixtures.CreatePoolParams.Flavor, + s.Fixtures.Repos[0].ID). WillReturnRows(sqlmock.NewRows([]string{"repo_id"})) s.Fixtures.SQLMock. ExpectQuery(regexp.QuoteMeta("SELECT * FROM `tags` WHERE name = ? AND `tags`.`deleted_at` IS NULL ORDER BY `tags`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"linux"})) - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `tags`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("INSERT INTO `pools`")). WillReturnResult(sqlmock.NewResult(1, 1)) - s.Fixtures.SQLMock.ExpectCommit() - s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. ExpectExec(regexp.QuoteMeta("UPDATE `pools` SET")). WillReturnResult(sqlmock.NewResult(1, 1)) @@ -683,124 +672,123 @@ func (s *RepoTestSuite) TestCreateRepositoryPoolDBFetchPoolErr() { ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE id = ? AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). WillReturnRows(sqlmock.NewRows([]string{"id"})) - _, err := s.StoreSQLMocked.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + + _, err = s.StoreSQLMocked.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) - s.assertSQLMockExpectations() s.Require().NotNil(err) s.Require().Equal("fetching pool: not found", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestListRepoPools() { + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) repoPools := []params.Pool{} for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Flavor = fmt.Sprintf("test-flavor-%d", i) - pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } repoPools = append(repoPools, pool) } - pools, err := s.Store.ListRepoPools(context.Background(), s.Fixtures.Repos[0].ID) + pools, err := s.Store.ListEntityPools(context.Background(), entity) s.Require().Nil(err) garmTesting.EqualDBEntityID(s.T(), repoPools, pools) } func (s *RepoTestSuite) TestListRepoPoolsInvalidRepoID() { - _, err := s.Store.ListRepoPools(context.Background(), "dummy-repo-id") + entity := params.GithubEntity{ + ID: "dummy-repo-id", + EntityType: params.GithubEntityTypeRepository, + } + _, err := s.Store.ListEntityPools(context.Background(), entity) s.Require().NotNil(err) s.Require().Equal("fetching pools: parsing id: invalid request", err.Error()) } func (s *RepoTestSuite) TestGetRepositoryPool() { - pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } - repoPool, err := s.Store.GetRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, pool.ID) + repoPool, err := s.Store.GetEntityPool(context.Background(), entity, pool.ID) s.Require().Nil(err) s.Require().Equal(repoPool.ID, pool.ID) } func (s *RepoTestSuite) TestGetRepositoryPoolInvalidRepoID() { - _, err := s.Store.GetRepositoryPool(context.Background(), "dummy-repo-id", "dummy-pool-id") + entity := params.GithubEntity{ + ID: "dummy-repo-id", + EntityType: params.GithubEntityTypeRepository, + } + _, err := s.Store.GetEntityPool(context.Background(), entity, "dummy-pool-id") s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) } func (s *RepoTestSuite) TestDeleteRepositoryPool() { - pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } - err = s.Store.DeleteRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, pool.ID) + err = s.Store.DeleteEntityPool(context.Background(), entity, pool.ID) s.Require().Nil(err) - _, err = s.Store.GetOrganizationPool(context.Background(), s.Fixtures.Repos[0].ID, pool.ID) + _, err = s.Store.GetEntityPool(context.Background(), entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } func (s *RepoTestSuite) TestDeleteRepositoryPoolInvalidRepoID() { - err := s.Store.DeleteRepositoryPool(context.Background(), "dummy-repo-id", "dummy-pool-id") + entity := params.GithubEntity{ + ID: "dummy-repo-id", + EntityType: params.GithubEntityTypeRepository, + } + err := s.Store.DeleteEntityPool(context.Background(), entity, "dummy-pool-id") s.Require().NotNil(err) - s.Require().Equal("looking up repo pool: parsing id: invalid request", err.Error()) + s.Require().Equal("parsing id: invalid request", err.Error()) } func (s *RepoTestSuite) TestDeleteRepositoryPoolDBDeleteErr() { - pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } - s.Fixtures.SQLMock. - ExpectQuery(regexp.QuoteMeta("SELECT * FROM `pools` WHERE (id = ? and repo_id = ?) AND `pools`.`deleted_at` IS NULL ORDER BY `pools`.`id` LIMIT 1")). - WithArgs(pool.ID, s.Fixtures.Repos[0].ID). - WillReturnRows(sqlmock.NewRows([]string{"repo_id", "id"}).AddRow(s.Fixtures.Repos[0].ID, pool.ID)) s.Fixtures.SQLMock.ExpectBegin() s.Fixtures.SQLMock. - ExpectExec(regexp.QuoteMeta("DELETE FROM `pools` WHERE `pools`.`id` = ?")). - WithArgs(pool.ID). + ExpectExec(regexp.QuoteMeta("DELETE FROM `pools` WHERE id = ? and repo_id = ?")). + WithArgs(pool.ID, s.Fixtures.Repos[0].ID). WillReturnError(fmt.Errorf("mocked deleting pool error")) s.Fixtures.SQLMock.ExpectRollback() - err = s.StoreSQLMocked.DeleteRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, pool.ID) - - s.assertSQLMockExpectations() + err = s.StoreSQLMocked.DeleteEntityPool(context.Background(), entity, pool.ID) s.Require().NotNil(err) - s.Require().Equal("deleting pool: mocked deleting pool error", err.Error()) -} - -func (s *RepoTestSuite) TestFindRepositoryPoolByTags() { - repoPool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) - if err != nil { - s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) - } - - pool, err := s.Store.FindRepositoryPoolByTags(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams.Tags) - s.Require().Nil(err) - s.Require().Equal(repoPool.ID, pool.ID) - s.Require().Equal(repoPool.Image, pool.Image) - s.Require().Equal(repoPool.Flavor, pool.Flavor) -} - -func (s *RepoTestSuite) TestFindRepositoryPoolByTagsMissingTags() { - tags := []string{} - - _, err := s.Store.FindRepositoryPoolByTags(context.Background(), s.Fixtures.Repos[0].ID, tags) - - s.Require().NotNil(err) - s.Require().Equal("fetching pool: missing tags", err.Error()) + s.Require().Equal("removing pool: mocked deleting pool error", err.Error()) + s.assertSQLMockExpectations() } func (s *RepoTestSuite) TestListRepoInstances() { - pool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + pool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } @@ -814,26 +802,32 @@ func (s *RepoTestSuite) TestListRepoInstances() { poolInstances = append(poolInstances, instance) } - instances, err := s.Store.ListRepoInstances(context.Background(), s.Fixtures.Repos[0].ID) + instances, err := s.Store.ListEntityInstances(context.Background(), entity) s.Require().Nil(err) s.equalInstancesByID(poolInstances, instances) } func (s *RepoTestSuite) TestListRepoInstancesInvalidRepoID() { - _, err := s.Store.ListRepoInstances(context.Background(), "dummy-repo-id") + entity := params.GithubEntity{ + ID: "dummy-repo-id", + EntityType: params.GithubEntityTypeRepository, + } + _, err := s.Store.ListEntityInstances(context.Background(), entity) s.Require().NotNil(err) - s.Require().Equal("fetching repo: parsing id: invalid request", err.Error()) + s.Require().Equal("fetching entity: parsing id: invalid request", err.Error()) } func (s *RepoTestSuite) TestUpdateRepositoryPool() { - repoPool, err := s.Store.CreateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, s.Fixtures.CreatePoolParams) + entity, err := s.Fixtures.Repos[0].GetEntity() + s.Require().Nil(err) + repoPool, err := s.Store.CreateEntityPool(context.Background(), entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } - pool, err := s.Store.UpdateRepositoryPool(context.Background(), s.Fixtures.Repos[0].ID, repoPool.ID, s.Fixtures.UpdatePoolParams) + pool, err := s.Store.UpdateEntityPool(context.Background(), entity, repoPool.ID, s.Fixtures.UpdatePoolParams) s.Require().Nil(err) s.Require().Equal(*s.Fixtures.UpdatePoolParams.MaxRunners, pool.MaxRunners) @@ -843,7 +837,11 @@ func (s *RepoTestSuite) TestUpdateRepositoryPool() { } func (s *RepoTestSuite) TestUpdateRepositoryPoolInvalidRepoID() { - _, err := s.Store.UpdateRepositoryPool(context.Background(), "dummy-org-id", "dummy-repo-id", s.Fixtures.UpdatePoolParams) + entity := params.GithubEntity{ + ID: "dummy-repo-id", + EntityType: params.GithubEntityTypeRepository, + } + _, err := s.Store.UpdateEntityPool(context.Background(), entity, "dummy-repo-id", s.Fixtures.UpdatePoolParams) s.Require().NotNil(err) s.Require().Equal("fetching pool: parsing id: invalid request", err.Error()) diff --git a/database/sql/util.go b/database/sql/util.go index 2dd810f5..aaea31fe 100644 --- a/database/sql/util.go +++ b/database/sql/util.go @@ -18,10 +18,12 @@ import ( "encoding/json" "fmt" + "github.com/google/uuid" "github.com/pkg/errors" "gorm.io/datatypes" "gorm.io/gorm" + runnerErrors "github.com/cloudbase/garm-provider-common/errors" commonParams "github.com/cloudbase/garm-provider-common/params" "github.com/cloudbase/garm-provider-common/util" "github.com/cloudbase/garm/params" @@ -275,9 +277,9 @@ func (s *sqlDatabase) sqlToParamsUser(user User) params.User { } } -func (s *sqlDatabase) getOrCreateTag(tagName string) (Tag, error) { +func (s *sqlDatabase) getOrCreateTag(tx *gorm.DB, tagName string) (Tag, error) { var tag Tag - q := s.conn.Where("name = ?", tagName).First(&tag) + q := tx.Where("name = ?", tagName).First(&tag) if q.Error == nil { return tag, nil } @@ -288,14 +290,13 @@ func (s *sqlDatabase) getOrCreateTag(tagName string) (Tag, error) { Name: tagName, } - q = s.conn.Create(&newTag) - if q.Error != nil { - return Tag{}, errors.Wrap(q.Error, "creating tag") + if err := tx.Create(&newTag).Error; err != nil { + return Tag{}, errors.Wrap(err, "creating tag") } return newTag, nil } -func (s *sqlDatabase) updatePool(pool Pool, param params.UpdatePoolParams) (params.Pool, error) { +func (s *sqlDatabase) updatePool(tx *gorm.DB, pool Pool, param params.UpdatePoolParams) (params.Pool, error) { if param.Enabled != nil && pool.Enabled != *param.Enabled { pool.Enabled = *param.Enabled } @@ -344,24 +345,75 @@ func (s *sqlDatabase) updatePool(pool Pool, param params.UpdatePoolParams) (para pool.Priority = *param.Priority } - if q := s.conn.Save(&pool); q.Error != nil { + if q := tx.Save(&pool); q.Error != nil { return params.Pool{}, errors.Wrap(q.Error, "saving database entry") } tags := []Tag{} if param.Tags != nil && len(param.Tags) > 0 { for _, val := range param.Tags { - t, err := s.getOrCreateTag(val) + t, err := s.getOrCreateTag(tx, val) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching tag") } tags = append(tags, t) } - if err := s.conn.Model(&pool).Association("Tags").Replace(&tags); err != nil { + if err := tx.Model(&pool).Association("Tags").Replace(&tags); err != nil { return params.Pool{}, errors.Wrap(err, "replacing tags") } } return s.sqlToCommonPool(pool) } + +func (s *sqlDatabase) getPoolByID(tx *gorm.DB, poolID string, preload ...string) (Pool, error) { + u, err := uuid.Parse(poolID) + if err != nil { + return Pool{}, errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") + } + var pool Pool + q := tx.Model(&Pool{}) + if len(preload) > 0 { + for _, item := range preload { + q = q.Preload(item) + } + } + + q = q.Where("id = ?", u).First(&pool) + + if q.Error != nil { + if errors.Is(q.Error, gorm.ErrRecordNotFound) { + return Pool{}, runnerErrors.ErrNotFound + } + return Pool{}, errors.Wrap(q.Error, "fetching org from database") + } + return pool, nil +} + +func (s *sqlDatabase) hasGithubEntity(tx *gorm.DB, entityType params.GithubEntityType, entityID string) error { + u, err := uuid.Parse(entityID) + if err != nil { + return errors.Wrap(runnerErrors.ErrBadRequest, "parsing id") + } + var q *gorm.DB + switch entityType { + case params.GithubEntityTypeRepository: + q = tx.Model(&Repository{}).Where("id = ?", u) + case params.GithubEntityTypeOrganization: + q = tx.Model(&Organization{}).Where("id = ?", u) + case params.GithubEntityTypeEnterprise: + q = tx.Model(&Enterprise{}).Where("id = ?", u) + default: + return errors.Wrap(runnerErrors.ErrBadRequest, "invalid entity type") + } + + var entity interface{} + if err := q.First(entity).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.Wrap(runnerErrors.ErrNotFound, "entity not found") + } + return errors.Wrap(err, "fetching entity from database") + } + return nil +} diff --git a/params/params.go b/params/params.go index 2b87a4c5..a2a44222 100644 --- a/params/params.go +++ b/params/params.go @@ -317,6 +317,27 @@ type Pool struct { Priority uint `json:"priority"` } +func (p Pool) GithubEntity() (GithubEntity, error) { + switch p.PoolType() { + case GithubEntityTypeRepository: + return GithubEntity{ + ID: p.RepoID, + EntityType: GithubEntityTypeRepository, + }, nil + case GithubEntityTypeOrganization: + return GithubEntity{ + ID: p.OrgID, + EntityType: GithubEntityTypeOrganization, + }, nil + case GithubEntityTypeEnterprise: + return GithubEntity{ + ID: p.EnterpriseID, + EntityType: GithubEntityTypeEnterprise, + }, nil + } + return GithubEntity{}, fmt.Errorf("pool has no associated entity") +} + func (p Pool) GetID() string { return p.ID } @@ -383,6 +404,18 @@ type Repository struct { WebhookSecret string `json:"-"` } +func (r Repository) GetEntity() (GithubEntity, error) { + if r.ID == "" { + return GithubEntity{}, fmt.Errorf("repository has no ID") + } + return GithubEntity{ + ID: r.ID, + EntityType: GithubEntityTypeRepository, + Owner: r.Owner, + Name: r.Name, + }, nil +} + func (r Repository) GetName() string { return r.Name } @@ -412,6 +445,18 @@ type Organization struct { WebhookSecret string `json:"-"` } +func (o Organization) GetEntity() (GithubEntity, error) { + if o.ID == "" { + return GithubEntity{}, fmt.Errorf("organization has no ID") + } + return GithubEntity{ + ID: o.ID, + EntityType: GithubEntityTypeOrganization, + Owner: o.Name, + WebhookSecret: o.WebhookSecret, + }, nil +} + func (o Organization) GetName() string { return o.Name } @@ -441,6 +486,18 @@ type Enterprise struct { WebhookSecret string `json:"-"` } +func (e Enterprise) GetEntity() (GithubEntity, error) { + if e.ID == "" { + return GithubEntity{}, fmt.Errorf("enterprise has no ID") + } + return GithubEntity{ + ID: e.ID, + EntityType: GithubEntityTypeEnterprise, + Owner: e.Name, + WebhookSecret: e.WebhookSecret, + }, nil +} + func (e Enterprise) GetName() string { return e.Name } diff --git a/runner/enterprises.go b/runner/enterprises.go index c76d3973..c5274e09 100644 --- a/runner/enterprises.go +++ b/runner/enterprises.go @@ -124,7 +124,12 @@ func (r *Runner) DeleteEnterprise(ctx context.Context, enterpriseID string) erro return errors.Wrap(err, "fetching enterprise") } - pools, err := r.store.ListEnterprisePools(ctx, enterpriseID) + entity, err := enterprise.GetEntity() + if err != nil { + return errors.Wrap(err, "getting entity") + } + + pools, err := r.store.ListEntityPools(ctx, entity) if err != nil { return errors.Wrap(err, "fetching enterprise pools") } @@ -193,30 +198,23 @@ func (r *Runner) CreateEnterprisePool(ctx context.Context, enterpriseID string, return params.Pool{}, runnerErrors.ErrUnauthorized } - r.mux.Lock() - defer r.mux.Unlock() - - enterprise, err := r.store.GetEnterpriseByID(ctx, enterpriseID) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching enterprise") - } - - if _, err := r.poolManagerCtrl.GetEnterprisePoolManager(enterprise); err != nil { - return params.Pool{}, runnerErrors.ErrNotFound - } - createPoolParams, err := r.appendTagsToCreatePoolParams(param) if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool params") + return params.Pool{}, fmt.Errorf("failed to append tags to create pool params: %w", err) } if param.RunnerBootstrapTimeout == 0 { param.RunnerBootstrapTimeout = appdefaults.DefaultRunnerBootstrapTimeout } - pool, err := r.store.CreateEnterprisePool(ctx, enterpriseID, createPoolParams) + entity := params.GithubEntity{ + ID: enterpriseID, + EntityType: params.GithubEntityTypeEnterprise, + } + + pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams) if err != nil { - return params.Pool{}, errors.Wrap(err, "creating pool") + return params.Pool{}, fmt.Errorf("failed to create enterprise pool: %w", err) } return pool, nil @@ -226,8 +224,11 @@ func (r *Runner) GetEnterprisePoolByID(ctx context.Context, enterpriseID, poolID if !auth.IsAdmin(ctx) { return params.Pool{}, runnerErrors.ErrUnauthorized } - - pool, err := r.store.GetEnterprisePool(ctx, enterpriseID, poolID) + entity := params.GithubEntity{ + ID: enterpriseID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } @@ -239,29 +240,27 @@ func (r *Runner) DeleteEnterprisePool(ctx context.Context, enterpriseID, poolID return runnerErrors.ErrUnauthorized } - // nolint:golangci-lint,godox - // TODO: dedup instance count verification - pool, err := r.store.GetEnterprisePool(ctx, enterpriseID, poolID) - if err != nil { - return errors.Wrap(err, "fetching pool") + entity := params.GithubEntity{ + ID: enterpriseID, + EntityType: params.GithubEntityTypeEnterprise, } - instances, err := r.store.ListPoolInstances(ctx, pool.ID) + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { - return errors.Wrap(err, "fetching instances") + return errors.Wrap(err, "fetching pool") } // nolint:golangci-lint,godox // TODO: implement a count function - if len(instances) > 0 { + if len(pool.Instances) > 0 { runnerIDs := []string{} - for _, run := range instances { + for _, run := range pool.Instances { runnerIDs = append(runnerIDs, run.ID) } return runnerErrors.NewBadRequestError("pool has runners: %s", strings.Join(runnerIDs, ", ")) } - if err := r.store.DeleteEnterprisePool(ctx, enterpriseID, poolID); err != nil { + if err := r.store.DeleteEntityPool(ctx, entity, poolID); err != nil { return errors.Wrap(err, "deleting pool") } return nil @@ -272,7 +271,11 @@ func (r *Runner) ListEnterprisePools(ctx context.Context, enterpriseID string) ( return []params.Pool{}, runnerErrors.ErrUnauthorized } - pools, err := r.store.ListEnterprisePools(ctx, enterpriseID) + entity := params.GithubEntity{ + ID: enterpriseID, + EntityType: params.GithubEntityTypeEnterprise, + } + pools, err := r.store.ListEntityPools(ctx, entity) if err != nil { return nil, errors.Wrap(err, "fetching pools") } @@ -284,7 +287,11 @@ func (r *Runner) UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID return params.Pool{}, runnerErrors.ErrUnauthorized } - pool, err := r.store.GetEnterprisePool(ctx, enterpriseID, poolID) + entity := params.GithubEntity{ + ID: enterpriseID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } @@ -303,7 +310,7 @@ func (r *Runner) UpdateEnterprisePool(ctx context.Context, enterpriseID, poolID return params.Pool{}, runnerErrors.NewBadRequestError("min_idle_runners cannot be larger than max_runners") } - newPool, err := r.store.UpdateEnterprisePool(ctx, enterpriseID, poolID, param) + newPool, err := r.store.UpdateEntityPool(ctx, entity, poolID, param) if err != nil { return params.Pool{}, errors.Wrap(err, "updating pool") } @@ -314,8 +321,11 @@ func (r *Runner) ListEnterpriseInstances(ctx context.Context, enterpriseID strin if !auth.IsAdmin(ctx) { return nil, runnerErrors.ErrUnauthorized } - - instances, err := r.store.ListEnterpriseInstances(ctx, enterpriseID) + entity := params.GithubEntity{ + ID: enterpriseID, + EntityType: params.GithubEntityTypeEnterprise, + } + instances, err := r.store.ListEntityInstances(ctx, entity) if err != nil { return []params.Instance{}, errors.Wrap(err, "fetching instances") } diff --git a/runner/enterprises_test.go b/runner/enterprises_test.go index 311e743a..dc81da5e 100644 --- a/runner/enterprises_test.go +++ b/runner/enterprises_test.go @@ -272,7 +272,11 @@ func (s *EnterpriseTestSuite) TestDeleteEnterpriseErrUnauthorized() { } func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolDefinedFailed() { - pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create store enterprises pool: %v", err)) } @@ -340,8 +344,6 @@ func (s *EnterpriseTestSuite) TestUpdateEnterpriseCreateEnterprisePoolMgrFailed( } func (s *EnterpriseTestSuite) TestCreateEnterprisePool() { - s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) - pool, err := s.Runner.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) @@ -365,30 +367,21 @@ func (s *EnterpriseTestSuite) TestCreateEnterprisePoolErrUnauthorized() { s.Require().Equal(runnerErrors.ErrUnauthorized, err) } -func (s *EnterpriseTestSuite) TestCreateEnterprisePoolErrNotFound() { - s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, runnerErrors.ErrNotFound) - - _, err := s.Runner.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) - - s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) - s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Equal(runnerErrors.ErrNotFound, err) -} - func (s *EnterpriseTestSuite) TestCreateEnterprisePoolFetchPoolParamsFailed() { s.Fixtures.CreatePoolParams.ProviderName = notExistingProviderName - - s.Fixtures.PoolMgrCtrlMock.On("GetEnterprisePoolManager", mock.AnythingOfType("params.Enterprise")).Return(s.Fixtures.PoolMgrMock, nil) - _, err := s.Runner.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Regexp("fetching pool params: no such provider", err.Error()) + s.Require().Regexp("failed to append tags to create pool params: no such provider not-existent-provider-name", err.Error()) } func (s *EnterpriseTestSuite) TestGetEnterprisePoolByID() { - enterprisePool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + enterprisePool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %s", err)) } @@ -406,7 +399,11 @@ func (s *EnterpriseTestSuite) TestGetEnterprisePoolByIDErrUnauthorized() { } func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() { - pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %s", err)) } @@ -415,7 +412,7 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePool() { s.Require().Nil(err) - _, err = s.Fixtures.Store.GetEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, pool.ID) + _, err = s.Fixtures.Store.GetEntityPool(s.Fixtures.AdminContext, entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } @@ -426,7 +423,11 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolErrUnauthorized() { } func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolRunnersFailed() { - pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } @@ -441,10 +442,14 @@ func (s *EnterpriseTestSuite) TestDeleteEnterprisePoolRunnersFailed() { } func (s *EnterpriseTestSuite) TestListEnterprisePools() { + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } enterprisePools := []params.Pool{} for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Image = fmt.Sprintf("test-enterprise-%v", i) - pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } @@ -464,7 +469,11 @@ func (s *EnterpriseTestSuite) TestListOrgPoolsErrUnauthorized() { } func (s *EnterpriseTestSuite) TestUpdateEnterprisePool() { - enterprisePool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + enterprisePool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %s", err)) } @@ -483,7 +492,11 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolErrUnauthorized() { } func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolMinIdleGreaterThanMax() { - pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %s", err)) } @@ -498,7 +511,11 @@ func (s *EnterpriseTestSuite) TestUpdateEnterprisePoolMinIdleGreaterThanMax() { } func (s *EnterpriseTestSuite) TestListEnterpriseInstances() { - pool, err := s.Fixtures.Store.CreateEnterprisePool(s.Fixtures.AdminContext, s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreEnterprises["test-enterprise-1"].ID, + EntityType: params.GithubEntityTypeEnterprise, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create enterprise pool: %v", err)) } diff --git a/runner/organizations.go b/runner/organizations.go index 3d24dcda..40847ccf 100644 --- a/runner/organizations.go +++ b/runner/organizations.go @@ -138,7 +138,12 @@ func (r *Runner) DeleteOrganization(ctx context.Context, orgID string, keepWebho return errors.Wrap(err, "fetching org") } - pools, err := r.store.ListOrgPools(ctx, orgID) + entity, err := org.GetEntity() + if err != nil { + return errors.Wrap(err, "getting entity") + } + + pools, err := r.store.ListEntityPools(ctx, entity) if err != nil { return errors.Wrap(err, "fetching org pools") } @@ -222,18 +227,6 @@ func (r *Runner) CreateOrgPool(ctx context.Context, orgID string, param params.C return params.Pool{}, runnerErrors.ErrUnauthorized } - r.mux.Lock() - defer r.mux.Unlock() - - org, err := r.store.GetOrganizationByID(ctx, orgID) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching org") - } - - if _, err := r.poolManagerCtrl.GetOrgPoolManager(org); err != nil { - return params.Pool{}, runnerErrors.ErrNotFound - } - createPoolParams, err := r.appendTagsToCreatePoolParams(param) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool params") @@ -243,7 +236,12 @@ func (r *Runner) CreateOrgPool(ctx context.Context, orgID string, param params.C param.RunnerBootstrapTimeout = appdefaults.DefaultRunnerBootstrapTimeout } - pool, err := r.store.CreateOrganizationPool(ctx, orgID, createPoolParams) + entity := params.GithubEntity{ + ID: orgID, + EntityType: params.GithubEntityTypeOrganization, + } + + pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams) if err != nil { return params.Pool{}, errors.Wrap(err, "creating pool") } @@ -256,10 +254,16 @@ func (r *Runner) GetOrgPoolByID(ctx context.Context, orgID, poolID string) (para return params.Pool{}, runnerErrors.ErrUnauthorized } - pool, err := r.store.GetOrganizationPool(ctx, orgID, poolID) + entity := params.GithubEntity{ + ID: orgID, + EntityType: params.GithubEntityTypeOrganization, + } + + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } + return pool, nil } @@ -268,29 +272,30 @@ func (r *Runner) DeleteOrgPool(ctx context.Context, orgID, poolID string) error return runnerErrors.ErrUnauthorized } - // nolint:golangci-lint,godox - // TODO: dedup instance count verification - pool, err := r.store.GetOrganizationPool(ctx, orgID, poolID) - if err != nil { - return errors.Wrap(err, "fetching pool") + entity := params.GithubEntity{ + ID: orgID, + EntityType: params.GithubEntityTypeOrganization, } - instances, err := r.store.ListPoolInstances(ctx, pool.ID) + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { - return errors.Wrap(err, "fetching instances") + if !errors.Is(err, runnerErrors.ErrNotFound) { + return errors.Wrap(err, "fetching pool") + } + return nil } // nolint:golangci-lint,godox // TODO: implement a count function - if len(instances) > 0 { + if len(pool.Instances) > 0 { runnerIDs := []string{} - for _, run := range instances { + for _, run := range pool.Instances { runnerIDs = append(runnerIDs, run.ID) } return runnerErrors.NewBadRequestError("pool has runners: %s", strings.Join(runnerIDs, ", ")) } - if err := r.store.DeleteOrganizationPool(ctx, orgID, poolID); err != nil { + if err := r.store.DeleteEntityPool(ctx, entity, poolID); err != nil { return errors.Wrap(err, "deleting pool") } return nil @@ -300,8 +305,11 @@ func (r *Runner) ListOrgPools(ctx context.Context, orgID string) ([]params.Pool, if !auth.IsAdmin(ctx) { return []params.Pool{}, runnerErrors.ErrUnauthorized } - - pools, err := r.store.ListOrgPools(ctx, orgID) + entity := params.GithubEntity{ + ID: orgID, + EntityType: params.GithubEntityTypeOrganization, + } + pools, err := r.store.ListEntityPools(ctx, entity) if err != nil { return nil, errors.Wrap(err, "fetching pools") } @@ -313,7 +321,12 @@ func (r *Runner) UpdateOrgPool(ctx context.Context, orgID, poolID string, param return params.Pool{}, runnerErrors.ErrUnauthorized } - pool, err := r.store.GetOrganizationPool(ctx, orgID, poolID) + entity := params.GithubEntity{ + ID: orgID, + EntityType: params.GithubEntityTypeOrganization, + } + + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } @@ -332,7 +345,7 @@ func (r *Runner) UpdateOrgPool(ctx context.Context, orgID, poolID string, param return params.Pool{}, runnerErrors.NewBadRequestError("min_idle_runners cannot be larger than max_runners") } - newPool, err := r.store.UpdateOrganizationPool(ctx, orgID, poolID, param) + newPool, err := r.store.UpdateEntityPool(ctx, entity, poolID, param) if err != nil { return params.Pool{}, errors.Wrap(err, "updating pool") } @@ -344,7 +357,12 @@ func (r *Runner) ListOrgInstances(ctx context.Context, orgID string) ([]params.I return nil, runnerErrors.ErrUnauthorized } - instances, err := r.store.ListOrgInstances(ctx, orgID) + entity := params.GithubEntity{ + ID: orgID, + EntityType: params.GithubEntityTypeOrganization, + } + + instances, err := r.store.ListEntityInstances(ctx, entity) if err != nil { return []params.Instance{}, errors.Wrap(err, "fetching instances") } diff --git a/runner/organizations_test.go b/runner/organizations_test.go index 7ebfcff8..d0113756 100644 --- a/runner/organizations_test.go +++ b/runner/organizations_test.go @@ -285,7 +285,11 @@ func (s *OrgTestSuite) TestDeleteOrganizationErrUnauthorized() { } func (s *OrgTestSuite) TestDeleteOrganizationPoolDefinedFailed() { - pool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create store organizations pool: %v", err)) } @@ -365,8 +369,6 @@ func (s *OrgTestSuite) TestUpdateOrganizationCreateOrgPoolMgrFailed() { } func (s *OrgTestSuite) TestCreateOrgPool() { - s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil) - pool, err := s.Runner.CreateOrgPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) @@ -390,21 +392,8 @@ func (s *OrgTestSuite) TestCreateOrgPoolErrUnauthorized() { s.Require().Equal(runnerErrors.ErrUnauthorized, err) } -func (s *OrgTestSuite) TestCreateOrgPoolErrNotFound() { - s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, runnerErrors.ErrNotFound) - - _, err := s.Runner.CreateOrgPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) - - s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) - s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Equal(runnerErrors.ErrNotFound, err) -} - func (s *OrgTestSuite) TestCreateOrgPoolFetchPoolParamsFailed() { s.Fixtures.CreatePoolParams.ProviderName = notExistingProviderName - - s.Fixtures.PoolMgrCtrlMock.On("GetOrgPoolManager", mock.AnythingOfType("params.Organization")).Return(s.Fixtures.PoolMgrMock, nil) - _, err := s.Runner.CreateOrgPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) @@ -413,7 +402,11 @@ func (s *OrgTestSuite) TestCreateOrgPoolFetchPoolParamsFailed() { } func (s *OrgTestSuite) TestGetOrgPoolByID() { - orgPool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + orgPool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %s", err)) } @@ -431,7 +424,11 @@ func (s *OrgTestSuite) TestGetOrgPoolByIDErrUnauthorized() { } func (s *OrgTestSuite) TestDeleteOrgPool() { - pool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %s", err)) } @@ -440,7 +437,7 @@ func (s *OrgTestSuite) TestDeleteOrgPool() { s.Require().Nil(err) - _, err = s.Fixtures.Store.GetOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, pool.ID) + _, err = s.Fixtures.Store.GetEntityPool(s.Fixtures.AdminContext, entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } @@ -451,7 +448,11 @@ func (s *OrgTestSuite) TestDeleteOrgPoolErrUnauthorized() { } func (s *OrgTestSuite) TestDeleteOrgPoolRunnersFailed() { - pool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } @@ -466,10 +467,14 @@ func (s *OrgTestSuite) TestDeleteOrgPoolRunnersFailed() { } func (s *OrgTestSuite) TestListOrgPools() { + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } orgPools := []params.Pool{} for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Image = fmt.Sprintf("test-org-%v", i) - pool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } @@ -489,7 +494,11 @@ func (s *OrgTestSuite) TestListOrgPoolsErrUnauthorized() { } func (s *OrgTestSuite) TestUpdateOrgPool() { - orgPool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + orgPool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %s", err)) } @@ -508,7 +517,11 @@ func (s *OrgTestSuite) TestUpdateOrgPoolErrUnauthorized() { } func (s *OrgTestSuite) TestUpdateOrgPoolMinIdleGreaterThanMax() { - pool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %s", err)) } @@ -523,7 +536,11 @@ func (s *OrgTestSuite) TestUpdateOrgPoolMinIdleGreaterThanMax() { } func (s *OrgTestSuite) TestListOrgInstances() { - pool, err := s.Fixtures.Store.CreateOrganizationPool(s.Fixtures.AdminContext, s.Fixtures.StoreOrgs["test-org-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreOrgs["test-org-1"].ID, + EntityType: params.GithubEntityTypeOrganization, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create org pool: %v", err)) } diff --git a/runner/pool/pool.go b/runner/pool/pool.go index 3fe1eb3e..bdfc0d3b 100644 --- a/runner/pool/pool.go +++ b/runner/pool/pool.go @@ -1725,10 +1725,6 @@ func (r *basePoolManager) WebhookSecret() string { return r.entity.WebhookSecret } -func (r *basePoolManager) GithubRunnerRegistrationToken() (string, error) { - return r.GetGithubRegistrationToken() -} - func (r *basePoolManager) ID() string { return r.entity.ID } @@ -2095,7 +2091,7 @@ func (r *basePoolManager) GetRunnerInfoFromWorkflow(job params.WorkflowJob) (par return params.RunnerInfo{}, fmt.Errorf("failed to find runner name from workflow") } -func (r *basePoolManager) GetGithubRegistrationToken() (string, error) { +func (r *basePoolManager) GithubRunnerRegistrationToken() (string, error) { tk, ghResp, err := r.ghcli.CreateEntityRegistrationToken(r.ctx) if err != nil { if ghResp != nil && ghResp.StatusCode == http.StatusUnauthorized { @@ -2177,41 +2173,15 @@ func (r *basePoolManager) GithubURL() string { } func (r *basePoolManager) FetchDbInstances() ([]params.Instance, error) { - switch r.entity.EntityType { - case params.GithubEntityTypeRepository: - return r.store.ListRepoInstances(r.ctx, r.entity.ID) - case params.GithubEntityTypeOrganization: - return r.store.ListOrgInstances(r.ctx, r.entity.ID) - case params.GithubEntityTypeEnterprise: - return r.store.ListEnterpriseInstances(r.ctx, r.entity.ID) - } - return nil, fmt.Errorf("unknown entity type: %s", r.entity.EntityType) + return r.store.ListEntityInstances(r.ctx, r.entity) } func (r *basePoolManager) ListPools() ([]params.Pool, error) { - switch r.entity.EntityType { - case params.GithubEntityTypeRepository: - return r.store.ListRepoPools(r.ctx, r.entity.ID) - case params.GithubEntityTypeOrganization: - return r.store.ListOrgPools(r.ctx, r.entity.ID) - case params.GithubEntityTypeEnterprise: - return r.store.ListEnterprisePools(r.ctx, r.entity.ID) - default: - return nil, fmt.Errorf("unknown entity type: %s", r.entity.EntityType) - } + return r.store.ListEntityPools(r.ctx, r.entity) } func (r *basePoolManager) GetPoolByID(poolID string) (params.Pool, error) { - switch r.entity.EntityType { - case params.GithubEntityTypeRepository: - return r.store.GetRepositoryPool(r.ctx, r.entity.ID, poolID) - case params.GithubEntityTypeOrganization: - return r.store.GetOrganizationPool(r.ctx, r.entity.ID, poolID) - case params.GithubEntityTypeEnterprise: - return r.store.GetEnterprisePool(r.ctx, r.entity.ID, poolID) - default: - return params.Pool{}, fmt.Errorf("unknown entity type: %s", r.entity.EntityType) - } + return r.store.GetEntityPool(r.ctx, r.entity, poolID) } func (r *basePoolManager) GetWebhookInfo(ctx context.Context) (params.HookInfo, error) { diff --git a/runner/pools.go b/runner/pools.go index 16194f65..aab423ff 100644 --- a/runner/pools.go +++ b/runner/pools.go @@ -16,7 +16,6 @@ package runner import ( "context" - "fmt" "github.com/pkg/errors" @@ -108,19 +107,12 @@ func (r *Runner) UpdatePoolByID(ctx context.Context, poolID string, param params param.Tags = newTags } - var newPool params.Pool - - switch { - case pool.RepoID != "": - newPool, err = r.store.UpdateRepositoryPool(ctx, pool.RepoID, poolID, param) - case pool.OrgID != "": - newPool, err = r.store.UpdateOrganizationPool(ctx, pool.OrgID, poolID, param) - case pool.EnterpriseID != "": - newPool, err = r.store.UpdateEnterprisePool(ctx, pool.EnterpriseID, poolID, param) - default: - return params.Pool{}, fmt.Errorf("pool not found to a repo, org or enterprise") + entity, err := pool.GithubEntity() + if err != nil { + return params.Pool{}, errors.Wrap(err, "getting entity") } + newPool, err := r.store.UpdateEntityPool(ctx, entity, poolID, param) if err != nil { return params.Pool{}, errors.Wrap(err, "updating pool") } diff --git a/runner/pools_test.go b/runner/pools_test.go index 59d6ff27..e2b269a0 100644 --- a/runner/pools_test.go +++ b/runner/pools_test.go @@ -64,11 +64,15 @@ func (s *PoolTestSuite) SetupTest() { } // create some pool objects in the database, for testing purposes + entity := params.GithubEntity{ + ID: org.ID, + EntityType: params.GithubEntityTypeOrganization, + } orgPools := []params.Pool{} for i := 1; i <= 3; i++ { - pool, err := db.CreateOrganizationPool( + pool, err := db.CreateEntityPool( context.Background(), - org.ID, + entity, params.CreatePoolParams{ ProviderName: "test-provider", MaxRunners: 4, diff --git a/runner/repositories.go b/runner/repositories.go index c71fab39..f7692b69 100644 --- a/runner/repositories.go +++ b/runner/repositories.go @@ -137,7 +137,12 @@ func (r *Runner) DeleteRepository(ctx context.Context, repoID string, keepWebhoo return errors.Wrap(err, "fetching repo") } - pools, err := r.store.ListRepoPools(ctx, repoID) + entity, err := repo.GetEntity() + if err != nil { + return errors.Wrap(err, "getting entity") + } + + pools, err := r.store.ListEntityPools(ctx, entity) if err != nil { return errors.Wrap(err, "fetching repo pools") } @@ -221,30 +226,23 @@ func (r *Runner) CreateRepoPool(ctx context.Context, repoID string, param params return params.Pool{}, runnerErrors.ErrUnauthorized } - r.mux.Lock() - defer r.mux.Unlock() - - repo, err := r.store.GetRepositoryByID(ctx, repoID) - if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching repo") - } - - if _, err := r.poolManagerCtrl.GetRepoPoolManager(repo); err != nil { - return params.Pool{}, runnerErrors.ErrNotFound - } - createPoolParams, err := r.appendTagsToCreatePoolParams(param) if err != nil { - return params.Pool{}, errors.Wrap(err, "fetching pool params") + return params.Pool{}, fmt.Errorf("failed to append tags to create pool params: %w", err) } if createPoolParams.RunnerBootstrapTimeout == 0 { createPoolParams.RunnerBootstrapTimeout = appdefaults.DefaultRunnerBootstrapTimeout } - pool, err := r.store.CreateRepositoryPool(ctx, repoID, createPoolParams) + entity := params.GithubEntity{ + ID: repoID, + EntityType: params.GithubEntityTypeRepository, + } + + pool, err := r.store.CreateEntityPool(ctx, entity, createPoolParams) if err != nil { - return params.Pool{}, errors.Wrap(err, "creating pool") + return params.Pool{}, fmt.Errorf("failed to create pool: %w", err) } return pool, nil @@ -255,10 +253,16 @@ func (r *Runner) GetRepoPoolByID(ctx context.Context, repoID, poolID string) (pa return params.Pool{}, runnerErrors.ErrUnauthorized } - pool, err := r.store.GetRepositoryPool(ctx, repoID, poolID) + entity := params.GithubEntity{ + ID: repoID, + EntityType: params.GithubEntityTypeRepository, + } + + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } + return pool, nil } @@ -267,27 +271,26 @@ func (r *Runner) DeleteRepoPool(ctx context.Context, repoID, poolID string) erro return runnerErrors.ErrUnauthorized } - pool, err := r.store.GetRepositoryPool(ctx, repoID, poolID) - if err != nil { - return errors.Wrap(err, "fetching pool") + entity := params.GithubEntity{ + ID: repoID, + EntityType: params.GithubEntityTypeRepository, } - - instances, err := r.store.ListPoolInstances(ctx, pool.ID) + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { - return errors.Wrap(err, "fetching instances") + return errors.Wrap(err, "fetching pool") } // nolint:golangci-lint,godox // TODO: implement a count function - if len(instances) > 0 { + if len(pool.Instances) > 0 { runnerIDs := []string{} - for _, run := range instances { + for _, run := range pool.Instances { runnerIDs = append(runnerIDs, run.ID) } return runnerErrors.NewBadRequestError("pool has runners: %s", strings.Join(runnerIDs, ", ")) } - if err := r.store.DeleteRepositoryPool(ctx, repoID, poolID); err != nil { + if err := r.store.DeleteEntityPool(ctx, entity, poolID); err != nil { return errors.Wrap(err, "deleting pool") } return nil @@ -297,8 +300,11 @@ func (r *Runner) ListRepoPools(ctx context.Context, repoID string) ([]params.Poo if !auth.IsAdmin(ctx) { return []params.Pool{}, runnerErrors.ErrUnauthorized } - - pools, err := r.store.ListRepoPools(ctx, repoID) + entity := params.GithubEntity{ + ID: repoID, + EntityType: params.GithubEntityTypeRepository, + } + pools, err := r.store.ListEntityPools(ctx, entity) if err != nil { return nil, errors.Wrap(err, "fetching pools") } @@ -322,7 +328,11 @@ func (r *Runner) UpdateRepoPool(ctx context.Context, repoID, poolID string, para return params.Pool{}, runnerErrors.ErrUnauthorized } - pool, err := r.store.GetRepositoryPool(ctx, repoID, poolID) + entity := params.GithubEntity{ + ID: repoID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := r.store.GetEntityPool(ctx, entity, poolID) if err != nil { return params.Pool{}, errors.Wrap(err, "fetching pool") } @@ -341,7 +351,7 @@ func (r *Runner) UpdateRepoPool(ctx context.Context, repoID, poolID string, para return params.Pool{}, runnerErrors.NewBadRequestError("min_idle_runners cannot be larger than max_runners") } - newPool, err := r.store.UpdateRepositoryPool(ctx, repoID, poolID, param) + newPool, err := r.store.UpdateEntityPool(ctx, entity, poolID, param) if err != nil { return params.Pool{}, errors.Wrap(err, "updating pool") } @@ -352,8 +362,11 @@ func (r *Runner) ListRepoInstances(ctx context.Context, repoID string) ([]params if !auth.IsAdmin(ctx) { return nil, runnerErrors.ErrUnauthorized } - - instances, err := r.store.ListRepoInstances(ctx, repoID) + entity := params.GithubEntity{ + ID: repoID, + EntityType: params.GithubEntityTypeRepository, + } + instances, err := r.store.ListEntityInstances(ctx, entity) if err != nil { return []params.Instance{}, errors.Wrap(err, "fetching instances") } diff --git a/runner/repositories_test.go b/runner/repositories_test.go index 8a1e8d9c..20814a86 100644 --- a/runner/repositories_test.go +++ b/runner/repositories_test.go @@ -295,7 +295,11 @@ func (s *RepoTestSuite) TestDeleteRepositoryErrUnauthorized() { } func (s *RepoTestSuite) TestDeleteRepositoryPoolDefinedFailed() { - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create store repositories pool: %v", err)) } @@ -376,8 +380,6 @@ func (s *RepoTestSuite) TestUpdateRepositoryCreateRepoPoolMgrFailed() { } func (s *RepoTestSuite) TestCreateRepoPool() { - s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil) - pool, err := s.Runner.CreateRepoPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) @@ -401,30 +403,21 @@ func (s *RepoTestSuite) TestCreateRepoPoolErrUnauthorized() { s.Require().Equal(runnerErrors.ErrUnauthorized, err) } -func (s *RepoTestSuite) TestCreateRepoPoolErrNotFound() { - s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, s.Fixtures.ErrMock) - - _, err := s.Runner.CreateRepoPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) - - s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) - s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Equal(runnerErrors.ErrNotFound, err) -} - func (s *RepoTestSuite) TestCreateRepoPoolFetchPoolParamsFailed() { s.Fixtures.CreatePoolParams.ProviderName = notExistingProviderName - - s.Fixtures.PoolMgrCtrlMock.On("GetRepoPoolManager", mock.AnythingOfType("params.Repository")).Return(s.Fixtures.PoolMgrMock, nil) - _, err := s.Runner.CreateRepoPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) s.Fixtures.PoolMgrMock.AssertExpectations(s.T()) s.Fixtures.PoolMgrCtrlMock.AssertExpectations(s.T()) - s.Require().Regexp("fetching pool params: no such provider", err.Error()) + s.Require().Regexp("failed to append tags to create pool params: no such provider not-existent-provider-name", err.Error()) } func (s *RepoTestSuite) TestGetRepoPoolByID() { - repoPool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + repoPool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %s", err)) } @@ -442,7 +435,11 @@ func (s *RepoTestSuite) TestGetRepoPoolByIDErrUnauthorized() { } func (s *RepoTestSuite) TestDeleteRepoPool() { - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %s", err)) } @@ -451,7 +448,7 @@ func (s *RepoTestSuite) TestDeleteRepoPool() { s.Require().Nil(err) - _, err = s.Fixtures.Store.GetRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, pool.ID) + _, err = s.Fixtures.Store.GetEntityPool(s.Fixtures.AdminContext, entity, pool.ID) s.Require().Equal("fetching pool: finding pool: not found", err.Error()) } @@ -462,7 +459,11 @@ func (s *RepoTestSuite) TestDeleteRepoPoolErrUnauthorized() { } func (s *RepoTestSuite) TestDeleteRepoPoolRunnersFailed() { - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %s", err)) } @@ -477,10 +478,14 @@ func (s *RepoTestSuite) TestDeleteRepoPoolRunnersFailed() { } func (s *RepoTestSuite) TestListRepoPools() { + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } repoPools := []params.Pool{} for i := 1; i <= 2; i++ { s.Fixtures.CreatePoolParams.Image = fmt.Sprintf("test-repo-%v", i) - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } @@ -500,7 +505,11 @@ func (s *RepoTestSuite) TestListRepoPoolsErrUnauthorized() { } func (s *RepoTestSuite) TestListPoolInstances() { - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } @@ -527,7 +536,11 @@ func (s *RepoTestSuite) TestListPoolInstancesErrUnauthorized() { } func (s *RepoTestSuite) TestUpdateRepoPool() { - repoPool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + repoPool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create store repositories pool: %v", err)) } @@ -546,7 +559,11 @@ func (s *RepoTestSuite) TestUpdateRepoPoolErrUnauthorized() { } func (s *RepoTestSuite) TestUpdateRepoPoolMinIdleGreaterThanMax() { - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %s", err)) } @@ -561,7 +578,11 @@ func (s *RepoTestSuite) TestUpdateRepoPoolMinIdleGreaterThanMax() { } func (s *RepoTestSuite) TestListRepoInstances() { - pool, err := s.Fixtures.Store.CreateRepositoryPool(s.Fixtures.AdminContext, s.Fixtures.StoreRepos["test-repo-1"].ID, s.Fixtures.CreatePoolParams) + entity := params.GithubEntity{ + ID: s.Fixtures.StoreRepos["test-repo-1"].ID, + EntityType: params.GithubEntityTypeRepository, + } + pool, err := s.Fixtures.Store.CreateEntityPool(s.Fixtures.AdminContext, entity, s.Fixtures.CreatePoolParams) if err != nil { s.FailNow(fmt.Sprintf("cannot create repo pool: %v", err)) } diff --git a/runner/runner.go b/runner/runner.go index 7eab27f9..a29fda0c 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -785,7 +785,8 @@ func (r *Runner) DispatchWorkflowJob(hookTargetType, signature string, jobData [ func (r *Runner) appendTagsToCreatePoolParams(param params.CreatePoolParams) (params.CreatePoolParams, error) { if err := param.Validate(); err != nil { - return params.CreatePoolParams{}, errors.Wrapf(runnerErrors.ErrBadRequest, "validating params: %s", err) + return params.CreatePoolParams{}, fmt.Errorf("failed to validate params (%q): %w", err, runnerErrors.ErrBadRequest) + // errors.Wrapf(runnerErrors.ErrBadRequest, "validating params: %s", err) } if !IsSupportedOSType(param.OSType) { @@ -803,7 +804,7 @@ func (r *Runner) appendTagsToCreatePoolParams(param params.CreatePoolParams) (pa newTags, err := r.processTags(string(param.OSArch), param.OSType, param.Tags) if err != nil { - return params.CreatePoolParams{}, errors.Wrap(err, "processing tags") + return params.CreatePoolParams{}, fmt.Errorf("failed to process tags: %w", err) } param.Tags = newTags