diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
index 26ac26822..d6137bff2 100644
--- a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
+++ b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
@@ -98,7 +98,7 @@ case class ColumnarShuffleExchangeExec(
     // check input datatype
     for (attr <- child.output) {
       try {
-        ConverterUtils.checkIfTypeSupported(attr.dataType)
+        ConverterUtils.createArrowField(attr)
       } catch {
         case e: UnsupportedOperationException =>
           throw new UnsupportedOperationException(
diff --git a/native-sql-engine/cpp/src/shuffle/splitter.cc b/native-sql-engine/cpp/src/shuffle/splitter.cc
index d04ae84c2..d3c259828 100644
--- a/native-sql-engine/cpp/src/shuffle/splitter.cc
+++ b/native-sql-engine/cpp/src/shuffle/splitter.cc
@@ -1289,9 +1289,11 @@ arrow::Status Splitter::AppendList(
   using ValueBuilderType = typename arrow::TypeTraits<ValueType>::BuilderType;
   using ValueArrayType = typename arrow::TypeTraits<ValueType>::ArrayType;
   std::vector<ValueBuilderType*> dst_values_builders;
-  for (auto builder : dst_builders) {
-    dst_values_builders.push_back(
-        checked_cast<ValueBuilderType*>(builder->value_builder()));
+  dst_values_builders.resize(dst_builders.size());
+  for (auto i = 0; i < dst_builders.size(); ++i) {
+    if (dst_builders[i] != nullptr)
+      dst_values_builders[i] =
+          checked_cast<ValueBuilderType*>(dst_builders[i]->value_builder());
   }
   auto src_arr_values =
       std::dynamic_pointer_cast<ValueArrayType>(src_arr->values());
diff --git a/native-sql-engine/cpp/src/tests/shuffle_split_test.cc b/native-sql-engine/cpp/src/tests/shuffle_split_test.cc
index d58152b36..4fefdbc7b 100644
--- a/native-sql-engine/cpp/src/tests/shuffle_split_test.cc
+++ b/native-sql-engine/cpp/src/tests/shuffle_split_test.cc
@@ -525,6 +525,51 @@ TEST_F(SplitterTest, TestRoundRobinListArraySplitter) {
   }
 }
 
+TEST_F(SplitterTest, TestHashListArraySplitterWithMorePartitions) {
+  int32_t num_partitions = 5;
+  split_options_.buffer_size = 4;
+
+  auto f_uint64 = field("f_uint64", arrow::uint64());
+  auto f_arr_str = field("f_arr", arrow::list(arrow::utf8()));
+
+  auto rb_schema = arrow::schema({f_uint64, f_arr_str});
+
+  const std::vector<std::string> input_batch_1_data = {
+      R"([1, 2])", R"([["alice0", "bob1"], ["alice2"]])"};
+  std::shared_ptr<arrow::RecordBatch> input_batch_arr;
+  MakeInputBatch(input_batch_1_data, rb_schema, &input_batch_arr);
+
+  auto f_2 = TreeExprBuilder::MakeField(f_uint64);
+  auto expr_1 = TreeExprBuilder::MakeExpression(f_2, field("f_uint64", uint64()));
+
+  ARROW_ASSIGN_OR_THROW(splitter_, Splitter::Make("hash", rb_schema, num_partitions,
+                                                  {expr_1}, split_options_));
+
+  ASSERT_NOT_OK(splitter_->Split(*input_batch_arr));
+
+  ASSERT_NOT_OK(splitter_->Stop());
+
+  const auto& lengths = splitter_->PartitionLengths();
+  ASSERT_EQ(lengths.size(), 5);
+
+  CheckFileExsists(splitter_->DataFile());
+
+  std::shared_ptr<arrow::ipc::RecordBatchStreamReader> file_reader;
+  ARROW_ASSIGN_OR_THROW(file_reader, GetRecordBatchStreamReader(splitter_->DataFile()));
+
+  ASSERT_EQ(*file_reader->schema(), *rb_schema);
+
+  std::vector<std::shared_ptr<arrow::RecordBatch>> batches;
+  ASSERT_NOT_OK(file_reader->ReadAll(&batches));
+
+  for (const auto& rb : batches) {
+    ASSERT_EQ(rb->num_columns(), rb_schema->num_fields());
+    for (auto i = 0; i < rb->num_columns(); ++i) {
+      ASSERT_EQ(rb->column(i)->length(), rb->num_rows());
+    }
+  }
+}
+
 TEST_F(SplitterTest, TestRoundRobinListArraySplitterwithCompression) {
   auto f_arr_str = field("f_arr", arrow::list(arrow::utf8()));
   auto f_arr_bool = field("f_bool", arrow::list(arrow::boolean()));