From 64a747a179d736475a66e02d926efbab5a133c32 Mon Sep 17 00:00:00 2001 From: SeaRise Date: Tue, 24 May 2022 15:20:46 +0800 Subject: [PATCH] refine `handleJoin` (#4722) ref pingcap/tiflash#4118 --- .../Coprocessor/DAGQueryBlockInterpreter.cpp | 329 +++------------- .../Coprocessor/DAGQueryBlockInterpreter.h | 14 - .../Coprocessor/JoinInterpreterHelper.cpp | 356 ++++++++++++++++++ .../Flash/Coprocessor/JoinInterpreterHelper.h | 133 +++++++ dbms/src/Flash/tests/exchange_perftest.cpp | 4 +- dbms/src/Interpreters/ExpressionAnalyzer.cpp | 2 +- dbms/src/Interpreters/Join.cpp | 138 ++++--- dbms/src/Interpreters/Join.h | 61 ++- dbms/src/Storages/StorageJoin.cpp | 2 +- tests/fullstack-test/mpp/misc_join.test | 41 ++ 10 files changed, 719 insertions(+), 361 deletions(-) create mode 100644 dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp create mode 100644 dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h create mode 100644 tests/fullstack-test/mpp/misc_join.test diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp index 7dfe0ebd871..4d8faffde6c 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp @@ -33,9 +33,7 @@ #include #include #include -#include #include -#include #include #include #include @@ -43,6 +41,7 @@ #include #include #include +#include #include #include #include @@ -50,9 +49,6 @@ #include #include #include -#include -#include - namespace DB { @@ -183,322 +179,117 @@ void DAGQueryBlockInterpreter::handleTableScan(const TiDBTableScan & table_scan, analyzer = std::move(storage_interpreter.analyzer); } -void DAGQueryBlockInterpreter::prepareJoin( - const google::protobuf::RepeatedPtrField & keys, - const DataTypes & key_types, - DAGPipeline & pipeline, - Names & key_names, - bool left, - bool is_right_out_join, - const google::protobuf::RepeatedPtrField & filters, - String & filter_column_name) -{ - NamesAndTypes source_columns; - for (auto const & p : pipeline.firstStream()->getHeader().getNamesAndTypesList()) - source_columns.emplace_back(p.name, p.type); - DAGExpressionAnalyzer dag_analyzer(std::move(source_columns), context); - ExpressionActionsChain chain; - if (dag_analyzer.appendJoinKeyAndJoinFilters(chain, keys, key_types, key_names, left, is_right_out_join, filters, filter_column_name)) - { - pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, chain.getLastActions(), log->identifier()); }); - } -} - -ExpressionActionsPtr DAGQueryBlockInterpreter::genJoinOtherConditionAction( - const tipb::Join & join, - NamesAndTypes & source_columns, - String & filter_column_for_other_condition, - String & filter_column_for_other_eq_condition) -{ - if (join.other_conditions_size() == 0 && join.other_eq_conditions_from_in_size() == 0) - return nullptr; - DAGExpressionAnalyzer dag_analyzer(source_columns, context); - ExpressionActionsChain chain; - std::vector condition_vector; - if (join.other_conditions_size() > 0) - { - for (const auto & c : join.other_conditions()) - { - condition_vector.push_back(&c); - } - filter_column_for_other_condition = dag_analyzer.appendWhere(chain, condition_vector); - } - if (join.other_eq_conditions_from_in_size() > 0) - { - condition_vector.clear(); - for (const auto & c : join.other_eq_conditions_from_in()) - { - condition_vector.push_back(&c); - } - filter_column_for_other_eq_condition = dag_analyzer.appendWhere(chain, condition_vector); - } - return chain.getLastActions(); -} - -/// ClickHouse require join key to be exactly the same type -/// TiDB only require the join key to be the same category -/// for example decimal(10,2) join decimal(20,0) is allowed in -/// TiDB and will throw exception in ClickHouse -void getJoinKeyTypes(const tipb::Join & join, DataTypes & key_types) -{ - for (int i = 0; i < join.left_join_keys().size(); i++) - { - if (!exprHasValidFieldType(join.left_join_keys(i)) || !exprHasValidFieldType(join.right_join_keys(i))) - throw TiFlashException("Join key without field type", Errors::Coprocessor::BadRequest); - DataTypes types; - types.emplace_back(getDataTypeByFieldTypeForComputingLayer(join.left_join_keys(i).field_type())); - types.emplace_back(getDataTypeByFieldTypeForComputingLayer(join.right_join_keys(i).field_type())); - DataTypePtr common_type = getLeastSupertype(types); - key_types.emplace_back(common_type); - } -} - void DAGQueryBlockInterpreter::handleJoin(const tipb::Join & join, DAGPipeline & pipeline, SubqueryForSet & right_query) { - // build - static const std::unordered_map equal_join_type_map{ - {tipb::JoinType::TypeInnerJoin, ASTTableJoin::Kind::Inner}, - {tipb::JoinType::TypeLeftOuterJoin, ASTTableJoin::Kind::Left}, - {tipb::JoinType::TypeRightOuterJoin, ASTTableJoin::Kind::Right}, - {tipb::JoinType::TypeSemiJoin, ASTTableJoin::Kind::Inner}, - {tipb::JoinType::TypeAntiSemiJoin, ASTTableJoin::Kind::Anti}, - {tipb::JoinType::TypeLeftOuterSemiJoin, ASTTableJoin::Kind::LeftSemi}, - {tipb::JoinType::TypeAntiLeftOuterSemiJoin, ASTTableJoin::Kind::LeftAnti}}; - static const std::unordered_map cartesian_join_type_map{ - {tipb::JoinType::TypeInnerJoin, ASTTableJoin::Kind::Cross}, - {tipb::JoinType::TypeLeftOuterJoin, ASTTableJoin::Kind::Cross_Left}, - {tipb::JoinType::TypeRightOuterJoin, ASTTableJoin::Kind::Cross_Right}, - {tipb::JoinType::TypeSemiJoin, ASTTableJoin::Kind::Cross}, - {tipb::JoinType::TypeAntiSemiJoin, ASTTableJoin::Kind::Cross_Anti}, - {tipb::JoinType::TypeLeftOuterSemiJoin, ASTTableJoin::Kind::Cross_LeftSemi}, - {tipb::JoinType::TypeAntiLeftOuterSemiJoin, ASTTableJoin::Kind::Cross_LeftAnti}}; - - if (input_streams_vec.size() != 2) + if (unlikely(input_streams_vec.size() != 2)) { throw TiFlashException("Join query block must have 2 input streams", Errors::BroadcastJoin::Internal); } - const auto & join_type_map = join.left_join_keys_size() == 0 ? cartesian_join_type_map : equal_join_type_map; - auto join_type_it = join_type_map.find(join.join_type()); - if (join_type_it == join_type_map.end()) - throw TiFlashException("Unknown join type in dag request", Errors::Coprocessor::BadRequest); - - /// (cartesian) (anti) left semi join. - const bool is_left_semi_family = join.join_type() == tipb::JoinType::TypeLeftOuterSemiJoin || join.join_type() == tipb::JoinType::TypeAntiLeftOuterSemiJoin; - - ASTTableJoin::Kind kind = join_type_it->second; - const bool is_semi_join = join.join_type() == tipb::JoinType::TypeSemiJoin || join.join_type() == tipb::JoinType::TypeAntiSemiJoin || is_left_semi_family; - ASTTableJoin::Strictness strictness = ASTTableJoin::Strictness::All; - if (is_semi_join) - strictness = ASTTableJoin::Strictness::Any; - - /// in DAG request, inner part is the build side, however for TiFlash implementation, - /// the build side must be the right side, so need to swap the join side if needed - /// 1. for (cross) inner join, there is no problem in this swap. - /// 2. for (cross) semi/anti-semi join, the build side is always right, needn't swap. - /// 3. for non-cross left/right join, there is no problem in this swap. - /// 4. for cross left join, the build side is always right, needn't and can't swap. - /// 5. for cross right join, the build side is always left, so it will always swap and change to cross left join. - /// note that whatever the build side is, we can't support cross-right join now. - - bool swap_join_side; - if (kind == ASTTableJoin::Kind::Cross_Right) - swap_join_side = true; - else if (kind == ASTTableJoin::Kind::Cross_Left) - swap_join_side = false; - else - swap_join_side = join.inner_idx() == 0; + JoinInterpreterHelper::TiFlashJoin tiflash_join{join}; - DAGPipeline left_pipeline; - DAGPipeline right_pipeline; + DAGPipeline probe_pipeline; + DAGPipeline build_pipeline; + probe_pipeline.streams = input_streams_vec[1 - tiflash_join.build_side_index]; + build_pipeline.streams = input_streams_vec[tiflash_join.build_side_index]; - if (swap_join_side) - { - if (kind == ASTTableJoin::Kind::Left) - kind = ASTTableJoin::Kind::Right; - else if (kind == ASTTableJoin::Kind::Right) - kind = ASTTableJoin::Kind::Left; - else if (kind == ASTTableJoin::Kind::Cross_Right) - kind = ASTTableJoin::Kind::Cross_Left; - left_pipeline.streams = input_streams_vec[1]; - right_pipeline.streams = input_streams_vec[0]; - } - else - { - left_pipeline.streams = input_streams_vec[0]; - right_pipeline.streams = input_streams_vec[1]; - } - - NamesAndTypes join_output_columns; - /// columns_for_other_join_filter is a vector of columns used - /// as the input columns when compiling other join filter. - /// Note the order in the column vector is very important: - /// first the columns in input_streams_vec[0], then followed - /// by the columns in input_streams_vec[1], if there are other - /// columns generated before compile other join filter, then - /// append the extra columns afterwards. In order to figure out - /// whether a given column is already in the column vector or - /// not quickly, we use another set to store the column names - NamesAndTypes columns_for_other_join_filter; - std::unordered_set column_set_for_other_join_filter; - bool make_nullable = join.join_type() == tipb::JoinType::TypeRightOuterJoin; - for (auto const & p : input_streams_vec[0][0]->getHeader().getNamesAndTypesList()) - { - join_output_columns.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); - columns_for_other_join_filter.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); - column_set_for_other_join_filter.emplace(p.name); - } - make_nullable = join.join_type() == tipb::JoinType::TypeLeftOuterJoin; - for (auto const & p : input_streams_vec[1][0]->getHeader().getNamesAndTypesList()) - { - if (!is_semi_join) - /// for semi join, the columns from right table will be ignored - join_output_columns.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); - /// however, when compiling join's other condition, we still need the columns from right table - columns_for_other_join_filter.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); - column_set_for_other_join_filter.emplace(p.name); - } + RUNTIME_ASSERT(!input_streams_vec[0].empty(), log, "left input streams cannot be empty"); + const Block & left_input_header = input_streams_vec[0].back()->getHeader(); - bool is_tiflash_left_join = kind == ASTTableJoin::Kind::Left || kind == ASTTableJoin::Kind::Cross_Left; - /// Cross_Right join will be converted to Cross_Left join, so no need to check Cross_Right - bool is_tiflash_right_join = kind == ASTTableJoin::Kind::Right; - /// all the columns from right table should be added after join, even for the join key - NamesAndTypesList columns_added_by_join; - make_nullable = is_tiflash_left_join; - for (auto const & p : right_pipeline.streams[0]->getHeader().getNamesAndTypesList()) - { - columns_added_by_join.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); - } - - String match_helper_name; - if (is_left_semi_family) - { - const auto & left_block = input_streams_vec[0][0]->getHeader(); - const auto & right_block = input_streams_vec[1][0]->getHeader(); + RUNTIME_ASSERT(!input_streams_vec[1].empty(), log, "right input streams cannot be empty"); + const Block & right_input_header = input_streams_vec[1].back()->getHeader(); - match_helper_name = Join::match_helper_prefix; - for (int i = 1; left_block.has(match_helper_name) || right_block.has(match_helper_name); ++i) - { - match_helper_name = Join::match_helper_prefix + std::to_string(i); - } - - columns_added_by_join.emplace_back(match_helper_name, Join::match_helper_type); - join_output_columns.emplace_back(match_helper_name, Join::match_helper_type); - } - - DataTypes join_key_types; - getJoinKeyTypes(join, join_key_types); - TiDB::TiDBCollators collators; - size_t join_key_size = join_key_types.size(); - if (join.probe_types_size() == static_cast(join_key_size) && join.build_types_size() == join.probe_types_size()) - for (size_t i = 0; i < join_key_size; i++) - { - if (removeNullable(join_key_types[i])->isString()) - { - if (join.probe_types(i).collate() != join.build_types(i).collate()) - throw TiFlashException("Join with different collators on the join key", Errors::Coprocessor::BadRequest); - collators.push_back(getCollatorFromFieldType(join.probe_types(i))); - } - else - collators.push_back(nullptr); - } - - Names left_key_names, right_key_names; - String left_filter_column_name, right_filter_column_name; + String match_helper_name = tiflash_join.genMatchHelperName(left_input_header, right_input_header); + NamesAndTypesList columns_added_by_join = tiflash_join.genColumnsAddedByJoin(build_pipeline.firstStream()->getHeader(), match_helper_name); + NamesAndTypes join_output_columns = tiflash_join.genJoinOutputColumns(left_input_header, right_input_header, match_helper_name); /// add necessary transformation if the join key is an expression - prepareJoin( - swap_join_side ? join.right_join_keys() : join.left_join_keys(), - join_key_types, - left_pipeline, - left_key_names, + bool is_tiflash_right_join = tiflash_join.isTiFlashRightJoin(); + + // prepare probe side + auto [probe_side_prepare_actions, probe_key_names, probe_filter_column_name] = JoinInterpreterHelper::prepareJoin( + context, + probe_pipeline.firstStream()->getHeader(), + tiflash_join.getProbeJoinKeys(), + tiflash_join.join_key_types, true, is_tiflash_right_join, - swap_join_side ? join.right_conditions() : join.left_conditions(), - left_filter_column_name); - - prepareJoin( - swap_join_side ? join.left_join_keys() : join.right_join_keys(), - join_key_types, - right_pipeline, - right_key_names, + tiflash_join.getProbeConditions()); + RUNTIME_ASSERT(probe_side_prepare_actions, log, "probe_side_prepare_actions cannot be nullptr"); + + // prepare build side + auto [build_side_prepare_actions, build_key_names, build_filter_column_name] = JoinInterpreterHelper::prepareJoin( + context, + build_pipeline.firstStream()->getHeader(), + tiflash_join.getBuildJoinKeys(), + tiflash_join.join_key_types, false, is_tiflash_right_join, - swap_join_side ? join.left_conditions() : join.right_conditions(), - right_filter_column_name); + tiflash_join.getBuildConditions()); + RUNTIME_ASSERT(build_side_prepare_actions, log, "build_side_prepare_actions cannot be nullptr"); - String other_filter_column_name, other_eq_filter_from_in_column_name; - for (auto const & p : left_pipeline.streams[0]->getHeader().getNamesAndTypesList()) - { - if (column_set_for_other_join_filter.find(p.name) == column_set_for_other_join_filter.end()) - columns_for_other_join_filter.emplace_back(p.name, p.type); - } - for (auto const & p : right_pipeline.streams[0]->getHeader().getNamesAndTypesList()) - { - if (column_set_for_other_join_filter.find(p.name) == column_set_for_other_join_filter.end()) - columns_for_other_join_filter.emplace_back(p.name, p.type); - } - - ExpressionActionsPtr other_condition_expr - = genJoinOtherConditionAction(join, columns_for_other_join_filter, other_filter_column_name, other_eq_filter_from_in_column_name); + auto [other_condition_expr, other_filter_column_name, other_eq_filter_from_in_column_name] + = tiflash_join.genJoinOtherConditionAction(context, left_input_header, right_input_header, probe_side_prepare_actions); const Settings & settings = context.getSettingsRef(); - size_t join_build_concurrency = settings.join_concurrent_build ? std::min(max_streams, right_pipeline.streams.size()) : 1; size_t max_block_size_for_cross_join = settings.max_block_size; fiu_do_on(FailPoints::minimum_block_size_for_cross_join, { max_block_size_for_cross_join = 1; }); JoinPtr join_ptr = std::make_shared( - left_key_names, - right_key_names, + probe_key_names, + build_key_names, true, SizeLimits(settings.max_rows_in_join, settings.max_bytes_in_join, settings.join_overflow_mode), - kind, - strictness, + tiflash_join.kind, + tiflash_join.strictness, log->identifier(), - join_build_concurrency, - collators, - left_filter_column_name, - right_filter_column_name, + tiflash_join.join_key_collators, + probe_filter_column_name, + build_filter_column_name, other_filter_column_name, other_eq_filter_from_in_column_name, other_condition_expr, max_block_size_for_cross_join, match_helper_name); - recordJoinExecuteInfo(swap_join_side ? 0 : 1, join_ptr); + recordJoinExecuteInfo(tiflash_join.build_side_index, join_ptr); + + size_t join_build_concurrency = settings.join_concurrent_build ? std::min(max_streams, build_pipeline.streams.size()) : 1; + /// build side streams + executeExpression(build_pipeline, build_side_prepare_actions); // add a HashJoinBuildBlockInputStream to build a shared hash table - size_t concurrency_build_index = 0; - auto get_concurrency_build_index = [&concurrency_build_index, &join_build_concurrency]() { - return (concurrency_build_index++) % join_build_concurrency; - }; - right_pipeline.transform([&](auto & stream) { + auto get_concurrency_build_index = JoinInterpreterHelper::concurrencyBuildIndexGenerator(join_build_concurrency); + build_pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, join_ptr, get_concurrency_build_index(), log->identifier()); }); - executeUnion(right_pipeline, max_streams, log, /*ignore_block=*/true); + executeUnion(build_pipeline, max_streams, log, /*ignore_block=*/true); - right_query.source = right_pipeline.firstStream(); + right_query.source = build_pipeline.firstStream(); right_query.join = join_ptr; - right_query.join->setSampleBlock(right_query.source->getHeader()); + join_ptr->init(right_query.source->getHeader(), join_build_concurrency); + /// probe side streams + executeExpression(probe_pipeline, probe_side_prepare_actions); NamesAndTypes source_columns; - for (const auto & p : left_pipeline.streams[0]->getHeader().getNamesAndTypesList()) + for (const auto & p : probe_pipeline.firstStream()->getHeader()) source_columns.emplace_back(p.name, p.type); DAGExpressionAnalyzer dag_analyzer(std::move(source_columns), context); ExpressionActionsChain chain; dag_analyzer.appendJoin(chain, right_query, columns_added_by_join); - pipeline.streams = left_pipeline.streams; + pipeline.streams = probe_pipeline.streams; /// add join input stream if (is_tiflash_right_join) { auto & join_execute_info = dagContext().getJoinExecuteInfoMap()[query_block.source_name]; - for (size_t i = 0; i < join_build_concurrency; i++) + size_t not_joined_concurrency = join_ptr->getNotJoinedStreamConcurrency(); + for (size_t i = 0; i < not_joined_concurrency; ++i) { - auto non_joined_stream = chain.getLastActions()->createStreamWithNonJoinedDataIfFullOrRightJoin( + auto non_joined_stream = join_ptr->createStreamWithNonJoinedRows( pipeline.firstStream()->getHeader(), i, - join_build_concurrency, + not_joined_concurrency, settings.max_block_size); pipeline.streams_with_non_joined_data.push_back(non_joined_stream); join_execute_info.non_joined_streams.push_back(non_joined_stream); diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h index 69bac9c3ba9..9b95a5c3e93 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h @@ -61,25 +61,11 @@ class DAGQueryBlockInterpreter void handleMockTableScan(const TiDBTableScan & table_scan, DAGPipeline & pipeline); void handleTableScan(const TiDBTableScan & table_scan, DAGPipeline & pipeline); void handleJoin(const tipb::Join & join, DAGPipeline & pipeline, SubqueryForSet & right_query); - void prepareJoin( - const google::protobuf::RepeatedPtrField & keys, - const DataTypes & key_types, - DAGPipeline & pipeline, - Names & key_names, - bool left, - bool is_right_out_join, - const google::protobuf::RepeatedPtrField & filters, - String & filter_column_name); void handleExchangeReceiver(DAGPipeline & pipeline); void handleMockExchangeReceiver(DAGPipeline & pipeline); void handleProjection(DAGPipeline & pipeline, const tipb::Projection & projection); void handleWindow(DAGPipeline & pipeline, const tipb::Window & window); void handleWindowOrder(DAGPipeline & pipeline, const tipb::Sort & window_sort); - ExpressionActionsPtr genJoinOtherConditionAction( - const tipb::Join & join, - NamesAndTypes & source_columns, - String & filter_column_for_other_condition, - String & filter_column_for_other_eq_condition); void executeWhere(DAGPipeline & pipeline, const ExpressionActionsPtr & expressionActionsPtr, String & filter_column); void executeExpression(DAGPipeline & pipeline, const ExpressionActionsPtr & expressionActionsPtr); void executeWindowOrder(DAGPipeline & pipeline, SortDescription sort_desc); diff --git a/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp new file mode 100644 index 00000000000..2582a84ac46 --- /dev/null +++ b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.cpp @@ -0,0 +1,356 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace DB::JoinInterpreterHelper +{ +namespace +{ +std::pair getJoinKindAndBuildSideIndex(const tipb::Join & join) +{ + static const std::unordered_map equal_join_type_map{ + {tipb::JoinType::TypeInnerJoin, ASTTableJoin::Kind::Inner}, + {tipb::JoinType::TypeLeftOuterJoin, ASTTableJoin::Kind::Left}, + {tipb::JoinType::TypeRightOuterJoin, ASTTableJoin::Kind::Right}, + {tipb::JoinType::TypeSemiJoin, ASTTableJoin::Kind::Inner}, + {tipb::JoinType::TypeAntiSemiJoin, ASTTableJoin::Kind::Anti}, + {tipb::JoinType::TypeLeftOuterSemiJoin, ASTTableJoin::Kind::LeftSemi}, + {tipb::JoinType::TypeAntiLeftOuterSemiJoin, ASTTableJoin::Kind::LeftAnti}}; + static const std::unordered_map cartesian_join_type_map{ + {tipb::JoinType::TypeInnerJoin, ASTTableJoin::Kind::Cross}, + {tipb::JoinType::TypeLeftOuterJoin, ASTTableJoin::Kind::Cross_Left}, + {tipb::JoinType::TypeRightOuterJoin, ASTTableJoin::Kind::Cross_Right}, + {tipb::JoinType::TypeSemiJoin, ASTTableJoin::Kind::Cross}, + {tipb::JoinType::TypeAntiSemiJoin, ASTTableJoin::Kind::Cross_Anti}, + {tipb::JoinType::TypeLeftOuterSemiJoin, ASTTableJoin::Kind::Cross_LeftSemi}, + {tipb::JoinType::TypeAntiLeftOuterSemiJoin, ASTTableJoin::Kind::Cross_LeftAnti}}; + + const auto & join_type_map = join.left_join_keys_size() == 0 ? cartesian_join_type_map : equal_join_type_map; + auto join_type_it = join_type_map.find(join.join_type()); + if (unlikely(join_type_it == join_type_map.end())) + throw TiFlashException("Unknown join type in dag request", Errors::Coprocessor::BadRequest); + + ASTTableJoin::Kind kind = join_type_it->second; + + /// in DAG request, inner part is the build side, however for TiFlash implementation, + /// the build side must be the right side, so need to swap the join side if needed + /// 1. for (cross) inner join, there is no problem in this swap. + /// 2. for (cross) semi/anti-semi join, the build side is always right, needn't swap. + /// 3. for non-cross left/right join, there is no problem in this swap. + /// 4. for cross left join, the build side is always right, needn't and can't swap. + /// 5. for cross right join, the build side is always left, so it will always swap and change to cross left join. + /// note that whatever the build side is, we can't support cross-right join now. + + size_t build_side_index = 0; + switch (kind) + { + case ASTTableJoin::Kind::Cross_Right: + build_side_index = 0; + break; + case ASTTableJoin::Kind::Cross_Left: + build_side_index = 1; + break; + default: + build_side_index = join.inner_idx(); + } + assert(build_side_index == 0 || build_side_index == 1); + + // should swap join side. + if (build_side_index != 1) + { + switch (kind) + { + case ASTTableJoin::Kind::Left: + kind = ASTTableJoin::Kind::Right; + break; + case ASTTableJoin::Kind::Right: + kind = ASTTableJoin::Kind::Left; + break; + case ASTTableJoin::Kind::Cross_Right: + kind = ASTTableJoin::Kind::Cross_Left; + default:; // just `default`, for other kinds, don't need to change kind. + } + } + + return {kind, build_side_index}; +} + +DataTypes getJoinKeyTypes(const tipb::Join & join) +{ + if (unlikely(join.left_join_keys_size() != join.right_join_keys_size())) + throw TiFlashException("size of join.left_join_keys != size of join.right_join_keys", Errors::Coprocessor::BadRequest); + DataTypes key_types; + for (int i = 0; i < join.left_join_keys_size(); ++i) + { + if (unlikely(!exprHasValidFieldType(join.left_join_keys(i)) || !exprHasValidFieldType(join.right_join_keys(i)))) + throw TiFlashException("Join key without field type", Errors::Coprocessor::BadRequest); + DataTypes types; + types.emplace_back(getDataTypeByFieldTypeForComputingLayer(join.left_join_keys(i).field_type())); + types.emplace_back(getDataTypeByFieldTypeForComputingLayer(join.right_join_keys(i).field_type())); + DataTypePtr common_type = getLeastSupertype(types); + key_types.emplace_back(common_type); + } + return key_types; +} + +TiDB::TiDBCollators getJoinKeyCollators(const tipb::Join & join, const DataTypes & join_key_types) +{ + TiDB::TiDBCollators collators; + size_t join_key_size = join_key_types.size(); + if (join.probe_types_size() == static_cast(join_key_size) && join.build_types_size() == join.probe_types_size()) + for (size_t i = 0; i < join_key_size; ++i) + { + if (removeNullable(join_key_types[i])->isString()) + { + if (unlikely(join.probe_types(i).collate() != join.build_types(i).collate())) + throw TiFlashException("Join with different collators on the join key", Errors::Coprocessor::BadRequest); + collators.push_back(getCollatorFromFieldType(join.probe_types(i))); + } + else + collators.push_back(nullptr); + } + return collators; +} + +std::tuple doGenJoinOtherConditionAction( + const Context & context, + const tipb::Join & join, + const NamesAndTypes & source_columns) +{ + if (join.other_conditions_size() == 0 && join.other_eq_conditions_from_in_size() == 0) + return {nullptr, "", ""}; + + DAGExpressionAnalyzer dag_analyzer(source_columns, context); + ExpressionActionsChain chain; + + String filter_column_for_other_condition; + if (join.other_conditions_size() > 0) + { + std::vector condition_vector; + for (const auto & c : join.other_conditions()) + { + condition_vector.push_back(&c); + } + filter_column_for_other_condition = dag_analyzer.appendWhere(chain, condition_vector); + } + + String filter_column_for_other_eq_condition; + if (join.other_eq_conditions_from_in_size() > 0) + { + std::vector condition_vector; + for (const auto & c : join.other_eq_conditions_from_in()) + { + condition_vector.push_back(&c); + } + filter_column_for_other_eq_condition = dag_analyzer.appendWhere(chain, condition_vector); + } + + return {chain.getLastActions(), std::move(filter_column_for_other_condition), std::move(filter_column_for_other_eq_condition)}; +} +} // namespace + +TiFlashJoin::TiFlashJoin(const tipb::Join & join_) // NOLINT(cppcoreguidelines-pro-type-member-init) + : join(join_) + , join_key_types(getJoinKeyTypes(join_)) + , join_key_collators(getJoinKeyCollators(join_, join_key_types)) +{ + std::tie(kind, build_side_index) = getJoinKindAndBuildSideIndex(join); + strictness = isSemiJoin() ? ASTTableJoin::Strictness::Any : ASTTableJoin::Strictness::All; +} + +String TiFlashJoin::genMatchHelperName(const Block & header1, const Block & header2) const +{ + if (!isLeftSemiFamily()) + { + return ""; + } + + size_t i = 0; + String match_helper_name = fmt::format("{}{}", Join::match_helper_prefix, i); + while (header1.has(match_helper_name) || header2.has(match_helper_name)) + { + match_helper_name = fmt::format("{}{}", Join::match_helper_prefix, ++i); + } + return match_helper_name; +} + +NamesAndTypes TiFlashJoin::genColumnsForOtherJoinFilter( + const Block & left_input_header, + const Block & right_input_header, + const ExpressionActionsPtr & probe_prepare_join_actions) const +{ +#ifndef NDEBUG + auto is_prepare_actions_valid = [](const Block & origin_block, const ExpressionActionsPtr & prepare_actions) { + const Block & prepare_sample_block = prepare_actions->getSampleBlock(); + for (const auto & p : origin_block) + { + if (!prepare_sample_block.has(p.name)) + return false; + } + return true; + }; + if (unlikely(!is_prepare_actions_valid(build_side_index == 1 ? left_input_header : right_input_header, probe_prepare_join_actions))) + { + throw TiFlashException("probe_prepare_join_actions isn't valid", Errors::Coprocessor::Internal); + } +#endif + + /// columns_for_other_join_filter is a vector of columns used + /// as the input columns when compiling other join filter. + /// Note the order in the column vector is very important: + /// first the columns in left_input_header, then followed + /// by the columns in right_input_header, if there are other + /// columns generated before compile other join filter, then + /// append the extra columns afterwards. In order to figure out + /// whether a given column is already in the column vector or + /// not quickly, we use another set to store the column names. + + /// The order of columns must be {left_input, right_input, extra columns}, + /// because tidb requires the input schema of join to be {left_input, right_input}. + /// Extra columns are appended to prevent extra columns from being repeatedly generated. + + NamesAndTypes columns_for_other_join_filter; + std::unordered_set column_set_for_origin_columns; + + auto append_origin_columns = [&columns_for_other_join_filter, &column_set_for_origin_columns](const Block & header, bool make_nullable) { + for (const auto & p : header) + { + columns_for_other_join_filter.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); + column_set_for_origin_columns.emplace(p.name); + } + }; + append_origin_columns(left_input_header, join.join_type() == tipb::JoinType::TypeRightOuterJoin); + append_origin_columns(right_input_header, join.join_type() == tipb::JoinType::TypeLeftOuterJoin); + + /// append the columns generated by probe side prepare join actions. + /// the new columns are + /// - filter_column and related temporary columns + /// - join keys and related temporary columns + auto append_new_columns = [&columns_for_other_join_filter, &column_set_for_origin_columns](const Block & header, bool make_nullable) { + for (const auto & p : header) + { + if (column_set_for_origin_columns.find(p.name) == column_set_for_origin_columns.end()) + columns_for_other_join_filter.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); + } + }; + bool make_nullable = build_side_index == 1 + ? join.join_type() == tipb::JoinType::TypeRightOuterJoin + : join.join_type() == tipb::JoinType::TypeLeftOuterJoin; + append_new_columns(probe_prepare_join_actions->getSampleBlock(), make_nullable); + + return columns_for_other_join_filter; +} + +/// all the columns from build side streams should be added after join, even for the join key. +NamesAndTypesList TiFlashJoin::genColumnsAddedByJoin( + const Block & build_side_header, + const String & match_helper_name) const +{ + NamesAndTypesList columns_added_by_join; + bool make_nullable = isTiFlashLeftJoin(); + for (auto const & p : build_side_header) + { + columns_added_by_join.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); + } + if (!match_helper_name.empty()) + { + columns_added_by_join.emplace_back(match_helper_name, Join::match_helper_type); + } + return columns_added_by_join; +} + +NamesAndTypes TiFlashJoin::genJoinOutputColumns( + const Block & left_input_header, + const Block & right_input_header, + const String & match_helper_name) const +{ + NamesAndTypes join_output_columns; + auto append_output_columns = [&join_output_columns](const Block & header, bool make_nullable) { + for (auto const & p : header) + { + join_output_columns.emplace_back(p.name, make_nullable ? makeNullable(p.type) : p.type); + } + }; + + append_output_columns(left_input_header, join.join_type() == tipb::JoinType::TypeRightOuterJoin); + if (!isSemiJoin()) + { + /// for semi join, the columns from right table will be ignored + append_output_columns(right_input_header, join.join_type() == tipb::JoinType::TypeLeftOuterJoin); + } + + if (!match_helper_name.empty()) + { + join_output_columns.emplace_back(match_helper_name, Join::match_helper_type); + } + + return join_output_columns; +} + +std::tuple TiFlashJoin::genJoinOtherConditionAction( + const Context & context, + const Block & left_input_header, + const Block & right_input_header, + const ExpressionActionsPtr & probe_side_prepare_join) const +{ + auto columns_for_other_join_filter + = genColumnsForOtherJoinFilter( + left_input_header, + right_input_header, + probe_side_prepare_join); + + return doGenJoinOtherConditionAction(context, join, columns_for_other_join_filter); +} + +std::tuple prepareJoin( + const Context & context, + const Block & input_header, + const google::protobuf::RepeatedPtrField & keys, + const DataTypes & key_types, + bool left, + bool is_right_out_join, + const google::protobuf::RepeatedPtrField & filters) +{ + NamesAndTypes source_columns; + for (auto const & p : input_header) + source_columns.emplace_back(p.name, p.type); + DAGExpressionAnalyzer dag_analyzer(std::move(source_columns), context); + ExpressionActionsChain chain; + Names key_names; + String filter_column_name; + dag_analyzer.appendJoinKeyAndJoinFilters(chain, keys, key_types, key_names, left, is_right_out_join, filters, filter_column_name); + return {chain.getLastActions(), std::move(key_names), std::move(filter_column_name)}; +} + +std::function concurrencyBuildIndexGenerator(size_t join_build_concurrency) +{ + size_t init_value = 0; + return [init_value, join_build_concurrency]() mutable { + return (init_value++) % join_build_concurrency; + }; +} +} // namespace DB::JoinInterpreterHelper \ No newline at end of file diff --git a/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h new file mode 100644 index 00000000000..d84c03d572d --- /dev/null +++ b/dbms/src/Flash/Coprocessor/JoinInterpreterHelper.h @@ -0,0 +1,133 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace DB +{ +class Context; + +namespace JoinInterpreterHelper +{ +struct TiFlashJoin +{ + explicit TiFlashJoin(const tipb::Join & join_); + + const tipb::Join & join; + + ASTTableJoin::Kind kind; + size_t build_side_index = 0; + + DataTypes join_key_types; + TiDB::TiDBCollators join_key_collators; + + ASTTableJoin::Strictness strictness; + + /// (cartesian) (anti) left semi join. + bool isLeftSemiFamily() const { return join.join_type() == tipb::JoinType::TypeLeftOuterSemiJoin || join.join_type() == tipb::JoinType::TypeAntiLeftOuterSemiJoin; } + + bool isSemiJoin() const { return join.join_type() == tipb::JoinType::TypeSemiJoin || join.join_type() == tipb::JoinType::TypeAntiSemiJoin || isLeftSemiFamily(); } + + const google::protobuf::RepeatedPtrField & getBuildJoinKeys() const + { + return build_side_index == 1 ? join.right_join_keys() : join.left_join_keys(); + } + + const google::protobuf::RepeatedPtrField & getProbeJoinKeys() const + { + return build_side_index == 0 ? join.right_join_keys() : join.left_join_keys(); + } + + const google::protobuf::RepeatedPtrField & getBuildConditions() const + { + return build_side_index == 1 ? join.right_conditions() : join.left_conditions(); + } + + const google::protobuf::RepeatedPtrField & getProbeConditions() const + { + return build_side_index == 0 ? join.right_conditions() : join.left_conditions(); + } + + bool isTiFlashLeftJoin() const { return kind == ASTTableJoin::Kind::Left || kind == ASTTableJoin::Kind::Cross_Left; } + + /// Cross_Right join will be converted to Cross_Left join, so no need to check Cross_Right + bool isTiFlashRightJoin() const { return kind == ASTTableJoin::Kind::Right; } + + /// return a name that is unique in header1 and header2 for left semi family join, + /// return "" for everything else. + String genMatchHelperName(const Block & header1, const Block & header2) const; + + /// columns_added_by_join + /// = join_output_columns - probe_side_columns + /// = build_side_columns + match_helper_name + NamesAndTypesList genColumnsAddedByJoin( + const Block & build_side_header, + const String & match_helper_name) const; + + /// The columns output by join will be: + /// {columns of left_input, columns of right_input, match_helper_name} + NamesAndTypes genJoinOutputColumns( + const Block & left_input_header, + const Block & right_input_header, + const String & match_helper_name) const; + + /// @other_condition_expr: generates other_filter_column and other_eq_filter_from_in_column + /// @other_filter_column_name: column name of `and(other_cond1, other_cond2, ...)` + /// @other_eq_filter_from_in_column_name: column name of `and(other_eq_cond1_from_in, other_eq_cond2_from_in, ...)` + /// such as + /// `select * from t where col1 in (select col2 from t2 where t1.col2 = t2.col3)` + /// - other_filter is `t1.col2 = t2.col3` + /// - other_eq_filter_from_in_column is `t1.col1 = t2.col2` + /// + /// new columns from build side prepare join actions cannot be appended. + /// because the input that other filter accepts is + /// {left_input_columns, right_input_columns, new_columns_from_probe_side_prepare, match_helper_name}. + std::tuple genJoinOtherConditionAction( + const Context & context, + const Block & left_input_header, + const Block & right_input_header, + const ExpressionActionsPtr & probe_side_prepare_join) const; + + NamesAndTypes genColumnsForOtherJoinFilter( + const Block & left_input_header, + const Block & right_input_header, + const ExpressionActionsPtr & probe_prepare_join_actions) const; +}; + +/// @join_prepare_expr_actions: generates join key columns and join filter column +/// @key_names: column names of keys. +/// @filter_column_name: column name of `and(filters)` +std::tuple prepareJoin( + const Context & context, + const Block & input_header, + const google::protobuf::RepeatedPtrField & keys, + const DataTypes & key_types, + bool left, + bool is_right_out_join, + const google::protobuf::RepeatedPtrField & filters); + +std::function concurrencyBuildIndexGenerator(size_t join_build_concurrency); +} // namespace JoinInterpreterHelper +} // namespace DB diff --git a/dbms/src/Flash/tests/exchange_perftest.cpp b/dbms/src/Flash/tests/exchange_perftest.cpp index 45dbac4a7f6..c2e047bec62 100644 --- a/dbms/src/Flash/tests/exchange_perftest.cpp +++ b/dbms/src/Flash/tests/exchange_perftest.cpp @@ -462,7 +462,7 @@ struct ReceiverHelper SizeLimits(0, 0, OverflowMode::THROW), ASTTableJoin::Kind::Inner, ASTTableJoin::Strictness::All, - concurrency, + /*req_id=*/"", TiDB::TiDBCollators{nullptr}, "", "", @@ -471,7 +471,7 @@ struct ReceiverHelper nullptr, 65536); - join_ptr->setSampleBlock(receiver_header); + join_ptr->init(receiver_header, concurrency); for (int i = 0; i < concurrency; ++i) streams[i] = std::make_shared(streams[i], join_ptr, i, /*req_id=*/""); diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.cpp b/dbms/src/Interpreters/ExpressionAnalyzer.cpp index cd947d08953..a532ed8a8e0 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.cpp +++ b/dbms/src/Interpreters/ExpressionAnalyzer.cpp @@ -2435,7 +2435,7 @@ bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_ty /// TODO You do not need to set this up when JOIN is only needed on remote servers. subquery_for_set.join = join; - subquery_for_set.join->setSampleBlock(subquery_for_set.source->getHeader()); + subquery_for_set.join->init(subquery_for_set.source->getHeader()); } addJoinAction(step.actions, false); diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index ab37a1cb29b..820618a6e8b 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -38,40 +38,67 @@ extern const int TYPE_MISMATCH; extern const int ILLEGAL_COLUMN; } // namespace ErrorCodes +namespace +{ /// Do I need to use the hash table maps_*_full, in which we remember whether the row was joined. -static bool getFullness(ASTTableJoin::Kind kind) +bool getFullness(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::Right || kind == ASTTableJoin::Kind::Cross_Right || kind == ASTTableJoin::Kind::Full; } -static bool isLeftJoin(ASTTableJoin::Kind kind) +bool isLeftJoin(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::Left || kind == ASTTableJoin::Kind::Cross_Left; } -static bool isRightJoin(ASTTableJoin::Kind kind) +bool isRightJoin(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::Right || kind == ASTTableJoin::Kind::Cross_Right; } -static bool isInnerJoin(ASTTableJoin::Kind kind) +bool isInnerJoin(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::Inner || kind == ASTTableJoin::Kind::Cross; } -static bool isAntiJoin(ASTTableJoin::Kind kind) +bool isAntiJoin(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::Anti || kind == ASTTableJoin::Kind::Cross_Anti; } -static bool isCrossJoin(ASTTableJoin::Kind kind) +bool isCrossJoin(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::Cross || kind == ASTTableJoin::Kind::Cross_Left || kind == ASTTableJoin::Kind::Cross_Right || kind == ASTTableJoin::Kind::Cross_Anti || kind == ASTTableJoin::Kind::Cross_LeftSemi || kind == ASTTableJoin::Kind::Cross_LeftAnti; } /// (cartesian) (anti) left semi join. -static bool isLeftSemiFamily(ASTTableJoin::Kind kind) +bool isLeftSemiFamily(ASTTableJoin::Kind kind) { return kind == ASTTableJoin::Kind::LeftSemi || kind == ASTTableJoin::Kind::LeftAnti || kind == ASTTableJoin::Kind::Cross_LeftSemi || kind == ASTTableJoin::Kind::Cross_LeftAnti; } +void convertColumnToNullable(ColumnWithTypeAndName & column) +{ + column.type = makeNullable(column.type); + if (column.column) + column.column = makeNullable(column.column); +} + +ColumnRawPtrs getKeyColumns(const Names & key_names, const Block & block) +{ + size_t keys_size = key_names.size(); + ColumnRawPtrs key_columns(keys_size); + + for (size_t i = 0; i < keys_size; ++i) + { + key_columns[i] = block.getByName(key_names[i]).column.get(); + + /// We will join only keys, where all components are not NULL. + if (key_columns[i]->isColumnNullable()) + key_columns[i] = &static_cast(*key_columns[i]).getNestedColumn(); + } + + return key_columns; +} +} // namespace + const std::string Join::match_helper_prefix = "__left-semi-join-match-helper"; const DataTypePtr Join::match_helper_type = makeNullable(std::make_shared()); @@ -84,7 +111,6 @@ Join::Join( ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_, const String & req_id, - size_t build_concurrency_, const TiDB::TiDBCollators & collators_, const String & left_filter_column_, const String & right_filter_column_, @@ -99,7 +125,7 @@ Join::Join( , key_names_left(key_names_left_) , key_names_right(key_names_right_) , use_nulls(use_nulls_) - , build_concurrency(std::max(1, build_concurrency_)) + , build_concurrency(0) , build_set_exceeded(false) , collators(collators_) , left_filter_column(left_filter_column_) @@ -113,8 +139,6 @@ Join::Join( , log(Logger::get("Join", req_id)) , limits(limits) { - for (size_t i = 0; i < build_concurrency; i++) - pools.emplace_back(std::make_shared()); if (other_condition_ptr != nullptr) { /// if there is other_condition, then should keep all the valid rows during probe stage @@ -123,14 +147,9 @@ Join::Join( strictness = ASTTableJoin::Strictness::All; } } - if (getFullness(kind)) - { - for (size_t i = 0; i < build_concurrency; i++) - rows_not_inserted_to_map.push_back(std::make_unique()); - } - if (!left_filter_column.empty() && !isLeftJoin(kind)) + if (unlikely(!left_filter_column.empty() && !isLeftJoin(kind))) throw Exception("Not supported: non left join with left conditions"); - if (!right_filter_column.empty() && !isRightJoin(kind)) + if (unlikely(!right_filter_column.empty() && !isRightJoin(kind))) throw Exception("Not supported: non right join with right conditions"); } @@ -324,7 +343,7 @@ struct KeyGetterForType using Type = typename KeyGetterForTypeImpl::Type; }; -void Join::init(Type type_) +void Join::initMapImpl(Type type_) { type = type_; @@ -334,16 +353,16 @@ void Join::init(Type type_) if (!getFullness(kind)) { if (strictness == ASTTableJoin::Strictness::Any) - initImpl(maps_any, type, build_concurrency); + initImpl(maps_any, type, getBuildConcurrencyInternal()); else - initImpl(maps_all, type, build_concurrency); + initImpl(maps_all, type, getBuildConcurrencyInternal()); } else { if (strictness == ASTTableJoin::Strictness::Any) - initImpl(maps_any_full, type, build_concurrency); + initImpl(maps_any_full, type, getBuildConcurrencyInternal()); else - initImpl(maps_all_full, type, build_concurrency); + initImpl(maps_all_full, type, getBuildConcurrencyInternal()); } } @@ -392,37 +411,24 @@ size_t Join::getTotalByteCount() const return res; } - -static void convertColumnToNullable(ColumnWithTypeAndName & column) +void Join::setBuildConcurrencyAndInitPool(size_t build_concurrency_) { - column.type = makeNullable(column.type); - if (column.column) - column.column = makeNullable(column.column); -} - + if (unlikely(build_concurrency > 0)) + throw Exception("Logical error: `setBuildConcurrencyAndInitPool` shouldn't be called more than once", ErrorCodes::LOGICAL_ERROR); + build_concurrency = std::max(1, build_concurrency_); -void Join::setSampleBlock(const Block & block) -{ - std::unique_lock lock(rwlock); - - if (!empty()) - return; - - size_t keys_size = key_names_right.size(); - ColumnRawPtrs key_columns(keys_size); - - for (size_t i = 0; i < keys_size; ++i) + for (size_t i = 0; i < getBuildConcurrencyInternal(); ++i) + pools.emplace_back(std::make_shared()); + // init for non-joined-streams. + if (getFullness(kind)) { - key_columns[i] = block.getByName(key_names_right[i]).column.get(); - - /// We will join only keys, where all components are not NULL. - if (key_columns[i]->isColumnNullable()) - key_columns[i] = &static_cast(*key_columns[i]).getNestedColumn(); + for (size_t i = 0; i < getNotJoinedStreamConcurrencyInternal(); ++i) + rows_not_inserted_to_map.push_back(std::make_unique()); } +} - /// Choose data structure to use for JOIN. - init(chooseMethod(key_columns, key_sizes)); - +void Join::setSampleBlock(const Block & block) +{ sample_block_with_columns_to_add = materializeBlock(block); /// Move from `sample_block_with_columns_to_add` key columns to `sample_block_with_keys`, keeping the order. @@ -457,6 +463,18 @@ void Join::setSampleBlock(const Block & block) sample_block_with_columns_to_add.insert(ColumnWithTypeAndName(Join::match_helper_type, match_helper_name)); } +void Join::init(const Block & sample_block, size_t build_concurrency_) +{ + std::unique_lock lock(rwlock); + if (unlikely(initialized)) + throw Exception("Logical error: Join has been initialized", ErrorCodes::LOGICAL_ERROR); + initialized = true; + setBuildConcurrencyAndInitPool(build_concurrency_); + /// Choose data structure to use for JOIN. + initMapImpl(chooseMethod(getKeyColumns(key_names_right, sample_block), key_sizes)); + setSampleBlock(sample_block); +} + namespace { @@ -757,9 +775,9 @@ void recordFilteredRows(const Block & block, const String & filter_column, Colum bool Join::insertFromBlock(const Block & block) { - if (empty()) - throw Exception("Logical error: Join was not initialized", ErrorCodes::LOGICAL_ERROR); std::unique_lock lock(rwlock); + if (unlikely(!initialized)) + throw Exception("Logical error: Join was not initialized", ErrorCodes::LOGICAL_ERROR); blocks.push_back(block); Block * stored_block = &blocks.back(); return insertFromBlockInternal(stored_block, 0); @@ -768,11 +786,12 @@ bool Join::insertFromBlock(const Block & block) /// the block should be valid. void Join::insertFromBlock(const Block & block, size_t stream_index) { - assert(stream_index < build_concurrency); + std::shared_lock lock(rwlock); + assert(stream_index < getBuildConcurrencyInternal()); + assert(stream_index < getNotJoinedStreamConcurrencyInternal()); - if (empty()) + if (unlikely(!initialized)) throw Exception("Logical error: Join was not initialized", ErrorCodes::LOGICAL_ERROR); - std::shared_lock lock(rwlock); Block * stored_block = nullptr; { std::lock_guard lk(blocks_lock); @@ -868,16 +887,16 @@ bool Join::insertFromBlockInternal(Block * stored_block, size_t stream_index) if (!getFullness(kind)) { if (strictness == ASTTableJoin::Strictness::Any) - insertFromBlockImpl(type, maps_any, rows, key_columns, key_sizes, collators, stored_block, null_map, nullptr, stream_index, build_concurrency, *pools[stream_index]); + insertFromBlockImpl(type, maps_any, rows, key_columns, key_sizes, collators, stored_block, null_map, nullptr, stream_index, getBuildConcurrencyInternal(), *pools[stream_index]); else - insertFromBlockImpl(type, maps_all, rows, key_columns, key_sizes, collators, stored_block, null_map, nullptr, stream_index, build_concurrency, *pools[stream_index]); + insertFromBlockImpl(type, maps_all, rows, key_columns, key_sizes, collators, stored_block, null_map, nullptr, stream_index, getBuildConcurrencyInternal(), *pools[stream_index]); } else { if (strictness == ASTTableJoin::Strictness::Any) - insertFromBlockImpl(type, maps_any_full, rows, key_columns, key_sizes, collators, stored_block, null_map, rows_not_inserted_to_map[stream_index].get(), stream_index, build_concurrency, *pools[stream_index]); + insertFromBlockImpl(type, maps_any_full, rows, key_columns, key_sizes, collators, stored_block, null_map, rows_not_inserted_to_map[stream_index].get(), stream_index, getBuildConcurrencyInternal(), *pools[stream_index]); else - insertFromBlockImpl(type, maps_all_full, rows, key_columns, key_sizes, collators, stored_block, null_map, rows_not_inserted_to_map[stream_index].get(), stream_index, build_concurrency, *pools[stream_index]); + insertFromBlockImpl(type, maps_all_full, rows, key_columns, key_sizes, collators, stored_block, null_map, rows_not_inserted_to_map[stream_index].get(), stream_index, getBuildConcurrencyInternal(), *pools[stream_index]); } } @@ -1954,7 +1973,8 @@ class NonJoinedBlockInputStream : public IProfilingBlockInputStream , max_block_size(max_block_size_) , add_not_mapped_rows(true) { - if (step > parent.build_concurrency || index >= parent.build_concurrency) + size_t build_concurrency = parent.getBuildConcurrency(); + if (unlikely(step > build_concurrency || index >= build_concurrency)) throw Exception("The concurrency of NonJoinedBlockInputStream should not be larger than join build concurrency"); /** left_sample_block contains keys and "left" columns. diff --git a/dbms/src/Interpreters/Join.h b/dbms/src/Interpreters/Join.h index 89dad0d1ca6..01916aa1dcc 100644 --- a/dbms/src/Interpreters/Join.h +++ b/dbms/src/Interpreters/Join.h @@ -99,7 +99,6 @@ class Join ASTTableJoin::Kind kind_, ASTTableJoin::Strictness strictness_, const String & req_id, - size_t build_concurrency = 1, const TiDB::TiDBCollators & collators_ = TiDB::dummy_collators, const String & left_filter_column = "", const String & right_filter_column = "", @@ -109,17 +108,10 @@ class Join size_t max_block_size = 0, const String & match_helper_name = ""); - bool empty() { return type == Type::EMPTY; } - - /** Set information about structure of right hand of JOIN (joined data). + /** Call `setBuildConcurrencyAndInitPool`, `initMapImpl` and `setSampleBlock`. * You must call this method before subsequent calls to insertFromBlock. */ - void setSampleBlock(const Block & block); - - /** Add block of data from right hand of JOIN to the map. - * Returns false, if some limit was exceeded and you should not insert more data. - */ - bool insertFromBlockInternal(Block * stored_block, size_t stream_index); + void init(const Block & sample_block, size_t build_concurrency_ = 1); bool insertFromBlock(const Block & block); @@ -153,9 +145,19 @@ class Join bool useNulls() const { return use_nulls; } const Names & getLeftJoinKeys() const { return key_names_left; } - size_t getBuildConcurrency() const { return build_concurrency; } + + size_t getBuildConcurrency() const + { + std::shared_lock lock(rwlock); + return getBuildConcurrencyInternal(); + } + size_t getNotJoinedStreamConcurrency() const + { + std::shared_lock lock(rwlock); + return getNotJoinedStreamConcurrencyInternal(); + } + bool isBuildSetExceeded() const { return build_set_exceeded.load(); } - size_t getNotJoinedStreamConcurrency() const { return build_concurrency; }; enum BuildTableState { @@ -171,7 +173,7 @@ class Join const Block * block; size_t row_num; - RowRef() {} + RowRef() = default; RowRef(const Block * block_, size_t row_num_) : block(block_) , row_num(row_num_) @@ -183,7 +185,7 @@ class Join { RowRefList * next = nullptr; - RowRefList() {} + RowRefList() = default; RowRefList(const Block * block_, size_t row_num_) : RowRef(block_, row_num_) {} @@ -342,11 +344,40 @@ class Join */ mutable std::shared_mutex rwlock; - void init(Type type_); + bool initialized = false; + + size_t getBuildConcurrencyInternal() const + { + if (unlikely(build_concurrency == 0)) + throw Exception("Logical error: `setBuildConcurrencyAndInitPool` has not been called", ErrorCodes::LOGICAL_ERROR); + return build_concurrency; + } + size_t getNotJoinedStreamConcurrencyInternal() const + { + return getBuildConcurrencyInternal(); + } + + /// Initialize map implementations for various join types. + void initMapImpl(Type type_); + + /** Set information about structure of right hand of JOIN (joined data). + * You must call this method before subsequent calls to insertFromBlock. + */ + void setSampleBlock(const Block & block); + + /** Set Join build concurrency and init hash map. + * You must call this method before subsequent calls to insertFromBlock. + */ + void setBuildConcurrencyAndInitPool(size_t build_concurrency_); /// Throw an exception if blocks have different types of key columns. void checkTypesOfKeys(const Block & block_left, const Block & block_right) const; + /** Add block of data from right hand of JOIN to the map. + * Returns false, if some limit was exceeded and you should not insert more data. + */ + bool insertFromBlockInternal(Block * stored_block, size_t stream_index); + template void joinBlockImpl(Block & block, const Maps & maps) const; diff --git a/dbms/src/Storages/StorageJoin.cpp b/dbms/src/Storages/StorageJoin.cpp index 4e3c01c6574..4ca3e79a7ab 100644 --- a/dbms/src/Storages/StorageJoin.cpp +++ b/dbms/src/Storages/StorageJoin.cpp @@ -52,7 +52,7 @@ StorageJoin::StorageJoin( /// NOTE StorageJoin doesn't use join_use_nulls setting. join = std::make_shared(key_names, key_names, false /* use_nulls */, SizeLimits(), kind, strictness, /*req_id=*/""); - join->setSampleBlock(getSampleBlock().sortColumns()); + join->init(getSampleBlock().sortColumns()); restore(); } diff --git a/tests/fullstack-test/mpp/misc_join.test b/tests/fullstack-test/mpp/misc_join.test new file mode 100644 index 00000000000..61a1de49925 --- /dev/null +++ b/tests/fullstack-test/mpp/misc_join.test @@ -0,0 +1,41 @@ +# Copyright 2022 PingCAP, Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Preparation. +mysql> drop table if exists test.t1; +mysql> create table test.t1 (id decimal(5,2), value bigint(20)); +mysql> insert into test.t1 values(1, 1),(2, 2); +mysql> drop table if exists test.t2; +mysql> create table test.t2 (id decimal(5,2), value bigint(20)); +mysql> insert into test.t2 values(1, 1),(2, 2),(3, 3),(4, 4); + +mysql> alter table test.t1 set tiflash replica 1 +mysql> alter table test.t2 set tiflash replica 1 +mysql> analyze table test.t1 +mysql> analyze table test.t2 + +func> wait_table test t1 +func> wait_table test t2 + +mysql> use test; set tidb_allow_mpp=1; set tidb_enforce_mpp=1; set tidb_isolation_read_engines='tiflash'; select * from t1 left join t2 on cast(t1.id as decimal(7,2)) = cast(t2.id as decimal(7,2)) and t1.id + cast(t2.id as decimal(7,2)) + t1.id > 10; ++------+-------+------+-------+ +| id | value | id | value | ++------+-------+------+-------+ +| 1.00 | 1 | NULL | NULL | +| 2.00 | 2 | NULL | NULL | ++------+-------+------+-------+ + +# Clean up. +mysql> drop table if exists test.t1 +mysql> drop table if exists test.t2