Skip to content

Commit

Permalink
refactor: remove the dependency between dao and repo
Browse files Browse the repository at this point in the history
Signed-off-by: saltbo <saltbo@foxmail.com>
  • Loading branch information
saltbo committed Aug 5, 2023
1 parent a76ee8e commit c06eed7
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 126 deletions.
2 changes: 1 addition & 1 deletion cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var serverCmd = &cobra.Command{
Use: "server",
Short: "A cloud disk base on the cloud service.",
RunE: func(cmd *cobra.Command, args []string) error {
s := app.NewServer()
s := app.InitializeServer()
return s.Run()
},
}
Expand Down
23 changes: 14 additions & 9 deletions internal/app/dao/dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,21 @@ func Init(driver, dsn string) error {
return nil
}

func GetDBQuery() *query.Query {
if viper.IsSet("installed") {
if err := Init(viper.GetString("database.driver"), viper.GetString("database.dsn")); err != nil {
log.Fatalln(err)
}
}
type DBQueryFactory struct {
}

return query.Use(gdb)
func NewDBQueryFactory() *DBQueryFactory {
return &DBQueryFactory{}
}

func Transaction(fc func(tx *gorm.DB) error) error {
return gdb.Transaction(fc)
func (D *DBQueryFactory) Q() *query.Query {
if !viper.IsSet("installed") {
return nil
}

if err := Init(viper.GetString("database.driver"), viper.GetString("database.dsn")); err != nil {
log.Fatalln(err)
}

return query.Use(gdb)
}
58 changes: 29 additions & 29 deletions internal/app/repo/matter.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,27 +47,27 @@ type Matter interface {
var _ Matter = (*MatterDBQuery)(nil)

type MatterDBQuery struct {
q *query.Query
DBQuery
}

func NewMatterDBQuery(q *query.Query) *MatterDBQuery {
return &MatterDBQuery{q: q}
func NewMatterDBQuery(q DBQuery) *MatterDBQuery {
return &MatterDBQuery{DBQuery: q}
}

func (db *MatterDBQuery) Find(ctx context.Context, id int64) (*entity.Matter, error) {
return db.q.Matter.WithContext(ctx).Where(db.q.Matter.Id.Eq(id)).First()
return db.Q().Matter.WithContext(ctx).Where(db.Q().Matter.Id.Eq(id)).First()
}

func (db *MatterDBQuery) FindWith(ctx context.Context, opt *MatterFindWithOption) (*entity.Matter, error) {
conds := make([]gen.Condition, 0)
if opt.Id != 0 {
conds = append(conds, db.q.Matter.Id.Eq(opt.Id))
conds = append(conds, db.Q().Matter.Id.Eq(opt.Id))
}
if opt.Alias != "" {
conds = append(conds, db.q.Matter.Alias_.Eq(opt.Alias))
conds = append(conds, db.Q().Matter.Alias_.Eq(opt.Alias))
}

q := db.q.Matter.WithContext(ctx)
q := db.Q().Matter.WithContext(ctx)
if opt.Deleted {
q = q.Unscoped()
}
Expand All @@ -76,7 +76,7 @@ func (db *MatterDBQuery) FindWith(ctx context.Context, opt *MatterFindWithOption
}

func (db *MatterDBQuery) FindByAlias(ctx context.Context, alias string) (*entity.Matter, error) {
return db.q.Matter.WithContext(ctx).Where(db.q.Matter.Alias_.Eq(alias)).First()
return db.Q().Matter.WithContext(ctx).Where(db.Q().Matter.Alias_.Eq(alias)).First()
}

func (db *MatterDBQuery) PathExist(ctx context.Context, filepath string) bool {
Expand All @@ -92,41 +92,41 @@ func (db *MatterDBQuery) PathExist(ctx context.Context, filepath string) bool {
parent, name = path.Split(filepath)
}

conds := []gen.Condition{db.q.Matter.Name.Eq(name)}
conds := []gen.Condition{db.Q().Matter.Name.Eq(name)}
if parent != name {
conds = append(conds, db.q.Matter.Parent.Eq(strings.TrimPrefix(parent, "/")))
conds = append(conds, db.Q().Matter.Parent.Eq(strings.TrimPrefix(parent, "/")))
}

_, err := db.q.Matter.WithContext(ctx).Where(conds...).First()
_, err := db.Q().Matter.WithContext(ctx).Where(conds...).First()
return err == nil
}

func (db *MatterDBQuery) FindAll(ctx context.Context, opts *MatterListOption) ([]*entity.Matter, int64, error) {
conds := make([]gen.Condition, 0)
if opts.Uid != 0 {
conds = append(conds, db.q.Matter.Uid.Eq(opts.Uid))
conds = append(conds, db.Q().Matter.Uid.Eq(opts.Uid))
}
if opts.Sid != 0 {
conds = append(conds, db.q.Matter.Sid.Eq(opts.Sid))
conds = append(conds, db.Q().Matter.Sid.Eq(opts.Sid))
}

if opts.Keyword != "" {
conds = append(conds, db.q.Matter.Name.Like(fmt.Sprintf("%%%s%%", opts.Keyword)))
conds = append(conds, db.Q().Matter.Name.Like(fmt.Sprintf("%%%s%%", opts.Keyword)))
} else if !opts.Draft {
conds = append(conds, db.q.Matter.Parent.Eq(opts.Dir))
conds = append(conds, db.Q().Matter.Parent.Eq(opts.Dir))
}

if opts.Type == "doc" {
conds = append(conds, db.q.Matter.Type.In(entity.DocTypes...))
conds = append(conds, db.Q().Matter.Type.In(entity.DocTypes...))
} else if opts.Type != "" {
conds = append(conds, db.q.Matter.Type.Like(fmt.Sprintf("%%%s%%", opts.Type)))
conds = append(conds, db.Q().Matter.Type.Like(fmt.Sprintf("%%%s%%", opts.Type)))
}

if !opts.Draft {
conds = append(conds, db.q.Matter.UploadedAt.IsNotNull())
conds = append(conds, db.Q().Matter.UploadedAt.IsNotNull())
}

q := db.q.Matter.WithContext(ctx).Where(conds...).Order(db.q.Matter.DirType.Desc(), db.q.Matter.Id.Desc())
q := db.Q().Matter.WithContext(ctx).Where(conds...).Order(db.Q().Matter.DirType.Desc(), db.Q().Matter.Id.Desc())

if opts.Limit == 0 {
matters, err := q.Find()
Expand All @@ -148,7 +148,7 @@ func (db *MatterDBQuery) Create(ctx context.Context, m *entity.Matter) error {
m.Name = strings.TrimSuffix(m.Name, ext) + suffix + ext
}

return db.q.Matter.Create(m)
return db.Q().Matter.Create(m)
}

func (db *MatterDBQuery) Copy(ctx context.Context, id int64, to string) (*entity.Matter, error) {
Expand All @@ -165,7 +165,7 @@ func (db *MatterDBQuery) Copy(ctx context.Context, id int64, to string) (*entity
newMatter.Parent = to
if !em.IsDir() {
// 如果是文件则只创建新的文件即可
return newMatter, db.q.Matter.Create(newMatter)
return newMatter, db.Q().Matter.Create(newMatter)
}

// 如果是文件夹则查找所有子文件/文件夹一起复制
Expand All @@ -180,7 +180,7 @@ func (db *MatterDBQuery) Copy(ctx context.Context, id int64, to string) (*entity
return newMatter
})

return newMatter, db.q.Matter.Create(newMatters...)
return newMatter, db.Q().Matter.Create(newMatters...)
}

func (db *MatterDBQuery) Update(ctx context.Context, id int64, m *entity.Matter) error {
Expand All @@ -189,7 +189,7 @@ func (db *MatterDBQuery) Update(ctx context.Context, id int64, m *entity.Matter)
return err
}

return db.q.Transaction(func(tx *query.Query) error {
return db.Q().Transaction(func(tx *query.Query) error {
tq := tx.Matter.WithContext(ctx)
if m.IsDir() {
// 如果是目录,则需要把该目录下的子文件/目录一并修改
Expand All @@ -212,7 +212,7 @@ func (db *MatterDBQuery) Delete(ctx context.Context, id int64) error {
}

m.TrashedBy = uuid.New().String()
return db.q.Transaction(func(tx *query.Query) error {
return db.Q().Transaction(func(tx *query.Query) error {
tq := tx.Matter.WithContext(ctx)
if m.IsDir() {
// 如果是目录,则需要把该目录下的子文件/目录一并删除
Expand All @@ -234,7 +234,7 @@ func (db *MatterDBQuery) Delete(ctx context.Context, id int64) error {
}

func (db *MatterDBQuery) Recovery(ctx context.Context, id int64) error {
m, err := db.q.Matter.WithContext(ctx).Unscoped().Where(db.q.Matter.Id.Eq(id)).First()
m, err := db.Q().Matter.WithContext(ctx).Unscoped().Where(db.Q().Matter.Id.Eq(id)).First()
if err != nil {
return err
}
Expand All @@ -243,8 +243,8 @@ func (db *MatterDBQuery) Recovery(ctx context.Context, id int64) error {
return fmt.Errorf("recovery: file parent[%s] not found", m.Parent)
}

_, err = db.q.Matter.WithContext(ctx).Unscoped().Where(db.q.Matter.TrashedBy.Eq(m.TrashedBy)).
UpdateSimple(db.q.Matter.TrashedBy.Value(""), db.q.Matter.DeletedAt.Null())
_, err = db.Q().Matter.WithContext(ctx).Unscoped().Where(db.Q().Matter.TrashedBy.Eq(m.TrashedBy)).
UpdateSimple(db.Q().Matter.TrashedBy.Value(""), db.Q().Matter.DeletedAt.Null())
return err
}

Expand All @@ -271,10 +271,10 @@ func (db *MatterDBQuery) GetObjects(ctx context.Context, id int64) ([]string, er
}

func (db *MatterDBQuery) findChildren(ctx context.Context, m *entity.Matter, withDeleted bool) ([]*entity.Matter, error) {
q := db.q.Matter.WithContext(ctx)
q := db.Q().Matter.WithContext(ctx)
if withDeleted {
q = q.Unscoped()
}

return q.Where(db.q.Matter.Parent.Like(m.FullPath() + "%")).Find()
return q.Where(db.Q().Matter.Parent.Like(m.FullPath() + "%")).Find()
}
26 changes: 4 additions & 22 deletions internal/app/repo/matter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,15 @@ import (
"context"
"database/sql/driver"
"testing"
"time"

"github.com/DATA-DOG/go-sqlmock"
"github.com/saltbo/zpan/internal/app/entity"
"github.com/saltbo/zpan/internal/app/repo/query"
"github.com/stretchr/testify/assert"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)

var nowFunc = func() time.Time {
return time.Unix(0, 0)
}

func newMockDB(t *testing.T) (sqlmock.Sqlmock, *gorm.DB) {
rdb, mock, err := sqlmock.New()
assert.NoError(t, err)
gdb, err := gorm.Open(mysql.New(mysql.Config{Conn: rdb, DriverName: "mysql", SkipInitializeWithVersion: true}), &gorm.Config{
NowFunc: nowFunc,
})
assert.NoError(t, err)
return mock, gdb.Debug()
}

func TestMatterDBQuery_PathExist(t *testing.T) {
mock, gdb := newMockDB(t)
q := NewMatterDBQuery(query.Use(gdb))
mock, db := newMockDB(t)
q := NewMatterDBQuery(db)
mock.ExpectQuery("SELECT").WithArgs("to", "path/")
q.PathExist(context.Background(), "/path/to/")

Expand Down Expand Up @@ -83,7 +65,7 @@ func TestMatterDBQuery_Update(t *testing.T) {

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
mock, gdb := newMockDB(t)
mock, db := newMockDB(t)
mock.ExpectQuery("SELECT").WithArgs(tc.target.Id).
WillReturnRows(tc.rows)

Expand All @@ -97,7 +79,7 @@ func TestMatterDBQuery_Update(t *testing.T) {
WillReturnResult(tc.expectMainResult)
mock.ExpectCommit()

q := NewMatterDBQuery(query.Use(gdb))
q := NewMatterDBQuery(db)
ctx := context.Background()
assert.NoError(t, q.Update(ctx, tc.target.Id, tc.target))
})
Expand Down
15 changes: 7 additions & 8 deletions internal/app/repo/recyclebin.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"

"github.com/saltbo/zpan/internal/app/entity"
"github.com/saltbo/zpan/internal/app/repo/query"
)

type RecycleBinFindOptions struct {
Expand All @@ -23,19 +22,19 @@ type RecycleBin interface {
var _ RecycleBin = (*RecycleBinDBQuery)(nil)

type RecycleBinDBQuery struct {
q *query.Query
DBQuery
}

func NewRecycleBinDBQuery(q *query.Query) *RecycleBinDBQuery {
return &RecycleBinDBQuery{q: q}
func NewRecycleBinDBQuery(q DBQuery) *RecycleBinDBQuery {
return &RecycleBinDBQuery{DBQuery: q}
}

func (r *RecycleBinDBQuery) Find(ctx context.Context, alias string) (*entity.RecycleBin, error) {
return r.q.RecycleBin.WithContext(ctx).Where(r.q.RecycleBin.Alias_.Eq(alias)).First()
return r.Q().RecycleBin.WithContext(ctx).Where(r.Q().RecycleBin.Alias_.Eq(alias)).First()
}

func (r *RecycleBinDBQuery) FindAll(ctx context.Context, opts *RecycleBinFindOptions) (rows []*entity.RecycleBin, total int64, err error) {
q := r.q.RecycleBin.WithContext(ctx).Where(r.q.RecycleBin.Uid.Eq(opts.Uid), r.q.RecycleBin.Sid.Eq(opts.Sid)).Order(r.q.RecycleBin.Id.Desc())
q := r.Q().RecycleBin.WithContext(ctx).Where(r.Q().RecycleBin.Uid.Eq(opts.Uid), r.Q().RecycleBin.Sid.Eq(opts.Sid)).Order(r.Q().RecycleBin.Id.Desc())

if opts.Limit == 0 {
rows, err = q.Find()
Expand All @@ -46,7 +45,7 @@ func (r *RecycleBinDBQuery) FindAll(ctx context.Context, opts *RecycleBinFindOpt
}

func (r *RecycleBinDBQuery) Create(ctx context.Context, m *entity.RecycleBin) error {
return r.q.RecycleBin.WithContext(ctx).Create(m)
return r.Q().RecycleBin.WithContext(ctx).Create(m)
}

func (r *RecycleBinDBQuery) Delete(ctx context.Context, alias string) error {
Expand All @@ -55,6 +54,6 @@ func (r *RecycleBinDBQuery) Delete(ctx context.Context, alias string) error {
return err
}

_, err = r.q.RecycleBin.WithContext(ctx).Delete(m)
_, err = r.Q().RecycleBin.WithContext(ctx).Delete(m)
return err
}
18 changes: 18 additions & 0 deletions internal/app/repo/base.go → internal/app/repo/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package repo

import (
"context"

"github.com/saltbo/zpan/internal/app/repo/query"
)

type QueryPage struct {
Expand Down Expand Up @@ -44,3 +46,19 @@ type Updater[T comparable, ID IDType] interface {
type Deleter[ID IDType] interface {
Delete(ctx context.Context, id ID) error
}

type DBQuery interface {
Q() *query.Query
}

type DBQueryFactory struct {
q *query.Query
}

func NewDBQueryFactory(q *query.Query) *DBQueryFactory {
return &DBQueryFactory{q: q}
}

func (f *DBQueryFactory) Q() *query.Query {
return f.q
}
26 changes: 26 additions & 0 deletions internal/app/repo/shared_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package repo

import (
"testing"
"time"

"github.com/DATA-DOG/go-sqlmock"
"github.com/saltbo/zpan/internal/app/repo/query"
"github.com/stretchr/testify/assert"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)

var nowFunc = func() time.Time {
return time.Unix(0, 0)
}

func newMockDB(t *testing.T) (sqlmock.Sqlmock, DBQuery) {
rdb, mock, err := sqlmock.New()
assert.NoError(t, err)
gdb, err := gorm.Open(mysql.New(mysql.Config{Conn: rdb, DriverName: "mysql", SkipInitializeWithVersion: true}), &gorm.Config{
NowFunc: nowFunc,
})
assert.NoError(t, err)
return mock, NewDBQueryFactory(query.Use(gdb.Debug()))
}
Loading

0 comments on commit c06eed7

Please sign in to comment.