diff --git a/be/src/pipeline/exec/hashjoin_probe_operator.cpp b/be/src/pipeline/exec/hashjoin_probe_operator.cpp index a5d195a2d8c592..476503f187e8c2 100644 --- a/be/src/pipeline/exec/hashjoin_probe_operator.cpp +++ b/be/src/pipeline/exec/hashjoin_probe_operator.cpp @@ -568,14 +568,35 @@ Status HashJoinProbeOperatorX::prepare(RuntimeState* state) { } } - const size_t right_col_idx = - (_is_right_semi_anti && !_have_other_join_conjunct) ? 0 : _left_table_data_types.size(); + _right_col_idx = (_is_right_semi_anti && !_have_other_join_conjunct && + (!_is_mark_join || _mark_join_conjuncts.empty())) + ? 0 + : _left_table_data_types.size(); + size_t idx = 0; for (const auto* slot : slots_to_check) { auto data_type = slot->get_data_type_ptr(); - const auto slot_on_left = idx < right_col_idx; + const auto slot_on_left = idx < _right_col_idx; + + if (slot_on_left) { + if (idx >= _left_table_data_types.size()) { + return Status::InternalError( + "Join node(id={}, OP={}) intermediate slot({}, #{})'s on left table " + "idx out bound of _left_table_data_types: {} vs {}", + _node_id, _join_op, slot->col_name(), slot->id(), idx, + _left_table_data_types.size()); + } + } else if (idx - _right_col_idx >= _right_table_data_types.size()) { + return Status::InternalError( + "Join node(id={}, OP={}) intermediate slot({}, #{})'s on right table " + "idx out bound of _right_table_data_types: {} vs {}(idx = {}, _right_col_idx = " + "{})", + _node_id, _join_op, slot->col_name(), slot->id(), idx - _right_col_idx, + _right_table_data_types.size(), idx, _right_col_idx); + } + auto target_data_type = slot_on_left ? _left_table_data_types[idx] - : _right_table_data_types[idx - right_col_idx]; + : _right_table_data_types[idx - _right_col_idx]; ++idx; if (data_type->equals(*target_data_type)) { continue; diff --git a/be/src/pipeline/exec/hashjoin_probe_operator.h b/be/src/pipeline/exec/hashjoin_probe_operator.h index 51758d8b8fbf18..7cbe443411272d 100644 --- a/be/src/pipeline/exec/hashjoin_probe_operator.h +++ b/be/src/pipeline/exec/hashjoin_probe_operator.h @@ -191,6 +191,9 @@ class HashJoinProbeOperatorX MOCK_REMOVE(final) std::set _should_not_lazy_materialized_column_ids; std::vector _right_table_column_names; const std::vector _partition_exprs; + + // Index of column(slot) from right table in the `_intermediate_row_desc`. + size_t _right_col_idx; }; } // namespace pipeline diff --git a/be/src/pipeline/exec/join/process_hash_table_probe.h b/be/src/pipeline/exec/join/process_hash_table_probe.h index 4fde7ed5fea69b..535ca74997d97c 100644 --- a/be/src/pipeline/exec/join/process_hash_table_probe.h +++ b/be/src/pipeline/exec/join/process_hash_table_probe.h @@ -22,6 +22,7 @@ #include "vec/columns/column.h" #include "vec/columns/columns_number.h" #include "vec/common/arena.h" +#include "vec/common/custom_allocator.h" namespace doris { namespace vectorized { @@ -119,8 +120,15 @@ struct ProcessHashTableProbe { RuntimeProfile::Counter* _probe_side_output_timer = nullptr; RuntimeProfile::Counter* _finish_probe_phase_timer = nullptr; - size_t _right_col_idx; + // See `HashJoinProbeOperatorX::_right_col_idx` + const size_t _right_col_idx; + size_t _right_col_len; + + // For right semi with mark join conjunct, we need to store the mark join flags + // in the hash table. + // -1 means null, 0 means false, 1 means true + DorisVector mark_join_flags; }; } // namespace pipeline diff --git a/be/src/pipeline/exec/join/process_hash_table_probe_impl.h b/be/src/pipeline/exec/join/process_hash_table_probe_impl.h index df7ad9456bb3f9..62fa5505d81b54 100644 --- a/be/src/pipeline/exec/join/process_hash_table_probe_impl.h +++ b/be/src/pipeline/exec/join/process_hash_table_probe_impl.h @@ -70,9 +70,7 @@ ProcessHashTableProbe::ProcessHashTableProbe(HashJoinProbeLocalState _build_side_output_timer(parent->_build_side_output_timer), _probe_side_output_timer(parent->_probe_side_output_timer), _finish_probe_phase_timer(parent->_finish_probe_phase_timer), - _right_col_idx((_parent_operator->_is_right_semi_anti && !_have_other_join_conjunct) - ? 0 - : _parent_operator->_left_table_data_types.size()), + _right_col_idx(_parent_operator->_right_col_idx), _right_col_len(_parent_operator->_right_table_data_types.size()) {} template @@ -272,7 +270,7 @@ Status ProcessHashTableProbe::process(HashTableType& hash_table_ctx, build_side_output_column(mcol, is_mark_join); - if (_have_other_join_conjunct || + if (_have_other_join_conjunct || !_parent->_mark_join_conjuncts.empty() || (JoinOpType != TJoinOp::RIGHT_SEMI_JOIN && JoinOpType != TJoinOp::RIGHT_ANTI_JOIN)) { probe_side_output_column(mcol); } @@ -281,7 +279,7 @@ Status ProcessHashTableProbe::process(HashTableType& hash_table_ctx, DCHECK_EQ(current_offset, output_block->rows()); COUNTER_UPDATE(_parent->_intermediate_rows_counter, current_offset); - if (is_mark_join && JoinOpType != TJoinOp::RIGHT_SEMI_JOIN) { + if (is_mark_join) { bool ignore_null_map = (JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN || JoinOpType == TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN) && @@ -409,15 +407,20 @@ Status ProcessHashTableProbe::finalize_block_with_filter( template Status ProcessHashTableProbe::do_mark_join_conjuncts(vectorized::Block* output_block, const uint8_t* null_map) { - DCHECK(JoinOpType == TJoinOp::LEFT_ANTI_JOIN || - JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN || - JoinOpType == TJoinOp::LEFT_SEMI_JOIN || - JoinOpType == TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN); + if (JoinOpType != TJoinOp::LEFT_ANTI_JOIN && JoinOpType != TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN && + JoinOpType != TJoinOp::LEFT_SEMI_JOIN && JoinOpType != TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN && + JoinOpType != TJoinOp::RIGHT_SEMI_JOIN && JoinOpType != TJoinOp::RIGHT_ANTI_JOIN) { + return Status::InternalError("join type {} is not supported", JoinOpType); + } constexpr bool is_anti_join = JoinOpType == TJoinOp::LEFT_ANTI_JOIN || - JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN; + JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN || + JoinOpType == TJoinOp::RIGHT_ANTI_JOIN; constexpr bool is_null_aware_join = JoinOpType == TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN || JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN; + constexpr bool is_right_half_join = + JoinOpType == TJoinOp::RIGHT_SEMI_JOIN || JoinOpType == TJoinOp::RIGHT_ANTI_JOIN; + const auto row_count = output_block->rows(); if (!row_count) { return Status::OK(); @@ -488,37 +491,77 @@ Status ProcessHashTableProbe::do_mark_join_conjuncts(vectorized::Blo } } + if constexpr (is_right_half_join) { + if (mark_join_flags.empty() && _build_block != nullptr) { + mark_join_flags.resize(_build_block->rows(), 0); + } + } + auto filter_column = vectorized::ColumnUInt8::create(row_count, 0); auto* __restrict filter_map = filter_column->get_data().data(); for (size_t i = 0; i != row_count; ++i) { - if (_parent->_last_probe_match == _probe_indexs.get_element(i)) { - continue; - } - if (_build_indexs.get_element(i) == 0) { - bool has_null_mark_value = - _parent->_last_probe_null_mark == _probe_indexs.get_element(i); - filter_map[i] = true; - mark_filter_data[i] = false; - mark_null_map[i] |= has_null_mark_value; - } else if (mark_null_map[i]) { - _parent->_last_probe_null_mark = _probe_indexs.get_element(i); - } else if (mark_filter_data[i]) { - filter_map[i] = true; - _parent->_last_probe_match = _probe_indexs.get_element(i); + if constexpr (is_right_half_join) { + const auto& build_index = _build_indexs.get_element(i); + if (build_index == 0) { + continue; + } + + if (mark_join_flags[build_index] == 1) { + continue; + } + + if (mark_null_map[i]) { + mark_join_flags[build_index] = -1; + } else if (mark_filter_data[i]) { + mark_join_flags[build_index] = 1; + } + } else { + if (_parent->_last_probe_match == _probe_indexs.get_element(i)) { + continue; + } + if (_build_indexs.get_element(i) == 0) { + bool has_null_mark_value = + _parent->_last_probe_null_mark == _probe_indexs.get_element(i); + filter_map[i] = true; + mark_filter_data[i] = false; + mark_null_map[i] |= has_null_mark_value; + } else if (mark_null_map[i]) { + _parent->_last_probe_null_mark = _probe_indexs.get_element(i); + } else if (mark_filter_data[i]) { + filter_map[i] = true; + _parent->_last_probe_match = _probe_indexs.get_element(i); + } } } - if constexpr (is_anti_join) { - // flip the mark column - for (size_t i = 0; i != row_count; ++i) { - mark_filter_data[i] ^= 1; // not null/ null + if constexpr (is_right_half_join) { + if constexpr (is_anti_join) { + // flip the mark column + for (size_t i = 0; i != row_count; ++i) { + if (mark_join_flags[i] == -1) { + // -1 means null. + continue; + } + + mark_join_flags[i] ^= 1; + } + } + // For right semi/anti join, no rows will be output in probe phase. + output_block->clear_column_data(); + return Status::OK(); + } else { + if constexpr (is_anti_join) { + // flip the mark column + for (size_t i = 0; i != row_count; ++i) { + mark_filter_data[i] ^= 1; // not null/ null + } } - } - auto result_column_id = output_block->columns(); - output_block->insert( - {std::move(filter_column), std::make_shared(), ""}); - return finalize_block_with_filter(output_block, result_column_id, result_column_id); + auto result_column_id = output_block->columns(); + output_block->insert( + {std::move(filter_column), std::make_shared(), ""}); + return finalize_block_with_filter(output_block, result_column_id, result_column_id); + } } template @@ -675,8 +718,31 @@ Status ProcessHashTableProbe::finish_probing(HashTableType& hash_tab } } + if constexpr (JoinOpType == TJoinOp::RIGHT_ANTI_JOIN || + JoinOpType == TJoinOp::RIGHT_SEMI_JOIN) { + if (is_mark_join) { + if (mark_join_flags.empty() && _build_block != nullptr) { + mark_join_flags.resize(_build_block->rows(), 0); + } + + // mark column is nullable + auto* mark_column = assert_cast( + mcol[_parent->_mark_column_id].get()); + mark_column->resize(block_size); + auto* null_map = mark_column->get_null_map_data().data(); + auto* data = assert_cast(mark_column->get_nested_column()) + .get_data() + .data(); + for (size_t i = 0; i != block_size; ++i) { + const auto build_index = _build_indexs.get_element(i); + null_map[i] = mark_join_flags[build_index] == -1; + data[i] = mark_join_flags[build_index] == 1; + } + } + } + // just resize the left table column in case with other conjunct to make block size is not zero - if (_parent_operator->_is_right_semi_anti && _have_other_join_conjunct) { + if (_parent_operator->_is_right_semi_anti && _right_col_idx != 0) { for (int i = 0; i < _right_col_idx; ++i) { mcol[i]->resize(block_size); } diff --git a/be/src/vec/common/hash_table/join_hash_table.h b/be/src/vec/common/hash_table/join_hash_table.h index 61dc8b1e18271d..86ce9854a9122a 100644 --- a/be/src/vec/common/hash_table/join_hash_table.h +++ b/be/src/vec/common/hash_table/join_hash_table.h @@ -106,7 +106,7 @@ class JoinHashTable { keys, build_idx_map, probe_idx, build_idx, probe_rows, probe_idxs, build_idxs); } - if (is_mark_join && JoinOpType != TJoinOp::RIGHT_SEMI_JOIN) { + if (is_mark_join) { bool is_null_aware_join = JoinOpType == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN || JoinOpType == TJoinOp::NULL_AWARE_LEFT_SEMI_JOIN; bool is_left_half_join = @@ -292,15 +292,6 @@ class JoinHashTable { auto do_the_probe = [&]() { while (build_idx && matched_cnt < batch_size) { - if constexpr (JoinOpType == TJoinOp::RIGHT_ANTI_JOIN || - JoinOpType == TJoinOp::RIGHT_SEMI_JOIN) { - if (!visited[build_idx] && keys[probe_idx] == build_keys[build_idx]) { - probe_idxs[matched_cnt] = probe_idx; - build_idxs[matched_cnt] = build_idx; - matched_cnt++; - } - } - if (keys[probe_idx] == build_keys[build_idx]) { build_idxs[matched_cnt] = build_idx; probe_idxs[matched_cnt] = probe_idx; diff --git a/be/test/pipeline/operator/hashjoin_probe_operator_test.cpp b/be/test/pipeline/operator/hashjoin_probe_operator_test.cpp index ebc5795dd66a31..bf0e29023447c5 100644 --- a/be/test/pipeline/operator/hashjoin_probe_operator_test.cpp +++ b/be/test/pipeline/operator/hashjoin_probe_operator_test.cpp @@ -798,6 +798,34 @@ TEST_F(HashJoinProbeOperatorTest, RightSemiJoin) { check_column_values(*sorted_block.get_by_position(1).column, {"c", "d"}); } +TEST_F(HashJoinProbeOperatorTest, RightSemiJoinMarkJoin) { + auto sink_block = ColumnHelper::create_block({1, 2, 3, 4, 5}); + sink_block.insert(ColumnHelper::create_nullable_column_with_name( + {"a", "b", "c", "d", "e"}, {1, 0, 0, 0, 1})); + + auto probe_block = + ColumnHelper::create_nullable_block({1, 2, 3, 4, 5}, {0, 1, 0, 0, 1}); + probe_block.insert( + ColumnHelper::create_column_with_name({"a", "b", "c", "d", "e"})); + + Block output_block; + std::vector build_blocks = {sink_block}; + std::vector probe_blocks = {probe_block}; + run_test({.join_op_type = TJoinOp::RIGHT_SEMI_JOIN, + .is_mark_join = true, + .mark_join_conjuncts_size = 1}, + {TPrimitiveType::INT, TPrimitiveType::STRING}, {true, false}, {false, true}, + build_blocks, probe_blocks, output_block); + + auto sorted_block = sort_block_by_columns(output_block); + std::cout << "Output block: " << sorted_block.dump_data() << std::endl; + ASSERT_EQ(sorted_block.rows(), 5); + + check_column_values(*sorted_block.get_by_position(2).column, {1, 2, 3, 4, 5}); + check_column_values(*sorted_block.get_by_position(3).column, {Null(), "b", "c", "d", Null()}); + check_column_values(*sorted_block.get_by_position(4).column, {0, Null(), 1, 1, 0}); +} + TEST_F(HashJoinProbeOperatorTest, NullAwareLeftAntiJoin) { auto sink_block = ColumnHelper::create_block({1, 2, 3, 4, 5}); sink_block.insert(ColumnHelper::create_nullable_column_with_name( diff --git a/regression-test/data/query_p0/join/mark_join/mark_join.out b/regression-test/data/query_p0/join/mark_join/mark_join.out index ed3575d0e14476..f4ac9204b4775c 100644 --- a/regression-test/data/query_p0/join/mark_join/mark_join.out +++ b/regression-test/data/query_p0/join/mark_join/mark_join.out @@ -17,3 +17,51 @@ 3 -3 \N c 3 3 \N c +-- !test_right_semi_mark_join -- +1 v1 o1 \N \N +2 v2 o2 \N \N +3 v3 o3 \N \N +4 v4 o4 \N \N +5 v5 o5 \N \N +6 v1 \N \N \N +7 v2 \N \N \N +8 v3 \N \N \N +9 v4 \N \N \N +10 v5 \N \N \N + +-- !test_right_semi_mark_join_2 -- +1 v1 o1 \N \N +2 v2 o2 \N \N +3 v3 o3 \N \N +4 v4 o4 \N \N +5 v5 o5 \N \N +6 v1 \N \N \N +7 v2 \N \N \N +8 v3 \N \N \N +9 v4 \N \N \N +10 v5 \N \N \N + +-- !test_right_semi_mark_join_no_null -- +1 v1 o1 false true +2 v2 o2 false true +3 v3 o3 false true +4 v4 o4 false true +5 v5 o5 false true +6 v1 \N \N \N +7 v2 \N \N \N +8 v3 \N \N \N +9 v4 \N \N \N +10 v5 \N \N \N + +-- !test_right_semi_mark_join_no_null_2 -- +1 v1 o1 false true +2 v2 o2 false true +3 v3 o3 false true +4 v4 o4 false true +5 v5 o5 false true +6 v1 \N \N \N +7 v2 \N \N \N +8 v3 \N \N \N +9 v4 \N \N \N +10 v5 \N \N \N + diff --git a/regression-test/suites/query_p0/join/mark_join/mark_join.groovy b/regression-test/suites/query_p0/join/mark_join/mark_join.groovy index 9759a0e9b4cd70..6b0a9d938c25d7 100644 --- a/regression-test/suites/query_p0/join/mark_join/mark_join.groovy +++ b/regression-test/suites/query_p0/join/mark_join/mark_join.groovy @@ -61,4 +61,219 @@ suite("mark_join") { qt_test """ select * from t1 where t1.k1 not in (select t2.k3 from t2 where t2.k2 = t1.k2) or k1 < 10 order by k1, k2; """ + + + sql "drop table if exists tbl1;" + sql "drop table if exists tbl2;" + sql "drop table if exists tbl3;" + + sql """ + CREATE TABLE `tbl1` ( + `unit_name` varchar(1080) NULL, + `cur_unit_name` varchar(1080) NOT NULL + ) ENGINE=OLAP + DUPLICATE KEY(`unit_name`) + DISTRIBUTED BY RANDOM BUCKETS AUTO + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + + sql """ + CREATE TABLE `tbl2` ( + `org_code` varchar(150) NOT NULL , + `org_name` varchar(300) NULL + ) ENGINE=OLAP + DUPLICATE KEY(`org_code`) + DISTRIBUTED BY HASH(`org_code`) BUCKETS 4 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + + sql """ + CREATE TABLE `tbl3` ( + `id` bigint NOT NULL, + `acntm_name` varchar(500) NULL , + `vendor_name` varchar(500) NULL + ) ENGINE=OLAP + DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS AUTO + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + + sql """ + insert into tbl1 (unit_name, cur_unit_name) values + ('v1', 'o1'), + ('v2', 'o2'), + ('v3', 'o3'), + ('v4', 'o4'), + ('v5', 'o5'), + (null, 'o1'), + ('v1', 'o1'), + ('v2', 'o2'), + ('v3', 'o3'), + ('v4', 'o4'), + ('v5', 'o5'), + (null, 'o1'), + (null, 'o2'), + (null, 'o3'), + (null, 'o4'), + (null, 'o5'), + ('v1', 'o1'), + ('v2', 'o2'), + ('v3', 'o3'), + ('v4', 'o4'), + ('v5', 'o5'); + """ + + sql """ + insert into tbl2(org_code, org_name) values + ('v1', 'o1'), + ('v2', 'o2'), + ('v3', 'o3'), + ('v4', 'o4'), + ('v5', 'o5'), + ('v1', null), + ('v2', null), + ('v3', null), + ('v4', null), + ('v5', null); + """ + + sql """ + insert into tbl3 (id, vendor_name, acntm_name) + values(1, 'o1', 'v1'), + (2, 'o2', 'v2'), + (3, 'o3', 'v3'), + (4, 'o4', 'v4'), + (5, 'o5', 'v5'), + (6, null, 'v1'), + (7, null, 'v2'), + (8, null, 'v3'), + (9, null, 'v4'), + (10, null, 'v5'); + """ + + sql " analyze table tbl1 with sync;" + sql " analyze table tbl2 with sync;" + sql " analyze table tbl3 with sync;" + + sql "set disable_join_reorder=0;" + qt_test_right_semi_mark_join """ + select + tbl3.id, + tbl3.acntm_name, + tbl3.vendor_name, + tbl3.vendor_name in ( + select + tbl1.unit_name + from + tbl2 + join tbl1 on tbl1.cur_unit_name = tbl2.org_name + where + tbl2.org_code = tbl3.acntm_name + ) v1, + tbl3.vendor_name not in ( + select + tbl1.unit_name + from + tbl2 + join tbl1 on tbl1.cur_unit_name = tbl2.org_name + where + tbl2.org_code = tbl3.acntm_name + ) v2 + from + tbl3 order by 1,2,3,4,5; + """ + + sql "set disable_join_reorder=1;" + qt_test_right_semi_mark_join_2 """ + select + tbl3.id, + tbl3.acntm_name, + tbl3.vendor_name, + tbl3.vendor_name in ( + select + tbl1.unit_name + from + tbl2 + join tbl1 on tbl1.cur_unit_name = tbl2.org_name + where + tbl2.org_code = tbl3.acntm_name + ) v1, + tbl3.vendor_name not in ( + select + tbl1.unit_name + from + tbl2 + join tbl1 on tbl1.cur_unit_name = tbl2.org_name + where + tbl2.org_code = tbl3.acntm_name + ) v2 + from + tbl3 order by 1,2,3,4,5; + """ + + sql "set disable_join_reorder=0;" + qt_test_right_semi_mark_join_no_null """ + select + tbl3.id, + tbl3.acntm_name, + tbl3.vendor_name, + tbl3.vendor_name in ( + select + tbl1.unit_name + from + tbl2 + join tbl1 on tbl1.cur_unit_name = tbl2.org_name + where + tbl2.org_code = tbl3.acntm_name + and tbl1.unit_name is not null + ) v1, + tbl3.vendor_name not in ( + select + tbl1.unit_name + from + tbl2 + join tbl1 on tbl1.cur_unit_name = tbl2.org_name + where + tbl2.org_code = tbl3.acntm_name + and tbl1.unit_name is not null + ) v2 + from + tbl3 order by 1,2,3,4,5; + """ + + sql "set disable_join_reorder=1;" + qt_test_right_semi_mark_join_no_null_2 """ + select + tbl3.id, + tbl3.acntm_name, + tbl3.vendor_name, + tbl3.vendor_name in ( + select + tbl1.unit_name + from + tbl2 + join tbl1 on tbl1.cur_unit_name = tbl2.org_name + where + tbl2.org_code = tbl3.acntm_name + and tbl1.unit_name is not null + ) v1, + tbl3.vendor_name not in ( + select + tbl1.unit_name + from + tbl2 + join tbl1 on tbl1.cur_unit_name = tbl2.org_name + where + tbl2.org_code = tbl3.acntm_name + and tbl1.unit_name is not null + ) v2 + from + tbl3 order by 1,2,3,4,5; + """ }