Skip to content

Commit

Permalink
fix: allow LIMIT 0 for SELECT queries
Browse files Browse the repository at this point in the history
This commit enables the correct queries of type `SELECT ... LIMIT 0`.
Before that, the limit-clause wasn't applied to the query.
  • Loading branch information
ygabuev committed Oct 19, 2023
1 parent 8a43835 commit 690f449
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
52 changes: 52 additions & 0 deletions internal/dbtest/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ func TestDB(t *testing.T) {
{testNilModel},
{testSelectScan},
{testSelectCount},
{testSelectLimit},
{testSelectMap},
{testSelectMapSlice},
{testSelectStruct},
Expand Down Expand Up @@ -348,6 +349,36 @@ func testSelectCount(t *testing.T, db *bun.DB) {
require.Equal(t, 3, count)
}

func testSelectLimit(t *testing.T, db *bun.DB) {
if !db.Dialect().Features().Has(feature.CTE) {
t.Skip()
return
}

values := db.NewValues(&[]map[string]interface{}{
{"num": 1},
{"num": 2},
{"num": 3},
})

q := db.NewSelect().
With("t", values).
Column("t.num").
TableExpr("t")

count, err := q.Limit(5).Count(ctx)
require.NoError(t, err)
require.Equal(t, 3, count)

count, err = q.Limit(2).Count(ctx)
require.NoError(t, err)
require.Equal(t, 2, count)

count, err = q.Limit(0).Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, count)
}

func testSelectMap(t *testing.T, db *bun.DB) {
var m map[string]interface{}
err := db.NewSelect().
Expand Down Expand Up @@ -1344,6 +1375,9 @@ func testScanAndCount(t *testing.T, db *bun.DB) {
})

t.Run("no limit", func(t *testing.T) {
err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

src := []Model{
{Str: "str1"},
{Str: "str2"},
Expand All @@ -1357,6 +1391,24 @@ func testScanAndCount(t *testing.T, db *bun.DB) {
require.Equal(t, 2, count)
require.Equal(t, 2, len(dest))
})

t.Run("limit 0", func(t *testing.T) {
err := db.ResetModel(ctx, (*Model)(nil))
require.NoError(t, err)

src := []Model{
{Str: "str1"},
{Str: "str2"},
}
_, err = db.NewInsert().Model(&src).Exec(ctx)
require.NoError(t, err)

var dest []Model
count, err := db.NewSelect().Model(&dest).Limit(0).ScanAndCount(ctx)
require.NoError(t, err)
require.Equal(t, 0, count)
require.Equal(t, 0, len(dest))
})
}

func testEmbedModelValue(t *testing.T, db *bun.DB) {
Expand Down
7 changes: 4 additions & 3 deletions query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func NewSelectQuery(db *DB) *SelectQuery {
conn: db.DB,
},
},
limit: -1,
}
}

Expand Down Expand Up @@ -631,7 +632,7 @@ func (q *SelectQuery) appendQuery(
b = append(b, " ROWS"...)
}
} else {
if q.limit > 0 {
if q.limit >= 0 {
b = append(b, " LIMIT "...)
b = strconv.AppendInt(b, int64(q.limit), 10)
}
Expand Down Expand Up @@ -958,7 +959,7 @@ func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{})
var mu sync.Mutex
var firstErr error

if q.limit >= 0 {
if q.limit >= -1 {
wg.Add(1)
go func() {
defer wg.Done()
Expand Down Expand Up @@ -995,7 +996,7 @@ func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{})
func (q *SelectQuery) scanAndCountSeq(ctx context.Context, dest ...interface{}) (int, error) {
var firstErr error

if q.limit >= 0 {
if q.limit >= -1 {
firstErr = q.Scan(ctx, dest...)
}

Expand Down

0 comments on commit 690f449

Please sign in to comment.