From 47edce20c0f58cd6e93036b35ec3c17bcc001477 Mon Sep 17 00:00:00 2001 From: light-city <455954986@qq.com> Date: Wed, 10 Apr 2024 16:26:24 +0800 Subject: [PATCH] Fix: left anti join filter empty rows. --- /* Reviewer note (patch commentary, not part of the commit message): this patch adds the regression test HashJoin.FilterEmptyRows, a LEFT_ANTI join run with key lists id and stu_id and the residual filter age > 25, in which the single left id (2) equals none of the right stu_id values (10, 12, 15, 10) and the expected output is the unchanged left row. It also makes JoinResidualFilter::FilterOneBatch return Status::OK() immediately when num_batch_rows is 0. NOTE(review): that early return is the first statement of the function body, so on this path the function never writes through num_passing_rows - confirm every caller initializes it beforehand, or zero it before returning. */ cpp/src/arrow/acero/hash_join_node_test.cc | 19 +++++++++++++++++++ cpp/src/arrow/acero/swiss_join.cc | 3 +++ 2 files changed, 22 insertions(+) diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc b/cpp/src/arrow/acero/hash_join_node_test.cc index 63969d9a3ed4b..c6818ca491ba3 100644 --- a/cpp/src/arrow/acero/hash_join_node_test.cc +++ b/cpp/src/arrow/acero/hash_join_node_test.cc @@ -2036,6 +2036,25 @@ TEST(HashJoin, ResidualFilter) { [3, 4, "alpha", 4, 16, "alpha"]])")}); } +TEST(HashJoin, FilterEmptyRows) { + BatchesWithSchema input_left; + input_left.batches = {ExecBatchFromJSON({int32(), utf8(), int32()}, R"([[2, "Jarry", 28]])")}; + input_left.schema = schema({field("id", int32()), field("name", utf8()), field("age", int32())}); + + BatchesWithSchema input_right; + input_right.batches = {ExecBatchFromJSON( + {int32(), int32(), utf8()}, + R"([[2, 10, "Jack"], [3, 12, "Mark"], [4, 15, "Tom"], [1, 10, "Jack"]])")}; + input_right.schema = schema({field("id", int32()), field("stu_id", int32()), field("subject", utf8())}); + + const ResidualFilterCaseRunner runner{std::move(input_left), std::move(input_right)}; + + Expression filter = greater(field_ref("age"), literal(25)); + + runner.Run(JoinType::LEFT_ANTI, {"id"}, {"stu_id"}, std::move(filter), + {ExecBatchFromJSON({int32(), utf8(), int32()}, R"([[2, "Jarry", 28]])")}); +} + TEST(HashJoin, TrivialResidualFilter) { Expression always_true = equal(call("add", {field_ref("l1"), field_ref("r1")}), literal(2)); // 1 + 1 == 2 diff --git a/cpp/src/arrow/acero/swiss_join.cc b/cpp/src/arrow/acero/swiss_join.cc index 61c8bfe95414e..a042ed4f07755 100644 --- a/cpp/src/arrow/acero/swiss_join.cc +++ b/cpp/src/arrow/acero/swiss_join.cc @@ -2160,6 +2160,9 @@ Status JoinResidualFilter::FilterOneBatch(const ExecBatch& 
keypayload_batch, bool output_key_ids, bool output_payload_ids, arrow::util::TempVectorStack* temp_stack, int* num_passing_rows) const { + if (num_batch_rows == 0) { + return Status::OK(); + } // Caller must do shortcuts for trivial filter. ARROW_DCHECK(!filter_.IsNullLiteral() && filter_ != literal(true) && filter_ != literal(false));