fix(go/adbc/driver/snowflake): handle empty result sets
lidavidm committed May 3, 2024
1 parent d6ddc01 commit 6145ed6
Showing 3 changed files with 105 additions and 70 deletions.
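In brief, the bug: newRecordReader dereferenced batches[0] unconditionally, so a query returning zero Arrow batches (such as SELECT 42 WHERE 1=0) had no IPC stream to read a schema from. The fix guards that path with if len(batches) > 0 and, in the empty case, derives the schema from Snowflake's row-type metadata via rowTypesToArrowSchema. The contract the fix preserves, sketched with arrow-go (a minimal illustration, not driver code; the v16 module path is an assumption):

package main

import (
	"fmt"

	"github.com/apache/arrow/go/v16/arrow"
	"github.com/apache/arrow/go/v16/arrow/array"
)

func main() {
	// An empty result set still carries a schema; it just yields no records.
	schema := arrow.NewSchema(
		[]arrow.Field{{Name: "N", Type: arrow.PrimitiveTypes.Int64}}, nil)

	rdr, err := array.NewRecordReader(schema, nil) // no records at all
	if err != nil {
		panic(err)
	}
	defer rdr.Release()

	fmt.Println(rdr.Schema()) // consumers can still inspect the schema
	for rdr.Next() {
		// never entered: the stream ends immediately
	}
}

The new else branch in record_reader.go produces exactly this shape of reader: a valid metadata-derived schema, and channels that never deliver a record.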
2 changes: 2 additions & 0 deletions c/validation/adbc_validation.h
@@ -407,6 +407,7 @@ class StatementTest {
   void TestSqlPrepareErrorNoQuery();
   void TestSqlPrepareErrorParamCountMismatch();
 
+  void TestSqlQueryEmpty();
   void TestSqlQueryInts();
   void TestSqlQueryFloats();
   void TestSqlQueryStrings();
@@ -504,6 +505,7 @@ class StatementTest {
   TEST_F(FIXTURE, SqlPrepareErrorParamCountMismatch) {                        \
     TestSqlPrepareErrorParamCountMismatch();                                  \
   }                                                                           \
+  TEST_F(FIXTURE, SqlQueryEmpty) { TestSqlQueryEmpty(); }                     \
   TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); }                       \
   TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); }                   \
   TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); }                 \
24 changes: 24 additions & 0 deletions c/validation/adbc_validation_statement.cc
@@ -2062,6 +2062,30 @@ void StatementTest::TestSqlPrepareErrorParamCountMismatch() {
               ::testing::Not(IsOkStatus(&error)));
 }
 
+void StatementTest::TestSqlQueryEmpty() {
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42 WHERE 1=0", &error),
+              IsOkStatus(&error));
+  StreamReader reader;
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                        &reader.rows_affected, &error),
+              IsOkStatus(&error));
+  ASSERT_THAT(reader.rows_affected,
+              ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1)));
+
+  ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+  ASSERT_EQ(1, reader.schema->n_children);
+
+  while (true) {
+    ASSERT_NO_FATAL_FAILURE(reader.Next());
+    if (!reader.array->release) {
+      break;
+    }
+    ASSERT_EQ(0, reader.array->length);
+  }
+  ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+}
+
 void StatementTest::TestSqlQueryInts() {
   ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
   ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
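The loop above exercises the Arrow C stream end-of-stream contract: Next() keeps producing arrays until it returns one whose release callback is null, and for an empty result set any array that does arrive must have length zero. rows_affected may be either 0 or -1 because ADBC permits drivers to report -1 when the row count is unknown. A hedged Go rendering of the same drain-and-check logic against arrow-go's array.RecordReader interface (drainEmpty is an invented helper name, not part of the validation suite):

package validation

import (
	"fmt"

	"github.com/apache/arrow/go/v16/arrow/array"
)

// drainEmpty mirrors TestSqlQueryEmpty's read loop: consume the stream to
// the end, requiring every delivered batch (there may be none) to be empty.
func drainEmpty(rdr array.RecordReader) error {
	defer rdr.Release()
	for rdr.Next() { // false once the stream is exhausted
		if n := rdr.Record().NumRows(); n != 0 {
			return fmt.Errorf("expected empty batch, got %d rows", n)
		}
	}
	return rdr.Err() // distinguishes clean end-of-stream from an error
}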
149 changes: 79 additions & 70 deletions go/adbc/driver/snowflake/record_reader.go
@@ -571,23 +571,9 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
 	}
 
 	ch := make(chan arrow.Record, bufferSize)
-	r, err := batches[0].GetStream(ctx)
-	if err != nil {
-		return nil, errToAdbcErr(adbc.StatusIO, err)
-	}
-
-	rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc))
-	if err != nil {
-		return nil, adbc.Error{
-			Msg:  err.Error(),
-			Code: adbc.StatusInvalidState,
-		}
-	}
-
 	group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
 	ctx, cancelFn := context.WithCancel(ctx)
-
-	schema, recTransform := getTransformer(rr.Schema(), ld, useHighPrecision)
+	group.SetLimit(prefetchConcurrency)
 
 	defer func() {
 		if err != nil {
@@ -596,80 +582,103 @@
 		}
 	}()
 
-	group.SetLimit(prefetchConcurrency)
-	group.Go(func() error {
-		defer rr.Release()
-		defer r.Close()
-		if len(batches) > 1 {
-			defer close(ch)
-		}
-
-		for rr.Next() && ctx.Err() == nil {
-			rec := rr.Record()
-			rec, err = recTransform(ctx, rec)
-			if err != nil {
-				return err
-			}
-			ch <- rec
-		}
-		return rr.Err()
-	})
-
 	chs := make([]chan arrow.Record, len(batches))
-	chs[0] = ch
 	rdr := &reader{
 		refCount: 1,
 		chs:      chs,
 		err:      nil,
 		cancelFn: cancelFn,
-		schema:   schema,
 	}
 
-	lastChannelIndex := len(chs) - 1
-	go func() {
-		for i, b := range batches[1:] {
-			batch, batchIdx := b, i+1
-			chs[batchIdx] = make(chan arrow.Record, bufferSize)
-			group.Go(func() error {
-				// close channels (except the last) so that Next can move on to the next channel properly
-				if batchIdx != lastChannelIndex {
-					defer close(chs[batchIdx])
-				}
-
-				rdr, err := batch.GetStream(ctx)
-				if err != nil {
-					return err
-				}
-				defer rdr.Close()
-
-				rr, err := ipc.NewReader(rdr, ipc.WithAllocator(alloc))
-				if err != nil {
-					return err
-				}
-				defer rr.Release()
-
-				for rr.Next() && ctx.Err() == nil {
-					rec := rr.Record()
-					rec, err = recTransform(ctx, rec)
-					if err != nil {
-						return err
-					}
-					chs[batchIdx] <- rec
-				}
-
-				return rr.Err()
-			})
-		}
-
-		// place this here so that we always clean up, but they can't be in a
-		// separate goroutine. Otherwise we'll have a race condition between
-		// the call to wait and the calls to group.Go to kick off the jobs
-		// to perform the pre-fetching (GH-1283).
-		rdr.err = group.Wait()
-		// don't close the last channel until after the group is finished,
-		// so that Next() can only return after reader.err may have been set
-		close(chs[lastChannelIndex])
-	}()
+	if len(batches) > 0 {
+		r, err := batches[0].GetStream(ctx)
+		if err != nil {
+			return nil, errToAdbcErr(adbc.StatusIO, err)
+		}
+
+		rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc))
+		if err != nil {
+			return nil, adbc.Error{
+				Msg:  err.Error(),
+				Code: adbc.StatusInvalidState,
+			}
+		}
+
+		var recTransform recordTransformer
+		rdr.schema, recTransform = getTransformer(rr.Schema(), ld, useHighPrecision)
+
+		group.Go(func() error {
+			defer rr.Release()
+			defer r.Close()
+			if len(batches) > 1 {
+				defer close(ch)
+			}
+
+			for rr.Next() && ctx.Err() == nil {
+				rec := rr.Record()
+				rec, err = recTransform(ctx, rec)
+				if err != nil {
+					return err
+				}
+				ch <- rec
+			}
+			return rr.Err()
+		})
+
+		chs[0] = ch
+
+		lastChannelIndex := len(chs) - 1
+		go func() {
+			for i, b := range batches[1:] {
+				batch, batchIdx := b, i+1
+				chs[batchIdx] = make(chan arrow.Record, bufferSize)
+				group.Go(func() error {
+					// close channels (except the last) so that Next can move on to the next channel properly
+					if batchIdx != lastChannelIndex {
+						defer close(chs[batchIdx])
+					}
+
+					rdr, err := batch.GetStream(ctx)
+					if err != nil {
+						return err
+					}
+					defer rdr.Close()
+
+					rr, err := ipc.NewReader(rdr, ipc.WithAllocator(alloc))
+					if err != nil {
+						return err
+					}
+					defer rr.Release()
+
+					for rr.Next() && ctx.Err() == nil {
+						rec := rr.Record()
+						rec, err = recTransform(ctx, rec)
+						if err != nil {
+							return err
+						}
+						chs[batchIdx] <- rec
+					}

+					return rr.Err()
+				})
+			}
+
+			// place this here so that we always clean up, but they can't be in a
+			// separate goroutine. Otherwise we'll have a race condition between
+			// the call to wait and the calls to group.Go to kick off the jobs
+			// to perform the pre-fetching (GH-1283).
+			rdr.err = group.Wait()
+			// don't close the last channel until after the group is finished,
+			// so that Next() can only return after reader.err may have been set
+			close(chs[lastChannelIndex])
+		}()
+	} else {
+		schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision)
+		if err != nil {
+			return nil, err
+		}
+		rdr.schema, _ = getTransformer(schema, ld, useHighPrecision)
+	}
 
 	return rdr, nil
 }
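For orientation, the surviving code implements a bounded prefetch: each Snowflake batch gets its own buffered channel, a shared errgroup capped at prefetchConcurrency fills the channels concurrently, and the reader drains them in order. Below is a self-contained sketch of that pattern with the Snowflake specifics stripped out (all names here are illustrative, not the driver's; the real code also keeps the last channel open until group.Wait() returns, per GH-1283, so Next() cannot unblock before the reader's error is recorded):

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	batches := []string{"a", "b", "c", "d"}

	// One buffered channel per batch, like chs in newRecordReader.
	chs := make([]chan string, len(batches))
	for i := range chs {
		chs[i] = make(chan string, 2)
	}

	group, ctx := errgroup.WithContext(context.Background())
	group.SetLimit(2) // cf. prefetchConcurrency

	for i, b := range batches {
		i, b := i, b // capture loop variables, as the diff does with batch, batchIdx
		group.Go(func() error {
			defer close(chs[i]) // simplified: the driver holds the last channel open longer
			select {
			case chs[i] <- "record from batch " + b:
				return nil
			case <-ctx.Done():
				return ctx.Err()
			}
		})
	}

	// Drain the channels in order, as reader.Next does channel by channel.
	for _, ch := range chs {
		for rec := range ch {
			fmt.Println(rec)
		}
	}
	if err := group.Wait(); err != nil {
		fmt.Println("prefetch failed:", err)
	}
}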
