From 6145ed6f05a974d22b7c0ed01235816a7f724809 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 2 May 2024 21:09:42 -0400
Subject: [PATCH] fix(go/adbc/driver/snowflake): handle empty result sets
Fixes #1804.
---
c/validation/adbc_validation.h | 2 +
c/validation/adbc_validation_statement.cc | 24 ++++
go/adbc/driver/snowflake/record_reader.go | 149 ++++++++++++----------
3 files changed, 105 insertions(+), 70 deletions(-)
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 6c59d95e09..abe9a76868 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -407,6 +407,7 @@ class StatementTest {
void TestSqlPrepareErrorNoQuery();
void TestSqlPrepareErrorParamCountMismatch();
+ void TestSqlQueryEmpty();
void TestSqlQueryInts();
void TestSqlQueryFloats();
void TestSqlQueryStrings();
@@ -504,6 +505,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlPrepareErrorParamCountMismatch) { \
TestSqlPrepareErrorParamCountMismatch(); \
} \
+ TEST_F(FIXTURE, SqlQueryEmpty) { TestSqlQueryEmpty(); } \
TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); } \
TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); } \
TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); } \
diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc
index 59f3f3f9a3..7da58c6986 100644
--- a/c/validation/adbc_validation_statement.cc
+++ b/c/validation/adbc_validation_statement.cc
@@ -2062,6 +2062,30 @@ void StatementTest::TestSqlPrepareErrorParamCountMismatch() {
::testing::Not(IsOkStatus(&error)));
}
+void StatementTest::TestSqlQueryEmpty() {  // Runs a query that matches no rows and checks the result stream is still well-formed.
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42 WHERE 1=0", &error),  // 1=0 is never true, so the result set has no rows
+ IsOkStatus(&error));
+ StreamReader reader;
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+ &reader.rows_affected, &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(reader.rows_affected,  // 0 or -1 are both accepted (-1 presumably means "unknown" -- confirm against the ADBC docs)
+ ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1)));
+
+ ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+ ASSERT_EQ(1, reader.schema->n_children);  // the schema must still report the single selected column
+
+ while (true) {  // drain the stream; batches may be delivered as long as each one is empty
+ ASSERT_NO_FATAL_FAILURE(reader.Next());
+ if (!reader.array->release) {  // a null release callback means no array was produced; stop reading
+ break;
+ }
+ ASSERT_EQ(0, reader.array->length);
+ }
+ ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+}
+
void StatementTest::TestSqlQueryInts() {
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
diff --git a/go/adbc/driver/snowflake/record_reader.go b/go/adbc/driver/snowflake/record_reader.go
index bda3e8f70d..5c71322209 100644
--- a/go/adbc/driver/snowflake/record_reader.go
+++ b/go/adbc/driver/snowflake/record_reader.go
@@ -571,23 +571,9 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
}
ch := make(chan arrow.Record, bufferSize)
- r, err := batches[0].GetStream(ctx)
- if err != nil {
- return nil, errToAdbcErr(adbc.StatusIO, err)
- }
-
- rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc))
- if err != nil {
- return nil, adbc.Error{
- Msg: err.Error(),
- Code: adbc.StatusInvalidState,
- }
- }
-
group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
ctx, cancelFn := context.WithCancel(ctx)
-
- schema, recTransform := getTransformer(rr.Schema(), ld, useHighPrecision)
+ group.SetLimit(prefetchConcurrency)
defer func() {
if err != nil {
@@ -596,80 +582,103 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
}
}()
- group.SetLimit(prefetchConcurrency)
- group.Go(func() error {
- defer rr.Release()
- defer r.Close()
- if len(batches) > 1 {
- defer close(ch)
- }
-
- for rr.Next() && ctx.Err() == nil {
- rec := rr.Record()
- rec, err = recTransform(ctx, rec)
- if err != nil {
- return err
- }
- ch <- rec
- }
- return rr.Err()
- })
-
chs := make([]chan arrow.Record, len(batches))
- chs[0] = ch
rdr := &reader{
refCount: 1,
chs: chs,
err: nil,
cancelFn: cancelFn,
- schema: schema,
}
- lastChannelIndex := len(chs) - 1
- go func() {
- for i, b := range batches[1:] {
- batch, batchIdx := b, i+1
- chs[batchIdx] = make(chan arrow.Record, bufferSize)
- group.Go(func() error {
- // close channels (except the last) so that Next can move on to the next channel properly
- if batchIdx != lastChannelIndex {
- defer close(chs[batchIdx])
- }
+ if len(batches) > 0 {
+ r, err := batches[0].GetStream(ctx)
+ if err != nil {
+ return nil, errToAdbcErr(adbc.StatusIO, err)
+ }
- rdr, err := batch.GetStream(ctx)
- if err != nil {
- return err
- }
- defer rdr.Close()
+ rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc))
+ if err != nil {
+ return nil, adbc.Error{
+ Msg: err.Error(),
+ Code: adbc.StatusInvalidState,
+ }
+ }
+
+ var recTransform recordTransformer
+ rdr.schema, recTransform = getTransformer(rr.Schema(), ld, useHighPrecision)
- rr, err := ipc.NewReader(rdr, ipc.WithAllocator(alloc))
+ group.Go(func() error {
+ defer rr.Release()
+ defer r.Close()
+ if len(batches) > 1 {
+ defer close(ch)
+ }
+
+ for rr.Next() && ctx.Err() == nil {
+ rec := rr.Record()
+ rec, err = recTransform(ctx, rec)
if err != nil {
return err
}
- defer rr.Release()
+ ch <- rec
+ }
+ return rr.Err()
+ })
+
+ chs[0] = ch
+
+ lastChannelIndex := len(chs) - 1
+ go func() {
+ for i, b := range batches[1:] {
+ batch, batchIdx := b, i+1
+ chs[batchIdx] = make(chan arrow.Record, bufferSize)
+ group.Go(func() error {
+ // close channels (except the last) so that Next can move on to the next channel properly
+ if batchIdx != lastChannelIndex {
+ defer close(chs[batchIdx])
+ }
- for rr.Next() && ctx.Err() == nil {
- rec := rr.Record()
- rec, err = recTransform(ctx, rec)
+ rdr, err := batch.GetStream(ctx)
if err != nil {
return err
}
- chs[batchIdx] <- rec
- }
+ defer rdr.Close()
- return rr.Err()
- })
- }
+ rr, err := ipc.NewReader(rdr, ipc.WithAllocator(alloc))
+ if err != nil {
+ return err
+ }
+ defer rr.Release()
- // place this here so that we always clean up, but they can't be in a
- // separate goroutine. Otherwise we'll have a race condition between
- // the call to wait and the calls to group.Go to kick off the jobs
- // to perform the pre-fetching (GH-1283).
- rdr.err = group.Wait()
- // don't close the last channel until after the group is finished,
- // so that Next() can only return after reader.err may have been set
- close(chs[lastChannelIndex])
- }()
+ for rr.Next() && ctx.Err() == nil {
+ rec := rr.Record()
+ rec, err = recTransform(ctx, rec)
+ if err != nil {
+ return err
+ }
+ chs[batchIdx] <- rec
+ }
+
+ return rr.Err()
+ })
+ }
+
+ // place this here so that we always clean up, but they can't be in a
+ // separate goroutine. Otherwise we'll have a race condition between
+ // the call to wait and the calls to group.Go to kick off the jobs
+ // to perform the pre-fetching (GH-1283).
+ rdr.err = group.Wait()
+ // don't close the last channel until after the group is finished,
+ // so that Next() can only return after reader.err may have been set
+ close(chs[lastChannelIndex])
+ }()
+ } else {
+ schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision)
+ if err != nil {
+ return nil, err
+ }
+ rdr.schema, _ = getTransformer(schema, ld, useHighPrecision)
+ }
return rdr, nil
}