From 755bd58c0bc1b7b21a0283745b268a7138ebfa3a Mon Sep 17 00:00:00 2001
From: Nikita Karmatskikh <nickkarmatskikh@gmail.com>
Date: Sun, 21 Nov 2021 01:00:39 +0300
Subject: [PATCH] qb: add named limit and per partition limit clauses

---
 qb/limit.go       | 45 +++++++++++++++++++++++++++++++++++++++++++++
 qb/select.go      | 34 ++++++++++++++++++----------------
 qb/select_test.go | 24 ++++++++++++++++++++++++
 3 files changed, 87 insertions(+), 16 deletions(-)
 create mode 100644 qb/limit.go

diff --git a/qb/limit.go b/qb/limit.go
new file mode 100644
index 0000000..e91a9a7
--- /dev/null
+++ b/qb/limit.go
@@ -0,0 +1,45 @@
+// Copyright (C) 2017 ScyllaDB
+// Use of this source code is governed by a ALv2-style
+// license that can be found in the LICENSE file.
+
+package qb
+
+import (
+	"bytes"
+	"strconv"
+)
+
+type limit struct {
+	value        value
+	perPartition bool
+}
+
+func limitLit(l uint, perPartition bool) limit {
+	val := strconv.FormatUint(uint64(l), 10)
+	return limit{
+		value:        lit(val),
+		perPartition: perPartition,
+	}
+}
+
+func limitNamed(name string, perPartition bool) limit {
+	return limit{
+		value:        param(name),
+		perPartition: perPartition,
+	}
+}
+
+func (l limit) writeCql(cql *bytes.Buffer) (names []string) {
+	if l.value == nil {
+		return nil
+	}
+
+	if l.perPartition {
+		cql.WriteString("PER PARTITION ")
+	}
+	cql.WriteString("LIMIT ")
+
+	names = l.value.writeCql(cql)
+	cql.WriteByte(' ')
+	return
+}
diff --git a/qb/select.go b/qb/select.go
index ac18e29..ce8c7a0 100644
--- a/qb/select.go
+++ b/qb/select.go
@@ -10,7 +10,6 @@ package qb
 import (
 	"bytes"
 	"context"
-	"fmt"
 	"time"
 
 	"github.com/scylladb/gocqlx/v2"
@@ -42,8 +41,8 @@ type SelectBuilder struct {
 	where             where
 	groupBy           columns
 	orderBy           columns
-	limit             uint
-	limitPerPartition uint
+	limit             limit
+	limitPerPartition limit
 	allowFiltering    bool
 	bypassCache       bool
 	json              bool
@@ -100,17 +99,8 @@ func (b *SelectBuilder) ToCql() (stmt string, names []string) {
 		cql.WriteByte(' ')
 	}
 
-	if b.limit != 0 {
-		cql.WriteString("LIMIT ")
-		cql.WriteString(fmt.Sprint(b.limit))
-		cql.WriteByte(' ')
-	}
-
-	if b.limitPerPartition != 0 {
-		cql.WriteString("PER PARTITION LIMIT ")
-		cql.WriteString(fmt.Sprint(b.limitPerPartition))
-		cql.WriteByte(' ')
-	}
+	names = append(names, b.limitPerPartition.writeCql(&cql)...)
+	names = append(names, b.limit.writeCql(&cql)...)
 
 	if b.allowFiltering {
 		cql.WriteString("ALLOW FILTERING ")
@@ -214,13 +204,25 @@ func (b *SelectBuilder) OrderBy(column string, o Order) *SelectBuilder {
 
 // Limit sets a LIMIT clause on the query.
 func (b *SelectBuilder) Limit(limit uint) *SelectBuilder {
-	b.limit = limit
+	b.limit = limitLit(limit, false)
+	return b
+}
+
+// LimitNamed produces LIMIT ? clause with a custom parameter name.
+func (b *SelectBuilder) LimitNamed(name string) *SelectBuilder {
+	b.limit = limitNamed(name, false)
 	return b
 }
 
 // LimitPerPartition sets a PER PARTITION LIMIT clause on the query.
 func (b *SelectBuilder) LimitPerPartition(limit uint) *SelectBuilder {
-	b.limitPerPartition = limit
+	b.limitPerPartition = limitLit(limit, true)
+	return b
+}
+
+// LimitPerPartitionNamed produces PER PARTITION LIMIT ? clause with a custom parameter name.
+func (b *SelectBuilder) LimitPerPartitionNamed(name string) *SelectBuilder {
+	b.limitPerPartition = limitNamed(name, true)
 	return b
 }
 
diff --git a/qb/select_test.go b/qb/select_test.go
index be01991..85d66ac 100644
--- a/qb/select_test.go
+++ b/qb/select_test.go
@@ -128,12 +128,36 @@ func TestSelectBuilder(t *testing.T) {
 			S: "SELECT * FROM cycling.cyclist_name WHERE id=? LIMIT 10 ",
 			N: []string{"expr"},
 		},
+		// Add named LIMIT
+		{
+			B: Select("cycling.cyclist_name").Where(w).LimitNamed("limit"),
+			S: "SELECT * FROM cycling.cyclist_name WHERE id=? LIMIT ? ",
+			N: []string{"expr", "limit"},
+		},
 		// Add PER PARTITION LIMIT
 		{
 			B: Select("cycling.cyclist_name").Where(w).LimitPerPartition(10),
 			S: "SELECT * FROM cycling.cyclist_name WHERE id=? PER PARTITION LIMIT 10 ",
 			N: []string{"expr"},
 		},
+		// Add named PER PARTITION LIMIT
+		{
+			B: Select("cycling.cyclist_name").Where(w).LimitPerPartitionNamed("partition_limit"),
+			S: "SELECT * FROM cycling.cyclist_name WHERE id=? PER PARTITION LIMIT ? ",
+			N: []string{"expr", "partition_limit"},
+		},
+		// Add PER PARTITION LIMIT and LIMIT
+		{
+			B: Select("cycling.cyclist_name").Where(w).LimitPerPartition(2).Limit(10),
+			S: "SELECT * FROM cycling.cyclist_name WHERE id=? PER PARTITION LIMIT 2 LIMIT 10 ",
+			N: []string{"expr"},
+		},
+		// Add named PER PARTITION LIMIT and LIMIT
+		{
+			B: Select("cycling.cyclist_name").Where(w).LimitPerPartitionNamed("partition_limit").LimitNamed("limit"),
+			S: "SELECT * FROM cycling.cyclist_name WHERE id=? PER PARTITION LIMIT ? LIMIT ? ",
+			N: []string{"expr", "partition_limit", "limit"},
+		},
 		// Add ALLOW FILTERING
 		{
 			B: Select("cycling.cyclist_name").Where(w).AllowFiltering(),