add CopyTransform and enhance opfusion (PaddlePaddle#353)
* add CopyTransform and enhance opfusion
haozech authored Mar 4, 2021
1 parent 313d6af commit 270bed6
Showing 24 changed files with 238 additions and 52 deletions.
5 changes: 3 additions & 2 deletions cinn/backends/codegen_cuda_dev.cc
100755 → 100644
@@ -52,8 +52,9 @@ std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(const ir::_LoweredFu
std::set<ir::Buffer> temp_buffer_set(temp_buffers.begin(), temp_buffers.end());
// prepare temp buffer alias
std::vector<Expr> buffer_alias;
- auto tensors = ir::CollectIRNodes(
-     op->body, [&](const Expr *x) { return x->as_tensor() && temp_buffer_set.count(x->as_tensor()->buffer); });
+ auto tensors = ir::CollectIRNodes(op->body, [&](const Expr *x) {
+   return x->as_tensor() && x->as_tensor()->buffer.defined() && temp_buffer_set.count(x->as_tensor()->buffer);
+ });

// unique tensors
std::set<ir::Tensor> unique_tensors;
3 changes: 2 additions & 1 deletion cinn/backends/llvm/execution_engine.cc
@@ -141,8 +141,9 @@ void ExecutionEngine::Link(const ir::Module &module) {
auto m = llvm::parseAssemblyString(AsStringRef(backends::kRuntimeLlvmIr), error, *ctx);
auto b = std::make_unique<llvm::IRBuilder<>>(*ctx);
auto ir_emitter = std::make_unique<CodeGenT>(m.get(), b.get());
VLOG(3) << "ir_emitter->Compile(module) Begin";
ir_emitter->Compile(module);

VLOG(3) << "ir_emitter->Compile(module) Succeed!";
CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid module found";

auto machine =
4 changes: 3 additions & 1 deletion cinn/frontend/interpreter.cc
@@ -83,7 +83,9 @@ void Interpreter::Impl::Build(const std::vector<std::string>& input_names,
auto graph = std::make_shared<hlir::framework::Graph>(*program_);

hlir::framework::ApplyPass(graph.get(), "InferShape");
- hlir::framework::ApplyPass(graph.get(), "OpFusion");
+ if (target.arch == Target::Arch::NVGPU) {
+   hlir::framework::ApplyPass(graph.get(), "OpFusion");
+ }
// Target target = common::DefaultHostTarget();
scope_ = hlir::framework::BuildScope(target, graph, scope_);
graph_compiler_.reset(new hlir::framework::GraphCompiler(target, scope_, graph));
35 changes: 31 additions & 4 deletions cinn/hlir/framework/graph_compiler.cc
100644 → 100755
@@ -5,6 +5,8 @@
#include "cinn/backends/codegen_cuda_dev.h"
#include "cinn/hlir/framework/instruction.h"
#include "cinn/hlir/framework/tensor.h"
#include "cinn/hlir/pe/schedule.h"
#include "cinn/poly/stage.h"

namespace cinn {
namespace hlir {
@@ -94,10 +96,24 @@ std::vector<std::unique_ptr<Instruction>> GraphCompiler::BuildInstructions() {
auto* node = nodes[i]->safe_as<Node>();
if (node && node->attrs.attr_store.count("FuseNumber") > 0) {
int fuse_number = std::get<int>(node->attrs.attr_store["FuseNumber"]);
- auto* end_node = nodes[i + 2 * fuse_number - 2]->safe_as<Node>();
- auto instr = std::unique_ptr<Instruction>(
-     new Instruction(target_, scope_.get(), OpGetInputNames(node), OpGetOutputNames(end_node)));
- auto* fn = compiler_->Lookup(GenOpFuncName(node) + "_fused");
+ std::vector<std::string> inputNames;
+ std::vector<std::string> outputNames;
+ for (int j = 0; j < fuse_number; j++) {
+   auto* temp_node = nodes[i + 2 * j]->safe_as<Node>();
+   CHECK(temp_node);
+   auto temp_inputnames = OpGetInputNames(temp_node);
+   if (j == 0) {
+     inputNames.insert(inputNames.end(), temp_inputnames.begin(), temp_inputnames.end());
+   } else {
+     inputNames.insert(inputNames.end(), temp_inputnames.begin() + 1, temp_inputnames.end());
+   }
+   if (j == fuse_number - 1) {
+     auto temp_outputnames = OpGetOutputNames(temp_node);
+     outputNames.insert(outputNames.end(), temp_outputnames.begin(), temp_outputnames.end());
+   }
+ }
+ auto instr = std::unique_ptr<Instruction>(new Instruction(target_, scope_.get(), inputNames, outputNames));
+ auto* fn = compiler_->Lookup(GenOpFuncName(node) + "_fused");
CHECK(fn);
instr->SetLoweredFunc(fn);
instructions.push_back(std::move(instr));
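For each fused group the instruction's input list is now stitched together from every member op: each op after the first drops its leading input, presumably because that slot is the previous op's output and stays internal to the fused kernel, while only the last op contributes output names. A minimal standalone sketch of that merging rule (hypothetical op and tensor names, not from this commit):

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical per-op input lists for a three-op fused chain; each op's first
  // input is assumed to be the previous op's output, so it is skipped after op 0.
  std::vector<std::vector<std::string>> op_inputs = {
      {"x", "w0"},      // op 0
      {"tmp0", "w1"},   // op 1: "tmp0" comes from op 0
      {"tmp1", "w2"}};  // op 2: "tmp1" comes from op 1
  std::vector<std::string> fused_inputs;
  for (size_t j = 0; j < op_inputs.size(); ++j) {
    auto begin = op_inputs[j].begin() + (j == 0 ? 0 : 1);  // skip the internal edge
    fused_inputs.insert(fused_inputs.end(), begin, op_inputs[j].end());
  }
  for (const auto &name : fused_inputs) std::cout << name << ' ';  // prints: x w0 w1 w2
  std::cout << '\n';
  return 0;
}
```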
@@ -166,6 +182,7 @@ ir::LoweredFunc GraphCompiler::GetOpFunc(const std::vector<Node*>& nodes) {
auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
auto& shape_dict = graph_->GetAttrs<std::unordered_map<std::string, shape_t>>("infershape");
auto& dtype_dict = graph_->GetAttrs<std::unordered_map<std::string, Type>>("inferdtype");
VLOG(2) << "GetOpFunc of fused op " << nodes[0]->id();
std::vector<ir::Tensor> inputs;
poly::StageMap stages;
std::vector<int> init_shape{1};
@@ -222,14 +239,24 @@ ir::LoweredFunc GraphCompiler::GetOpFunc(const std::vector<Node*>& nodes) {
stages->InsertLazily(temp.as_tensor_ref(), temp_stages[temp.as_tensor_ref()]);
if (index < fuse_number - 1 && !temp.as_tensor_ref()->is_reduce_tensor()) {
stages[temp.as_tensor_ref()]->ComputeInline();
+ } else if (index < fuse_number - 1 && temp.as_tensor_ref()->is_reduce_tensor()) {
+   temp.as_tensor_ref()->WithBuffer("local", "_" + temp.as_tensor_ref()->name + "_temp_buffer");
+   stages[temp.as_tensor_ref()]->SetScope(poly::ScopeKind::kLocal);
} else {
inputs.push_back(temp.as_tensor_ref());
}
}
index++;
}

+ for (auto& s : stages) {
+   if (s.second->tensor()->is_reduce_tensor()) {
+     stages[inputs.back()]->CopyTransform(s.second->transform(), s.second->domain());
+     stages[inputs.back()]->CopyLoopInfo(s.second->forloop_infos(), s.second->transform());
+   }
+ }
auto func = Lower(GenOpFuncName(nodes[0]) + "_fused", stages, inputs, {}, {}, nullptr, this->target_);
VLOG(3) << "The function of fused node [" << func->name << "] is:\n" << func;
return func;
}
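The new CopyTransform/CopyLoopInfo calls let the fused function's real output stage reuse the polyhedral transform, iteration domain and per-loop info (such as GPU axis bindings) already computed for a reduce stage inside the group, so the fused loop nest stays consistent. A minimal sketch of that usage, reconstructed from the hunk above (the function and parameter names are placeholders, not part of this commit):

```cpp
// Sketch only: propagate a reduce stage's schedule onto the fused output stage.
// Assumes poly::Stage provides transform(), domain(), forloop_infos(),
// CopyTransform() and CopyLoopInfo() exactly as used in the diff above.
void PropagateReduceSchedule(poly::Stage *reduce_stage, poly::Stage *out_stage) {
  // Copy the scheduling transform together with its iteration domain ...
  out_stage->CopyTransform(reduce_stage->transform(), reduce_stage->domain());
  // ... and the loop annotations (e.g. blockIdx/threadIdx bindings) that go with it.
  out_stage->CopyLoopInfo(reduce_stage->forloop_infos(), reduce_stage->transform());
}
```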

27 changes: 11 additions & 16 deletions cinn/hlir/op/nn.cc
100644 → 100755
@@ -207,11 +207,10 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
if (target.arch == Target::Arch::NVGPU) {
Expr Out = arg_pack[2];
CHECK(Out.as_tensor());
- // pe::CudaScheduleConv(stages, input_pad.as_tensor_ref(), weights_dilation.as_tensor_ref(), Out.as_tensor_ref(),
- // target);
- stages[Out.as_tensor_ref()]->Split(1, 2);
stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
- stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x");
+ stages[Out.as_tensor_ref()]->Bind(1, "blockIdx.y");
+ stages[Out.as_tensor_ref()]->Bind(2, "blockIdx.z");
+ stages[Out.as_tensor_ref()]->Bind(3, "threadIdx.x");
}
*ret = CINNValuePack{{arg_pack[2], CINNValue(stages)}};
});
@@ -350,9 +349,10 @@ std::shared_ptr<OpStrategy> StrategyForDepthwiseConv2d(const framework::NodeAttr
stages[input_pad.as_tensor_ref()]->ComputeInline();
}
if (target.arch == Target::Arch::NVGPU) {
- stages[Out.as_tensor_ref()]->Split(1, 2);
stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
- stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x");
+ stages[Out.as_tensor_ref()]->Bind(1, "blockIdx.y");
+ stages[Out.as_tensor_ref()]->Bind(2, "blockIdx.z");
+ stages[Out.as_tensor_ref()]->Bind(3, "threadIdx.x");
}

*ret = CINNValuePack{{CINNValue(Out), CINNValue(stages)}};
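Both conv2d and depthwise_conv2d now bind the output tensor's first four loop axes straight onto the CUDA launch dimensions. A minimal sketch of the pattern on a hypothetical stage (`stage` stands for stages[Out.as_tensor_ref()] in the hunks above):

```cpp
// Sketch: map a 4-D output's loop axes onto a 3-D grid plus the thread dimension,
// mirroring the conv2d/depthwise_conv2d schedules above.
void BindConvOutput(poly::Stage *stage) {
  stage->Bind(0, "blockIdx.x");   // outermost axis -> grid x
  stage->Bind(1, "blockIdx.y");   // second axis    -> grid y
  stage->Bind(2, "blockIdx.z");   // third axis     -> grid z
  stage->Bind(3, "threadIdx.x");  // fourth axis    -> threads
}
```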
@@ -1172,8 +1172,7 @@ std::shared_ptr<OpStrategy> StrategyForSlice(const framework::NodeAttr &attrs,
Expr Out = arg_pack[0];
poly::StageMap stages = arg_pack[1];
CHECK(Out.as_tensor());
- stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x");
- stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x");
+ pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target);
}
*ret = arg_pack;
});
@@ -1378,8 +1377,7 @@ CINN_REGISTER_HELPER(nn_ops) {
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForPool1d)
.set_attr("infershape", std::function(cinn::hlir::op::InferShapeForPool1d))
.set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForPool))
- .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern",
-                                                 cinn::hlir::framework::OpPatternKind::kOutEWiseFusable)
+ .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque)
.set_support_level(4);

CINN_REGISTER_OP(pool2d)
@@ -1389,8 +1387,7 @@
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForPool2d)
.set_attr("infershape", std::function(cinn::hlir::op::InferShapeForPool2d))
.set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForPool))
- .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern",
-                                                 cinn::hlir::framework::OpPatternKind::kOutEWiseFusable)
+ .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque)
.set_support_level(4);

CINN_REGISTER_OP(pool3d)
@@ -1400,8 +1397,7 @@
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForPool3d)
.set_attr("infershape", std::function(cinn::hlir::op::InferShapeForPool3d))
.set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForPool))
- .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern",
-                                                 cinn::hlir::framework::OpPatternKind::kOutEWiseFusable)
+ .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque)
.set_support_level(4);

CINN_REGISTER_OP(sigmoid)
@@ -1431,8 +1427,7 @@
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForSlice)
.set_attr("infershape", std::function(cinn::hlir::op::InferShapeForSlice))
.set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForSlice))
- .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern",
-                                                 cinn::hlir::framework::OpPatternKind::kOutEWiseFusable)
+ .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque)
.set_support_level(4);

CINN_REGISTER_OP(dropout_infer)
2 changes: 1 addition & 1 deletion cinn/hlir/pass/opfusion.cc
100644 → 100755
@@ -23,7 +23,7 @@ void OpFusionPass(Graph* graph) {
auto node = store_nodes[i]->safe_as<Node>();
if (node) {
auto op_pattern = op_pattern_dict[node->op()];
- if (op_pattern <= framework::kInjective) {
+ if (op_pattern <= framework::kOutEWiseFusable) {
int fuse_number = 1;
while (i + 2 < store_nodes.size() && store_nodes[i + 2]->safe_as<Node>()) {
auto temp_node = store_nodes[i + 2]->safe_as<Node>();
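Raising the fusion threshold from kInjective to kOutEWiseFusable lets the pass start fusion chains at reduction-style and output-elementwise-fusable ops as well, not only elementwise/broadcast/injective ones. The comparison relies on the ordering of OpPatternKind; the sketch below assumes the conventional TVM-style pattern lattice and should be checked against cinn/hlir/framework/op.h (member names other than the three appearing in this diff are assumptions):

```cpp
// Assumed ordering only -- verify against cinn/hlir/framework/op.h.
enum OpPatternKind {
  kElemWise = 0,         // elementwise ops
  kBroadcast = 1,        // broadcast ops
  kInjective = 2,        // injective index mappings (e.g. transpose)
  kCommReduce = 3,       // commutative reductions
  kOutEWiseFusable = 4,  // complex ops whose output can absorb elementwise ops (e.g. conv2d)
  kOpaque = 8,           // never fused
};
// Old check: op_pattern <= kInjective       -> fuse only patterns 0..2.
// New check: op_pattern <= kOutEWiseFusable -> patterns 3 and 4 may start a fusion chain too.
```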
Empty file modified cinn/hlir/pass/use_pass.h
100644 → 100755
Empty file.
Empty file modified cinn/ir/ir_mutator.h
100644 → 100755
Empty file.
6 changes: 3 additions & 3 deletions cinn/ir/lowered_func.cc
@@ -74,13 +74,13 @@ void _LoweredFunc_::PrepareAllocOutputBufferExprs() {
}

std::vector<Expr> _LoweredFunc_::PrepareAllocTempBufferExprs() const {
- std::vector<Expr> alloc_output_buffer_exprs;
+ std::vector<Expr> alloc_temp_buffer_exprs;
for (auto& temp_buf : temp_bufs) {
if (!temp_buf->shape.empty() && temp_buf->type() != Void()) {
- alloc_output_buffer_exprs.push_back(Alloc::Make(temp_buf, temp_buf->type(), temp_buf->shape, Expr(), Expr()));
+ alloc_temp_buffer_exprs.push_back(Alloc::Make(temp_buf, temp_buf->type(), temp_buf->shape, Expr(), Expr()));
}
}
- return alloc_output_buffer_exprs;
+ return alloc_temp_buffer_exprs;
}

std::vector<Expr> _LoweredFunc_::CudaPrepareAllocTempBufferExprs() const {
8 changes: 3 additions & 5 deletions cinn/ir/tensor.cc
@@ -234,9 +234,7 @@ ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, const Target &target)
shape, [=](const std::vector<Expr> &axis) { return GetReduceInitVal(); }, init_reduce_tensor_name);
stages->InsertLazily(init_tensor);
if (target.arch == Target::Arch::NVGPU) {
- if (init_tensor->shape.size() > 1) {
-   stages[init_tensor]->Split(1, 2);
- }
+ stages[init_tensor]->CopyTransform(stages[this]->transform(), stages[this]->domain());
stages[init_tensor]->ComputeAt2(stages[this], stages[init_tensor]->axis_names().size() - 1);
}
stages[this]->CtrlDepend(init_tensor);
@@ -311,9 +309,9 @@ void _Tensor_::WithBuffer(const Type &type) {
Bind(buf);
}

- void _Tensor_::WithBuffer(const std::string &memory_type, const Type &type) {
+ void _Tensor_::WithBuffer(const std::string &memory_type, const std::string &buffer_name, const Type &type) {
Type buf_type = type.is_void() ? type_ : type;
- lang::Buffer buf(buf_type);
+ lang::Buffer buf(buf_type, buffer_name);
buf->target = common::DefaultHostTarget();
Bind(buf);
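The extra buffer_name parameter lets callers hand temporaries a predictable buffer name; the passes changed later in this commit match such buffers by the "_temp_buffer" suffix. A small sketch of the new overload in use, following the GraphCompiler::GetOpFunc hunk above (the helper function itself is illustrative):

```cpp
// Sketch: give a reduce temporary a GPU-local buffer with a recognizable
// "_temp_buffer" suffix, as GraphCompiler::GetOpFunc does above.
void GiveLocalTempBuffer(poly::StageMap &stages, ir::Tensor &t) {
  t->WithBuffer("local", "_" + t->name + "_temp_buffer");  // memory type + explicit buffer name
  stages[t]->SetScope(poly::ScopeKind::kLocal);            // keep the stage in GPU-local scope
}
```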

2 changes: 1 addition & 1 deletion cinn/ir/tensor.h
@@ -249,7 +249,7 @@ class _Tensor_ : public ExprNode<_Tensor_> {

//! Create a buffer belong to this tensor.
void WithBuffer(const Type& type = Void());
- void WithBuffer(const std::string& memory_type, const Type& type = Void());
+ void WithBuffer(const std::string& memory_type, const std::string& buffer_name = "", const Type& type = Void());
Tensor GetInitTensor(poly::StageMap stages, const Target& target = common::DefaultHostTarget()) const;

private:
Empty file modified cinn/lang/lower.cc
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion cinn/lang/lower_impl.cc
@@ -450,7 +450,7 @@ ir::LoweredFunc LowerImpl::operator()() {
auto func = ir::_LoweredFunc_::Make(fn_name_, func_args, func_body, temp_buffers);

// some necessary modification.
- optim::ComputeInlineExpand(&func->body, stages_);
+ optim::ComputeInlineExpand(&func->body, stages_, &all_tensor_map);
Target target = cuda_axis_info_.valid() ? common::DefaultNVGPUTarget() : common::DefaultHostTarget();
auto res = optim::Optimize(func, target, FLAGS_cinn_runtime_display_debug_info);

40 changes: 35 additions & 5 deletions cinn/optim/compute_inline_expand.cc
100644 → 100755
@@ -13,18 +13,49 @@ namespace {
*/
struct TensorInlineExpandMutator : public ir::IRMutator<> {
const std::string &tensor_name;
+ std::map<std::string, ir::Tensor> *all_tensor_map_;
+ bool inline_code{false};
+ bool temp_buffer{false};
+ bool memory_local{false};

- TensorInlineExpandMutator(const std::string &tensor_name) : tensor_name(tensor_name) {}
+ TensorInlineExpandMutator(const std::string &tensor_name, std::map<std::string, ir::Tensor> *all_tensor_map)
+     : tensor_name(tensor_name), all_tensor_map_(all_tensor_map) {}

void operator()(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }

+ void Visit(const ir::_Var_ *expr, Expr *op) override {
+   if (inline_code && temp_buffer) {
+     if (utils::Startswith(expr->name, "blockIdx") || (utils::Startswith(expr->name, "threadIdx") && memory_local)) {
+       *op = ir::Expr(0);
+     }
+   }
+ }

void Visit(const ir::Load *op, Expr *expr) override {
auto *node = expr->As<ir::Load>();
auto *tensor = node->tensor.as_tensor();
if (tensor && tensor->name == tensor_name) {
- *expr = tensor->inline_expanded(op->indices);
+ *expr = tensor->inline_expanded(op->indices);
+ inline_code = true;
+ ir::IRMutator<>::Visit(expr, expr);
+ inline_code = false;
+ } else if (inline_code && tensor->buffer.defined() &&
+            (utils::Endswith(tensor->buffer->name, "_read_cache") ||
+             utils::Endswith(tensor->buffer->name, "_cache_write_out") ||
+             utils::Endswith(tensor->buffer->name, "_temp_buffer"))) {
+   bool keep_buffer = temp_buffer;
+   temp_buffer = true;
+   bool keep_memory_local = memory_local;
+   if ((*all_tensor_map_).at(tensor->name)->buffer->memory_type == ir::MemoryType::GPULocal) {
+     memory_local = true;
+   }
+   ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
+   for (auto &idx : node->indices) ir::IRMutator<>::Visit(&idx, &idx);
+   temp_buffer = keep_buffer;
+   memory_local = keep_memory_local;
} else {
ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
for (auto &idx : node->indices) ir::IRMutator<>::Visit(&idx, &idx);
}
}
};
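The new _Var_ visitor zeroes GPU axis variables while a cached or temp buffer access is being inlined: blockIdx components are dropped for any such buffer, and threadIdx components are additionally dropped when the buffer is GPU-local, since a per-thread (or per-block) copy starts at its own origin. A tiny self-contained illustration of the index collapse (hypothetical buffer layout, not from this commit):

```cpp
#include <cstdio>

// Hypothetical producer index: buf[blockIdx.x * 64 + threadIdx.x * 4 + k].
// When the producer is inlined, GPU axis variables are zeroed according to
// where the buffer lives, matching the _Var_ visitor above.
int InlinedIndex(bool buffer_is_thread_local, int block_idx, int thread_idx, int k) {
  int bx = 0;                                        // blockIdx always collapses for temp/cache buffers
  int tx = buffer_is_thread_local ? 0 : thread_idx;  // threadIdx collapses only for GPU-local buffers
  return bx * 64 + tx * 4 + k;
}

int main() {
  std::printf("local:  %d\n", InlinedIndex(true, 7, 13, 2));   // 2  -- per-thread copy
  std::printf("shared: %d\n", InlinedIndex(false, 7, 13, 2));  // 54 -- per-block copy keeps threadIdx
  return 0;
}
```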
@@ -52,7 +83,6 @@ struct SSABuilder : public ir::IRMutator<> {
}

void Visit(const ir::Store *op, Expr *expr) override {
LOG(INFO) << "Expr: " << *expr;
auto *node = expr->As<ir::Store>();

auto *cur_graph_node = graph.RetriveNode(node->tensor.as_tensor()->name);
@@ -74,7 +104,7 @@

} // namespace

- void ComputeInlineExpand(Expr *expr, poly::StageMap stages) {
+ void ComputeInlineExpand(Expr *expr, poly::StageMap stages, std::map<std::string, ir::Tensor> *all_tensor_map) {
// the inline tensors contained in the expression.
auto inline_tensors =
ir::CollectIRNodes(*expr, [&](const Expr *x) { return x->as_tensor() && stages[x->as_tensor()]->inlined(); });
@@ -86,7 +116,7 @@ void ComputeInlineExpand(Expr *expr, poly::StageMap stages) {
while (!inline_tensors.empty()) {
for (const auto &t : inline_tensors) {
auto *tensor = t.as_tensor();
- TensorInlineExpandMutator(tensor->name)(expr);
+ TensorInlineExpandMutator(tensor->name, all_tensor_map)(expr);
}

inline_tensors = ir::CollectLoadTensors(
2 changes: 1 addition & 1 deletion cinn/optim/compute_inline_expand.h
@@ -13,7 +13,7 @@ namespace optim {
* @param tensor_name name of the tensor to expand inline.
* @param memo a memo to avoid duplicate expand.
*/
- void ComputeInlineExpand(Expr* expr, poly::StageMap stages);
+ void ComputeInlineExpand(Expr* expr, poly::StageMap stages, std::map<std::string, ir::Tensor>* all_tensor_map);

} // namespace optim
} // namespace cinn
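ComputeInlineExpand now takes the lowering-time name-to-tensor map so the mutator above can query each cached tensor's buffer memory type while inlining. A minimal call-site sketch mirroring the lower_impl.cc hunk (wrapper name and parameters are illustrative):

```cpp
#include <map>
#include <string>

// Sketch of the new call shape; in LowerImpl::operator() the actual call is
// optim::ComputeInlineExpand(&func->body, stages_, &all_tensor_map).
void ExpandInlinedStages(ir::LoweredFunc &func,
                         poly::StageMap &stages,
                         std::map<std::string, ir::Tensor> &all_tensor_map) {
  optim::ComputeInlineExpand(&func->body, stages, &all_tensor_map);
}
```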
15 changes: 12 additions & 3 deletions cinn/optim/replace_var_with_expr.cc
@@ -112,7 +112,8 @@ struct ReplaceVarIndexOfCacheMutator : public ir::IRMutator<> {
VLOG(2) << "Store 's tensor name is : " << tensor->name;
if (tensor->buffer.defined() &&
(utils::Endswith(tensor->buffer->name, "_read_cache") ||
utils::Endswith(tensor->buffer->name, "_cache_write_out")) &&
utils::Endswith(tensor->buffer->name, "_cache_write_out") ||
utils::Endswith(tensor->buffer->name, "_temp_buffer")) &&
((*global_tensor_map_).at(tensor->name)->buffer->memory_type == ir::MemoryType::GPULocal || blockidx_)) {
bool temp_replace = do_replace_;
do_replace_ = true;
@@ -123,9 +124,14 @@ struct ReplaceVarIndexOfCacheMutator : public ir::IRMutator<> {
if (find_replace_ == true) ResizeTempMemory(tensor->name);
ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
do_replace_ = temp_replace;
+ ir::IRMutator<>::Visit(&node->value, &node->value);
} else {
+ for (auto& index : node->indices) {
+   ir::IRMutator<>::Visit(&index, &index);
+ }
ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
+ ir::IRMutator<>::Visit(&node->value, &node->value);
}
- ir::IRMutator<>::Visit(&node->value, &node->value);
}

void Visit(const ir::Load* expr, Expr* op) override {
@@ -134,14 +140,17 @@ struct ReplaceVarIndexOfCacheMutator : public ir::IRMutator<> {
VLOG(2) << "Load's tensor name is : " << tensor->name;
if (tensor->buffer.defined() &&
(utils::Endswith(tensor->buffer->name, "_read_cache") ||
utils::Endswith(tensor->buffer->name, "_cache_write_out")) &&
utils::Endswith(tensor->buffer->name, "_cache_write_out") ||
utils::Endswith(tensor->buffer->name, "_temp_buffer")) &&
((*global_tensor_map_).at(tensor->name)->buffer->memory_type == ir::MemoryType::GPULocal || blockidx_)) {
bool temp_replace = do_replace_;
do_replace_ = true;
ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
+ for (auto& idx : node->indices) ir::IRMutator<>::Visit(&idx, &idx);
do_replace_ = temp_replace;
} else {
ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
+ for (auto& idx : node->indices) ir::IRMutator<>::Visit(&idx, &idx);
}
}
