Skip to content

Commit

Permalink
Add schedule primitive CacheRead2 (PaddlePaddle#311)
Browse files Browse the repository at this point in the history
  • Loading branch information
haozech authored Dec 23, 2020
1 parent 491ab74 commit 0aa1e74
Show file tree
Hide file tree
Showing 22 changed files with 386 additions and 22 deletions.
2 changes: 1 addition & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function prepare {
mkdir -p $build_dir
cd $build_dir

python3 -m pip install sphinx sphinx_gallery recommonmark exhale scipy --trusted-host mirrors.aliyun.com
python3 -m pip install sphinx==3.3.1 sphinx_gallery==0.8.1 recommonmark==0.6.0 exhale scipy breathe==4.24.0 --trusted-host mirrors.aliyun.com
apt install doxygen -y

mkdir -p tests
Expand Down
Empty file modified cinn/backends/codegen_cuda_dev.cc
100644 → 100755
Empty file.
54 changes: 49 additions & 5 deletions cinn/backends/codegen_cuda_dev_test.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ TEST(CodeGenCUDA, compile_run_jit) {
{M, N}, [&](Var i, Var j) { return A(i, j) * B(i, j); }, "C");

auto stages = CreateStages({C});

std::vector<ir::Tensor> readers{C};
auto B_cache = stages[B]->CacheRead2("local", readers, stages);
stages[B_cache]->Bind(0, "blockIdx.x");
stages[B_cache]->Bind(1, "threadIdx.x");
stages[C]->Bind(0, "blockIdx.x");
stages[C]->Bind(1, "threadIdx.x");

Expand All @@ -132,7 +135,48 @@ TEST(CodeGenCUDA, compile_run_jit) {

auto source_code = codegen.Compile(builder.Build());

LOG(INFO) << "compiled code:\n\n\n" << source_code;
LOG(INFO) << "compiled CacheRead2 code:\n\n\n" << source_code;

std::string source_target = R"ROC(
extern "C" {
#include "cinn_cuda_runtime_source.cuh"
#ifdef __CUDACC_RTC__
typedef int int32_t;
typedef char int8_t;
#endif
__global__
void elementwise_add(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C)
{
float _B_read_cache [ ((1 * (((1 * 100) * 200) / 100)) / 200) ];
float* B_read_cache = _B_read_cache;
if ((blockIdx.x < 100)) {
{
if ((threadIdx.x < 200)) {
{
B_read_cache[0] = B[((200 * blockIdx.x) + threadIdx.x)];
}
};
}
};
if ((blockIdx.x < 100)) {
{
if ((threadIdx.x < 200)) {
{
C[((200 * blockIdx.x) + threadIdx.x)] = (A[((200 * blockIdx.x) + threadIdx.x)] * B_read_cache[0]);
}
};
}
};
}
}
)ROC";
ASSERT_EQ(utils::Trim(source_target), source_code);

// compile the code
using runtime::cuda::CUDAModule;
Expand Down Expand Up @@ -1272,14 +1316,14 @@ typedef char int8_t;
__global__
void fn3(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C)
{
__shared__ float _A_read_cache [ 40 * 40 ];
__shared__ float _A_read_cache [ (((1 * 40) * 40) / 40) ];
float* A_read_cache = _A_read_cache;
if ((blockIdx.x < 40)) {
{
if ((threadIdx.x < 4)) {
{
for (int32_t j_inner = 0; j_inner < 10; j_inner += 1) {
A_read_cache[((40 * blockIdx.x) + ((10 * threadIdx.x) + j_inner))] = A[((40 * blockIdx.x) + ((10 * threadIdx.x) + j_inner))];
A_read_cache[((10 * threadIdx.x) + j_inner)] = A[((40 * blockIdx.x) + ((10 * threadIdx.x) + j_inner))];
};
}
};
Expand All @@ -1291,7 +1335,7 @@ void fn3(const float* __restrict__ A, const float* __restrict__ B, float* __rest
if ((threadIdx.x < 4)) {
{
for (int32_t j_inner = 0; j_inner < 10; j_inner += 1) {
C[((40 * blockIdx.x) + ((10 * threadIdx.x) + j_inner))] = A_read_cache[((40 * blockIdx.x) + ((10 * threadIdx.x) + j_inner))];
C[((40 * blockIdx.x) + ((10 * threadIdx.x) + j_inner))] = A_read_cache[((10 * threadIdx.x) + j_inner)];
};
}
};
Expand Down
15 changes: 15 additions & 0 deletions cinn/backends/compiler.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ void Compiler::Build(const Module& module, const std::string& code) {
}
}

// Generates and returns the CUDA device source code for `module` without
// building it. Only implemented for the NVGPU target (and only when the
// library is compiled with CINN_WITH_CUDA).
std::string Compiler::GetSourceCode(const ir::Module& module) {
  if (target_.arch == Target::Arch::NVGPU) {
#ifdef CINN_WITH_CUDA
    // Keep only the device-side functions; the host stubs are not part of
    // the emitted CUDA source.
    auto [host_module, device_module] = SplitCudaAndHostModule(module);  // NOLINT
    CodeGenCUDA_Dev device_codegen(target_);
    return device_codegen.Compile(device_module);
#else
    CINN_NOT_IMPLEMENTED
#endif
  } else {
    CINN_NOT_IMPLEMENTED
  }
}

void Compiler::BuildDefault(const Module& module) {
if (target_.arch == Target::Arch::NVGPU) {
CompileCudaModule(module);
Expand Down
2 changes: 2 additions & 0 deletions cinn/backends/compiler.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class Compiler final {
*/
void Build(const ir::Module& module, const std::string& code = "");

std::string GetSourceCode(const ir::Module& module);

void BuildDefault(const ir::Module& module);

/**
Expand Down
19 changes: 19 additions & 0 deletions cinn/hlir/framework/graph_compiler.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,25 @@ void GraphCompiler::PrintFunc() {
}
}

std::string GraphCompiler::GenSourceCode() {
auto [nodes, edges] = graph_->topological_order();
for (auto& n : nodes) {
auto* node = n->safe_as<Node>();
if (node) {
auto lowered_func = GetOpFunc(node);
m_builder_.AddFunction(lowered_func);
}
}
// compile the module
if (!compiler_) {
compiler_ = backends::Compiler::Create(target_);
}

auto build_module = m_builder_.Build();

return compiler_->GetSourceCode(build_module);
}

std::unique_ptr<Program> GraphCompiler::Build(const std::string& code) {
auto [nodes, edges] = graph_->topological_order();
for (auto& n : nodes) {
Expand Down
2 changes: 2 additions & 0 deletions cinn/hlir/framework/graph_compiler.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class GraphCompiler final {

std::unique_ptr<Program> Build(const std::string& code = "");

std::string GenSourceCode();

void PrintFunc();

const std::shared_ptr<Scope>& GetScope() const { return scope_; }
Expand Down
Empty file modified cinn/hlir/op/broadcast.cc
100644 → 100755
Empty file.
8 changes: 8 additions & 0 deletions cinn/hlir/op/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,14 @@ std::shared_ptr<OpStrategy> StrategyForMul(const framework::NodeAttr &attrs,
VLOG(3) << "mul out: " << out;
stages->InsertLazily(out);
CHECK(!out_type.empty()) << "Output type of Mul is empty! Please check.\n";

if (target.arch == Target::Arch::NVGPU) {
std::vector<ir::Tensor> readers{out};
auto BB = stages[new_B]->CacheRead2("local", readers, stages);
stages[BB]->Split(0, 2);
stages[BB]->Bind(0, "threadIdx.x");
}

*ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
});

Expand Down
4 changes: 2 additions & 2 deletions cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -419,15 +419,15 @@ Expr PolyFor::Make(Var iterator,
std::vector<Expr *> PolyFor::expr_fields() { return {&init, &condition, &inc, &body}; }
std::vector<const Expr *> PolyFor::expr_fields() const { return {&init, &condition, &inc, &body}; }

Expr PolyFor::extent() const {
Expr PolyFor::ExtractExtent() const {
auto nodes = CollectIRNodes(condition, [&](const Expr *e) {
return e->As<NE>() || //
e->As<EQ>() || //
e->As<Min>() || //
e->As<Max>();
});

if (nodes.empty()) {
if (!nodes.empty()) {
return Expr();
}

Expand Down
2 changes: 1 addition & 1 deletion cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ struct PolyFor : public ExprNode<PolyFor>, public ForBase {

PolyFor() : ExprNode(Type()) {}

Expr extent() const;
Expr ExtractExtent() const;

static Expr Make(Var iterator,
Expr init_val,
Expand Down
14 changes: 14 additions & 0 deletions cinn/ir/lowered_func.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "cinn/ir/ir_visitor.h"
#include "cinn/optim/tensor_write_tell.h"
#include "cinn/runtime/intrinsic.h"
#include "cinn/utils/string.h"

namespace cinn {
namespace ir {
Expand Down Expand Up @@ -81,6 +82,19 @@ std::vector<Expr> _LoweredFunc_::PrepareAllocTempBufferExprs() const {
return alloc_output_buffer_exprs;
}

std::vector<Expr> _LoweredFunc_::CudaPrepareAllocTempBufferExprs() const {
std::vector<Expr> alloc_output_buffer_exprs;
for (auto temp_buf : temp_bufs) {
if (utils::Startswith(temp_buf->name, "_")) {
temp_buf->name = temp_buf->name.substr(1);
}
if (!temp_buf->shape.empty() && temp_buf->type() != Void()) {
alloc_output_buffer_exprs.push_back(Alloc::Make(temp_buf, temp_buf->type(), temp_buf->shape, Expr(), Expr()));
}
}
return alloc_output_buffer_exprs;
}

void _LoweredFunc_::PrepareDeallocOutputBufferExprs() {
CHECK(dealloc_output_buffer_exprs.empty()) << "duplicate prepare the allocate buffer for outputs";

Expand Down
1 change: 1 addition & 0 deletions cinn/ir/lowered_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> {

//! Prepare the expressions for `alloc_tmp_buffer_exprs`.
std::vector<Expr> PrepareAllocTempBufferExprs() const;
std::vector<Expr> CudaPrepareAllocTempBufferExprs() const;
std::vector<Expr> CudaAliasVarExprs() const;

private:
Expand Down
2 changes: 1 addition & 1 deletion cinn/lang/lower_impl.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ Expr LowerGroup(const poly::ScheduleGroup& group,

forloop_infos[stage->id()] = for_infos;
}
optim::TransformGpuForloops(forloop_infos, &e);
optim::TransformGpuForloops(forloop_infos, global_tensor_map, &e);
auto axis_info = optim::GatherAxisInfoFromStages(stages);
if (axis_info.valid()) cuda_axis_info->ExtendWith(axis_info);
}
Expand Down
2 changes: 1 addition & 1 deletion cinn/optim/insert_debug_log_callee.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ struct InsertDebugLogCalleeMutator : public ir::IRMutator<> {
}
case ir::IrNodeTy::PolyFor: {
auto *node = e.As<ir::PolyFor>();
ss << "<PolyFor " << node->iterator << " in [" << node->init << ", " << node->extent() << ")"
ss << "<PolyFor " << node->iterator << " in [" << node->init << ", " << node->ExtractExtent() << ")"
<< " with condition: " << node->condition << ">";
break;
}
Expand Down
78 changes: 78 additions & 0 deletions cinn/optim/replace_var_with_expr.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_mutator.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/ir/tensor.h"
#include "cinn/optim/ir_copy.h"

namespace cinn {
Expand Down Expand Up @@ -51,5 +52,82 @@ void ReplaceVarWithExpr(Expr* source, const Var& var, const Expr& expr) {
mutator(source);
}

/**
 * Mutator that replaces occurrences of `var` with `expr`, but only inside the
 * index expressions (and tensor refs) of "read_cache" tensors — used when a
 * cache tensor is bound to a CUDA threadIdx/blockIdx axis and that axis must
 * be dropped from the cache buffer's addressing.
 */
struct ReplaceVarIndexOfCacheMutator : public ir::IRMutator<> {
  ReplaceVarIndexOfCacheMutator(const Var& var,
                                const Expr& expr,
                                const std::map<std::string, ir::Tensor>* global_tensor_map,
                                bool blockidx)
      // Init-list follows member declaration order (members are initialized in
      // declaration order regardless; the old order triggered -Wreorder).
      : global_tensor_map_(global_tensor_map), var_(var), expr_(expr), blockidx_(blockidx) {}

  //! Entry point. `expr` must be a For or PolyFor node; only its body is
  //! visited so the loop's own iterator stays untouched.
  void Execute(Expr* expr) {
    auto* for_node      = expr->As<ir::For>();
    auto* poly_for_node = expr->As<ir::PolyFor>();
    // Guard against a null dereference: the old code unconditionally used
    // poly_for when the node was not a For.
    CHECK(for_node || poly_for_node) << "Execute expects a For or PolyFor expr";
    Expr* body = for_node ? &for_node->body : &poly_for_node->body;
    ir::IRMutator<>::Visit(body, body);
  }

 private:
  //! Whether `tensor_name` names a cache tensor whose indices should be
  //! rewritten: a "read_cache" tensor that either lives in GPU local memory
  //! or is being rewritten for a blockIdx-bound axis.
  bool ShouldRewrite(const std::string& tensor_name) const {
    return utils::Endswith(tensor_name, "read_cache") &&
           ((*global_tensor_map_).at(tensor_name)->buffer->memory_type == ir::MemoryType::GPULocal || blockidx_);
  }

  void Visit(const ir::_Var_* expr, Expr* op) override {
    // Only rewrite while inside the indices of a matching cache tensor.
    if (do_replace_) {
      if (expr->name != utils::GetStreamCnt(var_->name)) return;
      VLOG(2) << "Do Replace: " << expr->name << " to 0";
      *op = IRCopy(expr_);
    }
  }

  void Visit(const ir::Store* op, Expr* expr) override {
    auto* node   = expr->As<ir::Store>();
    auto* tensor = node->tensor.as_tensor();
    VLOG(2) << "Store 's tensor name is : " << tensor->name;
    if (ShouldRewrite(tensor->name)) {
      // Save/restore so nested visits don't leak the replace flag.
      bool saved  = do_replace_;
      do_replace_ = true;
      for (auto& index : node->indices) {
        ir::IRMutator<>::Visit(&index, &index);
      }
      ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
      do_replace_ = saved;
    }
    // The stored value may itself contain loads of cache tensors.
    ir::IRMutator<>::Visit(&node->value, &node->value);
  }

  void Visit(const ir::Load* expr, Expr* op) override {
    auto* node   = op->As<ir::Load>();
    auto* tensor = node->tensor.as_tensor();
    VLOG(2) << "Load's tensor name is : " << tensor->name;
    if (ShouldRewrite(tensor->name)) {
      bool saved  = do_replace_;
      do_replace_ = true;
      ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
      for (auto& idx : node->indices) ir::IRMutator<>::Visit(&idx, &idx);
      do_replace_ = saved;
    }
  }

  const std::map<std::string, ir::Tensor>* global_tensor_map_;
  const Var& var_;
  const Expr& expr_;
  bool blockidx_;
  //! Set while visiting the indices of a matching cache tensor.
  bool do_replace_{false};
};

// Public entry point: rewrites `var` to `expr` inside the cache-tensor index
// expressions of the loop `source` (see ReplaceVarIndexOfCacheMutator).
void CUDAReplaceIndexOfCachePass(Expr* source,
                                 const Var& var,
                                 const Expr& expr,
                                 const std::map<std::string, ir::Tensor>* global_tensor_map,
                                 bool blockidx) {
  // Thin wrapper: construct the mutator and run it over `source`.
  ReplaceVarIndexOfCacheMutator{var, expr, global_tensor_map, blockidx}.Execute(source);
}

} // namespace optim
} // namespace cinn
13 changes: 13 additions & 0 deletions cinn/optim/replace_var_with_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,18 @@ namespace optim {
*/
void ReplaceVarWithExpr(Expr *source, const Var &var, const Expr &expr);

/**
 * In the CUDA backend, replace the var bound to 'threadIdx.x'/'blockIdx.x'
 * in the cache tensor's indices with expr.
 * @param var The variable to replace.
 * @param expr The candidate expression.
 * @param global_tensor_map The global tensor map.
 * @param blockidx Whether the var to be replaced is bound to blockIdx.
 */
void CUDAReplaceIndexOfCachePass(Expr *source,
const Var &var,
const Expr &expr,
const std::map<std::string, ir::Tensor> *global_tensor_map,
bool blockidx);
} // namespace optim
} // namespace cinn
Loading

0 comments on commit 0aa1e74

Please sign in to comment.