
Cinn trivalop fuse #59

Merged
87 changes: 38 additions & 49 deletions paddle/cinn/hlir/framework/pir/trivial_op.cc
@@ -224,12 +224,19 @@ std::set<Expr> GetStoreFromBody(const ir::Expr& body) {
  return store_tensor_exprs;
}

std::vector<ir::Var> GetOutputIters(const std::vector<ir::Expr>& indices) {
  std::vector<ir::Var> vars;
  std::transform(indices.begin(),
                 indices.end(),
                 std::back_inserter(vars),
                 [](const ir::Expr& expr) { return expr.as_var_ref(); });
  return vars;
}

bool CheckIterEq(std::vector<ir::Var> up_iter, std::vector<ir::Var> down_iter) {
  // TODO: implement; a hedged sketch follows below.
  return false;
}
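CheckIterEq is still a TODO in this patch. A minimal sketch of what it might check, assuming "iterator equality" means same arity plus matching variable names (an assumption; the diff does not pin the criterion down):

bool CheckIterEq(const std::vector<ir::Var>& up_iter,
                 const std::vector<ir::Var>& down_iter) {
  if (up_iter.size() != down_iter.size()) return false;
  for (size_t i = 0; i < up_iter.size(); ++i) {
    // Assumption: comparing variable names is sufficient; range equality
    // may also be required once the fusion contract is fixed.
    if (up_iter[i]->name != down_iter[i]->name) return false;
  }
  return true;
}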

} // namespace ComposeUtils

@@ -243,23 +250,16 @@ struct TrivialOp {
    return GetSingleStoreExpr(func_body).As<ir::Store>()->value;
  }

  std::vector<ir::Var> GetOutputIters() const {
    return ComposeUtils::GetOutputIters(
        GetSingleStoreExpr(func_body).As<ir::Store>()->indices);
  }

  std::vector<ir::Var> GetAllIterVar() const { return GetOutputIters(); }

  ir::Expr* GetStoreValuePointer() const {
    return &GetSingleStoreExpr(func_body).As<ir::Store>()->value;
  }
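For intuition, with invented names: if the single store in func_body is B[i, j] = A[i, j] + 1, GetOutputIters() returns the ir::Vars {i, j} extracted from the store indices.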

  ir::Expr GetFuncBody() const { return func_body; }

  ir::Tensor GetOutputTensor() const {
@@ -313,14 +313,7 @@ struct ReduceOp {
  }

  std::vector<ir::Var> GetOutputIters() const {
    return ComposeUtils::GetOutputIters(
        GetSingleStoreExpr(func_body).As<ir::Store>()->indices);
  }

  ir::Expr GetFuncBody() const { return func_body; }
@@ -520,31 +513,24 @@ TrivialOp TransformT2R(ReduceOp reduce_upper, TrivialOp trivial_down) {}

bool CheckAllLoopRangeEq(ReduceOp reduce_upper, TrivialOp trivial_down) {
  // TODO: implement; a hedged sketch follows below.
  return false;
}
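A minimal sketch of CheckAllLoopRangeEq, assuming the intended condition is that the downstream trivial op iterates over exactly the reduce op's output iterators (an assumption; the diff leaves the contract unstated):

bool CheckAllLoopRangeEq(ReduceOp reduce_upper, TrivialOp trivial_down) {
  // Assumption: loop-range equality reduces to iterator equality between
  // the reduce op's output iterators and the trivial op's iterators.
  return ComposeUtils::CheckIterEq(reduce_upper.GetOutputIters(),
                                   trivial_down.GetOutputIters());
}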

std::vector<ReduceOp> TransformReduceLoopRange(ReduceOp upstream,
                                               ReduceOp downstream) {
  VLOG(4) << "RRTransform begin";

  // We only support fusing two reduces whose reduce iterators are equal.
  CHECK(ComposeUtils::CheckIterEq(upstream.GetReduceIters(),
                                  downstream.GetReduceIters()));

  const auto& load_upstream_expr =
      downstream.GetEachTensorLoadExpr(upstream.GetOutputTensor());
  std::vector<ReduceOp> results;
  for (const auto& load_tensor : load_upstream_expr) {
    // Rewrite the upstream body so that its output iterators are replaced
    // by the indices of this downstream load, then rebuild the reduce.
    ir::Expr new_reduce = CreateReduceExpr(
        downstream,
        ComposeUtils::CopyedReplaceExpr(upstream.GetFuncBody(),
                                        upstream.GetOutputIters(),
                                        load_tensor.As<ir::Load>()->indices),
        upstream.GetInitExpr(),
        new_tensor);  // TODO: new_tensor is not yet defined in this patch.
    ComposeUtils::MappingTargetExprToDestExprMutator(
        load_tensor.As<ir::Load>()->tensor, new_tensor)(
        downstream.GetFuncBody());
    results.emplace_back(new_reduce);
  }

  VLOG(4) << "RRTransform end";
  return results;
}
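To make the substitution step concrete, a worked illustration with invented tensor and iterator names: suppose the upstream reduce stores up_out[i] = reduce_sum(A[i, k], k) and the downstream body loads it as up_out[j * 2]. CopyedReplaceExpr(upstream_body, {i}, {j * 2}) then yields up_out[j * 2] = reduce_sum(A[j * 2, k], k), which CreateReduceExpr splices into the downstream reduce loop nest, and the mutator redirects the downstream load to the new tensor.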

FusibleOp TrivialFusion(FusionNode* upstream, FusionNode* downstream) {
@@ -558,14 +544,14 @@ FusibleOp TrivialFusion(FusionNode* upstream, FusionNode* downstream) {
}
}

std::vector<FusibleOp> ReduceTransform(FusionNode* upstream,
                                       FusionNode* downstream) {
  if (downstream->IsTrivial()) {
    CHECK(CheckAllLoopRangeEq(std::get<ReduceOp>(upstream->fusible_op),
                              std::get<TrivialOp>(downstream->fusible_op)));
    return {upstream->fusible_op};
  } else {
    const auto& reduce_ops =
        TransformReduceLoopRange(std::get<ReduceOp>(upstream->fusible_op),
                                 std::get<ReduceOp>(downstream->fusible_op));
    // Wrap each ReduceOp into the FusibleOp variant for the return type.
    return std::vector<FusibleOp>(reduce_ops.begin(), reduce_ops.end());
  }
}

@@ -702,7 +688,10 @@ struct FusionGraph {
      bfs_candidate.pop();
      for (const auto& pair_data : downstream->upstream) {
        FusionNode* upstream = pair_data.first;
        const auto& new_fusible_ops = ReduceTransform(upstream, downstream);

        // TODO: update the topological structure for the multiple upstream
        // nodes produced here; a hedged sketch follows after this block.

        bfs_candidate.push(upstream);
      }
}
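The topology update above is left as a TODO. One hedged sketch, assuming FusionNode exposes the upstream/downstream edge maps used above and that a node can be created per produced op (CreateFusionNode and RemoveNode are hypothetical helpers, not part of this diff):

for (const auto& op : new_fusible_ops) {
  FusionNode* node = CreateFusionNode(op);  // hypothetical helper
  node->upstream = upstream->upstream;      // inherit the old node's edges
  node->downstream = upstream->downstream;
  bfs_candidate.push(node);
}
RemoveNode(upstream);                       // hypothetical helper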