diff --git a/paddle/cinn/hlir/framework/pir/CMakeLists.txt b/paddle/cinn/hlir/framework/pir/CMakeLists.txt index b2c3edfa06673..c764e57995f2d 100755 --- a/paddle/cinn/hlir/framework/pir/CMakeLists.txt +++ b/paddle/cinn/hlir/framework/pir/CMakeLists.txt @@ -8,6 +8,7 @@ if(NOT CINN_ONLY) op_lowering_impl.cc op_mapper.cc op_lowering_util.cc - trivial_op.cc + trivial_op_impl.cc + trivial_op_util.cc compilation_task.cc) endif() diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc index 5f9b2428ac5e1..847115bf8dbbf 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc @@ -22,7 +22,7 @@ #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/framework/compile_error.h" #include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" -#include "paddle/cinn/hlir/framework/pir/trivial_op.h" +#include "paddle/cinn/hlir/framework/pir/trivial_op_impl.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/cinn/hlir/op/external_api_registry.h" #include "paddle/cinn/hlir/pe/map_expr_to_ir.h" diff --git a/paddle/cinn/hlir/framework/pir/trivial_op.cc b/paddle/cinn/hlir/framework/pir/trivial_op.cc deleted file mode 100644 index a0ad3ad799869..0000000000000 --- a/paddle/cinn/hlir/framework/pir/trivial_op.cc +++ /dev/null @@ -1,1243 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/hlir/framework/pir/trivial_op.h" - -#include - -#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" -#include "paddle/cinn/hlir/framework/compile_error.h" -#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" -#include "paddle/cinn/hlir/framework/pir/utils.h" -#include "paddle/cinn/hlir/op/external_api_registry.h" -#include "paddle/cinn/hlir/pe/map_expr_to_ir.h" -#include "paddle/cinn/ir/dim.h" -#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" -#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" -#include "paddle/cinn/ir/schedule/ir_schedule.h" -#include "paddle/cinn/ir/schedule/ir_schedule_util.h" -#include "paddle/cinn/lang/placeholder.h" -#include "paddle/cinn/optim/schedule_block_dce.h" -#include "paddle/cinn/optim/transform_gpu_forloop.h" -#include "paddle/common/ddim.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" - -// #include "paddle/cinn/frontend/group_pattern_util.h" - -namespace cinn { -namespace hlir { -namespace framework { -namespace pir { -namespace trivial_fusion_detail { - -namespace ComposeUtils { - -template -std::vector ConcatVector(const std::vector& first, - const std::vector& second) { - std::vector result = first; - result.insert(result.end(), second.begin(), second.end()); - return result; -} - -std::vector ExprVec2VarVec(const std::vector& in) { - std::vector out; - for (auto& expr : in) { - out.push_back(expr.as_var_ref()); - } - return out; -} - -std::vector VarVec2ExprVec(const std::vector& in) { - return std::vector(in.begin(), in.end()); -} - -std::vector GetEachTensorLoadExpr(const ir::Expr& body, - const ir::Tensor& tensor) { - VLOG(4) << "Start GetEachTensorLoadExpr: " << tensor; - std::set load_exprs = cinn::ir::ir_utils::CollectIRNodesWithoutTensor( - body, [&tensor](const Expr* expr) { - return expr->As() && expr->As()->is_addr_tensor() && - expr->As()->tensor.as_tensor_ref()->name == - tensor->name; - }); - for (auto& t : load_exprs) { - VLOG(4) << "GetEachTensorLoadExpr: " << t << " " << t.ptr(); - } - return std::vector(load_exprs.begin(), load_exprs.end()); -} - -struct MappingTargetExprToDestExprMutator : public ir::IRMutator<> { - explicit MappingTargetExprToDestExprMutator(const ir::Expr& source, - const ir::Expr& dest) - : source_(source), dest_(dest) {} - - void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } - - private: - void Visit(const ir::Load* load, Expr* op) override { - VLOG(4) << "SubstitudeTargetExprWithDestExpr: " << load << " vs " - << source_.ptr(); - if (load == source_.ptr()) { - VLOG(4) << "substitude find!"; - *op = dest_; - } else { - IRMutator::Visit(load, op); - } - } - void Visit(const ir::Store* store, Expr* op) override { - VLOG(4) << "SubstitudeTargetExprWithDestExpr: " << store << " vs " - << source_.ptr(); - if (store == source_.ptr()) { - VLOG(4) << "substitude find!"; - *op = dest_; - } else { - IRMutator::Visit(store, op); - } - } - void Visit(const ir::Reduce* reduce, Expr* op) override { - VLOG(4) << "SubstitudeTargetExprWithDestExpr: " << reduce << " vs " - << source_.ptr(); - if (reduce == source_.ptr()) { - VLOG(4) << "substitude find!"; - *op = dest_; - } else { - IRMutator::Visit(reduce, op); - } - } - - private: - ir::Expr source_; - ir::Expr dest_; -}; - -bool CheckIterEq(const std::vector& up_iter, - const std::vector& down_iter) { - if (up_iter.size() != down_iter.size()) return false; - - for (int i = 0; i < up_iter.size(); ++i) { - const ir::Var& up_iter_var = up_iter[i]; - const ir::Var& down_iter_var = down_iter[i]; - - if (up_iter_var != down_iter_var) return false; - if (up_iter_var->lower_bound.as_int64() != - down_iter_var->lower_bound.as_int64()) - return false; - if (up_iter_var->upper_bound.as_int64() != - down_iter_var->upper_bound.as_int64()) - return false; - } - return true; -} - -static ir::Expr CopyedReplaceExpr(const Expr& source, - const std::vector& replaced, - const std::vector& candidates) { - VLOG(4) << "Copyed Replace Expr Start"; - CHECK_EQ(replaced.size(), candidates.size()) - << "In ReplaceExpr, the size of Vars to be replaced must be equal to " - "the " - "size of cadidate Exprs! Please check."; - auto copyed_source = ir::ir_utils::IRCopy(source); - if (replaced.empty()) return copyed_source; - std::map replacing_map; - for (int i = 0; i < replaced.size(); ++i) { - // If the Var to be replaced is equal to the candidate, we skip it. - if (candidates[i].is_var() && candidates[i].as_var_ref() == replaced[i]) - continue; - replacing_map[replaced[i]] = candidates[i]; - } - ir::MappingVarToExprMutator mapper(replacing_map); - mapper(©ed_source); - VLOG(4) << "Copyed Replace Expr End"; - return copyed_source; -} - -static void SubstitudeTargetExprWithDestExpr(const ir::Expr& source, - const ir::Expr& dest, - ir::Expr* body) { - VLOG(4) << "Start SubstitudeTargetExprWithDestExpr"; - MappingTargetExprToDestExprMutator mapper(source, dest); - mapper(body); - VLOG(4) << "End SubstitudeTargetExprWithDestExpr"; -} - -static ir::Expr SubstitudeIndexVector(const Expr& source, - const std::vector& load_vars, - const std::vector& indices) { - return CopyedReplaceExpr(source, load_vars, indices); -} - -template -static void ReplaceDownstreamLoadExprWithUpstreamComputeBody( - const FusionOp& upstream, - const ir::Expr& downstream_load_expr, - ir::Expr* downstream_body) { - ComposeUtils::SubstitudeTargetExprWithDestExpr( - downstream_load_expr, - ComposeUtils::SubstitudeIndexVector( - GetComputeBody(upstream), - GetOutputIters(upstream), - downstream_load_expr.As()->indices), - downstream_body); -} - -} // namespace ComposeUtils - -namespace SearchUtils { - -// 1. search by type. DONE -// 2. search by value. DONE -// 3. search by father. TODO - -using ExprSet = std::vector; -using Func = std::function; -struct Mapping { - Func f_; - std::string name; - explicit Mapping(Func f, std::string s = "") { - f_ = f; - name = s; - } - ExprSet operator()(const ir::Expr& x) const { return f_(x); } - ir::Expr GetSingle(const ir::Expr& x) const { - Mapping call = (*this) * Mapping::GetIdentity(); - const auto& o = call.operator()(x); - if (o.size() != 1) { - PADDLE_THROW("Try to get single result, but we get %d.", o.size()); - } - return *o.begin(); - } - Mapping operator*(Mapping x) const { - auto new_f = [self = *this, x = x](const ir::Expr& e) -> ExprSet { - const auto& rs = self.f_(e); - VLOG(6) << "Mapping Info : " << self.name; - VLOG(6) << " Inputs :" << e; - for (const auto& r : rs) { - VLOG(6) << " Outputs : \n" << r; - } - std::vector res; - for (const auto& r : rs) { - const auto& x_res = x.f_(r); - res.insert(res.begin(), x_res.begin(), x_res.end()); - } - return res; - }; - return Mapping(std::function(new_f), x.name + "*" + this->name); - } - static Mapping GetIdentity() { - return Mapping([](const ir::Expr& e) { return std::vector{e}; }, - "identity"); - } -}; - -Mapping Identity = Mapping::GetIdentity(); - -template -Mapping Collector(Teller t, std::string name = "") { - return Mapping( - [=](const ir::Expr& x) -> ExprSet { - const auto& rs = cinn::ir::ir_utils::CollectIRNodesWithoutTensor(x, t); - return std::vector(rs.begin(), rs.end()); - }, - name); -} - -template -Mapping FilterMaker(FilterFunc t, std::string name = "SomeFilter") { - return Mapping( - [=](const ir::Expr& x) -> ExprSet { - if (t(x)) { - return {x}; - } - return {}; - }, - name); -} - -Mapping Store2Value = Mapping( - [](const ir::Expr& e) -> ExprSet { - if (e.As()) { - return {e.As()->value}; - } - return {}; - }, - "Store2Value"); - -Mapping Realizer2ScheduleBlock = Mapping( - [](const ir::Expr& e) -> ExprSet { - if (e.As()) { - return {e.As()->schedule_block}; - } - return {}; - }, - "Realizer2ScheduleBlock"); - -Mapping ScheduleBlock2Body = Mapping( - [](const ir::Expr& e) -> ExprSet { - if (e.As()) { - return {e.As()->body}; - } - return {}; - }, - "ScheduleBlock2Body"); - -Mapping ScheduleBlockRealizeNotRoot = FilterMaker( - [](const ir::Expr& e) -> bool { - return (e.As() && - e.As() - ->schedule_block.As() - ->name.find("root") == std::string::npos); - }, - "ScheduleBlockRealizeNotRoot"); - -Mapping ScheduleBlockRealizeIsNotInit = FilterMaker( - [](const ir::Expr& e) -> bool { - return (e.As() && - e.As() - ->schedule_block.As() - ->name.find("__reduce_init") == std::string::npos); - }, - "ScheduleBlockRealizeIsNotInit"); - -Mapping ScheduleBlockRealizeIsInit = FilterMaker( - [](const ir::Expr& e) -> bool { - return (e.As() && - e.As() - ->schedule_block.As() - ->name.find("__reduce_init") != std::string::npos); - }, - "ScheduleBlockRealizeIsInit"); - -Mapping IsFor = FilterMaker( - [](const ir::Expr& e) -> bool { return e.As(); }, "IsFor"); - -Mapping ChildScheduleBlocks = - Collector([](const ir::Expr* e) { return e->As(); }, - "ChildScheduleBlocks"); - -Mapping ChildScheduleBlockRealizes = - Collector( - [](const ir::Expr* e) { return e->As(); }, - "ChildScheduleBlockRealizes") * - ScheduleBlockRealizeNotRoot; - -Mapping IsForIterVar(const ir::Var& var) { - return FilterMaker( - [var = var](const ir::Expr& e) -> bool { - return e.As() && e.As()->loop_var == var; - }, - "IsForIterVar"); -} - -Mapping For2Min = - Mapping([](const ir::Expr& e) -> ExprSet { return {e.As()->min}; }, - "For2Min"); - -Mapping For2Max = Mapping( - [](const ir::Expr& e) -> ExprSet { return {e.As()->extent}; }, - "For2Max"); - -Mapping ChildStores = Collector( - [](const ir::Expr* e) { return e->As(); }, "ChildStores"); - -Mapping ChildTensorLoads = Collector( - [](const ir::Expr* e) { - return e->As() && e->As()->is_addr_tensor(); - }, - "ChildLoads"); - -Mapping ChildTensorStores = Collector( - [](const ir::Expr* e) { - return e->As() && e->As()->is_addr_tensor(); - }, - "ChildTensorStores"); - -Mapping FilterLoadByTensor(const ir::Tensor& tensor) { - return FilterMaker( - [tensor = tensor](const ir::Expr& e) -> bool { - return e.As() && - e.As()->tensor.as_tensor_ref()->name == tensor->name; - }, - "FilterLoadByTensor(" + tensor->name + ")"); -} - -Mapping ChildFors = - Collector([](const ir::Expr* e) { return e->As(); }, "ChildFors"); - -Mapping FindFather(const ir::Expr& root) { - const auto& f = [&](const auto& child) -> ExprSet { - Mapping find_child = - Collector([child](const ir::Expr* e) { return *e == child; }); - const auto& father_collector = Collector( - [&](const ir::Expr* current) { return !find_child(*current).empty(); }); - return father_collector(root); - }; - return Mapping(f, "FindFather"); -} - -template -std::vector MapVector(const std::vector& as, M func) { - std::vector res; - for (const auto& a : as) { - res.push_back(func(a)); - } - return res; -} - -} // namespace SearchUtils - -namespace TransformerUtils { -using TransformFunc = std::function; -struct Transformer { - TransformFunc f_; - explicit Transformer(TransformFunc f) { f_ = f; } - ir::Expr operator()(const ir::Expr& x) const { return f_(x); } - Transformer operator*(const Transformer& x) const { - auto new_f = [self = *this, x = x](const ir::Expr& e) -> ir::Expr { - const auto& rs = self.f_(e); - return x.f_(rs); - }; - return Transformer(std::function(new_f)); - } -}; - -Transformer Identity = Transformer([](const ir::Expr& e) { return e; }); -Transformer WrapForTransformer(const ir::Var& v) { - const auto& f = [=](const ir::Expr& e) -> ir::Expr { - auto block = e; - if (!block.As()) { - block = ir::Block::Make({e}); - } - return ir::For::Make(v, - v->lower_bound, - v->upper_bound, - ir::ForType::Serial, - ir::DeviceAPI::Host, - block); - }; - return Transformer(f); -} - -Transformer WrapForsTransformer(const std::vector& vs) { - const auto& f = [&](const ir::Expr& e) -> ir::Expr { - Transformer t = Identity; - for (const auto& v : vs) { - t = WrapForTransformer(v) * t; - } - return t(e); - }; - return Transformer(f); -} - -Transformer ChangeTensorLoadTransformer(const ir::Tensor& tensor, - const ir::Expr dst_load) { - const auto& f = [&](const ir::Expr& e) -> ir::Expr { - auto copied_e = ir::ir_utils::IRCopy(e); - const auto& load = (SearchUtils::ChildTensorLoads * - SearchUtils::FilterLoadByTensor(tensor)) - .GetSingle(copied_e); - ComposeUtils::MappingTargetExprToDestExprMutator(load, dst_load)(&copied_e); - return copied_e; - }; - return Transformer(f); -} - -void ReplaceTarget(ir::Expr* e, const ir::Expr& t, const ir::Expr dst) { - ComposeUtils::MappingTargetExprToDestExprMutator(t, dst)(e); -} - -Transformer WrapStoreTransformer(const ir::Tensor& tensor, - const std::vector& indices) { - const auto& f = [=](const ir::Expr& e) -> ir::Expr { - return ir::Store::Make(tensor, e, indices); - }; - return Transformer(f); -} - -std::vector CreateInnerBlockVars( - const std::vector& block_vars) { - int i = 0; - std::vector vars; - for (const auto& v : block_vars) { - vars.emplace_back("inner_block_" + std::to_string(i++)); - } - return vars; -} - -Transformer ChangeVarTransformer(const std::vector& target_vars, - const std::vector& dest_vars) { - const auto& f = [=](const ir::Expr& e) -> ir::Expr { - return ComposeUtils::CopyedReplaceExpr( - e, - target_vars, - std::vector(dest_vars.begin(), dest_vars.end())); - }; - return Transformer(f); -} - -Transformer SubstitudeByScheduleBlockRealize(const ir::Expr& realize) { - const auto& f = [=](const ir::Expr& e) -> ir::Expr { - const auto& iter_values = - realize.As()->iter_values; - const auto& iter_vars = realize.As() - ->schedule_block.As() - ->iter_vars; - return TransformerUtils::ChangeVarTransformer( - iter_vars, ComposeUtils::ExprVec2VarVec(iter_values))(e); - }; - return Transformer(f); -} - -Transformer WrapScheduleRealizer(const std::vector& block_vars, - const std::string& tensor_name) { - const auto& f = [=](const ir::Expr& e) -> ir::Expr { - if (e.As()) { - PADDLE_THROW("please input a non-schedule block expr."); - } - const auto& inner_block_var = CreateInnerBlockVars(block_vars); - const auto& replaced_e = - ChangeVarTransformer(block_vars, inner_block_var)(e); - const auto& schedule_block = ir::ScheduleBlock::Make( - inner_block_var, {}, {}, tensor_name, replaced_e); - const auto& schedule_realizer = ir::ScheduleBlockRealize::Make( - std::vector(block_vars.begin(), block_vars.end()), - schedule_block); - return schedule_realizer; - }; - return Transformer(f); -} - -} // namespace TransformerUtils - -std::vector GetOpPatternKindVector( - const std::vector<::pir::Operation*>& ops) { - const auto& op_pattern_map = - Operator::GetAttrs("OpPattern"); - std::vector op_patterns; - const auto ConvertToPattern = [&op_pattern_map](const ::pir::Operation* op) { - const std::string cinn_op_name = CompatibleInfo::OpName(*op); - const hlir::framework::Operator* cinn_op = Operator::Get(cinn_op_name); - return op_pattern_map[cinn_op]; - }; - std::transform(ops.begin(), - ops.end(), - std::back_inserter(op_patterns), - ConvertToPattern); - return op_patterns; -} - -template -void SequenceMutator(const std::vector& as, C* acc, const Func& mutator) { - VLOG(4) << "SequenceTransform Init: " << acc; - for (int i = 0; i < as.size(); ++i) { - mutator(as[i], acc); - VLOG(4) << "SequenceTransform Iter: " << acc; - } -} - -inline bool IsTrivialKind(OpPatternKind kind) { - return kind == OpPatternKind::kElementWise || - kind == OpPatternKind::kBroadcast || kind == OpPatternKind::kInjective; -} - -void CheckFusionInputValid(const std::vector& op_compute_bodies, - const std::vector& op_patterns) { - if (VLOG_IS_ON(4)) { - for (const auto& func : op_compute_bodies) { - VLOG(4) << "TrivialOpFusion: {FuncBody is} :" << func; - } - for (const auto& op_ptn : op_patterns) { - VLOG(4) << "OpPattern is :" << op_ptn; - } - } - VLOG(4) << " op_patterns.size() = " << op_compute_bodies.size(); - VLOG(4) << "op_compute_bodies.size() = " << op_patterns.size(); - PADDLE_ENFORCE_EQ( - op_patterns.size(), op_compute_bodies.size(), "ops and size not equal"); -} - -struct TrivialOp { - public: - explicit TrivialOp(const ir::Expr& origin_func_body) { - func_body = ir::ir_utils::IRCopy(origin_func_body); - } - - TrivialOp(const TrivialOp& trivial_op) { - func_body = trivial_op.GetFuncBody(); - } - - void _SetFuncBody(ir::Expr new_body) { func_body = new_body; } - - ir::Expr* _GetFuncBodyPointer() { return &func_body; } - - ir::Expr GetFuncBody() const { return func_body; } - - private: - ir::Expr func_body; -}; - -struct ReduceOp { - public: - explicit ReduceOp(const ir::Expr& origin_func_body) { - func_body = ir::ir_utils::IRCopy(origin_func_body); - } - - ReduceOp(const ReduceOp& reduce_op) { func_body = reduce_op.GetFuncBody(); } - - void _SetFuncBody(ir::Expr new_body) { func_body = new_body; } - - ir::Expr GetFuncBody() const { return func_body; } - - ir::Expr* _GetFuncBodyPointer() { return &func_body; } - - private: - ir::Expr func_body; -}; - -using FusibleOp = std::variant; - -ir::Expr _GetRootExpr(const FusibleOp& op) { - return std::visit([](auto&& arg) { return arg.GetFuncBody(); }, op); -} - -void _SetFuncBody(FusibleOp& op, ir::Expr new_body) { - std::visit([&](auto&& arg) { arg._SetFuncBody(new_body); }, op); -} - -ir::Expr GetComputeBody(const FusibleOp& op) { - struct Visitor { - ir::Expr operator()(const ReduceOp& op) { - const auto& compute_realize = (SearchUtils::ChildScheduleBlockRealizes * - SearchUtils::ScheduleBlockRealizeIsNotInit) - .GetSingle(_GetRootExpr(op)); - const auto& compute_body = - (SearchUtils::ChildStores * SearchUtils::Store2Value) - .GetSingle(compute_realize); - return TransformerUtils::SubstitudeByScheduleBlockRealize( - compute_realize)(compute_body); - } - ir::Expr operator()(const TrivialOp& op) { - const auto& compute_realize = - (SearchUtils::ChildScheduleBlockRealizes).GetSingle(_GetRootExpr(op)); - const auto& compute_body = - (SearchUtils::ChildStores * SearchUtils::Store2Value) - .GetSingle(compute_realize); - return TransformerUtils::SubstitudeByScheduleBlockRealize( - compute_realize)(compute_body); - } - }; - VLOG(4) << "GetComputeBody"; - return std::visit(Visitor(), op); -} - -ir::Tensor GetOutputTensor(const FusibleOp& op) { - struct Visitor { - ir::Tensor operator()(const ReduceOp& op) { - const auto& compute_body = (SearchUtils::ChildScheduleBlockRealizes * - SearchUtils::ScheduleBlockRealizeIsNotInit * - SearchUtils::ChildStores) - .GetSingle(_GetRootExpr(op)); - return compute_body.As()->tensor.as_tensor_ref(); - } - ir::Tensor operator()(const TrivialOp& op) { - VLOG(4) << "Root is :" << _GetRootExpr(op); - VLOG(4) << "Searched is:" - << SearchUtils::ChildScheduleBlockRealizes.GetSingle( - _GetRootExpr(op)); - const auto& compute_body = - (SearchUtils::ChildScheduleBlockRealizes * SearchUtils::ChildStores) - .GetSingle(_GetRootExpr(op)); - return compute_body.As()->tensor.as_tensor_ref(); - } - }; - VLOG(4) << "GetOutputTensor"; - return std::visit(Visitor(), op); -} - -ir::Expr _GetOriginalStoreValuePointer(const FusibleOp& op) { - struct Visitor { - ir::Expr operator()(const ReduceOp& op) { - return (SearchUtils::ChildScheduleBlockRealizes * - SearchUtils::ScheduleBlockRealizeIsNotInit * - SearchUtils::ChildStores * SearchUtils::Store2Value) - .GetSingle(_GetRootExpr(op)); - } - ir::Expr operator()(const TrivialOp& op) { - return (SearchUtils::ChildScheduleBlockRealizes * - SearchUtils::ChildStores * SearchUtils::Store2Value) - .GetSingle(_GetRootExpr(op)); - } - }; - return std::visit(Visitor(), op); -} - -std::vector AppendBound(const std::vector vars, - const ir::Expr& root) { - using namespace SearchUtils; - return MapVector(vars, [&](const auto& v) -> ir::Var { - VLOG(4) << "AppendBound for " << v; - VLOG(4) << "lower: " - << (ChildFors * IsForIterVar(v) * For2Min).GetSingle(root); - VLOG(4) << "upper: " - << (ChildFors * IsForIterVar(v) * For2Max).GetSingle(root); - return ir::Var((ChildFors * IsForIterVar(v) * For2Min).GetSingle(root), - (ChildFors * IsForIterVar(v) * For2Max).GetSingle(root), - v->name); - }); -} - -std::vector GetOutputIters(const FusibleOp& op) { - struct Visitor { - std::vector operator()(const ReduceOp& op) { - ir::Expr init_block_realize = (SearchUtils::ChildScheduleBlockRealizes * - SearchUtils::ScheduleBlockRealizeIsInit) - .GetSingle(_GetRootExpr(op)); - const std::vector& outer_iter_expr = - init_block_realize.As()->iter_values; - return trivial_fusion_detail::ComposeUtils::ExprVec2VarVec( - outer_iter_expr); - } - std::vector operator()(const TrivialOp& op) { - const auto& compute_realize = - (SearchUtils::ChildScheduleBlockRealizes).GetSingle(_GetRootExpr(op)); - const std::vector& outer_iter_expr = - compute_realize.As()->iter_values; - return trivial_fusion_detail::ComposeUtils::ExprVec2VarVec( - outer_iter_expr); - } - }; - return AppendBound(std::visit(Visitor(), op), _GetRootExpr(op)); -} - -std::vector GetAllIterVars(const ReduceOp& op) { - ir::Expr compute_schedule_block_realize = - (SearchUtils::ChildScheduleBlockRealizes * - SearchUtils::ScheduleBlockRealizeIsNotInit) - .GetSingle(_GetRootExpr(op)); - - const std::vector& all_iter_expr = - compute_schedule_block_realize.As() - ->iter_values; - return ComposeUtils::ExprVec2VarVec(all_iter_expr); -} - -std::vector GetReduceIters(const ReduceOp& op) { - // Iter Vars not appearing in outer_iter_vars are pushed into - // reduce_iter_vars - std::vector all_iter_vars = GetAllIterVars(op); - std::vector outer_iter_vars = GetOutputIters(op); - std::vector reduce_iter_vars; - - for (auto& iter_var : all_iter_vars) { - if (!(std::find(outer_iter_vars.begin(), outer_iter_vars.end(), iter_var) != - outer_iter_vars.end())) { - reduce_iter_vars.push_back(iter_var); - } - } - return AppendBound(reduce_iter_vars, _GetRootExpr(op)); -} - -ir::Expr GetInitExpr(const ReduceOp& op) { - return (SearchUtils::ChildScheduleBlockRealizes * - SearchUtils::ScheduleBlockRealizeIsInit * SearchUtils::ChildStores * - SearchUtils::Store2Value) - .GetSingle(op.GetFuncBody()); -} - -ir::Expr* _GetFuncBodyPointer(FusibleOp op) { - return std::visit([&](auto&& arg) { return arg._GetFuncBodyPointer(); }, op); -} - -ir::Expr CopyReduceBody(const FusibleOp& downstream, const ReduceOp& upstream) { - struct Visitor { - ir::Expr operator()(const ReduceOp& op) { - return ir::ir_utils::IRCopy(op.GetFuncBody()); - } - ir::Expr operator()(const TrivialOp& op) { - PADDLE_THROW("TrivialOp cannot be copied."); - } - }; - return std::visit(Visitor(), downstream); -} - -ir::Expr CreateReduceExpr( - const std::vector& output_iters, - const std::vector& reduce_iters, - const ir::Expr& init_body, // relay on output_iters - const ir::Expr& reduce_body, // relay on output_iters + reduce_iters - const ir::Tensor& new_write_tensor, - const ir::Tensor& origin_write_tensor) { - VLOG(4) << "CreateReduceExpr Start."; - const std::vector indice_expr = - std::vector(output_iters.begin(), output_iters.end()); - const auto& new_init_tensor = ir::Tensor(new_write_tensor->name + "__init", - new_write_tensor->type(), - new_write_tensor->shape, - new_write_tensor->domain, - new_write_tensor->operation); - - const auto& init_schedule_block = - (TransformerUtils::WrapStoreTransformer(new_init_tensor, indice_expr) * - TransformerUtils::WrapScheduleRealizer( - output_iters, new_init_tensor->name))(init_body); - - const auto& reduce_schedule_block = - (TransformerUtils::ChangeTensorLoadTransformer( - origin_write_tensor, new_write_tensor(indice_expr)) * - TransformerUtils::WrapStoreTransformer(new_write_tensor, indice_expr) * - TransformerUtils::WrapScheduleRealizer( - ComposeUtils::ConcatVector(output_iters, reduce_iters), - new_write_tensor->name) * - TransformerUtils::WrapForsTransformer(reduce_iters))(reduce_body); - - const auto& gather_body = ir::Block::Make( - std::vector({init_schedule_block, reduce_schedule_block})); - return ir::Block::Make( - {(TransformerUtils::WrapForsTransformer(output_iters) * - TransformerUtils::WrapScheduleRealizer({}, "root"))(gather_body)}); -} - -ir::Expr CreateTrivialExpr(const std::vector& output_iters, - const ir::Expr& function_body, - const ir::Tensor& new_write_tensor) { - VLOG(4) << "CreateTrivialExpr Start."; - const std::vector indice_expr = - std::vector(output_iters.begin(), output_iters.end()); - const auto& compute_body_schedule_block = - (TransformerUtils::WrapStoreTransformer(new_write_tensor, indice_expr) * - TransformerUtils::WrapScheduleRealizer( - output_iters, new_write_tensor->name))(function_body); - return ir::Block::Make({(TransformerUtils::WrapForsTransformer(output_iters) * - TransformerUtils::WrapScheduleRealizer({}, "root"))( - ir::Block::Make({compute_body_schedule_block}))}); -} - -ir::Expr CreateExprWithNewComputeBody(FusibleOp fusible_op, - ir::Expr new_compute_body) { - struct Visitor { - ir::Expr operator()(const ReduceOp& op) { - return CreateReduceExpr(GetOutputIters(op), - GetReduceIters(op), - GetInitExpr(op), - compute_body_, - GetOutputTensor(op), - GetOutputTensor(op)); - } - ir::Expr operator()(const TrivialOp& op) { - return CreateTrivialExpr( - GetOutputIters(op), compute_body_, GetOutputTensor(op)); - } - - ir::Expr compute_body_; - explicit Visitor(ir::Expr compute_body) { compute_body_ = compute_body; } - }; - VLOG(4) << "CreateExprWithNewComputeBody"; - return std::visit(Visitor(new_compute_body), fusible_op); -} - -struct FusionNode { - FusibleOp fusible_op; - ::pir::Operation* expr_related_op; - - std::unordered_map upstream; - std::unordered_map downstream; - - explicit FusionNode(FusibleOp fusible_op) : fusible_op(fusible_op) {} - - static std::string GetTensorCounter() { - static int i = 0; - return std::to_string(i++); - } - - void replace_topo_structure_of_fused_nodes(FusionNode* fused_up_node, - FusionNode* fused_down_node) { - upstream.insert(fused_up_node->upstream.begin(), - fused_up_node->upstream.end()); - upstream.insert(fused_down_node->upstream.begin(), - fused_down_node->upstream.end()); - upstream.erase(fused_up_node); - - downstream.insert(fused_up_node->downstream.begin(), - fused_up_node->downstream.end()); - downstream.insert(fused_down_node->downstream.begin(), - fused_down_node->downstream.end()); - downstream.erase(fused_down_node); - - expr_related_op = fused_down_node->expr_related_op; - - for (const auto& pair_data : upstream) { - FusionNode* upstream_node = pair_data.first; - ::pir::Value related_value = pair_data.second; - if (upstream_node->downstream.find(fused_up_node) != - upstream_node->downstream.end()) { - upstream_node->downstream.erase(fused_up_node); - } - if (upstream_node->downstream.find(fused_down_node) != - upstream_node->downstream.end()) { - upstream_node->downstream.erase(fused_down_node); - } - upstream_node->downstream[this] = related_value; - } - - for (const auto& pair_data : downstream) { - FusionNode* downstream_node = pair_data.first; - ::pir::Value related_value = pair_data.second; - if (downstream_node->upstream.find(fused_up_node) != - downstream_node->upstream.end()) { - downstream_node->upstream.erase(fused_up_node); - } - if (downstream_node->upstream.find(fused_down_node) != - downstream_node->upstream.end()) { - downstream_node->upstream.erase(fused_down_node); - } - downstream_node->upstream[this] = related_value; - } - } - - bool IsTrivial() const { - return std::holds_alternative(fusible_op); - } -}; - -template -DownStreamOp TrivalxOther_Fusion(TrivialOp upstream, DownStreamOp downstream) { - VLOG(4) << "Trivial x OtherFusion begin."; - - const auto& replaced_tensor = GetOutputTensor(upstream); - VLOG(4) << "upstream is " << upstream.GetFuncBody(); - VLOG(4) << "downstream is " << downstream.GetFuncBody(); - - DownStreamOp fused(ir::ir_utils::IRCopy(downstream.GetFuncBody())); - ir::Expr origin_compute_body = _GetOriginalStoreValuePointer(fused); - SequenceMutator( - ComposeUtils::GetEachTensorLoadExpr(origin_compute_body, replaced_tensor), - &origin_compute_body, - [&](const ir::Expr& downstream_load_expr, ir::Expr* downstream_body) { - ComposeUtils::ReplaceDownstreamLoadExprWithUpstreamComputeBody( - upstream, downstream_load_expr, downstream_body); - }); - - VLOG(4) << "After mutate, compute body: " << origin_compute_body; - VLOG(4) << "TTFusion end:\n" << fused.GetFuncBody(); - return fused; -} - -bool CheckAllLoopRangeEq(ReduceOp reduce_upper, TrivialOp trivial_down) {} - -std::vector TransformReduceLoopRange(const ReduceOp& upstream, - FusibleOp* downstream) { - // downstream will be mutated by this transform. - VLOG(4) << "RRTransform begin"; - VLOG(4) << "Upstream is " << upstream.GetFuncBody(); - ir::Expr modified_downstream_compute_body = GetComputeBody(*downstream); - const auto& load_upstream_expr = ComposeUtils::GetEachTensorLoadExpr( - modified_downstream_compute_body, GetOutputTensor(upstream)); - std::vector results; - ir::Tensor downstream_output_tensor = GetOutputTensor(*downstream); - const auto create_new_tensor = [&](const ir::Tensor& downstream_load_tensor) { - VLOG(4) << "downstream output tensor: " << downstream_output_tensor; - VLOG(4) << "downstream_load_tensor : " << downstream_load_tensor; - return ir::Tensor( - downstream_load_tensor->name + "_" + FusionNode::GetTensorCounter(), - downstream_load_tensor->type(), - downstream_output_tensor->shape, - downstream_output_tensor->domain, - downstream_load_tensor->operation); - }; - - for (const auto& load_tensor : load_upstream_expr) { - const auto& new_tensor = - create_new_tensor(load_tensor.As()->tensor.as_tensor_ref()); - VLOG(4) << "GetInit: " << GetInitExpr(upstream); - VLOG(4) << "GetNewTensor: " << new_tensor; - VLOG(4) << "GetOutputIter: " - << utils::Join(GetOutputIters(*downstream), " "); - VLOG(4) << "GetReduceIter: " << utils::Join(GetReduceIters(upstream), " "); - VLOG(4) << "GetCompute: " - << ComposeUtils::CopyedReplaceExpr( - GetComputeBody(upstream), - GetOutputIters(upstream), - load_tensor.As()->indices); - ir::Expr new_reduce = CreateReduceExpr( - GetOutputIters(*downstream), - GetReduceIters(upstream), - GetInitExpr(upstream), - ComposeUtils::CopyedReplaceExpr(GetComputeBody(upstream), - GetOutputIters(upstream), - load_tensor.As()->indices), - new_tensor, - GetOutputTensor(upstream)); - results.emplace_back(ReduceOp(new_reduce)); - TransformerUtils::ReplaceTarget( - &modified_downstream_compute_body, - load_tensor, - new_tensor(ComposeUtils::VarVec2ExprVec(GetOutputIters(*downstream)))); - } - _SetFuncBody(*downstream, - CreateExprWithNewComputeBody(*downstream, - modified_downstream_compute_body)); - VLOG(4) << "After Replace Downstream Load: \n" << _GetRootExpr(*downstream); - return results; -} - -FusibleOp TrivialFusion(FusionNode* upstream, FusionNode* downstream) { - CHECK(upstream->IsTrivial()); - if (downstream->IsTrivial()) { - return TrivalxOther_Fusion(std::get(upstream->fusible_op), - std::get(downstream->fusible_op)); - } else { - return TrivalxOther_Fusion(std::get(upstream->fusible_op), - std::get(downstream->fusible_op)); - } -} - -FusibleOp SinkTrivialLoopAlign(TrivialOp trivial_op, ReduceOp reduce_op) { - ir::Expr new_trivial_body = ir::ir_utils::IRCopy(trivial_op.GetFuncBody()); - ir::Var last_iter = GetOutputIters(trivial_op).back(); - ir::Expr trivial_last_for = - (SearchUtils::ChildFors * SearchUtils::IsForIterVar(last_iter)) - .GetSingle(new_trivial_body); - ir::Expr new_for_body = trivial_last_for.As()->body; - new_for_body = TransformerUtils::WrapForsTransformer( - GetReduceIters(reduce_op))(new_for_body); - trivial_last_for.As()->body = new_for_body; - return TrivialOp(new_trivial_body); -} - -std::vector ReduceTransformRecursive(FusibleOp root_op, - FusionNode* fusion_tree) { - VLOG(4) << "ReduceTransformRecursive: " << *_GetFuncBodyPointer(root_op); - std::vector result; - for (auto& pair : fusion_tree->upstream) { - auto transformed_nodes = TransformReduceLoopRange( - std::get(pair.first->fusible_op), &root_op); - for (auto& node : transformed_nodes) { - auto child_flatten = ReduceTransformRecursive(node, pair.first); - result.insert(result.end(), child_flatten.begin(), child_flatten.end()); - } - } - VLOG(4) << "Before push_back, is trivial_op: " - << std::holds_alternative(root_op); - result.push_back( - std::holds_alternative(root_op) - ? SinkTrivialLoopAlign( - std::get(root_op), - std::get( - fusion_tree->upstream.begin()->first->fusible_op)) - : root_op); - VLOG(4) << "After push_back."; - return result; -} - -std::vector ReduceTransform(FusionNode* downstream) { - if (downstream->IsTrivial() && downstream->upstream.empty()) { - return {downstream->fusible_op}; - } - auto reduces = ReduceTransformRecursive(downstream->fusible_op, downstream); - return reduces; -} - -FusibleOp CreateFusibleOp(ir::Expr compute_body, OpPatternKind op_pattern) { - if (IsTrivialKind(op_pattern)) { - return TrivialOp(compute_body); - } else { - return ReduceOp(compute_body); - } -} - -struct FusionGraph { - explicit FusionGraph(const std::vector<::pir::Operation*>& ops, - const std::vector& op_compute_bodies) { - // shardable_axes_ = InferShardableAxes(ops); - VLOG(4) << "CreateFusionGraph"; - - const auto& op_patterns = GetOpPatternKindVector(ops); - CheckFusionInputValid(op_compute_bodies, op_patterns); - - std::unordered_map<::pir::Operation*, FusionNode*> op_to_node_map; - - for (int i = 0; i < ops.size(); ++i) { - FusionNode* node = - new FusionNode(CreateFusibleOp(op_compute_bodies[i], op_patterns[i])); - op_to_node_map[ops[i]] = node; - all_fusion_nodes_.emplace(node); - node->expr_related_op = ops[i]; - } - - for (::pir::Operation* op : ops) { - FusionNode* cur_node = op_to_node_map[op]; - - // add upstream nodes - for (int i = 0; i < op->num_operands(); ++i) { - ::pir::Value related_value = op->operand_source(i); - ::pir::Operation* input_op = related_value.defining_op(); - if (op_to_node_map.find(input_op) != op_to_node_map.end()) { - FusionNode* upstream_node = op_to_node_map[input_op]; - cur_node->upstream[upstream_node] = related_value; - upstream_node->downstream[cur_node] = related_value; - } - } - - // add downstream nodes - for (int i = 0; i < op->num_results(); ++i) { - ::pir::Value related_value = op->result(i); - for (auto consumer_it = related_value.use_begin(); - consumer_it != related_value.use_end(); - ++consumer_it) { - ::pir::Operation* output_op = consumer_it->owner(); - if (op_to_node_map.find(output_op) != op_to_node_map.end()) { - FusionNode* downstream_node = op_to_node_map[output_op]; - cur_node->downstream[downstream_node] = related_value; - downstream_node->upstream[cur_node] = related_value; - } - } - } - - if (cur_node->upstream.empty()) { - entrance_nodes_.emplace(cur_node); - } - - if (cur_node->downstream.empty()) { - exit_nodes_.emplace(cur_node); - } - } - - VLOG(4) << "FusionGraph Created, fusion node size: " - << all_fusion_nodes_.size(); - } - - ~FusionGraph() { - for (FusionNode* node : all_fusion_nodes_) { - delete node; - } - } - - std::vector DoFusion() { - VLOG(4) << "Start Trivial Fusion"; - DoTrivialFusion(); - VLOG(4) << "Start R + T and R + R Fusion"; - ReduceLoopTranform(); - return GetExprResults(); - } - - private: - FusionNode* FindTrivialFusibleNode() { - for (FusionNode* node : all_fusion_nodes_) { - if (node->IsTrivial() && !node->downstream.empty()) { - return node; - } - } - return nullptr; - } - - void DoTrivialFusion() { - FusionNode* upstream = nullptr; - // use funcion to get upstream and downstream is save here - // cause we might delete Nodes in this process - while ((upstream = FindTrivialFusibleNode()) != nullptr) { - std::unordered_map fusion_candidate = - upstream->downstream; - upstream->downstream.clear(); - for (const auto& pair_data : fusion_candidate) { - FusionNode* downstream = pair_data.first; - FusionNode* new_node = - new FusionNode(TrivialFusion(upstream, downstream)); - new_node->replace_topo_structure_of_fused_nodes(upstream, downstream); - AppendNode(new_node); - RemoveNode(downstream); - } - RemoveNode(upstream); - } - } - - void ReduceLoopTranform() { - for (FusionNode* node : exit_nodes_) { - auto fusion_nodes = ReduceTransform(node); - fusion_results_.insert( - fusion_results_.end(), fusion_nodes.begin(), fusion_nodes.end()); - } - } - - std::vector GetExprResults() { - std::vector output_exprs; - for (const auto& node : fusion_results_) { - output_exprs.emplace_back(_GetRootExpr(node)); - } - return output_exprs; - } - - void RemoveNode(FusionNode* node) { - if (all_fusion_nodes_.find(node) != all_fusion_nodes_.end()) { - all_fusion_nodes_.erase(node); - } - if (entrance_nodes_.find(node) != entrance_nodes_.end()) { - entrance_nodes_.erase(node); - } - if (exit_nodes_.find(node) != exit_nodes_.end()) { - exit_nodes_.erase(node); - } - delete node; - } - - void AppendNode(FusionNode* node) { - all_fusion_nodes_.emplace(node); - if (node->upstream.empty()) { - entrance_nodes_.emplace(node); - } - - if (node->downstream.empty()) { - exit_nodes_.emplace(node); - } - } - - FusionNode* FindReduceUpstream(FusionNode* node) { - for (const auto& pair_data : node->upstream) { - FusionNode* upstream = pair_data.first; - if (!upstream->IsTrivial()) { - return upstream; - } - } - return nullptr; - } - - private: - std::unordered_set all_fusion_nodes_; - std::vector fusion_results_; - std::unordered_set entrance_nodes_; - std::unordered_set exit_nodes_; - - // std::unordered_map<::pir::Value, ShardableAxes> shardable_axes_; -}; - -} // namespace trivial_fusion_detail - -std::vector OperationFusion( - const std::vector<::pir::Operation*>& ops, - const std::vector& op_compute_bodies) { - trivial_fusion_detail::FusionGraph graph = - trivial_fusion_detail::FusionGraph(ops, op_compute_bodies); - auto output = graph.DoFusion(); - VLOG(4) << "Fusion Result: output size is " << output.size(); - for (const auto& expr : output) { - VLOG(4) << expr; - } - return output; -} - -} // namespace pir -} // namespace framework -} // namespace hlir -} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/trivial_op.h b/paddle/cinn/hlir/framework/pir/trivial_op.h deleted file mode 100644 index 14d38cdda088f..0000000000000 --- a/paddle/cinn/hlir/framework/pir/trivial_op.h +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once - -#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" -#include "paddle/cinn/hlir/framework/compile_error.h" -#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" -#include "paddle/cinn/hlir/framework/pir/utils.h" -#include "paddle/cinn/hlir/op/external_api_registry.h" -#include "paddle/cinn/hlir/pe/map_expr_to_ir.h" -#include "paddle/cinn/ir/dim.h" -#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" -#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" -#include "paddle/cinn/ir/schedule/ir_schedule.h" -#include "paddle/cinn/lang/placeholder.h" -#include "paddle/cinn/optim/schedule_block_dce.h" -#include "paddle/cinn/optim/transform_gpu_forloop.h" -#include "paddle/common/ddim.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" - -namespace cinn { -namespace hlir { -namespace framework { -namespace pir { -std::vector OperationFusion( - const std::vector<::pir::Operation*>& ops, - const std::vector& op_compute_bodies); -} -} // namespace framework -} // namespace hlir -} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc new file mode 100644 index 0000000000000..aebda5bf8c1c4 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc @@ -0,0 +1,671 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/framework/pir/trivial_op_impl.h" + +#include + +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/framework/compile_error.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/cinn/hlir/op/external_api_registry.h" +#include "paddle/cinn/hlir/pe/map_expr_to_ir.h" +#include "paddle/cinn/ir/dim.h" +#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" +#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" +#include "paddle/cinn/lang/placeholder.h" +#include "paddle/cinn/optim/schedule_block_dce.h" +#include "paddle/cinn/optim/transform_gpu_forloop.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +// #include "paddle/cinn/frontend/group_pattern_util.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { +namespace trivial_fusion_detail { + +TrivialOp::TrivialOp(const ir::Expr& origin_func_body) { + func_body = ir::ir_utils::IRCopy(origin_func_body); +} + +TrivialOp::TrivialOp(const TrivialOp& trivial_op) { + func_body = trivial_op.GetFuncBody(); +} + +void TrivialOp::_SetFuncBody(ir::Expr new_body) { func_body = new_body; } + +ir::Expr* TrivialOp::_GetFuncBodyPointer() { return &func_body; } + +ir::Expr TrivialOp::GetFuncBody() const { return func_body; } + +ReduceOp::ReduceOp(const ir::Expr& origin_func_body) { + func_body = ir::ir_utils::IRCopy(origin_func_body); +} + +ReduceOp::ReduceOp(const ReduceOp& reduce_op) { + func_body = reduce_op.GetFuncBody(); +} + +void ReduceOp::_SetFuncBody(ir::Expr new_body) { func_body = new_body; } + +ir::Expr ReduceOp::GetFuncBody() const { return func_body; } + +ir::Expr* ReduceOp::_GetFuncBodyPointer() { return &func_body; } + +using FusibleOp = std::variant; + +ir::Expr _GetRootExpr(const FusibleOp& op) { + return std::visit([](auto&& arg) { return arg.GetFuncBody(); }, op); +} + +void _SetFuncBody(FusibleOp& op, ir::Expr new_body) { + std::visit([&](auto&& arg) { arg._SetFuncBody(new_body); }, op); +} + +ir::Expr GetComputeBody(const FusibleOp& op) { + struct Visitor { + ir::Expr operator()(const ReduceOp& op) { + const auto& compute_realize = (SearchUtils::ChildScheduleBlockRealizes * + SearchUtils::ScheduleBlockRealizeIsNotInit) + .GetSingle(_GetRootExpr(op)); + const auto& compute_body = + (SearchUtils::ChildStores * SearchUtils::Store2Value) + .GetSingle(compute_realize); + return TransformerUtils::SubstitudeByScheduleBlockRealize( + compute_realize)(compute_body); + } + ir::Expr operator()(const TrivialOp& op) { + const auto& compute_realize = + (SearchUtils::ChildScheduleBlockRealizes).GetSingle(_GetRootExpr(op)); + const auto& compute_body = + (SearchUtils::ChildStores * SearchUtils::Store2Value) + .GetSingle(compute_realize); + return TransformerUtils::SubstitudeByScheduleBlockRealize( + compute_realize)(compute_body); + } + }; + VLOG(4) << "GetComputeBody"; + return std::visit(Visitor(), op); +} + +ir::Tensor GetOutputTensor(const FusibleOp& op) { + struct Visitor { + ir::Tensor operator()(const ReduceOp& op) { + const auto& compute_body = (SearchUtils::ChildScheduleBlockRealizes * + SearchUtils::ScheduleBlockRealizeIsNotInit * + SearchUtils::ChildStores) + .GetSingle(_GetRootExpr(op)); + return compute_body.As()->tensor.as_tensor_ref(); + } + ir::Tensor operator()(const TrivialOp& op) { + VLOG(4) << "Root is :" << _GetRootExpr(op); + VLOG(4) << "Searched is:" + << SearchUtils::ChildScheduleBlockRealizes.GetSingle( + _GetRootExpr(op)); + const auto& compute_body = + (SearchUtils::ChildScheduleBlockRealizes * SearchUtils::ChildStores) + .GetSingle(_GetRootExpr(op)); + return compute_body.As()->tensor.as_tensor_ref(); + } + }; + VLOG(4) << "GetOutputTensor"; + return std::visit(Visitor(), op); +} + +ir::Expr _GetOriginalStoreValuePointer(const FusibleOp& op) { + struct Visitor { + ir::Expr operator()(const ReduceOp& op) { + return (SearchUtils::ChildScheduleBlockRealizes * + SearchUtils::ScheduleBlockRealizeIsNotInit * + SearchUtils::ChildStores * SearchUtils::Store2Value) + .GetSingle(_GetRootExpr(op)); + } + ir::Expr operator()(const TrivialOp& op) { + return (SearchUtils::ChildScheduleBlockRealizes * + SearchUtils::ChildStores * SearchUtils::Store2Value) + .GetSingle(_GetRootExpr(op)); + } + }; + return std::visit(Visitor(), op); +} + +std::vector AppendBound(const std::vector vars, + const ir::Expr& root) { + return SearchUtils::MapVector(vars, [&](const auto& v) -> ir::Var { + VLOG(4) << "AppendBound for " << v; + VLOG(4) << "lower: " + << (SearchUtils::ChildFors * SearchUtils::IsForIterVar(v) * + SearchUtils::For2Min) + .GetSingle(root); + VLOG(4) << "upper: " + << (SearchUtils::ChildFors * SearchUtils::IsForIterVar(v) * + SearchUtils::For2Max) + .GetSingle(root); + return ir::Var((SearchUtils::ChildFors * SearchUtils::IsForIterVar(v) * + SearchUtils::For2Min) + .GetSingle(root), + (SearchUtils::ChildFors * SearchUtils::IsForIterVar(v) * + SearchUtils::For2Max) + .GetSingle(root), + v->name); + }); +} + +std::vector GetOutputIters(const FusibleOp& op) { + struct Visitor { + std::vector operator()(const ReduceOp& op) { + ir::Expr init_block_realize = (SearchUtils::ChildScheduleBlockRealizes * + SearchUtils::ScheduleBlockRealizeIsInit) + .GetSingle(_GetRootExpr(op)); + const std::vector& outer_iter_expr = + init_block_realize.As()->iter_values; + return trivial_fusion_detail::ComposeUtils::ExprVec2VarVec( + outer_iter_expr); + } + std::vector operator()(const TrivialOp& op) { + const auto& compute_realize = + (SearchUtils::ChildScheduleBlockRealizes).GetSingle(_GetRootExpr(op)); + const std::vector& outer_iter_expr = + compute_realize.As()->iter_values; + return trivial_fusion_detail::ComposeUtils::ExprVec2VarVec( + outer_iter_expr); + } + }; + return AppendBound(std::visit(Visitor(), op), _GetRootExpr(op)); +} + +std::vector GetAllIterVars(const ReduceOp& op) { + ir::Expr compute_schedule_block_realize = + (SearchUtils::ChildScheduleBlockRealizes * + SearchUtils::ScheduleBlockRealizeIsNotInit) + .GetSingle(_GetRootExpr(op)); + + const std::vector& all_iter_expr = + compute_schedule_block_realize.As() + ->iter_values; + return ComposeUtils::ExprVec2VarVec(all_iter_expr); +} + +std::vector GetReduceIters(const ReduceOp& op) { + // Iter Vars not appearing in outer_iter_vars are pushed into + // reduce_iter_vars + std::vector all_iter_vars = GetAllIterVars(op); + std::vector outer_iter_vars = GetOutputIters(op); + std::vector reduce_iter_vars; + + for (auto& iter_var : all_iter_vars) { + if (!(std::find(outer_iter_vars.begin(), outer_iter_vars.end(), iter_var) != + outer_iter_vars.end())) { + reduce_iter_vars.push_back(iter_var); + } + } + return AppendBound(reduce_iter_vars, _GetRootExpr(op)); +} + +ir::Expr GetInitExpr(const ReduceOp& op) { + return (SearchUtils::ChildScheduleBlockRealizes * + SearchUtils::ScheduleBlockRealizeIsInit * SearchUtils::ChildStores * + SearchUtils::Store2Value) + .GetSingle(op.GetFuncBody()); +} + +ir::Expr* _GetFuncBodyPointer(FusibleOp op) { + return std::visit([&](auto&& arg) { return arg._GetFuncBodyPointer(); }, op); +} + +ir::Expr CopyReduceBody(const FusibleOp& downstream, const ReduceOp& upstream) { + struct Visitor { + ir::Expr operator()(const ReduceOp& op) { + return ir::ir_utils::IRCopy(op.GetFuncBody()); + } + ir::Expr operator()(const TrivialOp& op) { + PADDLE_THROW("TrivialOp cannot be copied."); + } + }; + return std::visit(Visitor(), downstream); +} + +ir::Expr CreateReduceExpr( + const std::vector& output_iters, + const std::vector& reduce_iters, + const ir::Expr& init_body, // relay on output_iters + const ir::Expr& reduce_body, // relay on output_iters + reduce_iters + const ir::Tensor& new_write_tensor, + const ir::Tensor& origin_write_tensor) { + VLOG(4) << "CreateReduceExpr Start."; + const std::vector indice_expr = + std::vector(output_iters.begin(), output_iters.end()); + const auto& new_init_tensor = ir::Tensor(new_write_tensor->name + "__init", + new_write_tensor->type(), + new_write_tensor->shape, + new_write_tensor->domain, + new_write_tensor->operation); + + const auto& init_schedule_block = + (TransformerUtils::WrapStoreTransformer(new_init_tensor, indice_expr) * + TransformerUtils::WrapScheduleRealizer( + output_iters, new_init_tensor->name))(init_body); + + const auto& reduce_schedule_block = + (TransformerUtils::ChangeTensorLoadTransformer( + origin_write_tensor, new_write_tensor(indice_expr)) * + TransformerUtils::WrapStoreTransformer(new_write_tensor, indice_expr) * + TransformerUtils::WrapScheduleRealizer( + ComposeUtils::ConcatVector(output_iters, reduce_iters), + new_write_tensor->name) * + TransformerUtils::WrapForsTransformer(reduce_iters))(reduce_body); + + const auto& gather_body = ir::Block::Make( + std::vector({init_schedule_block, reduce_schedule_block})); + return ir::Block::Make( + {(TransformerUtils::WrapForsTransformer(output_iters) * + TransformerUtils::WrapScheduleRealizer({}, "root"))(gather_body)}); +} + +ir::Expr CreateTrivialExpr(const std::vector& output_iters, + const ir::Expr& function_body, + const ir::Tensor& new_write_tensor) { + VLOG(4) << "CreateTrivialExpr Start."; + const std::vector indice_expr = + std::vector(output_iters.begin(), output_iters.end()); + const auto& compute_body_schedule_block = + (TransformerUtils::WrapStoreTransformer(new_write_tensor, indice_expr) * + TransformerUtils::WrapScheduleRealizer( + output_iters, new_write_tensor->name))(function_body); + return ir::Block::Make({(TransformerUtils::WrapForsTransformer(output_iters) * + TransformerUtils::WrapScheduleRealizer({}, "root"))( + ir::Block::Make({compute_body_schedule_block}))}); +} + +ir::Expr CreateExprWithNewComputeBody(FusibleOp fusible_op, + ir::Expr new_compute_body) { + struct Visitor { + ir::Expr operator()(const ReduceOp& op) { + return CreateReduceExpr(GetOutputIters(op), + GetReduceIters(op), + GetInitExpr(op), + compute_body_, + GetOutputTensor(op), + GetOutputTensor(op)); + } + ir::Expr operator()(const TrivialOp& op) { + return CreateTrivialExpr( + GetOutputIters(op), compute_body_, GetOutputTensor(op)); + } + + ir::Expr compute_body_; + explicit Visitor(ir::Expr compute_body) { compute_body_ = compute_body; } + }; + VLOG(4) << "CreateExprWithNewComputeBody"; + return std::visit(Visitor(new_compute_body), fusible_op); +} + +FusionNode::FusionNode(FusibleOp fusible_op) : fusible_op(fusible_op) {} + +std::string FusionNode::GetTensorCounter() { + static int i = 0; + return std::to_string(i++); +} + +void FusionNode::replace_topo_structure_of_fused_nodes( + FusionNode* fused_up_node, FusionNode* fused_down_node) { + upstream.insert(fused_up_node->upstream.begin(), + fused_up_node->upstream.end()); + upstream.insert(fused_down_node->upstream.begin(), + fused_down_node->upstream.end()); + upstream.erase(fused_up_node); + + downstream.insert(fused_up_node->downstream.begin(), + fused_up_node->downstream.end()); + downstream.insert(fused_down_node->downstream.begin(), + fused_down_node->downstream.end()); + downstream.erase(fused_down_node); + + expr_related_op = fused_down_node->expr_related_op; + + for (const auto& pair_data : upstream) { + FusionNode* upstream_node = pair_data.first; + ::pir::Value related_value = pair_data.second; + if (upstream_node->downstream.find(fused_up_node) != + upstream_node->downstream.end()) { + upstream_node->downstream.erase(fused_up_node); + } + if (upstream_node->downstream.find(fused_down_node) != + upstream_node->downstream.end()) { + upstream_node->downstream.erase(fused_down_node); + } + upstream_node->downstream[this] = related_value; + } + + for (const auto& pair_data : downstream) { + FusionNode* downstream_node = pair_data.first; + ::pir::Value related_value = pair_data.second; + if (downstream_node->upstream.find(fused_up_node) != + downstream_node->upstream.end()) { + downstream_node->upstream.erase(fused_up_node); + } + if (downstream_node->upstream.find(fused_down_node) != + downstream_node->upstream.end()) { + downstream_node->upstream.erase(fused_down_node); + } + downstream_node->upstream[this] = related_value; + } +} + +bool FusionNode::IsTrivial() const { + return std::holds_alternative(fusible_op); +} + +bool CheckAllLoopRangeEq(ReduceOp reduce_upper, TrivialOp trivial_down) {} + +std::vector TransformReduceLoopRange(const ReduceOp& upstream, + FusibleOp* downstream) { + // downstream will be mutated by this transform. + VLOG(4) << "RRTransform begin"; + VLOG(4) << "Upstream is " << upstream.GetFuncBody(); + ir::Expr modified_downstream_compute_body = GetComputeBody(*downstream); + const auto& load_upstream_expr = ComposeUtils::GetEachTensorLoadExpr( + modified_downstream_compute_body, GetOutputTensor(upstream)); + std::vector results; + ir::Tensor downstream_output_tensor = GetOutputTensor(*downstream); + const auto create_new_tensor = [&](const ir::Tensor& downstream_load_tensor) { + VLOG(4) << "downstream output tensor: " << downstream_output_tensor; + VLOG(4) << "downstream_load_tensor : " << downstream_load_tensor; + return ir::Tensor( + downstream_load_tensor->name + "_" + FusionNode::GetTensorCounter(), + downstream_load_tensor->type(), + downstream_output_tensor->shape, + downstream_output_tensor->domain, + downstream_load_tensor->operation); + }; + + for (const auto& load_tensor : load_upstream_expr) { + const auto& new_tensor = + create_new_tensor(load_tensor.As()->tensor.as_tensor_ref()); + VLOG(4) << "GetInit: " << GetInitExpr(upstream); + VLOG(4) << "GetNewTensor: " << new_tensor; + VLOG(4) << "GetOutputIter: " + << utils::Join(GetOutputIters(*downstream), " "); + VLOG(4) << "GetReduceIter: " << utils::Join(GetReduceIters(upstream), " "); + VLOG(4) << "GetCompute: " + << ComposeUtils::CopyedReplaceExpr( + GetComputeBody(upstream), + GetOutputIters(upstream), + load_tensor.As()->indices); + ir::Expr new_reduce = CreateReduceExpr( + GetOutputIters(*downstream), + GetReduceIters(upstream), + GetInitExpr(upstream), + ComposeUtils::CopyedReplaceExpr(GetComputeBody(upstream), + GetOutputIters(upstream), + load_tensor.As()->indices), + new_tensor, + GetOutputTensor(upstream)); + results.emplace_back(ReduceOp(new_reduce)); + TransformerUtils::ReplaceTarget( + &modified_downstream_compute_body, + load_tensor, + new_tensor(ComposeUtils::VarVec2ExprVec(GetOutputIters(*downstream)))); + } + _SetFuncBody(*downstream, + CreateExprWithNewComputeBody(*downstream, + modified_downstream_compute_body)); + VLOG(4) << "After Replace Downstream Load: \n" << _GetRootExpr(*downstream); + return results; +} + +FusibleOp TrivialFusion(FusionNode* upstream, FusionNode* downstream) { + CHECK(upstream->IsTrivial()); + if (downstream->IsTrivial()) { + return TrivalxOther_Fusion(std::get(upstream->fusible_op), + std::get(downstream->fusible_op)); + } else { + return TrivalxOther_Fusion(std::get(upstream->fusible_op), + std::get(downstream->fusible_op)); + } +} + +FusibleOp SinkTrivialLoopAlign(TrivialOp trivial_op, ReduceOp reduce_op) { + ir::Expr new_trivial_body = ir::ir_utils::IRCopy(trivial_op.GetFuncBody()); + ir::Var last_iter = GetOutputIters(trivial_op).back(); + ir::Expr trivial_last_for = + (SearchUtils::ChildFors * SearchUtils::IsForIterVar(last_iter)) + .GetSingle(new_trivial_body); + ir::Expr new_for_body = trivial_last_for.As()->body; + new_for_body = TransformerUtils::WrapForsTransformer( + GetReduceIters(reduce_op))(new_for_body); + trivial_last_for.As()->body = new_for_body; + return TrivialOp(new_trivial_body); +} + +std::vector ReduceTransformRecursive(FusibleOp root_op, + FusionNode* fusion_tree) { + VLOG(4) << "ReduceTransformRecursive: " << *_GetFuncBodyPointer(root_op); + std::vector result; + for (auto& pair : fusion_tree->upstream) { + auto transformed_nodes = TransformReduceLoopRange( + std::get(pair.first->fusible_op), &root_op); + for (auto& node : transformed_nodes) { + auto child_flatten = ReduceTransformRecursive(node, pair.first); + result.insert(result.end(), child_flatten.begin(), child_flatten.end()); + } + } + VLOG(4) << "Before push_back, is trivial_op: " + << std::holds_alternative(root_op); + result.push_back( + std::holds_alternative(root_op) + ? SinkTrivialLoopAlign( + std::get(root_op), + std::get( + fusion_tree->upstream.begin()->first->fusible_op)) + : root_op); + VLOG(4) << "After push_back."; + return result; +} + +std::vector ReduceTransform(FusionNode* downstream) { + if (downstream->IsTrivial() && downstream->upstream.empty()) { + return {downstream->fusible_op}; + } + auto reduces = ReduceTransformRecursive(downstream->fusible_op, downstream); + return reduces; +} + +FusibleOp CreateFusibleOp(ir::Expr compute_body, OpPatternKind op_pattern) { + if (IsTrivialKind(op_pattern)) { + return TrivialOp(compute_body); + } else { + return ReduceOp(compute_body); + } +} + +FusionGraph::FusionGraph(const std::vector<::pir::Operation*>& ops, + const std::vector& op_compute_bodies) { + // shardable_axes_ = InferShardableAxes(ops); + VLOG(4) << "CreateFusionGraph"; + + const auto& op_patterns = GetOpPatternKindVector(ops); + CheckFusionInputValid(op_compute_bodies, op_patterns); + + std::unordered_map<::pir::Operation*, FusionNode*> op_to_node_map; + + for (int i = 0; i < ops.size(); ++i) { + FusionNode* node = + new FusionNode(CreateFusibleOp(op_compute_bodies[i], op_patterns[i])); + op_to_node_map[ops[i]] = node; + all_fusion_nodes_.emplace(node); + node->expr_related_op = ops[i]; + } + + for (::pir::Operation* op : ops) { + FusionNode* cur_node = op_to_node_map[op]; + + // add upstream nodes + for (int i = 0; i < op->num_operands(); ++i) { + ::pir::Value related_value = op->operand_source(i); + ::pir::Operation* input_op = related_value.defining_op(); + if (op_to_node_map.find(input_op) != op_to_node_map.end()) { + FusionNode* upstream_node = op_to_node_map[input_op]; + cur_node->upstream[upstream_node] = related_value; + upstream_node->downstream[cur_node] = related_value; + } + } + + // add downstream nodes + for (int i = 0; i < op->num_results(); ++i) { + ::pir::Value related_value = op->result(i); + for (auto consumer_it = related_value.use_begin(); + consumer_it != related_value.use_end(); + ++consumer_it) { + ::pir::Operation* output_op = consumer_it->owner(); + if (op_to_node_map.find(output_op) != op_to_node_map.end()) { + FusionNode* downstream_node = op_to_node_map[output_op]; + cur_node->downstream[downstream_node] = related_value; + downstream_node->upstream[cur_node] = related_value; + } + } + } + + if (cur_node->upstream.empty()) { + entrance_nodes_.emplace(cur_node); + } + + if (cur_node->downstream.empty()) { + exit_nodes_.emplace(cur_node); + } + } + + VLOG(4) << "FusionGraph Created, fusion node size: " + << all_fusion_nodes_.size(); +} + +FusionGraph::~FusionGraph() { + for (FusionNode* node : all_fusion_nodes_) { + delete node; + } +} + +std::vector FusionGraph::DoFusion() { + VLOG(4) << "Start Trivial Fusion"; + DoTrivialFusion(); + VLOG(4) << "Start R + T and R + R Fusion"; + ReduceLoopTranform(); + return GetExprResults(); +} + +FusionNode* FusionGraph::FindTrivialFusibleNode() { + for (FusionNode* node : all_fusion_nodes_) { + if (node->IsTrivial() && !node->downstream.empty()) { + return node; + } + } + return nullptr; +} + +void FusionGraph::DoTrivialFusion() { + FusionNode* upstream = nullptr; + // use funcion to get upstream and downstream is save here + // cause we might delete Nodes in this process + while ((upstream = FindTrivialFusibleNode()) != nullptr) { + std::unordered_map fusion_candidate = + upstream->downstream; + upstream->downstream.clear(); + for (const auto& pair_data : fusion_candidate) { + FusionNode* downstream = pair_data.first; + FusionNode* new_node = + new FusionNode(TrivialFusion(upstream, downstream)); + new_node->replace_topo_structure_of_fused_nodes(upstream, downstream); + AppendNode(new_node); + RemoveNode(downstream); + } + RemoveNode(upstream); + } +} + +void FusionGraph::ReduceLoopTranform() { + for (FusionNode* node : exit_nodes_) { + auto fusion_nodes = ReduceTransform(node); + fusion_results_.insert( + fusion_results_.end(), fusion_nodes.begin(), fusion_nodes.end()); + } +} + +std::vector FusionGraph::GetExprResults() { + std::vector output_exprs; + for (const auto& node : fusion_results_) { + output_exprs.emplace_back(_GetRootExpr(node)); + } + return output_exprs; +} + +void FusionGraph::RemoveNode(FusionNode* node) { + if (all_fusion_nodes_.find(node) != all_fusion_nodes_.end()) { + all_fusion_nodes_.erase(node); + } + if (entrance_nodes_.find(node) != entrance_nodes_.end()) { + entrance_nodes_.erase(node); + } + if (exit_nodes_.find(node) != exit_nodes_.end()) { + exit_nodes_.erase(node); + } + delete node; +} + +void FusionGraph::AppendNode(FusionNode* node) { + all_fusion_nodes_.emplace(node); + if (node->upstream.empty()) { + entrance_nodes_.emplace(node); + } + + if (node->downstream.empty()) { + exit_nodes_.emplace(node); + } +} + +FusionNode* FusionGraph::FindReduceUpstream(FusionNode* node) { + for (const auto& pair_data : node->upstream) { + FusionNode* upstream = pair_data.first; + if (!upstream->IsTrivial()) { + return upstream; + } + } + return nullptr; +} + +} // namespace trivial_fusion_detail + +std::vector OperationFusion( + const std::vector<::pir::Operation*>& ops, + const std::vector& op_compute_bodies) { + trivial_fusion_detail::FusionGraph graph = + trivial_fusion_detail::FusionGraph(ops, op_compute_bodies); + auto output = graph.DoFusion(); + VLOG(4) << "Fusion Result: output size is " << output.size(); + for (const auto& expr : output) { + VLOG(4) << expr; + } + return output; +} + +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_impl.h b/paddle/cinn/hlir/framework/pir/trivial_op_impl.h new file mode 100644 index 0000000000000..de146230b83c7 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/trivial_op_impl.h @@ -0,0 +1,209 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include + +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/framework/compile_error.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" +#include "paddle/cinn/hlir/framework/pir/trivial_op_util.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/cinn/hlir/op/external_api_registry.h" +#include "paddle/cinn/hlir/pe/map_expr_to_ir.h" +#include "paddle/cinn/ir/dim.h" +#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" +#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" +#include "paddle/cinn/lang/placeholder.h" +#include "paddle/cinn/optim/schedule_block_dce.h" +#include "paddle/cinn/optim/transform_gpu_forloop.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +// #include "paddle/cinn/frontend/group_pattern_util.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { +namespace trivial_fusion_detail { + +struct TrivialOp { + public: + explicit TrivialOp(const ir::Expr& origin_func_body); + + TrivialOp(const TrivialOp& trivial_op); + + void _SetFuncBody(ir::Expr new_body); + ir::Expr* _GetFuncBodyPointer(); + + ir::Expr GetFuncBody() const; + + private: + ir::Expr func_body; +}; + +struct ReduceOp { + public: + explicit ReduceOp(const ir::Expr& origin_func_body); + ReduceOp(const ReduceOp& reduce_op); + + void _SetFuncBody(ir::Expr new_body); + + ir::Expr GetFuncBody() const; + + ir::Expr* _GetFuncBodyPointer(); + + private: + ir::Expr func_body; +}; + +using FusibleOp = std::variant; + +ir::Expr _GetRootExpr(const FusibleOp& op); + +void _SetFuncBody(FusibleOp& op, ir::Expr new_body); +ir::Expr GetComputeBody(const FusibleOp& op); + +ir::Tensor GetOutputTensor(const FusibleOp& op); + +ir::Expr _GetOriginalStoreValuePointer(const FusibleOp& op); + +std::vector AppendBound(const std::vector vars, + const ir::Expr& root); + +std::vector GetOutputIters(const FusibleOp& op); + +std::vector GetAllIterVars(const ReduceOp& op); + +std::vector GetReduceIters(const ReduceOp& op); + +ir::Expr GetInitExpr(const ReduceOp& op); + +ir::Expr* _GetFuncBodyPointer(FusibleOp op); + +ir::Expr CopyReduceBody(const FusibleOp& downstream, const ReduceOp& upstream); + +ir::Expr CreateReduceExpr( + const std::vector& output_iters, + const std::vector& reduce_iters, + const ir::Expr& init_body, // relay on output_iters + const ir::Expr& reduce_body, // relay on output_iters + reduce_iters + const ir::Tensor& new_write_tensor, + const ir::Tensor& origin_write_tensor); + +ir::Expr CreateTrivialExpr(const std::vector& output_iters, + const ir::Expr& function_body, + const ir::Tensor& new_write_tensor); +ir::Expr CreateExprWithNewComputeBody(FusibleOp fusible_op, + ir::Expr new_compute_body); +struct FusionNode { + FusibleOp fusible_op; + ::pir::Operation* expr_related_op; + + std::unordered_map upstream; + std::unordered_map downstream; + + explicit FusionNode(FusibleOp fusible_op); + + static std::string GetTensorCounter(); + void replace_topo_structure_of_fused_nodes(FusionNode* fused_up_node, + FusionNode* fused_down_node); + + bool IsTrivial() const; +}; + +template +DownStreamOp TrivalxOther_Fusion(TrivialOp upstream, DownStreamOp downstream) { + VLOG(4) << "Trivial x OtherFusion begin."; + + const auto& replaced_tensor = GetOutputTensor(upstream); + VLOG(4) << "upstream is " << upstream.GetFuncBody(); + VLOG(4) << "downstream is " << downstream.GetFuncBody(); + + DownStreamOp fused(ir::ir_utils::IRCopy(downstream.GetFuncBody())); + ir::Expr origin_compute_body = _GetOriginalStoreValuePointer(fused); + SequenceMutator( + ComposeUtils::GetEachTensorLoadExpr(origin_compute_body, replaced_tensor), + &origin_compute_body, + [&](const ir::Expr& downstream_load_expr, ir::Expr* downstream_body) { + ComposeUtils::ReplaceDownstreamLoadExprWithUpstreamComputeBody( + upstream, downstream_load_expr, downstream_body); + }); + + VLOG(4) << "After mutate, compute body: " << origin_compute_body; + VLOG(4) << "TTFusion end:\n" << fused.GetFuncBody(); + return fused; +} + +bool CheckAllLoopRangeEq(ReduceOp reduce_upper, TrivialOp trivial_down); + +std::vector TransformReduceLoopRange(const ReduceOp& upstream, + FusibleOp* downstream); + +FusibleOp TrivialFusion(FusionNode* upstream, FusionNode* downstream); + +FusibleOp SinkTrivialLoopAlign(TrivialOp trivial_op, ReduceOp reduce_op); + +std::vector ReduceTransformRecursive(FusibleOp root_op, + FusionNode* fusion_tree); +std::vector ReduceTransform(FusionNode* downstream); + +FusibleOp CreateFusibleOp(ir::Expr compute_body, OpPatternKind op_pattern); + +struct FusionGraph { + explicit FusionGraph(const std::vector<::pir::Operation*>& ops, + const std::vector& op_compute_bodies); + + ~FusionGraph(); + + std::vector DoFusion(); + + private: + FusionNode* FindTrivialFusibleNode(); + + void DoTrivialFusion(); + + void ReduceLoopTranform(); + + std::vector GetExprResults(); + + void RemoveNode(FusionNode* node); + + void AppendNode(FusionNode* node); + + FusionNode* FindReduceUpstream(FusionNode* node); + + private: + std::unordered_set all_fusion_nodes_; + std::vector fusion_results_; + std::unordered_set entrance_nodes_; + std::unordered_set exit_nodes_; + + // std::unordered_map<::pir::Value, ShardableAxes> shardable_axes_; +}; + +} // namespace trivial_fusion_detail + +std::vector OperationFusion( + const std::vector<::pir::Operation*>& ops, + const std::vector& op_compute_bodies); + +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_util.cc b/paddle/cinn/hlir/framework/pir/trivial_op_util.cc new file mode 100644 index 0000000000000..cf92dc3c0f6fa --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/trivial_op_util.cc @@ -0,0 +1,494 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/framework/pir/trivial_op_util.h" + +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/framework/compile_error.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/cinn/hlir/op/external_api_registry.h" +#include "paddle/cinn/hlir/pe/map_expr_to_ir.h" +#include "paddle/cinn/ir/dim.h" +#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" +#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" +#include "paddle/cinn/lang/placeholder.h" +#include "paddle/cinn/optim/schedule_block_dce.h" +#include "paddle/cinn/optim/transform_gpu_forloop.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { +namespace trivial_fusion_detail { + +namespace ComposeUtils { + +std::vector ExprVec2VarVec(const std::vector& in) { + std::vector out; + for (auto& expr : in) { + out.push_back(expr.as_var_ref()); + } + return out; +} + +std::vector VarVec2ExprVec(const std::vector& in) { + return std::vector(in.begin(), in.end()); +} + +std::vector GetEachTensorLoadExpr(const ir::Expr& body, + const ir::Tensor& tensor) { + VLOG(4) << "Start GetEachTensorLoadExpr: " << tensor; + std::set load_exprs = cinn::ir::ir_utils::CollectIRNodesWithoutTensor( + body, [&tensor](const Expr* expr) { + return expr->As() && expr->As()->is_addr_tensor() && + expr->As()->tensor.as_tensor_ref()->name == + tensor->name; + }); + for (auto& t : load_exprs) { + VLOG(4) << "GetEachTensorLoadExpr: " << t << " " << t.ptr(); + } + return std::vector(load_exprs.begin(), load_exprs.end()); +} + +MappingTargetExprToDestExprMutator::MappingTargetExprToDestExprMutator( + const ir::Expr& source, const ir::Expr& dest) + : source_(source), dest_(dest) {} + +void MappingTargetExprToDestExprMutator::operator()(Expr* expr) { + IRMutator::Visit(expr, expr); +} + +void MappingTargetExprToDestExprMutator::Visit(const ir::Load* load, Expr* op) { + VLOG(4) << "SubstitudeTargetExprWithDestExpr: " << load << " vs " + << source_.ptr(); + if (load == source_.ptr()) { + VLOG(4) << "substitude find!"; + *op = dest_; + } else { + IRMutator::Visit(load, op); + } +} +void MappingTargetExprToDestExprMutator::Visit(const ir::Store* store, + Expr* op) { + VLOG(4) << "SubstitudeTargetExprWithDestExpr: " << store << " vs " + << source_.ptr(); + if (store == source_.ptr()) { + VLOG(4) << "substitude find!"; + *op = dest_; + } else { + IRMutator::Visit(store, op); + } +} +void MappingTargetExprToDestExprMutator::Visit(const ir::Reduce* reduce, + Expr* op) { + VLOG(4) << "SubstitudeTargetExprWithDestExpr: " << reduce << " vs " + << source_.ptr(); + if (reduce == source_.ptr()) { + VLOG(4) << "substitude find!"; + *op = dest_; + } else { + IRMutator::Visit(reduce, op); + } +} + +bool CheckIterEq(const std::vector& up_iter, + const std::vector& down_iter) { + if (up_iter.size() != down_iter.size()) return false; + + for (int i = 0; i < up_iter.size(); ++i) { + const ir::Var& up_iter_var = up_iter[i]; + const ir::Var& down_iter_var = down_iter[i]; + + if (up_iter_var != down_iter_var) return false; + if (up_iter_var->lower_bound.as_int64() != + down_iter_var->lower_bound.as_int64()) + return false; + if (up_iter_var->upper_bound.as_int64() != + down_iter_var->upper_bound.as_int64()) + return false; + } + return true; +} + +ir::Expr CopyedReplaceExpr(const Expr& source, + const std::vector& replaced, + const std::vector& candidates) { + VLOG(4) << "Copyed Replace Expr Start"; + CHECK_EQ(replaced.size(), candidates.size()) + << "In ReplaceExpr, the size of Vars to be replaced must be equal to " + "the " + "size of cadidate Exprs! Please check."; + auto copyed_source = ir::ir_utils::IRCopy(source); + if (replaced.empty()) return copyed_source; + std::map replacing_map; + for (int i = 0; i < replaced.size(); ++i) { + // If the Var to be replaced is equal to the candidate, we skip it. + if (candidates[i].is_var() && candidates[i].as_var_ref() == replaced[i]) + continue; + replacing_map[replaced[i]] = candidates[i]; + } + ir::MappingVarToExprMutator mapper(replacing_map); + mapper(©ed_source); + VLOG(4) << "Copyed Replace Expr End"; + return copyed_source; +} + +void SubstitudeTargetExprWithDestExpr(const ir::Expr& source, + const ir::Expr& dest, + ir::Expr* body) { + VLOG(4) << "Start SubstitudeTargetExprWithDestExpr"; + MappingTargetExprToDestExprMutator mapper(source, dest); + mapper(body); + VLOG(4) << "End SubstitudeTargetExprWithDestExpr"; +} + +ir::Expr SubstitudeIndexVector(const Expr& source, + const std::vector& load_vars, + const std::vector& indices) { + return CopyedReplaceExpr(source, load_vars, indices); +} +} // namespace ComposeUtils + +namespace SearchUtils { + +using ExprSet = std::vector; +using Func = std::function; +Mapping::Mapping(Func f, std::string s) { + f_ = f; + name = s; +} +ExprSet Mapping::operator()(const ir::Expr& x) const { return f_(x); } +ir::Expr Mapping::GetSingle(const ir::Expr& x) const { + Mapping call = (*this) * Mapping::GetIdentity(); + const auto& o = call.operator()(x); + if (o.size() != 1) { + PADDLE_THROW("Try to get single result, but we get %d.", o.size()); + } + return *o.begin(); +} +Mapping Mapping::operator*(Mapping x) const { + auto new_f = [self = *this, x = x](const ir::Expr& e) -> ExprSet { + const auto& rs = self.f_(e); + VLOG(6) << "Mapping Info : " << self.name; + VLOG(6) << " Inputs :" << e; + for (const auto& r : rs) { + VLOG(6) << " Outputs : \n" << r; + } + std::vector res; + for (const auto& r : rs) { + const auto& x_res = x.f_(r); + res.insert(res.begin(), x_res.begin(), x_res.end()); + } + return res; + }; + return Mapping(std::function(new_f), x.name + "*" + this->name); +} +Mapping Mapping::GetIdentity() { + return Mapping([](const ir::Expr& e) { return std::vector{e}; }, + "identity"); +} + +Mapping Identity = Mapping::GetIdentity(); + +Mapping Store2Value = Mapping( + [](const ir::Expr& e) -> ExprSet { + if (e.As()) { + return {e.As()->value}; + } + return {}; + }, + "Store2Value"); + +Mapping Realizer2ScheduleBlock = Mapping( + [](const ir::Expr& e) -> ExprSet { + if (e.As()) { + return {e.As()->schedule_block}; + } + return {}; + }, + "Realizer2ScheduleBlock"); + +Mapping ScheduleBlock2Body = Mapping( + [](const ir::Expr& e) -> ExprSet { + if (e.As()) { + return {e.As()->body}; + } + return {}; + }, + "ScheduleBlock2Body"); + +Mapping ScheduleBlockRealizeNotRoot = FilterMaker( + [](const ir::Expr& e) -> bool { + return (e.As() && + e.As() + ->schedule_block.As() + ->name.find("root") == std::string::npos); + }, + "ScheduleBlockRealizeNotRoot"); + +Mapping ScheduleBlockRealizeIsNotInit = FilterMaker( + [](const ir::Expr& e) -> bool { + return (e.As() && + e.As() + ->schedule_block.As() + ->name.find("__reduce_init") == std::string::npos); + }, + "ScheduleBlockRealizeIsNotInit"); + +Mapping ScheduleBlockRealizeIsInit = FilterMaker( + [](const ir::Expr& e) -> bool { + return (e.As() && + e.As() + ->schedule_block.As() + ->name.find("__reduce_init") != std::string::npos); + }, + "ScheduleBlockRealizeIsInit"); + +Mapping IsFor = FilterMaker( + [](const ir::Expr& e) -> bool { return e.As(); }, "IsFor"); + +Mapping ChildScheduleBlocks = + Collector([](const ir::Expr* e) { return e->As(); }, + "ChildScheduleBlocks"); + +Mapping ChildScheduleBlockRealizes = + Collector( + [](const ir::Expr* e) { return e->As(); }, + "ChildScheduleBlockRealizes") * + ScheduleBlockRealizeNotRoot; + +Mapping IsForIterVar(const ir::Var& var) { + return FilterMaker( + [var = var](const ir::Expr& e) -> bool { + return e.As() && e.As()->loop_var == var; + }, + "IsForIterVar"); +} + +Mapping For2Min = + Mapping([](const ir::Expr& e) -> ExprSet { return {e.As()->min}; }, + "For2Min"); + +Mapping For2Max = Mapping( + [](const ir::Expr& e) -> ExprSet { return {e.As()->extent}; }, + "For2Max"); + +Mapping ChildStores = Collector( + [](const ir::Expr* e) { return e->As(); }, "ChildStores"); + +Mapping ChildTensorLoads = Collector( + [](const ir::Expr* e) { + return e->As() && e->As()->is_addr_tensor(); + }, + "ChildLoads"); + +Mapping ChildTensorStores = Collector( + [](const ir::Expr* e) { + return e->As() && e->As()->is_addr_tensor(); + }, + "ChildTensorStores"); + +Mapping FilterLoadByTensor(const ir::Tensor& tensor) { + return FilterMaker( + [tensor = tensor](const ir::Expr& e) -> bool { + return e.As() && + e.As()->tensor.as_tensor_ref()->name == tensor->name; + }, + "FilterLoadByTensor(" + tensor->name + ")"); +} + +Mapping ChildFors = + Collector([](const ir::Expr* e) { return e->As(); }, "ChildFors"); + +Mapping FindFather(const ir::Expr& root) { + const auto& f = [&](const auto& child) -> ExprSet { + Mapping find_child = + Collector([child](const ir::Expr* e) { return *e == child; }); + const auto& father_collector = Collector( + [&](const ir::Expr* current) { return !find_child(*current).empty(); }); + return father_collector(root); + }; + return Mapping(f, "FindFather"); +} +} // namespace SearchUtils + +namespace TransformerUtils { +using TransformFunc = std::function; + +Transformer::Transformer(TransformFunc f) { f_ = f; } +ir::Expr Transformer::operator()(const ir::Expr& x) const { return f_(x); } +Transformer Transformer::operator*(const Transformer& x) const { + auto new_f = [self = *this, x = x](const ir::Expr& e) -> ir::Expr { + const auto& rs = self.f_(e); + return x.f_(rs); + }; + return Transformer(std::function(new_f)); +} + +Transformer Identity = Transformer([](const ir::Expr& e) { return e; }); +Transformer WrapForTransformer(const ir::Var& v) { + const auto& f = [=](const ir::Expr& e) -> ir::Expr { + auto block = e; + if (!block.As()) { + block = ir::Block::Make({e}); + } + return ir::For::Make(v, + v->lower_bound, + v->upper_bound, + ir::ForType::Serial, + ir::DeviceAPI::Host, + block); + }; + return Transformer(f); +} + +Transformer WrapForsTransformer(const std::vector& vs) { + const auto& f = [&](const ir::Expr& e) -> ir::Expr { + Transformer t = Identity; + for (const auto& v : vs) { + t = WrapForTransformer(v) * t; + } + return t(e); + }; + return Transformer(f); +} + +Transformer ChangeTensorLoadTransformer(const ir::Tensor& tensor, + const ir::Expr dst_load) { + const auto& f = [&](const ir::Expr& e) -> ir::Expr { + auto copied_e = ir::ir_utils::IRCopy(e); + const auto& load = (SearchUtils::ChildTensorLoads * + SearchUtils::FilterLoadByTensor(tensor)) + .GetSingle(copied_e); + ComposeUtils::MappingTargetExprToDestExprMutator(load, dst_load)(&copied_e); + return copied_e; + }; + return Transformer(f); +} + +void ReplaceTarget(ir::Expr* e, const ir::Expr& t, const ir::Expr dst) { + ComposeUtils::MappingTargetExprToDestExprMutator(t, dst)(e); +} + +Transformer WrapStoreTransformer(const ir::Tensor& tensor, + const std::vector& indices) { + const auto& f = [=](const ir::Expr& e) -> ir::Expr { + return ir::Store::Make(tensor, e, indices); + }; + return Transformer(f); +} + +std::vector CreateInnerBlockVars( + const std::vector& block_vars) { + int i = 0; + std::vector vars; + for (const auto& v : block_vars) { + vars.emplace_back("inner_block_" + std::to_string(i++)); + } + return vars; +} + +Transformer ChangeVarTransformer(const std::vector& target_vars, + const std::vector& dest_vars) { + const auto& f = [=](const ir::Expr& e) -> ir::Expr { + return ComposeUtils::CopyedReplaceExpr( + e, + target_vars, + std::vector(dest_vars.begin(), dest_vars.end())); + }; + return Transformer(f); +} + +Transformer SubstitudeByScheduleBlockRealize(const ir::Expr& realize) { + const auto& f = [=](const ir::Expr& e) -> ir::Expr { + const auto& iter_values = + realize.As()->iter_values; + const auto& iter_vars = realize.As() + ->schedule_block.As() + ->iter_vars; + return TransformerUtils::ChangeVarTransformer( + iter_vars, ComposeUtils::ExprVec2VarVec(iter_values))(e); + }; + return Transformer(f); +} + +Transformer WrapScheduleRealizer(const std::vector& block_vars, + const std::string& tensor_name) { + const auto& f = [=](const ir::Expr& e) -> ir::Expr { + if (e.As()) { + PADDLE_THROW("please input a non-schedule block expr."); + } + const auto& inner_block_var = CreateInnerBlockVars(block_vars); + const auto& replaced_e = + ChangeVarTransformer(block_vars, inner_block_var)(e); + const auto& schedule_block = ir::ScheduleBlock::Make( + inner_block_var, {}, {}, tensor_name, replaced_e); + const auto& schedule_realizer = ir::ScheduleBlockRealize::Make( + std::vector(block_vars.begin(), block_vars.end()), + schedule_block); + return schedule_realizer; + }; + return Transformer(f); +} +} // namespace TransformerUtils + +std::vector GetOpPatternKindVector( + const std::vector<::pir::Operation*>& ops) { + const auto& op_pattern_map = + Operator::GetAttrs("OpPattern"); + std::vector op_patterns; + const auto ConvertToPattern = [&op_pattern_map](const ::pir::Operation* op) { + const std::string cinn_op_name = CompatibleInfo::OpName(*op); + const hlir::framework::Operator* cinn_op = Operator::Get(cinn_op_name); + return op_pattern_map[cinn_op]; + }; + std::transform(ops.begin(), + ops.end(), + std::back_inserter(op_patterns), + ConvertToPattern); + return op_patterns; +} + +bool IsTrivialKind(OpPatternKind kind) { + return kind == OpPatternKind::kElementWise || + kind == OpPatternKind::kBroadcast || kind == OpPatternKind::kInjective; +} + +void CheckFusionInputValid(const std::vector& op_compute_bodies, + const std::vector& op_patterns) { + if (VLOG_IS_ON(4)) { + for (const auto& func : op_compute_bodies) { + VLOG(4) << "TrivialOpFusion: {FuncBody is} :" << func; + } + for (const auto& op_ptn : op_patterns) { + VLOG(4) << "OpPattern is :" << op_ptn; + } + } + VLOG(4) << " op_patterns.size() = " << op_compute_bodies.size(); + VLOG(4) << "op_compute_bodies.size() = " << op_patterns.size(); + PADDLE_ENFORCE_EQ( + op_patterns.size(), op_compute_bodies.size(), "ops and size not equal"); +} + +} // namespace trivial_fusion_detail +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_util.h b/paddle/cinn/hlir/framework/pir/trivial_op_util.h new file mode 100644 index 0000000000000..e87b33ba2fcef --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/trivial_op_util.h @@ -0,0 +1,240 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/framework/compile_error.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/cinn/hlir/op/external_api_registry.h" +#include "paddle/cinn/hlir/pe/map_expr_to_ir.h" +#include "paddle/cinn/ir/dim.h" +#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" +#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/lang/placeholder.h" +#include "paddle/cinn/optim/schedule_block_dce.h" +#include "paddle/cinn/optim/transform_gpu_forloop.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { +namespace trivial_fusion_detail { + +namespace ComposeUtils { + +template +std::vector ConcatVector(const std::vector& first, + const std::vector& second) { + std::vector result = first; + result.insert(result.end(), second.begin(), second.end()); + return result; +} + +std::vector ExprVec2VarVec(const std::vector& in); +std::vector VarVec2ExprVec(const std::vector& in); + +std::vector GetEachTensorLoadExpr(const ir::Expr& body, + const ir::Tensor& tensor); + +struct MappingTargetExprToDestExprMutator : public ir::IRMutator<> { + explicit MappingTargetExprToDestExprMutator(const ir::Expr& source, + const ir::Expr& dest); + + void operator()(Expr* expr); + + private: + void Visit(const ir::Load* load, Expr* op) override; + void Visit(const ir::Store* store, Expr* op) override; + void Visit(const ir::Reduce* reduce, Expr* op) override; + + private: + ir::Expr source_; + ir::Expr dest_; +}; + +bool CheckIterEq(const std::vector& up_iter, + const std::vector& down_iter); + +ir::Expr CopyedReplaceExpr(const Expr& source, + const std::vector& replaced, + const std::vector& candidates); +void SubstitudeTargetExprWithDestExpr(const ir::Expr& source, + const ir::Expr& dest, + ir::Expr* body); + +ir::Expr SubstitudeIndexVector(const Expr& source, + const std::vector& load_vars, + const std::vector& indices); + +template +void ReplaceDownstreamLoadExprWithUpstreamComputeBody( + const FusionOp& upstream, + const ir::Expr& downstream_load_expr, + ir::Expr* downstream_body) { + ComposeUtils::SubstitudeTargetExprWithDestExpr( + downstream_load_expr, + ComposeUtils::SubstitudeIndexVector( + GetComputeBody(upstream), + GetOutputIters(upstream), + downstream_load_expr.As()->indices), + downstream_body); +} +} // namespace ComposeUtils + +namespace SearchUtils { + +using ExprSet = std::vector; +using Func = std::function; +struct Mapping { + Func f_; + std::string name; + explicit Mapping(Func f, std::string s = ""); + + ExprSet operator()(const ir::Expr& x) const; + ir::Expr GetSingle(const ir::Expr& x) const; + Mapping operator*(Mapping x) const; + static Mapping GetIdentity(); +}; + +template +Mapping Collector(Teller t, std::string name = "") { + return Mapping( + [=](const ir::Expr& x) -> ExprSet { + const auto& rs = cinn::ir::ir_utils::CollectIRNodesWithoutTensor(x, t); + return std::vector(rs.begin(), rs.end()); + }, + name); +} + +template +Mapping FilterMaker(FilterFunc t, std::string name) { + return Mapping( + [=](const ir::Expr& x) -> ExprSet { + if (t(x)) { + return {x}; + } + return {}; + }, + name); +} + +extern Mapping Identity; + +extern Mapping Store2Value; + +extern Mapping Realizer2ScheduleBlock; + +extern Mapping ScheduleBlock2Body; + +extern Mapping ScheduleBlockRealizeNotRoot; + +extern Mapping ScheduleBlockRealizeIsNotInit; + +extern Mapping ScheduleBlockRealizeIsInit; + +extern Mapping IsFor; + +extern Mapping ChildScheduleBlocks; + +extern Mapping ChildScheduleBlockRealizes; + +extern Mapping For2Min; + +extern Mapping For2Max; + +extern Mapping ChildStores; + +extern Mapping ChildTensorLoads; + +extern Mapping ChildTensorStores; + +extern Mapping ChildFors; + +Mapping IsForIterVar(const ir::Var& var); + +Mapping FilterLoadByTensor(const ir::Tensor& tensor); + +Mapping FindFather(const ir::Expr& root); + +template +std::vector MapVector(const std::vector& as, M func) { + std::vector res; + for (const auto& a : as) { + res.push_back(func(a)); + } + return res; +} +} // namespace SearchUtils + +namespace TransformerUtils { +using TransformFunc = std::function; +struct Transformer { + TransformFunc f_; + explicit Transformer(TransformFunc f); + ir::Expr operator()(const ir::Expr& x) const; + Transformer operator*(const Transformer& x) const; +}; + +extern Transformer Identity; + +Transformer WrapForTransformer(const ir::Var& v); + +Transformer WrapForsTransformer(const std::vector& vs); +Transformer ChangeTensorLoadTransformer(const ir::Tensor& tensor, + const ir::Expr dst_load); + +void ReplaceTarget(ir::Expr* e, const ir::Expr& t, const ir::Expr dst); + +Transformer WrapStoreTransformer(const ir::Tensor& tensor, + const std::vector& indices); + +std::vector CreateInnerBlockVars( + const std::vector& block_vars); + +Transformer ChangeVarTransformer(const std::vector& target_vars, + const std::vector& dest_vars); + +Transformer SubstitudeByScheduleBlockRealize(const ir::Expr& realize); + +Transformer WrapScheduleRealizer(const std::vector& block_vars, + const std::string& tensor_name); +} // namespace TransformerUtils + +std::vector GetOpPatternKindVector( + const std::vector<::pir::Operation*>& ops); + +template +void SequenceMutator(const std::vector& as, C* acc, const Func& mutator) { + VLOG(4) << "SequenceTransform Init: " << acc; + for (int i = 0; i < as.size(); ++i) { + mutator(as[i], acc); + VLOG(4) << "SequenceTransform Iter: " << acc; + } +} + +bool IsTrivialKind(OpPatternKind kind); + +void CheckFusionInputValid(const std::vector& op_compute_bodies, + const std::vector& op_patterns); + +} // namespace trivial_fusion_detail +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn