This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[Hackathon NO.84] Add the ReverseComputeInline primitive to the neural network compiler CINN #1331

Merged
8 commits merged on Apr 20, 2023
Changes from 2 commits
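The PR adds a ReverseComputeInline primitive to IRSchedule: it inlines a schedule block into its producer block, the reverse direction of the existing ComputeInline. Below is a minimal sketch of driving the new primitive, mirroring the test added in cinn/backends/ir_schedule_test.cc; the function name "fn_reverse_inline" and the lowering boilerplate are illustrative, not part of the patch.

// Build a two-stage elementwise computation, then inline the consumer block "C" into its producer "B".
Expr M(32), N(32), P(32);
Target target = common::DefaultHostTarget();

Placeholder<float> A("A", {M, N, P});
auto B = Compute(
    {M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B");
auto C = Compute(
    {M, N, P}, [&](Var i, Var j, Var k) { return B(j, i, k) * Expr(2.f); }, "C");

auto stages = CreateStages({A, B, C});
auto funcs  = cinn::lang::LowerVec("fn_reverse_inline", stages, {A, C}, {}, {}, nullptr, target, true);

std::vector<Expr> vec_ast{funcs[0]->body};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);

auto block_c = ir_sch.GetBlock("C");   // the consumer block to be inlined
ir_sch.ReverseComputeInline(block_c);  // C's computation is fused into B's loop nest and block C is removed

After the call, the remaining block writes C directly from A, as in the expected source code checked by the new test.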
65 changes: 65 additions & 0 deletions cinn/backends/ir_schedule_test.cc
@@ -691,6 +691,7 @@ function test_bind (_A, _B)
}
)ROC"));
}

TEST(IrSchedule, simple_compute_at) {
Context::Global().ResetNameId();
Expr M(128);
@@ -2497,6 +2498,70 @@ void test_compute_inline4(const float* __restrict__ A, float* __restrict__ C)
}
#endif

TEST(IrSchedule, reverse_compute_inline1) {
Context::Global().ResetNameId();
Expr M(32);
Expr N(32);
Expr P(32);

Target target = common::DefaultHostTarget();

Placeholder<float> A("A", {M, N, P});
auto B = Compute(
{M, N, P}, [&](Var i, Var j, Var k) { return A(i, j, k) + Expr(1.f); }, "B");
auto C = Compute(
{M, N, P}, [&](Var i, Var j, Var k) { return B(j, i, k) * Expr(2.f); }, "C");

auto stages = CreateStages({A, B, C});

auto func = cinn::lang::LowerVec("test_compute_inline1", stages, {A, C}, {}, {}, nullptr, target, true);

auto ast_expr = func[0]->body;
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);

auto block_c = ir_sch.GetBlock("C");
ir_sch.ReverseComputeInline(block_c);
Module::Builder builder("module1", target);
for (auto& i : func) {
builder.AddFunction(i);
}
auto module = builder.Build();
CodeGenC codegen(target);
codegen.SetInlineBuiltinCodes(false);
auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);

VLOG(1) << "reverse_compute_inline1 source code is :\n" << source_code;

std::string target_code = R"ROC(
#include <cinn_runtime.h>
#include <stdio.h>

void test_compute_inline1(void* _args, int32_t num_args)
{
const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
cinn_buffer_t* _B = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 32, 32, 32 });
cinn_buffer_malloc((void*)(0), _C);
cinn_buffer_malloc((void*)(0), _B);
const float* A = ((const float*)(_A->memory));
float* B = ((float*)(_B->memory));
float* C = ((float*)(_C->memory));
for (int32_t i = 0; i < 32; i += 1) {
for (int32_t j = 0; j < 32; j += 1) {
for (int32_t k = 0; k < 32; k += 1) {
C[((1024 * i) + ((32 * j) + k))] = fma(2.00000000f, A[((32 * i) + ((1024 * j) + k))], 2.00000000f);
};
};
};
cinn_buffer_free((void*)(0), _B);
cinn_buffer_free((void*)(0), _C);
}
)ROC";
ASSERT_EQ(utils::Trim(target_code), utils::Trim(source_code));
}
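For reference, the expected source above computes C[i][j][k] = fma(2.0f, A[j][i][k], 2.0f) over 32x32x32 row-major buffers (stride 1024 for the first axis, 32 for the second), which is exactly the fully inlined form of B = A + 1 followed by C(i, j, k) = B(j, i, k) * 2. A small standalone check of that equivalence (plain C++, illustrative only, not part of the patch):

#include <cassert>
#include <cmath>
#include <vector>

// Checks that the inlined form emitted above, C[i][j][k] = fma(2.f, A[j][i][k], 2.f),
// matches the original two-stage computation B = A + 1; C(i, j, k) = B(j, i, k) * 2.
int main() {
  const int M = 32, N = 32, P = 32;
  std::vector<float> A(M * N * P), B(M * N * P), C_two_stage(M * N * P), C_inlined(M * N * P);
  // Inputs are small exact floats so both forms round identically.
  for (int idx = 0; idx < M * N * P; ++idx) A[idx] = 0.25f * (idx % 7);

  // Row-major flatten for a {32, 32, 32} buffer: stride 1024 for the first axis, 32 for the second.
  auto at = [](int a, int b, int c) { return 1024 * a + 32 * b + c; };

  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < P; ++k) B[at(i, j, k)] = A[at(i, j, k)] + 1.f;
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < P; ++k) C_two_stage[at(i, j, k)] = B[at(j, i, k)] * 2.f;

  // After ReverseComputeInline, B disappears and its value feeds C directly.
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < P; ++k) C_inlined[at(i, j, k)] = std::fma(2.f, A[at(j, i, k)], 2.f);

  for (int idx = 0; idx < M * N * P; ++idx) assert(C_two_stage[idx] == C_inlined[idx]);
  return 0;
}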

TEST(IrSchedule, copytransform1) {
Context::Global().ResetNameId();
Expr M(32);
75 changes: 74 additions & 1 deletion cinn/ir/ir_schedule.cc
@@ -45,7 +45,7 @@ namespace cinn {
namespace ir {

/**
* A struct helps to implment Schedule primitives.
* A struct helps to implement Schedule primitives.
*/
class ScheduleImpl {
public:
@@ -96,6 +96,7 @@ class ScheduleImpl {
void Vectorize(const Expr& loop, int factor);
void Unroll(const Expr& loop);
void ComputeInline(const Expr& schedule_block);
void ReverseComputeInline(const Expr& schedule_block);
void Bind(const Expr& loop, const std::string& thread_axis);
Expr Rfactor(const Expr& rf_loop, int rf_axis);
Expr AddUnitLoop(const Expr& block) const;
@@ -1371,6 +1372,72 @@ void ComputeInlineChecker::BuildDataDependency() {
ir_schedule_.SyncThreads(loops.back(), true);
}

bool ReverseComputeInliner::BodyPatternAllowInline() {
if (!inlined_store_.defined()) {
return false;
}
CHECK(inlined_store_.As<Store>());
auto find_vars = ir::CollectIRNodesWithoutTensor(inlined_store_, [&](const Expr* x) { return x->as_var(); });
std::set<Var, CompVar> vars_set;
for (auto& i : find_vars) vars_set.insert(i.as_var_ref());
int n_vars = vars_set.size();
if (!UpdateAndCheckIndexVars(inlined_store_.As<Store>()->indices, n_vars)) {
return false;
}
return true;
}

void ReverseComputeInliner::Visit(const ir::Load* expr, Expr* op) {
VLOG(3) << "Visit Load of tensor " << (expr->tensor).as_tensor_ref()->name << ", tensor to be inlined is " << inlined_tensor_->name;
if ((expr->tensor).as_tensor_ref()->name == inlined_tensor_->name) {
// Replace the Load of the inlined tensor with the value computed in its producer block.
*op = ReplaceInlinedTensor(op);
return;
}
IRMutator::Visit(expr, op);
}

void ReverseComputeInliner::Visit(const ir::Store* expr, Expr* op) {
VLOG(3) << "Visit Store of tensor " << (expr->tensor).as_tensor_ref()->name << ", tensor to be inlined is " << inlined_tensor_->name;
if ((expr->tensor).as_tensor_ref()->name == inlined_tensor_->name) {
// Replace the producer's Store of the inlined tensor with the consumer's Store (the target store).
*op = target_store_;
return;
}
IRMutator::Visit(expr, op);
}

//! Replace the 'Load' node on the inlined tensor with the value computed in its producer block.
Expr ReverseComputeInliner::ReplaceInlinedTensor(Expr* load) {
CHECK(load->As<ir::Load>());
SetIndexSubstitution(load->As<ir::Load>()->indices);
Expr value_copy = optim::IRCopy(inlined_store_.As<Store>()->value);
ReplaceExpr(&value_copy, idx_sub_var_, idx_sub_expr_);
return value_copy;
}

void ScheduleImpl::ReverseComputeInline(const Expr& schedule_block) {
Expr root = this->GetRootBlock(schedule_block);
// The Store of the producer block; its value will replace Loads of the inlined tensor.
Expr inlined_store = CheckReverseComputeInlineValidationAndGetStore(schedule_block, root);

CHECK(schedule_block.As<ir::ScheduleBlockRealize>());
auto compute_body = schedule_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->body;

// The Store of the consumer block (the block being inlined); it replaces the producer's Store.
auto find_store = ir::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Store>(); }, true);
CHECK_EQ(find_store.size(), 1U);
auto target_store = *find_store.begin();

ReverseComputeInliner inliner(inlined_store.As<ir::Store>()->tensor.as_tensor_ref(), inlined_store, target_store);
CHECK(inliner.BodyPatternAllowInline());
// Create a plan that removes the block to be inlined
LeafBlockRemovalPlan remove_plan(schedule_block, &inliner.src_stmt, &inliner.tgt_stmt);
remove_plan(&root);
inliner(&root);
inliner(&root);
}

struct FindBlockParent : public ir::IRMutator<> {
public:
FindBlockParent(const std::string& block_name) : block_name_(block_name) {}
@@ -2104,6 +2171,12 @@ void IRSchedule::ComputeInline(const Expr& schedule_block) {
trace_.Append(ScheduleDesc::Step("ComputeInline", {{"schedule_block", std::vector<Expr>({schedule_block})}}, {}, {}));
}

void IRSchedule::ReverseComputeInline(const Expr& schedule_block) {
impl_->ReverseComputeInline(schedule_block);
trace_.Append(
ScheduleDesc::Step("ReverseComputeInline", {{"schedule_block", std::vector<Expr>({schedule_block})}}, {}, {}));
}

void IRSchedule::Bind(const Expr& loop, const std::string& thread_axis) {
impl_->Bind(loop, thread_axis);
trace_.Append(ScheduleDesc::Step("Bind", {{"loop", std::vector<Expr>({loop})}}, {{"thread_axis", thread_axis}}, {}));
31 changes: 31 additions & 0 deletions cinn/ir/ir_schedule.h
@@ -296,6 +296,12 @@ class IRSchedule {
*/
void ComputeInline(const Expr& schedule_block);

/**
* \brief Inline a schedule block into its producer, the reverse operation of ComputeInline.
* @param schedule_block the schedule block to be inlined.
*/
void ReverseComputeInline(const Expr& schedule_block);

/**
* \brief Bind the loop to the given thread axis.
* @param loop the loop to Bind.
@@ -474,6 +480,31 @@ class ComputeInliner : public BaseInliner {
Expr ReplaceInlinedTensor(Expr* load);
};

/*!
* \brief Helper to inline a consumer block into its producer block, i.e. the reverse of ComputeInliner.
* The derived class implements the following functionalities:
* 1) Substitute the `Load` of the tensor to be inlined
* with its value calculation in the producer block
* 2) Replace the `Store` of the producer block with the `Store` of the consumer block
* 3) Analyze the producer block to determine the remapping of index variables
*/
class ReverseComputeInliner : public BaseInliner {
public:
explicit ReverseComputeInliner(const Tensor& inlined_tensor, const Expr& inlined_store, const Expr& target_store)
: BaseInliner(inlined_tensor, inlined_store), target_store_(target_store) {}

bool BodyPatternAllowInline();

protected:
Expr target_store_{nullptr};

private:
void Visit(const ir::Load* expr, Expr* op) override;
void Visit(const ir::Store* expr, Expr* op) override;

//! Replace the 'Load' node on the inlined tensor with the value computed in its producer block.
Expr ReplaceInlinedTensor(Expr* load);
};

// The struct used to remove the original block in ComputeAt.
class LeafBlockRemovalPlan : public ir::IRMutator<> {
public:
29 changes: 29 additions & 0 deletions cinn/ir/ir_schedule_util.cc
@@ -953,6 +953,35 @@ Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block, const E
return (*find_store.begin());
}

Expr CheckReverseComputeInlineValidationAndGetStore(const Expr& schedule_block, const Expr& root) {
CHECK(schedule_block.As<ir::ScheduleBlockRealize>());
auto compute_body = schedule_block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>()->body;
// 1. Check that the tensor read by the schedule block to be reverse inlined (i.e. its producer's output) is not a reduce tensor.
auto find_load = ir::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Load>(); }, true);
CHECK_EQ(find_load.size(), 1U);
Expr tensor = (*find_load.begin()).As<ir::Load>()->tensor;
CHECK(!tensor.as_tensor_ref()->is_reduce_tensor());
// 2. Check this schedule block is the only reader of the tensor.
find_load = ir::CollectIRNodesWithoutTensor(
root,
[&](const Expr* x) {
return x->As<ir::Load>() && (x->As<ir::Load>()->tensor).as_tensor_ref()->name == tensor.as_tensor_ref()->name;
},
true);
CHECK_EQ(find_load.size(), 1U);
// 3. Check there is no overlap between the buffers the schedule block reads and writes.
auto find_store = ir::CollectIRNodesWithoutTensor(
compute_body, [&](const Expr* x) { return x->As<ir::Store>() && x->As<ir::Store>()->tensor == tensor; });
CHECK(find_store.empty());
// 4. Get store that will be inlined.
auto find_inlined_store = ir::CollectIRNodesWithoutTensor(
root, [&](const Expr* x) { return x->As<ir::Store>() && x->As<ir::Store>()->tensor == tensor; });
CHECK_EQ(find_inlined_store.size(), 1U);
auto inlined_store = *find_inlined_store.begin();
return inlined_store;
}

bool ContainVar(const std::vector<Expr>& exprs, const std::string& var_name) {
for (auto& expr : exprs) {
auto find_expr = ir::CollectIRNodesWithoutTensor(
2 changes: 2 additions & 0 deletions cinn/ir/ir_schedule_util.h
@@ -416,6 +416,8 @@ std::vector<IterRange> CalculateRequiredRegions(const Expr& block,

Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block, const Expr& root);

Expr CheckReverseComputeInlineValidationAndGetStore(const Expr& schedule_block, const Expr& root);

/*!
* \brief Get the prime factors of a number.
* For example, 12 = 2^2 * 3^1, then the return value is {2: 2, 3: 1}.
4 changes: 4 additions & 0 deletions cinn/ir/schedule_desc.cc
@@ -409,6 +409,10 @@ CINN_BUILD_STEP_KIND(ComputeInline)
.Inputs({"schedule_block"})
.SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::ComputeInline)));

CINN_BUILD_STEP_KIND(ReverseComputeInline)
.Inputs({"schedule_block"})
.SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::ReverseComputeInline)));

CINN_BUILD_STEP_KIND(Bind)
.Inputs({"loop"})
.Attrs({"thread_axis"})
11 changes: 11 additions & 0 deletions cinn/ir/schedule_desc_test.cc
@@ -575,6 +575,17 @@ TEST_F(TestScheduleDesc, StepKind_ComputeInline) {
CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

TEST_F(TestScheduleDesc, StepKind_ReverseComputeInline) {
lowered_funcs = LowerCompute({32, 32, 32}, target, true, "elementwise-add_const");
ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);
auto block_c = ir_sch.GetBlock("C");
trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("C")}}, {block_c}));
ir_sch.ReverseComputeInline(block_c);
trace.Append(ScheduleDesc::Step("ReverseComputeInline", {{"schedule_block", std::vector<Expr>({block_c})}}, {}, {}));
CheckReplayResult(ir_sch, trace);
CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

TEST_F(TestScheduleDesc, StepKind_Bind) {
lowered_funcs = LowerCompute({32, 128}, target);
ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);