Skip to content

Commit

Permalink
move ir_copy from namespace optim to ir_utils (PaddlePaddle#57582)
Browse files Browse the repository at this point in the history
  • Loading branch information
Courtesy-Xs authored and jiahy0825 committed Oct 16, 2023
1 parent f11d7e8 commit 2966339
Show file tree
Hide file tree
Showing 24 changed files with 101 additions and 95 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/auto_schedule/analysis/analyze_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices) {
for (const ir::Expr& e : indices) {
// Whether we have to convert other types, like const numbers to Var?
if (e.As<ir::_Var_>() != nullptr) {
ir::Expr copy_e = optim::IRCopy(e);
ir::Expr copy_e = ir::ir_utils::IRCopy(e);
ir::_Var_* var_ref = copy_e.As<ir::_Var_>();
result.emplace_back(ir::Var(var_ref));
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/auto_schedule/cost_model/feature_extractor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ void FeatureExtractor::Visit(const For *x) {
}

void FeatureExtractor::Visit(const PolyFor *x) {
Expr copy = optim::IRCopy(Expr(x));
Expr copy = ir::ir_utils::IRCopy(Expr(x));
feature_.IntoLoopBlock();
optim::TransformPolyForToFor(&copy);
ir::For *loop = copy.As<For>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ ir::IRSchedule MakeIRSchedule(const std::vector<ir::LoweredFunc>& lowered_funcs,
const std::string& task_key) {
std::vector<Expr> exprs;
for (auto&& func : lowered_funcs) {
exprs.emplace_back(optim::IRCopy(func->body));
exprs.emplace_back(ir::ir_utils::IRCopy(func->body));
}
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
task_registry->Regist(task_key, ir::ModuleExpr(exprs));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ std::vector<SearchState> EvolutionarySearch::GetTopKCandidatesFromDatabase(
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
for (auto&& record : records) {
ir::IRSchedule ir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr),
ir::ir_utils::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(&rand_seed_));
ir::ScheduleDesc::ReplayWithProto(record.trace, &ir_sch);
results.emplace_back(SearchState(std::move(ir_sch), record.predicted_cost));
Expand Down Expand Up @@ -181,9 +181,9 @@ SearchState EvolutionarySearch::CrossOver(const SearchState& state1,

for (size_t i = 0; i < father_exprs.size(); ++i) {
if (utils::SampleUniformInt(0, 2, &rand_seed_) == 0) {
cross_over_exprs.push_back(optim::IRCopy(father_exprs[i]));
cross_over_exprs.push_back(ir::ir_utils::IRCopy(father_exprs[i]));
} else {
cross_over_exprs.push_back(optim::IRCopy(mother_exprs[i]));
cross_over_exprs.push_back(ir::ir_utils::IRCopy(mother_exprs[i]));
}
}
auto res = SearchState(ir::IRSchedule(ir::ModuleExpr(cross_over_exprs),
Expand Down Expand Up @@ -217,7 +217,7 @@ SearchState EvolutionarySearch::Mutate(
const auto& task_key = tune_task_.serialized_key;
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
ir::IRSchedule new_ir_sch(
optim::IRCopy(task_registry->Get(task_key)->module_expr),
ir::ir_utils::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(rand_seed));
new_trace.Replay(&new_ir_sch, true);
ApplyPostScheduleRules(&new_ir_sch, post_schedule_rules_);
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/auto_schedule/task/task_optimizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ TaskOptimizer::Result TaskOptimizer::OptimizeByEvolution(
auto& optimized_funcs = result.functions;
auto& best_cost = result.cost;
// use initial lowered function as default result
optimized_funcs = optim::IRCopy(task_->lowered_funcs);
optimized_funcs = ir::ir_utils::IRCopy(task_->lowered_funcs);
if (options.num_measure_trials ==
0) { // no need to measure and simply return the best searched
std::vector<MeasureInput> measure_candidates;
Expand Down Expand Up @@ -347,7 +347,7 @@ std::vector<SearchState> TaskOptimizer::SearchOneRound(
CHECK_EQ(best_exprs.size(), task_->lowered_funcs.size())
<< "RuntimeError: Expr size is not equal to LoweredFunc size in "
"TaskOptimizer";
auto init_funcs = optim::IRCopy(task_->lowered_funcs);
auto init_funcs = ir::ir_utils::IRCopy(task_->lowered_funcs);
std::vector<ir::LoweredFunc> valid_funcs;
for (size_t j = 0; j < best_exprs.size(); ++j) {
auto updated_f =
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/auto_schedule/task/task_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class InitialTaskRegistry : public Registry<InitialTaskInfo> {
std::lock_guard<std::mutex> guard(registering_mutex);
if (fmap_.count(task_key) == 0) {
InitialTaskInfo* task_info =
new InitialTaskInfo(task_key, optim::IRCopy(module_expr));
new InitialTaskInfo(task_key, ir::ir_utils::IRCopy(module_expr));
__REGISTER__(task_key, task_info);
}
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/backends/codegen_cuda_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
}

Expr CreateDeviceFunctionGivenDeviceKernel(Expr expr) {
auto copied = optim::IRCopy(expr);
auto copied = ir::ir_utils::IRCopy(expr);
auto* lowered_func = copied.as_lowered_func();
lowered_func->name = GenDeviceKernelName(lowered_func->name);
return copied;
Expand Down
16 changes: 8 additions & 8 deletions paddle/cinn/common/cas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1584,7 +1584,7 @@ bool CASasSymbol(Expr expr) {

Expr ConvertCinnToCAS(Expr expr) {
VLOG(7) << "Begin ConvertCinnToCAS " << expr;
Expr copied = optim::IRCopy(expr);
Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* expr) { Visit(expr); }
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
Expand Down Expand Up @@ -1710,7 +1710,7 @@ Expr ConvertCinnToCAS(Expr expr) {
* simplify the condition ensures correctness, though not sufficient.
*/
Expr ReplaceMinToConstant(Expr expr) {
Expr copied = optim::IRCopy(expr);
Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* expr) { Visit(expr); }
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
Expand All @@ -1727,10 +1727,10 @@ Expr ReplaceMinToConstant(Expr expr) {
auto min_b = op->b();
if (min_a.is_constant() && !min_b.is_constant()) {
CHECK(min_a->type().is_integer());
*expr = optim::IRCopy(min_a);
*expr = ir::ir_utils::IRCopy(min_a);
} else if (min_b.is_constant() && !min_a.is_constant()) {
CHECK(min_b->type().is_integer());
*expr = optim::IRCopy(min_b);
*expr = ir::ir_utils::IRCopy(min_b);
}
}
};
Expand All @@ -1743,7 +1743,7 @@ Expr ReplaceMinToConstant(Expr expr) {
* constant value and 1 inconstant value, return the constant max value.
*/
Expr ReplaceMaxToConstant(Expr expr) {
Expr copied = optim::IRCopy(expr);
Expr copied = ir::ir_utils::IRCopy(expr);
struct Mutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* expr) { Visit(expr); }
void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
Expand All @@ -1760,10 +1760,10 @@ Expr ReplaceMaxToConstant(Expr expr) {
auto max_b = op->b();
if (max_a.is_constant() && !max_b.is_constant()) {
CHECK(max_a->type().is_integer());
*expr = optim::IRCopy(max_a);
*expr = ir::ir_utils::IRCopy(max_a);
} else if (max_b.is_constant() && !max_a.is_constant()) {
CHECK(max_b->type().is_integer());
*expr = optim::IRCopy(max_b);
*expr = ir::ir_utils::IRCopy(max_b);
}
}
};
Expand All @@ -1773,7 +1773,7 @@ Expr ReplaceMaxToConstant(Expr expr) {

Expr ConvertCasToCinn(Expr expr) {
VLOG(7) << "Begin ConvertCasToCinn : " << expr;
Expr copied = optim::IRCopy(expr);
Expr copied = ir::ir_utils::IRCopy(expr);

struct Mutator : ir::IRMutator<Expr*> {
void operator()(Expr* expr) { Visit(expr); }
Expand Down
43 changes: 22 additions & 21 deletions paddle/cinn/ir/schedule/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
new_loop_vars.push_back(temp_var);
}
substitute_value = common::AutoSimplify(substitute_value);
Expr new_node = optim::IRCopy(for_node->body);
Expr new_node = ir::ir_utils::IRCopy(for_node->body);
ReplaceExpr(&new_node, {for_node->loop_var}, {substitute_value});
std::vector<Expr> splited_loops;
splited_loops.resize(processed_factors.size());
Expand Down Expand Up @@ -252,7 +252,7 @@ Expr ScheduleImpl::Fuse(const std::vector<Expr>& loops) {
}
substitute_value[0] = fused_expr;

Expr fused_body = optim::IRCopy(for_nodes.back()->body);
Expr fused_body = ir::ir_utils::IRCopy(for_nodes.back()->body);
ReplaceExpr(&fused_body, loop_vars, substitute_value);
optim::Simplify(&fused_body);
Expr fused_extent(1);
Expand Down Expand Up @@ -321,7 +321,7 @@ void ScheduleImpl::MutateForType(const Expr& loop,
<< "loop is not serial, current forloop type is "
<< static_cast<int>(for_node->for_type()) << ", and it cannot become "
<< static_cast<int>(for_type);
auto loop_copy = optim::IRCopy(loop);
auto loop_copy = ir::ir_utils::IRCopy(loop);
auto* new_for_node = loop_copy.As<ir::For>();
CHECK(new_for_node);
new_for_node->set_for_type(for_type);
Expand Down Expand Up @@ -674,7 +674,7 @@ struct RfCreater : public ir::IRMutator<> {
CHECK(root_realize);
auto root_block = root_realize->schedule_block.As<ScheduleBlock>();
CHECK(root_block);
Expr root_loop = optim::IRCopy(root_block->body);
Expr root_loop = ir::ir_utils::IRCopy(root_block->body);
if (auto block = root_loop.As<Block>()) {
CHECK_EQ(block->stmts.size(), 1U)
<< "rfactor root should only have one block stmt";
Expand All @@ -685,13 +685,13 @@ struct RfCreater : public ir::IRMutator<> {
auto rf_for = rf_loop_.As<For>();
CHECK(rf_for);
// create new rfactor forloops
Expr new_rf_forloop = optim::IRCopy(root_loop);
Expr new_rf_forloop = ir::ir_utils::IRCopy(root_loop);
RfMutator rf_mutator(rf_loop_, rf_axis_);
rf_mutator(&new_rf_forloop);
VLOG(3) << "After RfMutator, new rf_forloop is\n" << new_rf_forloop;
auto new_rf_tensor = rf_mutator.GetNewRfTensor();
// create final write-back forloops
Expr final_forloop = optim::IRCopy(root_loop);
Expr final_forloop = ir::ir_utils::IRCopy(root_loop);
FinalMutator final_mutator(rf_loop_, rf_axis_, new_rf_tensor);
final_mutator(&final_forloop);
VLOG(3) << "After FinalMuator, final write-back forloop is\n"
Expand Down Expand Up @@ -721,7 +721,7 @@ struct CacheReadRewriter : public ir::IRMutator<> {
public:
static Expr Rewrite(const Expr& root, CacheBlockInfo* info) {
CacheReadRewriter rewriter(root, info);
Expr new_root = optim::IRCopy(root);
Expr new_root = ir::ir_utils::IRCopy(root);
rewriter(&new_root);
return new_root;
}
Expand Down Expand Up @@ -762,7 +762,7 @@ struct CacheWriteRewriter : public ir::IRMutator<> {
public:
static Expr Rewrite(const Expr& root, CacheBlockInfo* info) {
CacheWriteRewriter rewriter(root, info);
Expr new_root = optim::IRCopy(root);
Expr new_root = ir::ir_utils::IRCopy(root);
rewriter.mutate_cache_block = true;
rewriter(&info->cache_block);
rewriter.mutate_cache_block = false;
Expand Down Expand Up @@ -1194,7 +1194,7 @@ struct LoopReconstructor : public ir::IRMutator<> {
loop_.As<ir::For>()->device_api,
std::move(loop_body));
}
new_loop_ = optim::IRCopy(loop_);
new_loop_ = ir::ir_utils::IRCopy(loop_);

// Replace the copied Tensor object with the original Tensor object,
// to ensure that the same Tensor in a AST is the same object.
Expand Down Expand Up @@ -1431,9 +1431,9 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
}

Expr result = loops.size() < block_loops.size()
? optim::IRCopy(block_loops[loops.size()])
: optim::IRCopy(this_block);
Expr new_loop = optim::IRCopy(this_loop);
? ir::ir_utils::IRCopy(block_loops[loops.size()])
: ir::ir_utils::IRCopy(this_block);
Expr new_loop = ir::ir_utils::IRCopy(this_loop);

// Get the body of block_loop under the same loops
auto body = block_loops.at(loops.size() - 1).As<ir::For>()->body;
Expand Down Expand Up @@ -1608,7 +1608,7 @@ void ComputeInliner::Visit(const ir::Load* expr, Expr* op) {
Expr ComputeInliner::ReplaceInlinedTensor(Expr* load) {
CHECK(load->As<ir::Load>());
SetIndexSubstitution(load->As<ir::Load>()->indices);
Expr value_copy = optim::IRCopy(inlined_store_.As<Store>()->value);
Expr value_copy = ir::ir_utils::IRCopy(inlined_store_.As<Store>()->value);
ReplaceExpr(&value_copy, idx_sub_var_, idx_sub_expr_);
return value_copy;
}
Expand Down Expand Up @@ -1684,7 +1684,7 @@ void ReverseComputeInliner::Visit(const ir::Store* expr, Expr* op) {
Expr ReverseComputeInliner::ReplaceInlinedTensor(Expr* load) {
CHECK(load->As<ir::Load>());
SetIndexSubstitution(load->As<ir::Load>()->indices);
Expr value_copy = optim::IRCopy(inlined_store_.As<Store>()->value);
Expr value_copy = ir::ir_utils::IRCopy(inlined_store_.As<Store>()->value);
return value_copy;
}

Expand All @@ -1699,7 +1699,7 @@ Expr ReverseComputeInliner::ReplaceTargetTensor(Expr* store) {
idx_sub_expr_.emplace_back(idx_vars_[i]);
}

Expr value_copy = optim::IRCopy(target_store_);
Expr value_copy = ir::ir_utils::IRCopy(target_store_);
ReplaceExpr(&value_copy, idx_sub_var_, idx_sub_expr_);
return value_copy;
}
Expand Down Expand Up @@ -1936,7 +1936,7 @@ void ScheduleImpl::Annotate(const Expr& block,
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>());
auto copied_block = optim::IRCopy(block);
auto copied_block = ir::ir_utils::IRCopy(block);
auto* schedule_block = copied_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>();
schedule_block->attrs.emplace(key, value);
Expand Down Expand Up @@ -2195,7 +2195,7 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
}
CHECK(!used_target_loop_vars.empty());
std::vector<Expr> used_target_loops;
auto expr_copy = optim::IRCopy(expr);
auto expr_copy = ir::ir_utils::IRCopy(expr);
for (auto& var : used_target_loop_vars) {
auto find_loop_var = ir::ir_utils::CollectIRNodesWithoutTensor(
expr_copy,
Expand All @@ -2220,7 +2220,7 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
VLOG(3) << "changed_loop_num is : " << changed_loop_num;
VLOG(3) << "old_iter_values.size() is : " << old_iter_values.size();
if (changed_loop_num >= static_cast<int>(old_iter_values.size())) {
new_loop = optim::IRCopy(block);
new_loop = ir::ir_utils::IRCopy(block);
new_loop.As<ir::ScheduleBlockRealize>()->iter_values = new_iter_values;
} else {
CHECK(old_iter_values[changed_loop_num].as_var());
Expand All @@ -2234,7 +2234,7 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
},
true);
CHECK_EQ(find_partial_loop.size(), 1U);
new_loop = optim::IRCopy(*find_partial_loop.begin());
new_loop = ir::ir_utils::IRCopy(*find_partial_loop.begin());
auto find_schedule_block = ir::ir_utils::CollectIRNodesWithoutTensor(
new_loop,
[&](const Expr* x) { return x->As<ir::ScheduleBlockRealize>(); },
Expand Down Expand Up @@ -2332,13 +2332,14 @@ IRSchedule::IRSchedule(ir::ModuleExpr&& mod_expr,
}

IRSchedule::IRSchedule(const IRSchedule& other)
: impl_(std::make_unique<ScheduleImpl>(optim::IRCopy(other.GetModule()))),
: impl_(std::make_unique<ScheduleImpl>(
ir::ir_utils::IRCopy(other.GetModule()))),
trace_(other.trace_) {
this->InitSeed(other.ForkSeed());
}

IRSchedule& IRSchedule::operator=(const IRSchedule& src) {
impl_ = std::make_unique<ScheduleImpl>(optim::IRCopy(src.GetModule()));
impl_ = std::make_unique<ScheduleImpl>(ir::ir_utils::IRCopy(src.GetModule()));
trace_ = src.trace_;
this->InitSeed(src.ForkSeed());
return *this;
Expand Down
18 changes: 9 additions & 9 deletions paddle/cinn/ir/schedule/ir_schedule_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,8 @@ IterRange GetAccessedRange(const Expr& index,
var_maxs.emplace_back(range.min + range.extent - 1);
}

Expr indice_min = optim::IRCopy(index);
Expr indice_max = optim::IRCopy(index);
Expr indice_min = ir::ir_utils::IRCopy(index);
Expr indice_max = ir::ir_utils::IRCopy(index);
// replace the var by the corresponding iter_value
ReplaceExpr(&indice_min, iter_vars, var_mins);
ReplaceExpr(&indice_max, iter_vars, var_maxs);
Expand Down Expand Up @@ -408,7 +408,7 @@ std::vector<IterRange> CalculateTensorRegions(

std::vector<IterRange> result;
for (int i = 0; i < tensor_indices.size(); ++i) {
Expr binded_index = optim::IRCopy(tensor_indices[i]);
Expr binded_index = ir::ir_utils::IRCopy(tensor_indices[i]);
ReplaceExpr(&binded_index, iter_vars, iter_values);
auto range = GetAccessedRange(binded_index, loop_vars, loop_ranges);

Expand Down Expand Up @@ -656,7 +656,7 @@ Expr ConstructOtherStmtChain(const std::vector<Expr>& stmts,
const std::vector<int> reordered_indices) {
Expr new_loop;
for (int i = reordered_indices.size() - 1; i >= 0; --i) {
Expr temp = optim::IRCopy(loops[reordered_indices[i]]);
Expr temp = ir::ir_utils::IRCopy(loops[reordered_indices[i]]);
CHECK(temp.defined());
CHECK(temp.As<ir::For>());
if (new_loop.defined()) {
Expand Down Expand Up @@ -695,10 +695,10 @@ Expr ConstructNewLoopChain(const std::vector<Expr>& chain,
Expr temp;
if (loop_set.count(loop_in_chain)) {
CHECK_GE(index, 0);
temp = optim::IRCopy(ordered_loops[index]);
temp = ir::ir_utils::IRCopy(ordered_loops[index]);
--index;
} else {
temp = optim::IRCopy(loop_in_chain);
temp = ir::ir_utils::IRCopy(loop_in_chain);
}
CHECK(temp.defined());
CHECK(temp.As<ir::For>());
Expand Down Expand Up @@ -1029,9 +1029,9 @@ std::vector<IterRange> CalculateRequiredRegions(
for (const Expr& req_block : required_blocks) {
CHECK(req_block.As<ir::ScheduleBlockRealize>());
Expr block_body =
optim::IRCopy(req_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body);
ir::ir_utils::IRCopy(req_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->body);
auto iter_vars = req_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->iter_vars;
Expand Down
7 changes: 4 additions & 3 deletions paddle/cinn/ir/test/ir_copy_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
#include "paddle/cinn/ir/utils/ir_printer.h"

namespace cinn {
namespace optim {
namespace ir {
namespace ir_utils {

TEST(IrCopy, basic) {
Expr a(1.f);
auto aa = IRCopy(a);
LOG(INFO) << "aa " << aa;
}

} // namespace optim
} // namespace ir_utils
} // namespace ir
} // namespace cinn
Loading

0 comments on commit 2966339

Please sign in to comment.