Skip to content

Commit

Permalink
Batch rows count API (#1063)
Browse files Browse the repository at this point in the history
* Add Rows() func to batch

* Add tests
  • Loading branch information
EpicStep authored Aug 10, 2023
1 parent 4a2e8ff commit f97fdbe
Show file tree
Hide file tree
Showing 27 changed files with 123 additions and 1 deletion.
4 changes: 4 additions & 0 deletions conn_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ func (b *batch) Flush() error {
return nil
}

func (b *batch) Rows() int {
return b.block.Rows()
}

type batchColumn struct {
err error
batch driver.Batch
Expand Down
4 changes: 4 additions & 0 deletions conn_http_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,4 +229,8 @@ func (b *httpBatch) Send() (err error) {
return err
}

func (b *httpBatch) Rows() int {
return b.block.Rows()
}

var _ driver.Batch = (*httpBatch)(nil)
1 change: 1 addition & 0 deletions lib/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ type (
Flush() error
Send() error
IsSent() bool
Rows() int
}
BatchColumn interface {
Append(any) error
Expand Down
7 changes: 7 additions & 0 deletions tests/array_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func TestSimpleArray(t *testing.T) {
)
for i := 0; i < 10; i++ {
require.NoError(t, batch.Append(col1Data))
require.Equal(t, 1, batch.Rows())
batch.Flush()
}
require.NoError(t, batch.Send())
Expand Down Expand Up @@ -92,8 +93,10 @@ func TestCustomArray(t *testing.T) {
)
for i := 0; i < 10; i++ {
require.NoError(t, batch.Append(col1Data, col2Data))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Flush())
}
require.Equal(t, 0, batch.Rows())
require.NoError(t, batch.Send())
rows, err := conn.Query(ctx, "SELECT * FROM test_array")
require.NoError(t, err)
Expand Down Expand Up @@ -133,6 +136,7 @@ func TestInterfaceArray(t *testing.T) {
for i := 0; i < 10; i++ {
require.NoError(t, batch.Append(col1Data))
}
require.Equal(t, 10, batch.Rows())
require.Nil(t, batch.Send())
rows, err := conn.Query(ctx, "SELECT * FROM test_array")
require.NoError(t, err)
Expand Down Expand Up @@ -200,8 +204,10 @@ func TestArray(t *testing.T) {
)
for i := 0; i < 10; i++ {
require.NoError(t, batch.Append(col1Data, col2Data, col3Data, col4Data))
require.Equal(t, 1, batch.Rows())
batch.Flush()
}
require.Equal(t, 0, batch.Rows())
require.NoError(t, batch.Send())
rows, err := conn.Query(ctx, "SELECT * FROM test_array")
require.NoError(t, err)
Expand Down Expand Up @@ -289,6 +295,7 @@ func TestColumnarArray(t *testing.T) {
require.NoError(t, batch.Column(1).Append(col2DataColArr))
require.NoError(t, batch.Column(2).Append(col3DataColArr))
require.NoError(t, batch.Column(3).Append(col4DataColArr))
require.Equal(t, 10, batch.Rows())
require.NoError(t, batch.Send())
rows, err := conn.Query(ctx, "SELECT * FROM test_array")
require.NoError(t, err)
Expand Down
9 changes: 9 additions & 0 deletions tests/base_types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func TestUInt8(t *testing.T) {
require.NoError(t, err)
require.NoError(t, batch.Append(uint8(1), data, &data, []uint8{data}, []*uint8{&data, nil, &data}, customUint8(data)))
require.NoError(t, batch.Append(uint8(2), data, nil, []uint8{data}, []*uint8{nil, nil, &data}, customUint8(data)))
require.Equal(t, 2, batch.Rows())
require.NoError(t, batch.Send())
var (
result1 result
Expand Down Expand Up @@ -151,6 +152,7 @@ func TestColumnarUInt8(t *testing.T) {
return
}
}
require.Equal(t, 1000, batch.Rows())
require.NoError(t, batch.Send())
var result struct {
Col1 uint8
Expand Down Expand Up @@ -209,6 +211,7 @@ func TestNullableInt(t *testing.T) {
col5Data := sql.NullInt16{Int16: 3, Valid: true}
col6Data := sql.NullInt16{Int16: 0, Valid: false}
require.NoError(t, batch.Append(col1Data, col2Data, col3Data, col4Data, col5Data, col6Data))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col1 sql.NullInt64
Expand Down Expand Up @@ -250,9 +253,15 @@ func TestIntFlush(t *testing.T) {
vals[i] = uint8(i)
require.NoError(t, batch.Append(vals[i]))
if i%100 == 0 {
if i == 0 {
require.Equal(t, 1, batch.Rows())
} else {
require.Equal(t, 100, batch.Rows())
}
require.NoError(t, batch.Flush())
}
}
require.Equal(t, 99, batch.Rows())
batch.Send()
rows, err := conn.Query(ctx, "SELECT * FROM int_flush")
require.NoError(t, err)
Expand Down
6 changes: 6 additions & 0 deletions tests/bigint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func TestSimpleBigInt(t *testing.T) {
col1Data, ok := new(big.Int).SetString("170141183460469231731687303715884105727", 10)
require.True(t, ok)
require.NoError(t, batch.Append(col1Data))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col1 big.Int
Expand Down Expand Up @@ -111,6 +112,7 @@ func TestBigInt(t *testing.T) {
}
)
require.NoError(t, batch.Append(col1Data, col2Data, col3Data, col4Data, col5Data, col6Data, col7Data))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col1 big.Int
Expand Down Expand Up @@ -178,6 +180,7 @@ func TestNullableBigInt(t *testing.T) {
}
)
require.NoError(t, batch.Append(col1Data, col2Data, col3Data, col4Data, col5Data, col6Data))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col1 *big.Int
Expand Down Expand Up @@ -246,6 +249,7 @@ func TestBigIntUIntOverflow(t *testing.T) {
}
)
require.NoError(t, batch.Append(col1Data, col2Data, col3Data, col4Data, col5Data, col6Data))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col1 big.Int
Expand Down Expand Up @@ -287,8 +291,10 @@ func TestBigIntFlush(t *testing.T) {
bigUint128Val.SetString(RandIntString(20), 10)
vals[i] = bigUint128Val
batch.Append(vals[i])
require.Equal(t, 1, batch.Rows())
batch.Flush()
}
require.Equal(t, 0, batch.Rows())
batch.Send()
rows, err := conn.Query(ctx, "SELECT * FROM big_int_flush")
require.NoError(t, err)
Expand Down
4 changes: 4 additions & 0 deletions tests/bool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func TestBool(t *testing.T) {
Bool: false,
Valid: false,
}))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col1 bool
Expand Down Expand Up @@ -140,6 +141,7 @@ func TestColumnarBool(t *testing.T) {
require.NoError(t, batch.Column(3).Append(col3))
require.NoError(t, batch.Column(4).Append(col4))
require.NoError(t, batch.Column(5).Append(col5))
require.Equal(t, 1000, batch.Rows())
require.NoError(t, batch.Send())
var (
id uint64
Expand Down Expand Up @@ -182,8 +184,10 @@ func TestBoolFlush(t *testing.T) {
for i := 0; i < 1000; i++ {
vals[i] = r.Intn(2) != 0
require.NoError(t, batch.Append(vals[i]))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Flush())
}
require.Equal(t, 0, batch.Rows())
batch.Send()
rows, err := conn.Query(ctx, "SELECT * FROM bool_flush")
require.NoError(t, err)
Expand Down
5 changes: 5 additions & 0 deletions tests/columnar_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ func TestColumnarInterface(t *testing.T) {
require.NoError(t, batch.Column(3).Append(col4Data))
require.NoError(t, batch.Column(4).Append(col5Data))
require.NoError(t, batch.Column(5).Append(col6Data))
require.Equal(t, 150, batch.Rows())
require.NoError(t, batch.Send())
var count uint64
require.NoError(t, conn.QueryRow(ctx, "SELECT COUNT() FROM test_column_interface").Scan(&count))
Expand Down Expand Up @@ -151,6 +152,7 @@ func TestNullableColumnarInterface(t *testing.T) {
require.NoError(t, batch.Column(1).Append(col2Data))
require.NoError(t, batch.Column(2).Append(col3Data))
require.NoError(t, batch.Column(3).Append(col4Data))
require.Equal(t, 150, batch.Rows())
require.NoError(t, batch.Send())
var count uint64
require.NoError(t, conn.QueryRow(ctx, "SELECT COUNT() FROM test_column_interface").Scan(&count))
Expand Down Expand Up @@ -210,6 +212,7 @@ func TestNullableColumnarInterface(t *testing.T) {
require.NoError(t, batch.Column(1).Append(col2Data))
require.NoError(t, batch.Column(2).Append(col3Data))
require.NoError(t, batch.Column(3).Append(col4Data))
require.Equal(t, 150, batch.Rows())
require.NoError(t, batch.Send())
var count uint64
require.NoError(t, conn.QueryRow(ctx, "SELECT COUNT() FROM test_column_interface").Scan(&count))
Expand Down Expand Up @@ -283,6 +286,7 @@ func TestColumnarAppendRowInterface(t *testing.T) {
require.NoError(t, batch.Column(4).AppendRow(sql.NullTime{Time: currentTime, Valid: true}))
require.NoError(t, batch.Column(5).AppendRow(sql.NullInt64{Int64: int64(i), Valid: true}))
}
require.Equal(t, 150, batch.Rows())
require.NoError(t, batch.Send())
var count uint64
require.NoError(t, conn.QueryRow(ctx, "SELECT COUNT() FROM test_column_interface").Scan(&count))
Expand Down Expand Up @@ -352,6 +356,7 @@ func TestNullableAppendRowColumnarInterface(t *testing.T) {
require.NoError(t, batch.Column(3).AppendRow(&decimalVal))
}
}
require.Equal(t, 150, batch.Rows())
require.NoError(t, batch.Send())
var count uint64
require.NoError(t, conn.QueryRow(ctx, "SELECT COUNT() FROM test_column_interface").Scan(&count))
Expand Down
8 changes: 8 additions & 0 deletions tests/date32_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func TestDate32(t *testing.T) {
require.NoError(t, batch.Append(uint8(1), date1, &date2, []time.Time{date2}, []*time.Time{&date2, nil, &date1}, dateStr1, dateStrNil, []string{dateStr1, dateStr2, dateStr3}, []*string{dateStrNil, &dateStr1, dateStrNil}))
require.NoError(t, batch.Append(uint8(2), date2, nil, []time.Time{date1}, []*time.Time{nil, nil, &date2}, &testStr{Col1: dateStr1}, nil, []string{dateStr1, dateStr2, dateStr3}, []*string{nil, &dateStr1, dateStrNil}))
require.NoError(t, batch.Append(uint8(3), date3, nil, []time.Time{date3}, []*time.Time{nil, nil, &date3}, &testStr{Col1: dateStr1}, &dateStr1, []string{dateStr1, dateStr2, dateStr3}, []*string{nil, nil, dateStrNil}))
require.Equal(t, 3, batch.Rows())
require.NoError(t, batch.Send())
var (
result1 result
Expand Down Expand Up @@ -158,6 +159,7 @@ func TestNullableDate32(t *testing.T) {
require.NoError(t, err)
dateStr := "2283-11-11"
require.NoError(t, batch.Append(date, &date, dateStr, &dateStr))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col1 *time.Time
Expand All @@ -176,6 +178,7 @@ func TestNullableDate32(t *testing.T) {
date, err = time.Parse("2006-01-02 15:04:05", "1925-01-01 00:00:00")
require.NoError(t, err)
require.NoError(t, batch.Append(date, nil, &date, nil))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
col2 = nil
col4 = nil
Expand Down Expand Up @@ -245,6 +248,7 @@ func TestColumnarDate32(t *testing.T) {
require.NoError(t, batch.Column(3).Append(col3Data))
require.NoError(t, batch.Column(4).Append(col4Data))
}
require.Equal(t, 1000, batch.Rows())
require.NoError(t, batch.Send())
var result struct {
Col1 time.Time
Expand Down Expand Up @@ -286,8 +290,10 @@ func TestDate32Flush(t *testing.T) {
for i := 0; i < 1000; i++ {
vals[i] = now.Add(time.Duration(i) * time.Hour)
batch.Append(vals[i])
require.Equal(t, 1, batch.Rows())
batch.Flush()
}
require.Equal(t, 0, batch.Rows())
batch.Send()
rows, err := conn.Query(ctx, "SELECT * FROM date_32_flush")
require.NoError(t, err)
Expand Down Expand Up @@ -321,6 +327,7 @@ func TestDate32TZ(t *testing.T) {
"2022-07-20 +08:00",
))
require.NoError(t, err)
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col15, col16 time.Time
Expand Down Expand Up @@ -355,6 +362,7 @@ func TestCustomDateTime32(t *testing.T) {
require.NoError(t, err)
now := time.Now().UTC().Truncate(time.Hour)
require.NoError(t, batch.Append(now))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
row := conn.QueryRow(ctx, "SELECT * FROM date32_custom")
var col1 CustomDateTime
Expand Down
6 changes: 6 additions & 0 deletions tests/datetime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func TestDateTime(t *testing.T) {
&testStr{Col1: dateTimeStr},
iDateTime,
))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col1 time.Time
Expand Down Expand Up @@ -141,6 +142,7 @@ func TestNullableDateTime(t *testing.T) {
require.NoError(t, err)
datetime := time.Now().Truncate(time.Second)
require.NoError(t, batch.Append(datetime, datetime, datetime, datetime, datetime, datetime, datetime, datetime, datetime, datetime))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col1 time.Time
Expand Down Expand Up @@ -181,6 +183,7 @@ func TestNullableDateTime(t *testing.T) {
datetimeNilStr *string = nil
)
require.NoError(t, batch.Append(datetime, nil, datetime, nil, datetime, nil, datetimeStr, nil, datetimeStr, datetimeNilStr))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col1 time.Time
Expand Down Expand Up @@ -294,6 +297,7 @@ func TestColumnarDateTime(t *testing.T) {
require.NoError(t, batch.Column(6).Append(col6Data))
require.NoError(t, batch.Column(7).Append(col7Data))
}
require.Equal(t, 1000, batch.Rows())
require.NoError(t, batch.Send())
var result struct {
Col1 time.Time
Expand Down Expand Up @@ -338,8 +342,10 @@ func TestDateTimeFlush(t *testing.T) {
for i := 0; i < 1000; i++ {
vals[i] = now.Add(time.Duration(i) * time.Hour).Truncate(time.Second)
batch.Append(vals[i])
require.Equal(t, 1, batch.Rows())
batch.Flush()
}
require.Equal(t, 0, batch.Rows())
batch.Send()
rows, err := conn.Query(ctx, "SELECT * FROM datetime_flush")
require.NoError(t, err)
Expand Down
10 changes: 9 additions & 1 deletion tests/decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func TestDecimal(t *testing.T) {
decimal.New(135, 7),
decimal.New(256, 8),
))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col1 decimal.Decimal
Expand Down Expand Up @@ -109,6 +110,7 @@ func TestNegativeDecimal(t *testing.T) {
decimal.RequireFromString("-0.01171"),
decimal.RequireFromString("-3.0111"),
decimal.RequireFromString("-21111122.0111111111111111111171")))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col1 decimal.Decimal
Expand Down Expand Up @@ -149,6 +151,7 @@ func TestNullableDecimal(t *testing.T) {
batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_decimal")
require.NoError(t, err)
require.NoError(t, batch.Append(decimal.New(25, 0), decimal.New(30, 0), decimal.New(35, 0)))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
var (
col1 *decimal.Decimal
Expand All @@ -163,6 +166,7 @@ func TestNullableDecimal(t *testing.T) {
batch, err = conn.PrepareBatch(ctx, "INSERT INTO test_decimal")
require.NoError(t, err)
require.NoError(t, batch.Append(decimal.New(25, 0), nil, decimal.New(35, 0)))
require.Equal(t, 1, batch.Rows())
require.NoError(t, batch.Send())
{
var (
Expand Down Expand Up @@ -198,8 +202,10 @@ func TestDecimalFlush(t *testing.T) {
for i := 0; i < 1000; i++ {
vals[i] = decimal.RequireFromString(fmt.Sprintf("1.%s", RandIntString(5)))
batch.Append(vals[i])
require.Equal(t, 1, batch.Rows())
batch.Flush()
}
require.Equal(t, 0, batch.Rows())
batch.Send()
rows, err := conn.Query(ctx, "SELECT * FROM decimal_flush")
require.NoError(t, err)
Expand Down Expand Up @@ -253,9 +259,11 @@ func TestRoundDecimals(t *testing.T) {
decimal.NewFromFloat(601), // this will make decimal 601*e^0
decimal.NewFromFloat(601.21), // check that normal case is working
}
for _, c := range checks {
for i, c := range checks {
batch.Append(c)
require.Equal(t, i+1, batch.Rows())
}
require.Equal(t, 3, batch.Rows())
batch.Send()
rows, err := conn.Query(ctx, "SELECT * FROM decimal_flush ORDER BY Col1 asc")
require.NoError(tt, err)
Expand Down
1 change: 1 addition & 0 deletions tests/empty_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,6 @@ func TestEmptyQuery(t *testing.T) {
defer cancel()
batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_empty_query")
require.NoError(t, err)
require.Equal(t, 0, batch.Rows())
assert.NoError(t, batch.Send())
}
Loading

0 comments on commit f97fdbe

Please sign in to comment.