From 3ee659c1dc4f3abebccdf20fc755cef6dc144fce Mon Sep 17 00:00:00 2001 From: Yuan Date: Thu, 8 Sep 2022 10:46:06 +0800 Subject: [PATCH] [NSE-1104] fix hashagg w/ empty string (#1099) * fix hashagg w/ empty string Signed-off-by: Yuan Zhou * adding empty bits in unsafe row Signed-off-by: Yuan Zhou * fix ut Signed-off-by: Yuan Zhou Signed-off-by: Yuan Zhou --- .../nativesql/NativeSQLConvertedSuite.scala | 15 ++++++++++++++ .../cpp/src/precompile/unsafe_array.h | 1 + .../third_party/row_wise_memory/unsafe_row.h | 20 ++++++++++++++++--- 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/nativesql/NativeSQLConvertedSuite.scala b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/nativesql/NativeSQLConvertedSuite.scala index 30d22b4d3..d6589496e 100644 --- a/native-sql-engine/core/src/test/scala/org/apache/spark/sql/nativesql/NativeSQLConvertedSuite.scala +++ b/native-sql-engine/core/src/test/scala/org/apache/spark/sql/nativesql/NativeSQLConvertedSuite.scala @@ -368,6 +368,21 @@ class NativeSQLConvertedSuite extends QueryTest Row("val1a", Timestamp.valueOf("2014-04-04 01:02:00.001")))) } + test("groupby with empty string") { + Seq[(String, Integer, String, String)]( + ("9", 1, "", "20220608"), + ("9", 2, "20220608", ""), + ("9", 2, "20220608", ""), + ("9", 1, "20220608", ""), + ("9", 2, "20220608", ""), + ("9", 1, "20220608", ""), + ("9", null, "20220608", ""), + ("9", 3, "20220608", "")).toDF("a", "b", "c", "d").createOrReplaceTempView("testData") + + val df1 = sql( "SELECT a,c,d, COUNT(*) FROM testData group by a, c, d") + checkAnswer(df1, Seq(Row("9","", "20220608", 1), Row("9","20220608","", 7))) + } + test("groupby") { Seq[(Integer, java.lang.Boolean)]( (1, true), diff --git a/native-sql-engine/cpp/src/precompile/unsafe_array.h b/native-sql-engine/cpp/src/precompile/unsafe_array.h index e13314135..56d20ed70 100644 --- a/native-sql-engine/cpp/src/precompile/unsafe_array.h +++ 
b/native-sql-engine/cpp/src/precompile/unsafe_array.h @@ -104,6 +104,7 @@ class TypedUnsafeArray> : public Unsaf appendToUnsafeRow((*unsafe_row).get(), idx_, *(int64_t*)(v.data())); break; default: + // empty string will go here appendToUnsafeRow((*unsafe_row).get(), idx_, v); } } diff --git a/native-sql-engine/cpp/src/third_party/row_wise_memory/unsafe_row.h b/native-sql-engine/cpp/src/third_party/row_wise_memory/unsafe_row.h index 64c76af25..14993772a 100644 --- a/native-sql-engine/cpp/src/third_party/row_wise_memory/unsafe_row.h +++ b/native-sql-engine/cpp/src/third_party/row_wise_memory/unsafe_row.h @@ -47,12 +47,14 @@ struct UnsafeRow { char* data = nullptr; int cursor; int validity_size; + int is_empty_size; UnsafeRow() {} UnsafeRow(int numFields) : numFields(numFields) { validity_size = (numFields / 8) + 1; - cursor = validity_size; + is_empty_size = (numFields / 8) + 1; + cursor = validity_size + is_empty_size; data = (char*)nativeMalloc(TEMP_UNSAFEROW_BUFFER_SIZE, MEMTYPE_ROW); - memset(data, 0, validity_size); + memset(data, 0, validity_size + is_empty_size); } ~UnsafeRow() { if (data) { @@ -61,8 +63,10 @@ struct UnsafeRow { } int sizeInBytes() { return cursor; } void reset() { + validity_size = (numFields / 8) + 1; + is_empty_size = (numFields / 8) + 1; memset(data, 0, cursor); - cursor = validity_size; + cursor = validity_size + is_empty_size; } bool isNullExists() { for (int i = 0; i < ((numFields / 8) + 1); i++) { @@ -99,6 +103,12 @@ static inline void setNullAt(UnsafeRow* row, int index) { *(row->data + bitSetIdx) |= kBitmask[index % 8]; } +static inline void setEmptyAt(UnsafeRow* row, int index) { + assert((index >= 0) && (index < row->numFields)); + auto bitSetIdx = index >> 3; // byte index: index / 8 + *(row->data + row->validity_size + bitSetIdx) |= kBitmask[index % 8]; +} + template using is_number_alike = std::integral_constant::value || @@ -115,6 +125,10 @@ static inline void appendToUnsafeRow(UnsafeRow* row, const int& index, const T& static inline void
appendToUnsafeRow(UnsafeRow* row, const int& index, arrow::util::string_view str) { + if (unlikely(str.size() == 0)) { + setEmptyAt(row, index); + return; + } if (unlikely(row->cursor + str.size() > TEMP_UNSAFEROW_BUFFER_SIZE)) row->data = (char*)nativeRealloc(row->data, 2 * TEMP_UNSAFEROW_BUFFER_SIZE, MEMTYPE_ROW);