Skip to content

Commit

Permalink
feat: Add method ScanAndCount and AllAndCount
Browse files Browse the repository at this point in the history
  • Loading branch information
lusson-luo committed May 15, 2023
1 parent bda5d25 commit 9ea5611
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 0 deletions.
115 changes: 115 additions & 0 deletions contrib/drivers/sqlite/sqlite_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,41 @@ func Test_Model_All(t *testing.T) {
})
}

func Test_Model_AllAndCount(t *testing.T) {
table := createInitTable()
defer dropTable(table)
// AllAndCount with all data
gtest.C(t, func(t *gtest.T) {
result, count, err := db.Model(table).AllAndCount()
t.AssertNil(err)
t.Assert(len(result), TableSize)
t.Assert(count, TableSize)
})
// AllAndCount with no data
gtest.C(t, func(t *gtest.T) {
result, count, err := db.Model(table).Where("id<0").AllAndCount()
t.Assert(result, nil)
t.AssertNil(err)
t.Assert(count, 0)
})
// AllAndCount with page
gtest.C(t, func(t *gtest.T) {
result, count, err := db.Model(table).Page(1, 5).AllAndCount()
t.AssertNil(err)
t.Assert(len(result), 5)
t.Assert(count, TableSize)
})
// AllAndCount with normal result
gtest.C(t, func(t *gtest.T) {
result, count, err := db.Model(table).Where("id=?", 1).AllAndCount()
t.AssertNil(err)
t.Assert(count, 1)
t.Assert(result[0]["id"], 1)
t.Assert(result[0]["nickname"], "name_1")
t.Assert(result[0]["passport"], "user_1")
})
}

func Test_Model_Fields(t *testing.T) {
tableName1 := createInitTable()
defer dropTable(tableName1)
Expand Down Expand Up @@ -1080,6 +1115,86 @@ func Test_Model_Scan(t *testing.T) {
})
}

func Test_Model_ScanAndCount(t *testing.T) {
table := createInitTable()
defer dropTable(table)
// ScanAndCount with normal struct result
gtest.C(t, func(t *gtest.T) {
type User struct {
Id int
Passport string
Password string
NickName string
CreateTime *gtime.Time
}
user := new(User)
var count int
err := db.Model(table).Where("id=1").ScanAndCount(user, &count)
t.AssertNil(err)
t.Assert(user.NickName, "name_1")
t.Assert(user.CreateTime.String(), CreateTime)
t.Assert(count, 1)
})
// ScanAndCount with normal array result
gtest.C(t, func(t *gtest.T) {
type User struct {
Id int
Passport string
Password string
NickName string
CreateTime gtime.Time
}
var users []User
var count int
err := db.Model(table).Order("id asc").ScanAndCount(&users, &count)
t.AssertNil(err)
t.Assert(len(users), TableSize)
t.Assert(users[0].Id, 1)
t.Assert(users[1].Id, 2)
t.Assert(users[2].Id, 3)
t.Assert(users[0].NickName, "name_1")
t.Assert(users[1].NickName, "name_2")
t.Assert(users[2].NickName, "name_3")
t.Assert(users[0].CreateTime.String(), CreateTime)
t.Assert(count, len(users))
})
// sql.ErrNoRows
gtest.C(t, func(t *gtest.T) {
type User struct {
Id int
Passport string
Password string
NickName string
CreateTime *gtime.Time
}
var (
user = new(User)
users = new([]*User)
)
var count int
err1 := db.Model(table).Where("id < 0").ScanAndCount(user, &count)
err2 := db.Model(table).Where("id < 0").ScanAndCount(users, &count)
t.Assert(err1, nil)
t.Assert(err2, nil)
})
// ScanAndCount with page
gtest.C(t, func(t *gtest.T) {
type User struct {
Id int
Passport string
Password string
NickName string
CreateTime gtime.Time
}
var users []User
var count int
err := db.Model(table).Order("id asc").Page(1, 3).ScanAndCount(&users, &count)
t.AssertNil(err)
t.Assert(len(users), 3)
t.Assert(count, TableSize)
})
}

func Test_Model_Scan_NilSliceAttrWhenNoRecordsFound(t *testing.T) {
table := createTable()
defer dropTable(table)
Expand Down
64 changes: 64 additions & 0 deletions database/gdb/gdb_model_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,36 @@ func (m *Model) All(where ...interface{}) (Result, error) {
return m.doGetAll(ctx, false, where...)
}

// AllAndCount retrieves all records that match the given conditions and counts the total number of records.
// It returns a Result containing the retrieved records and an integer representing the total number of records that match the given conditions.
// The where parameter is an optional list of conditions to use when retrieving records.
//
// Example:
//
// var model Model
// var result Result
// var count int
// where := []interface{}{"name = ?", "John"}
// result, count, err := model.AllAndCount(where...)
// if err != nil {
// // Handle error.
// }
// // Use the retrieved records and total count.
// fmt.Println(result, count)
func (m *Model) AllAndCount() (result Result, count int, err error) {
countModel := m.Clone()
countModel.fields = "1"
count, err = countModel.Count()
if err != nil {
return
}
if count == 0 {
return
}
result, err = m.doGetAll(m.GetCtx(), false)
return
}

// Chunk iterates the query result with given `size` and `handler` function.
func (m *Model) Chunk(size int, handler ChunkHandler) {
page := m.start
Expand Down Expand Up @@ -235,6 +265,40 @@ func (m *Model) Scan(pointer interface{}, where ...interface{}) error {
}
}

// ScanAndCount scans a single record or record array that matches the given conditions and counts the total number of records that match those conditions.
// The pointer parameter is a pointer to a struct that the scanned data will be stored in.
// The pointerCount parameter is a pointer to an integer that will be set to the total number of records that match the given conditions.
// The where parameter is an optional list of conditions to use when retrieving records.
// ScanAndCount can't support Fileds with *, example: .Fileds("a.*, b.name").ScanAndCount()
//
// Example:
//
// var count int
// user := new(User)
// err := db.Model("user").Where("id", 1).ScanAndCount(user,&count)
//
// // Use the retrieved data and total count.
// fmt.Println(user, count)
//
// Example:
// var count int
// var users []User
// dao.Station.Ctx(ctx).Page(page.PageNo, page.PageSize).ScanAndCount(&users, &count)
func (m *Model) ScanAndCount(pointer interface{}, count *int) (err error) {
// support Fileds with *, example: .Fileds("a.*, b.name"). Count sql is select count(1) from xxx
countModel := m.Clone()
countModel.fields = "1"
*count, err = countModel.Count()
if err != nil {
return
}
if *count == 0 {
return
}
err = m.Scan(pointer)
return
}

// ScanList converts `r` to struct slice which contains other complex struct attributes.
// Note that the parameter `listPointer` should be type of *[]struct/*[]*struct.
//
Expand Down

0 comments on commit 9ea5611

Please sign in to comment.