From 6145ed6f05a974d22b7c0ed01235816a7f724809 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 2 May 2024 21:09:42 -0400 Subject: [PATCH] fix(go/adbc/driver/snowflake): handle empty result sets Fixes #1804. --- c/validation/adbc_validation.h | 2 + c/validation/adbc_validation_statement.cc | 24 ++++ go/adbc/driver/snowflake/record_reader.go | 149 ++++++++++++---------- 3 files changed, 105 insertions(+), 70 deletions(-) diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h index 6c59d95e09..abe9a76868 100644 --- a/c/validation/adbc_validation.h +++ b/c/validation/adbc_validation.h @@ -407,6 +407,7 @@ class StatementTest { void TestSqlPrepareErrorNoQuery(); void TestSqlPrepareErrorParamCountMismatch(); + void TestSqlQueryEmpty(); void TestSqlQueryInts(); void TestSqlQueryFloats(); void TestSqlQueryStrings(); @@ -504,6 +505,7 @@ class StatementTest { TEST_F(FIXTURE, SqlPrepareErrorParamCountMismatch) { \ TestSqlPrepareErrorParamCountMismatch(); \ } \ + TEST_F(FIXTURE, SqlQueryEmpty) { TestSqlQueryEmpty(); } \ TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); } \ TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); } \ TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); } \ diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc index 59f3f3f9a3..7da58c6986 100644 --- a/c/validation/adbc_validation_statement.cc +++ b/c/validation/adbc_validation_statement.cc @@ -2062,6 +2062,30 @@ void StatementTest::TestSqlPrepareErrorParamCountMismatch() { ::testing::Not(IsOkStatus(&error))); } +void StatementTest::TestSqlQueryEmpty() { + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42 WHERE 1=0", &error), + IsOkStatus(&error)); + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_EQ(1, reader.schema->n_children); + + while (true) { + ASSERT_NO_FATAL_FAILURE(reader.Next()); + if (!reader.array->release) { + break; + } + ASSERT_EQ(0, reader.array->length); + } + ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); +} + void StatementTest::TestSqlQueryInts() { ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error), diff --git a/go/adbc/driver/snowflake/record_reader.go b/go/adbc/driver/snowflake/record_reader.go index bda3e8f70d..5c71322209 100644 --- a/go/adbc/driver/snowflake/record_reader.go +++ b/go/adbc/driver/snowflake/record_reader.go @@ -571,23 +571,9 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake } ch := make(chan arrow.Record, bufferSize) - r, err := batches[0].GetStream(ctx) - if err != nil { - return nil, errToAdbcErr(adbc.StatusIO, err) - } - - rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc)) - if err != nil { - return nil, adbc.Error{ - Msg: err.Error(), - Code: adbc.StatusInvalidState, - } - } - group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc)) ctx, cancelFn := context.WithCancel(ctx) - - schema, recTransform := getTransformer(rr.Schema(), ld, useHighPrecision) + group.SetLimit(prefetchConcurrency) defer func() { if err != nil { @@ -596,80 +582,103 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake } }() - group.SetLimit(prefetchConcurrency) - group.Go(func() error { - defer rr.Release() - defer r.Close() - if len(batches) > 1 { - defer close(ch) - } - - for rr.Next() && ctx.Err() == nil { - rec := rr.Record() - rec, err = recTransform(ctx, rec) - if err != nil { - return err - } - ch <- rec - } - return rr.Err() - }) - chs := make([]chan arrow.Record, len(batches)) - chs[0] = ch rdr := &reader{ refCount: 1, chs: chs, err: nil, cancelFn: cancelFn, - schema: schema, } - lastChannelIndex := len(chs) - 1 - go func() { - for i, b := range batches[1:] { - batch, batchIdx := b, i+1 - chs[batchIdx] = make(chan arrow.Record, bufferSize) - group.Go(func() error { - // close channels (except the last) so that Next can move on to the next channel properly - if batchIdx != lastChannelIndex { - defer close(chs[batchIdx]) - } + if len(batches) > 0 { + r, err := batches[0].GetStream(ctx) + if err != nil { + return nil, errToAdbcErr(adbc.StatusIO, err) + } - rdr, err := batch.GetStream(ctx) - if err != nil { - return err - } - defer rdr.Close() + rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc)) + if err != nil { + return nil, adbc.Error{ + Msg: err.Error(), + Code: adbc.StatusInvalidState, + } + } + + var recTransform recordTransformer + rdr.schema, recTransform = getTransformer(rr.Schema(), ld, useHighPrecision) - rr, err := ipc.NewReader(rdr, ipc.WithAllocator(alloc)) + group.Go(func() error { + defer rr.Release() + defer r.Close() + if len(batches) > 1 { + defer close(ch) + } + + for rr.Next() && ctx.Err() == nil { + rec := rr.Record() + rec, err = recTransform(ctx, rec) if err != nil { return err } - defer rr.Release() + ch <- rec + } + return rr.Err() + }) + + chs[0] = ch + + lastChannelIndex := len(chs) - 1 + go func() { + for i, b := range batches[1:] { + batch, batchIdx := b, i+1 + chs[batchIdx] = make(chan arrow.Record, bufferSize) + group.Go(func() error { + // close channels (except the last) so that Next can move on to the next channel properly + if batchIdx != lastChannelIndex { + defer close(chs[batchIdx]) + } - for rr.Next() && ctx.Err() == nil { - rec := rr.Record() - rec, err = recTransform(ctx, rec) + rdr, err := batch.GetStream(ctx) if err != nil { return err } - chs[batchIdx] <- rec - } + defer rdr.Close() - return rr.Err() - }) - } + rr, err := ipc.NewReader(rdr, ipc.WithAllocator(alloc)) + if err != nil { + return err + } + defer rr.Release() - // place this here so that we always clean up, but they can't be in a - // separate goroutine. Otherwise we'll have a race condition between - // the call to wait and the calls to group.Go to kick off the jobs - // to perform the pre-fetching (GH-1283). - rdr.err = group.Wait() - // don't close the last channel until after the group is finished, - // so that Next() can only return after reader.err may have been set - close(chs[lastChannelIndex]) - }() + for rr.Next() && ctx.Err() == nil { + rec := rr.Record() + rec, err = recTransform(ctx, rec) + if err != nil { + return err + } + chs[batchIdx] <- rec + } + + return rr.Err() + }) + } + + // place this here so that we always clean up, but they can't be in a + // separate goroutine. Otherwise we'll have a race condition between + // the call to wait and the calls to group.Go to kick off the jobs + // to perform the pre-fetching (GH-1283). + rdr.err = group.Wait() + // don't close the last channel until after the group is finished, + // so that Next() can only return after reader.err may have been set + close(chs[lastChannelIndex]) + }() + } else { + schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision) + if err != nil { + return nil, err + } + rdr.schema, _ = getTransformer(schema, ld, useHighPrecision) + } return rdr, nil }