diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp index 88ab33b50de0..8817d8709e82 100644 --- a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp +++ b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp @@ -200,14 +200,14 @@ class TKqpLogicalOptTransformer : public TOptimizeTransformerBase { TMaybeNode RewriteEquiJoin(TExprBase node, TExprContext& ctx) { bool useCBO = Config->CostBasedOptimizationLevel.Get().GetOrElse(Config->DefaultCostBasedOptimizationLevel) >= 2; - TExprBase output = DqRewriteEquiJoin(node, KqpCtx.Config->GetHashJoinMode(), useCBO, ctx, TypesCtx, KqpCtx.JoinsCount); + TExprBase output = DqRewriteEquiJoin(node, KqpCtx.Config->GetHashJoinMode(), useCBO, ctx, TypesCtx, KqpCtx.JoinsCount, KqpCtx.GetOptimizerHints()); DumpAppliedRule("RewriteEquiJoin", node.Ptr(), output.Ptr(), ctx); return output; } TMaybeNode JoinToIndexLookup(TExprBase node, TExprContext& ctx) { bool useCBO = Config->CostBasedOptimizationLevel.Get().GetOrElse(Config->DefaultCostBasedOptimizationLevel) >= 2; - TExprBase output = KqpJoinToIndexLookup(node, ctx, KqpCtx, useCBO); + TExprBase output = KqpJoinToIndexLookup(node, ctx, KqpCtx, useCBO, KqpCtx.GetOptimizerHints()); DumpAppliedRule("JoinToIndexLookup", node.Ptr(), output.Ptr(), ctx); return output; } diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp b/ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp index 6c32c107ad83..da15a663f954 100644 --- a/ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp +++ b/ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp @@ -331,7 +331,7 @@ bool IsParameterToListOfStructsRepack(const TExprBase& expr) { return true; } -//#define DBG(...) YQL_CLOG(DEBUG, ProviderKqp) << __VA_ARGS__ +// #define DBG(...) YQL_CLOG(DEBUG, ProviderKqp) << __VA_ARGS__ #define DBG(...) TMaybeNode BuildKqpStreamIndexLookupJoin( @@ -935,7 +935,38 @@ TMaybeNode KqpJoinToIndexLookupImpl(const TDqJoin& join, TExprContext } // anonymous namespace -TExprBase KqpJoinToIndexLookup(const TExprBase& node, TExprContext& ctx, const TKqpOptimizeContext& kqpCtx, bool useCBO) +TVector CollectLabels(const TExprBase& node) { + TVector rels; + + if (node.Maybe()) { + auto precompute = node.Cast(); + return CollectLabels(precompute.Input()); + } + + if (node.Maybe()) { + auto join = node.Cast(); + + if (join.LeftLabel().Maybe()) { + rels.push_back(join.LeftLabel().Cast().StringValue()); + } else { + auto lhs = CollectLabels(join.LeftInput()); + rels.insert(rels.end(), std::make_move_iterator(lhs.begin()), std::make_move_iterator(lhs.end())); + } + + if (join.RightLabel().Maybe()) { + rels.push_back(join.RightLabel().Cast().StringValue()); + } else { + auto rhs = CollectLabels(join.RightInput()); + rels.insert(rels.end(), std::make_move_iterator(rhs.begin()), std::make_move_iterator(rhs.end())); + } + + return rels; + } + + return {}; +} + +TExprBase KqpJoinToIndexLookup(const TExprBase& node, TExprContext& ctx, const TKqpOptimizeContext& kqpCtx, bool useCBO, const TOptimizerHints& hints) { if (!node.Maybe()) { return node; @@ -952,11 +983,26 @@ TExprBase KqpJoinToIndexLookup(const TExprBase& node, TExprContext& ctx, const T return node; } - if (useCBO){ - - if (algo != EJoinAlgoType::LookupJoin && algo != EJoinAlgoType::LookupJoinReverse) { + if (useCBO && algo != EJoinAlgoType::LookupJoin && algo != EJoinAlgoType::LookupJoinReverse){ + return node; + } + + /* + * this cycle looks for applied hints for these join labels. if we've found one then we will leave the function. + * But if it is a LookupJoin we will rewrite it with KqpJoinToIndexLookupImpl because lookup join needs to be rewritten + */ + auto joinLabels = CollectLabels(node); + for (const auto& hint: hints.JoinAlgoHints->Hints) { + if ( + std::unordered_set(hint.JoinLabels.begin(), hint.JoinLabels.end()) == + std::unordered_set(joinLabels.begin(), joinLabels.end()) && hint.Applied + ) { + if (hint.Algo == EJoinAlgoType::LookupJoin || hint.Algo == EJoinAlgoType::LookupJoinReverse) { + break; + } + return node; - } + } } DBG("-- Join: " << KqpExprToPrettyString(join, ctx)); @@ -964,8 +1010,6 @@ TExprBase KqpJoinToIndexLookup(const TExprBase& node, TExprContext& ctx, const T // SqlIn support (preferred lookup direction) if (join.JoinType().Value() == "LeftSemi") { auto flipJoin = FlipLeftSemiJoin(join, ctx); - DBG("-- Flip join"); - if (auto indexLookupJoin = KqpJoinToIndexLookupImpl(flipJoin, ctx, kqpCtx)) { return indexLookupJoin.Cast(); } diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log_rules.h b/ydb/core/kqp/opt/logical/kqp_opt_log_rules.h index 40a6ed4f0b01..9ae28f4cbe30 100644 --- a/ydb/core/kqp/opt/logical/kqp_opt_log_rules.h +++ b/ydb/core/kqp/opt/logical/kqp_opt_log_rules.h @@ -25,7 +25,7 @@ NYql::NNodes::TExprBase KqpPushExtractedPredicateToReadTable(NYql::NNodes::TExpr const TKqpOptimizeContext& kqpCtx, NYql::TTypeAnnotationContext& typesCtx, const NYql::TParentsMap& parentsMap); NYql::NNodes::TExprBase KqpJoinToIndexLookup(const NYql::NNodes::TExprBase& node, NYql::TExprContext& ctx, - const TKqpOptimizeContext& kqpCtx, bool useCBO); + const TKqpOptimizeContext& kqpCtx, bool useCBO, const NYql::TOptimizerHints& hints); NYql::NNodes::TExprBase KqpRewriteSqlInToEquiJoin(const NYql::NNodes::TExprBase& node, NYql::TExprContext& ctx, const TKqpOptimizeContext& kqpCtx, const NYql::TKikimrConfiguration::TPtr& config); diff --git a/ydb/core/kqp/ut/join/data/queries/oltp_join_type_hint_cbo_turnoff.sql b/ydb/core/kqp/ut/join/data/queries/oltp_join_type_hint_cbo_turnoff.sql new file mode 100644 index 000000000000..bb6c1898545a --- /dev/null +++ b/ydb/core/kqp/ut/join/data/queries/oltp_join_type_hint_cbo_turnoff.sql @@ -0,0 +1,14 @@ +PRAGMA TablePathPrefix='/Root'; +PRAGMA ydb.OptimizerHints = +' + JoinType(R S Shuffle) + JoinType(R S T Broadcast) + JoinType(R S T U Shuffle) + JoinType(R S T U V Broadcast) +'; + +SELECT * FROM + R INNER JOIN S on R.id = S.id + INNER JOIN T on R.id = T.id + INNER JOIN U on T.id = U.id + INNER JOIN V on U.id = V.id; diff --git a/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp b/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp index 29a95d24f186..5b2bcbe65ddc 100644 --- a/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp +++ b/ydb/core/kqp/ut/join/kqp_join_order_ut.cpp @@ -81,14 +81,11 @@ static void CreateSampleTable(TSession session, bool useColumnStore) { CreateTables(session, "schema/lookupbug.sql", useColumnStore); } -static TKikimrRunner GetKikimrWithJoinSettings(bool useStreamLookupJoin = false, TString stats = ""){ +static TKikimrRunner GetKikimrWithJoinSettings(bool useStreamLookupJoin = false, TString stats = "", bool useCBO = true){ TVector settings; NKikimrKqp::TKqpSetting setting; - setting.SetName("CostBasedOptimizationLevel"); - setting.SetValue("4"); - settings.push_back(setting); if (stats != "") { setting.SetName("OptOverrideStatistics"); @@ -100,6 +97,9 @@ static TKikimrRunner GetKikimrWithJoinSettings(bool useStreamLookupJoin = false, appConfig.MutableTableServiceConfig()->SetEnableKqpDataQueryStreamIdxLookupJoin(useStreamLookupJoin); appConfig.MutableTableServiceConfig()->SetEnableConstantFolding(true); appConfig.MutableTableServiceConfig()->SetCompileTimeoutMs(TDuration::Minutes(10).MilliSeconds()); + if (!useCBO) { + appConfig.MutableTableServiceConfig()->SetDefaultCostBasedOptimizationLevel(0); + } auto serverSettings = TKikimrSettings().SetAppConfig(appConfig); serverSettings.SetKqpSettings(settings); @@ -197,8 +197,8 @@ class TChainTester { size_t ChainSize; }; -void ExplainJoinOrderTestDataQueryWithStats(const TString& queryPath, const TString& statsPath, bool useStreamLookupJoin, bool useColumnStore) { - auto kikimr = GetKikimrWithJoinSettings(useStreamLookupJoin, GetStatic(statsPath)); +void ExplainJoinOrderTestDataQueryWithStats(const TString& queryPath, const TString& statsPath, bool useStreamLookupJoin, bool useColumnStore, bool useCBO = true) { + auto kikimr = GetKikimrWithJoinSettings(useStreamLookupJoin, GetStatic(statsPath), useCBO); auto db = kikimr.GetTableClient(); auto session = db.CreateSession().GetValueSync().GetSession(); @@ -333,8 +333,8 @@ Y_UNIT_TEST_SUITE(KqpJoinOrder) { // TChainTester(65).Test(); //} - TString ExecuteJoinOrderTestDataQueryWithStats(const TString& queryPath, const TString& statsPath, bool useStreamLookupJoin, bool useColumnStore) { - auto kikimr = GetKikimrWithJoinSettings(useStreamLookupJoin, GetStatic(statsPath)); + TString ExecuteJoinOrderTestDataQueryWithStats(const TString& queryPath, const TString& statsPath, bool useStreamLookupJoin, bool useColumnStore, bool useCBO = true) { + auto kikimr = GetKikimrWithJoinSettings(useStreamLookupJoin, GetStatic(statsPath), useCBO); auto db = kikimr.GetTableClient(); auto session = db.CreateSession().GetValueSync().GetSession(); @@ -518,6 +518,69 @@ Y_UNIT_TEST_SUITE(KqpJoinOrder) { CheckJoinCardinality("queries/test_join_hint2.sql", "stats/basic.json", "InnerJoin (MapJoin)", 1, StreamLookupJoin, ColumnStore); } + + class TFindJoinWithLabels { + public: + TFindJoinWithLabels( + const NJson::TJsonValue& plan + ) + : Plan(plan) + {} + + TString Find(const TVector& labels) { + Labels = labels; + std::sort(Labels.begin(), Labels.end()); + TVector dummy; + auto res = FindImpl(Plan, dummy); + return res; + } + + private: + TString FindImpl(const NJson::TJsonValue& plan, TVector& subtreeLabels) { + auto planMap = plan.GetMapSafe(); + if (!planMap.contains("table")) { + TString opName = planMap.at("op_name").GetStringSafe(); + + auto inputs = planMap.at("args").GetArraySafe(); + for (size_t i = 0; i < inputs.size(); ++i) { + TVector childLabels; + if (auto maybeOpName = FindImpl(inputs[i], childLabels) ) { + return maybeOpName; + } + subtreeLabels.insert(subtreeLabels.end(), childLabels.begin(), childLabels.end()); + } + + if (AreRequestedLabels(subtreeLabels)) { + return opName; + } + + return ""; + } + + subtreeLabels = {planMap.at("table").GetStringSafe()}; + return ""; + } + + bool AreRequestedLabels(TVector labels) { + std::sort(labels.begin(), labels.end()); + return Labels == labels; + } + + NJson::TJsonValue Plan; + TVector Labels; + }; + + Y_UNIT_TEST(OltpJoinTypeHintCBOTurnOFF) { + auto plan = ExecuteJoinOrderTestDataQueryWithStats("queries/oltp_join_type_hint_cbo_turnoff.sql", "stats/basic.json", false, false, false); + auto detailedPlan = GetDetailedJoinOrder(plan); + + auto joinFinder = TFindJoinWithLabels(detailedPlan); + UNIT_ASSERT(joinFinder.Find({"R", "S"}) == "InnerJoin (Grace)"); + UNIT_ASSERT(joinFinder.Find({"R", "S", "T"}) == "InnerJoin (MapJoin)"); + UNIT_ASSERT(joinFinder.Find({"R", "S", "T", "U"}) == "InnerJoin (Grace)"); + UNIT_ASSERT(joinFinder.Find({"R", "S", "T", "U", "V"}) == "InnerJoin (MapJoin)"); + } + Y_UNIT_TEST_XOR_OR_BOTH_FALSE(TestJoinOrderHintsSimple, StreamLookupJoin, ColumnStore) { auto plan = ExecuteJoinOrderTestDataQueryWithStats("queries/join_order_hints_simple.sql", "stats/basic.json", StreamLookupJoin, ColumnStore); UNIT_ASSERT_VALUES_EQUAL(GetJoinOrder(plan).GetStringRobust(), R"(["T",["R","S"]])") ; diff --git a/ydb/library/yql/dq/opt/dq_opt_join.cpp b/ydb/library/yql/dq/opt/dq_opt_join.cpp index b1dea51ea3b9..42308ac1a7e2 100644 --- a/ydb/library/yql/dq/opt/dq_opt_join.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_join.cpp @@ -116,44 +116,67 @@ TExprBase BuildDqJoinInput(TExprContext& ctx, TPositionHandle pos, const TExprBa return partition; } -TMaybe BuildDqJoin(const TCoEquiJoinTuple& joinTuple, - const THashMap& inputs, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx) +TMaybe BuildDqJoin( + const TCoEquiJoinTuple& joinTuple, + const THashMap& inputs, + EHashJoinMode mode, + TExprContext& ctx, + const TTypeAnnotationContext& typeCtx, + TVector& subtreeLabels, + const NYql::TOptimizerHints& hints +) { - auto options = joinTuple.Options(); - auto linkSettings = GetEquiJoinLinkSettings(options.Ref()); - YQL_ENSURE(linkSettings.JoinAlgo != EJoinAlgoType::StreamLookupJoin || typeCtx.StreamLookupJoin, "Unsupported join strategy: streamlookup"); - - if (linkSettings.JoinAlgo == EJoinAlgoType::MapJoin) { - mode = EHashJoinMode::Map; - } else if (linkSettings.JoinAlgo == EJoinAlgoType::GraceJoin) { - mode = EHashJoinMode::GraceAndSelf; - } - - bool leftAny = linkSettings.LeftHints.contains("any"); - bool rightAny = linkSettings.RightHints.contains("any"); - TMaybe left; + TVector lhsLabels; if (joinTuple.LeftScope().Maybe()) { + lhsLabels.push_back(joinTuple.LeftScope().Cast().StringValue()); left = inputs.at(joinTuple.LeftScope().Cast().Value()); YQL_ENSURE(left, "unknown scope " << joinTuple.LeftScope().Cast().Value()); } else { - left = BuildDqJoin(joinTuple.LeftScope().Cast(), inputs, mode, useCBO, ctx, typeCtx); + left = BuildDqJoin(joinTuple.LeftScope().Cast(), inputs, mode, ctx, typeCtx, lhsLabels, hints); if (!left) { return {}; } } TMaybe right; + TVector rhsLabels; if (joinTuple.RightScope().Maybe()) { + rhsLabels.push_back(joinTuple.RightScope().Cast().StringValue()); right = inputs.at(joinTuple.RightScope().Cast().Value()); YQL_ENSURE(right, "unknown scope " << joinTuple.RightScope().Cast().Value()); } else { - right = BuildDqJoin(joinTuple.RightScope().Cast(), inputs, mode, useCBO, ctx, typeCtx); + right = BuildDqJoin(joinTuple.RightScope().Cast(), inputs, mode, ctx, typeCtx, rhsLabels, hints); if (!right) { return {}; } } + subtreeLabels.insert(subtreeLabels.end(), std::make_move_iterator(lhsLabels.begin()), std::make_move_iterator(lhsLabels.end())); + subtreeLabels.insert(subtreeLabels.end(), std::make_move_iterator(rhsLabels.begin()), std::make_move_iterator(rhsLabels.end())); + + auto options = joinTuple.Options(); + auto linkSettings = GetEquiJoinLinkSettings(options.Ref()); + for (auto& hint: hints.JoinAlgoHints->Hints) { + if ( + std::unordered_set(hint.JoinLabels.begin(), hint.JoinLabels.end()) == + std::unordered_set(subtreeLabels.begin(), subtreeLabels.end()) + ) { + linkSettings.JoinAlgo = hint.Algo; + hint.Applied = true; + } + } + YQL_ENSURE(linkSettings.JoinAlgo != EJoinAlgoType::StreamLookupJoin || typeCtx.StreamLookupJoin, "Unsupported join strategy: streamlookup"); + + if (linkSettings.JoinAlgo == EJoinAlgoType::MapJoin) { + mode = EHashJoinMode::Map; + } else if (linkSettings.JoinAlgo == EJoinAlgoType::GraceJoin) { + mode = EHashJoinMode::GraceAndSelf; + } + + bool leftAny = linkSettings.LeftHints.contains("any"); + bool rightAny = linkSettings.RightHints.contains("any"); + TStringBuf joinType = joinTuple.Type().Value(); TSet> resultKeys; if (joinType != TStringBuf("RightOnly") && joinType != TStringBuf("RightSemi")) { @@ -379,9 +402,16 @@ bool CheckJoinColumns(const TExprBase& node) { } } -TExprBase DqRewriteEquiJoin(const TExprBase& node, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx) { - int dummyJoinCounter; - return DqRewriteEquiJoin(node, mode, useCBO, ctx, typeCtx, dummyJoinCounter); +TExprBase DqRewriteEquiJoin( + const TExprBase& node, + EHashJoinMode mode, + bool useCBO, + TExprContext& ctx, + const TTypeAnnotationContext& typeCtx, + const TOptimizerHints& hints +) { + int dummyJoinCounter = 0; + return DqRewriteEquiJoin(node, mode, useCBO, ctx, typeCtx, dummyJoinCounter, hints); } /** @@ -389,7 +419,15 @@ TExprBase DqRewriteEquiJoin(const TExprBase& node, EHashJoinMode mode, bool useC * physical stages with join operators. * Potentially this optimizer can also perform joins reorder given cardinality information. */ -TExprBase DqRewriteEquiJoin(const TExprBase& node, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx, int& joinCounter) { +TExprBase DqRewriteEquiJoin( + const TExprBase& node, + EHashJoinMode mode, + bool /* useCBO */, + TExprContext& ctx, + const TTypeAnnotationContext& typeCtx, + int& joinCounter, + const TOptimizerHints& hints +) { if (!node.Maybe()) { return node; } @@ -406,7 +444,8 @@ TExprBase DqRewriteEquiJoin(const TExprBase& node, EHashJoinMode mode, bool useC } auto joinTuple = equiJoin.Arg(equiJoin.ArgCount() - 2).Cast(); - auto result = BuildDqJoin(joinTuple, inputs, mode, useCBO, ctx, typeCtx); + TVector dummy; + auto result = BuildDqJoin(joinTuple, inputs, mode, ctx, typeCtx, dummy, hints); if (!result) { return node; } diff --git a/ydb/library/yql/dq/opt/dq_opt_join.h b/ydb/library/yql/dq/opt/dq_opt_join.h index 50e7fe21ac5e..ecafe563cd84 100644 --- a/ydb/library/yql/dq/opt/dq_opt_join.h +++ b/ydb/library/yql/dq/opt/dq_opt_join.h @@ -4,6 +4,7 @@ #include #include +#include namespace NYql { @@ -12,9 +13,9 @@ struct TRelOptimizerNode; namespace NDq { -NNodes::TExprBase DqRewriteEquiJoin(const NNodes::TExprBase& node, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx); +NNodes::TExprBase DqRewriteEquiJoin(const NNodes::TExprBase& node, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx, const TOptimizerHints& hints = {}); -NNodes::TExprBase DqRewriteEquiJoin(const NNodes::TExprBase& node, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx, int& joinCounter); +NNodes::TExprBase DqRewriteEquiJoin(const NNodes::TExprBase& node, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx, int& joinCounter, const TOptimizerHints& hints = {}); NNodes::TExprBase DqBuildPhyJoin(const NNodes::TDqJoin& join, bool pushLeftStage, TExprContext& ctx, IOptimizationContext& optCtx);