diff --git a/cinn/backends/codegen_cuda_dev_test.cc b/cinn/backends/codegen_cuda_dev_test.cc
old mode 100755
new mode 100644
index 77dab6a969c33..f23a7d60bd2c2
--- a/cinn/backends/codegen_cuda_dev_test.cc
+++ b/cinn/backends/codegen_cuda_dev_test.cc
@@ -614,7 +614,8 @@ TEST(Conv, basic) {
   auto stages = CreateStages({A, W, Apad, B});
   stages[Apad]->ComputeInline();

-  auto B_cache = stages[B]->CacheRead("shared", {}, stages);
+  std::vector<ir::Tensor> temp;
+  auto B_cache = stages[B]->CacheRead2("shared", temp, stages);

   auto fn = Lower("fn", stages, {A, W, B, B_cache});

@@ -636,10 +637,10 @@ TEST(elementwise_add, share_local_cache) {
       {M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j); }, "C");

   auto stages = CreateStages({C});
-
-  auto AA = stages[A]->CacheRead("shared", {C}, stages);
+  std::vector<ir::Tensor> temp{C};
+  auto CC = stages[C]->CacheWrite2("local", stages);
+  auto AA = stages[A]->CacheRead2("shared", temp, stages);
   // NOTE here, the CC replace the C as the output the function.
-  auto CC = stages[C]->CacheWrite("local", stages);

   stages[C]->Bind(0, "blockIdx.x");
   stages[C]->Bind(1, "threadIdx.x");
@@ -650,8 +651,7 @@ TEST(elementwise_add, share_local_cache) {
   stages[CC]->Bind(0, "blockIdx.x");
   stages[CC]->Bind(1, "threadIdx.x");

-  Target target;
-  Module::Builder builder("gpu_module", target);
+  Module::Builder builder("gpu_module", common::DefaultNVGPUTarget());

   auto fn = Lower("elementwise_add", stages, {A, B, CC}, {}, {AA, C}, &builder);

@@ -669,7 +669,7 @@ TEST(elementwise_add, share_local_cache) {
   }

   // compile with device code
-  CodeGenCUDA_Dev codegen(target);
+  CodeGenCUDA_Dev codegen(common::DefaultNVGPUTarget());
   auto source_code = codegen.Compile(builder.Build());

   LOG(INFO) << "device source code:\n" << source_code;
@@ -811,12 +811,12 @@ TEST(Conv, optimize) {
       "B");

   auto stages = CreateStages({B});
-
-  auto AA = stages[Apad]->CacheRead("shared", {B}, stages);
-  auto WW = stages[W]->CacheRead("shared", {B}, stages);
-  auto AL = stages[AA]->CacheRead("local", {B}, stages);
-  auto WL = stages[WW]->CacheRead("local", {B}, stages);
-  auto BL = stages[B]->CacheWrite("local", stages);
+  std::vector<ir::Tensor> temp{B};
+  auto BL = stages[B]->CacheWrite2("local", stages);
+  auto AA = stages[Apad]->CacheRead2("shared", temp, stages);
+  auto WW = stages[W]->CacheRead2("shared", temp, stages);
+  auto AL = stages[AA]->CacheRead2("local", temp, stages);
+  auto WL = stages[WW]->CacheRead2("local", temp, stages);

   stages[Apad]->ComputeInline();

@@ -861,8 +861,9 @@ TEST(ElementwiseAdd, cache_read_local) {

     auto stages = CreateStages({C});

-    auto AL = stages[A]->CacheRead("local", {C}, stages);
+    std::vector<ir::Tensor> temp{C};
+    auto AL = stages[A]->CacheRead2("local", temp, stages);

     stages[C]->Split(1, 10);
     stages[AL]->Split(1, 10);

@@ -1182,7 +1183,8 @@ TEST(ElementwiseAdd, cache_read_shared) {
     auto C = Compute(
         {M, N}, [&](Expr i, Expr j) { return A(i, j); }, "C");

     auto stages = CreateStages({A, B, C});
-    auto AL = stages[A]->CacheRead("shared", {C}, stages);
+    std::vector<ir::Tensor> temp{C};
+    auto AL = stages[A]->CacheRead2("shared", temp, stages);

     stages[C]->Split(1, 10);
@@ -1208,7 +1210,7 @@ TEST(ElementwiseAdd, cache_read_shared) {
     builder.AddFunction(fn);

     auto source_code = codegen.Compile(builder.Build());
-    std::cout << "CUDA source:\n" << source_code << std::endl;
+    std::cout << "CUDA source2:\n" << source_code << std::endl;

     auto target_source = R"ROC(
 extern "C" {
@@ -1236,7 +1238,6 @@ void fn2(const float* __restrict__ A, const float* __restrict__ B, float* __rest
       }
     };
   };
-  __syncthreads();
   if ((threadIdx.x < 20)) {
   {
     for (int32_t j = 0; j < 10; j += 1) {
@@ -1275,7 +1276,8 @@ TEST(ElementwiseAdd, cache_read_shared_no_compute_at) {
         {M, N}, [&](Expr i, Expr j) { return A(i, j); }, "C");

     auto stages = CreateStages({A, B, C});
-    auto AL = stages[A]->CacheRead("shared", {C}, stages);
+    std::vector<ir::Tensor> temp{C};
+    auto AL = stages[A]->CacheRead2("shared", temp, stages);

     stages[C]->Split(1, 10);
     stages[AL]->Split(1, 10);
@@ -1299,7 +1301,7 @@ TEST(ElementwiseAdd, cache_read_shared_no_compute_at) {
     builder.AddFunction(fn);

     auto source_code = codegen.Compile(builder.Build());
-    std::cout << "CUDA source:\n" << source_code << std::endl;
+    std::cout << "CUDA source3:\n" << source_code << std::endl;

     auto target_source = R"ROC(
 extern "C" {
@@ -1329,7 +1331,6 @@ void fn3(const float* __restrict__ A, const float* __restrict__ B, float* __rest
       };
     }
   };
-  __syncthreads();
   if ((blockIdx.x < 40)) {
   {
     if ((threadIdx.x < 4)) {
@@ -1369,7 +1370,7 @@ TEST(ElementwiseAdd, cache_write_local) {

     auto stages = CreateStages({A, B, C});

-    auto Co = stages[C]->CacheWrite("local", stages);
+    auto Co = stages[C]->CacheWrite2("local", stages);

     stages[C]->Split(1, 10);
     // Cache write local, the local memory can just share in a single thread, so it must ComputeAt(inside) the innermost
@@ -1384,12 +1385,12 @@ TEST(ElementwiseAdd, cache_write_local) {
   };

   auto [A, B, C, Co, stages] = create_module();  // NOLINT
-  Target target;
-  CodeGenCUDA_Dev codegen(target);
+
+  CodeGenCUDA_Dev codegen(common::DefaultNVGPUTarget());

   auto fn = Lower("fn4", stages, {A, B, Co}, {}, {C});

-  Module::Builder builder("module", common::DefaultHostTarget());
+  Module::Builder builder("module", common::DefaultNVGPUTarget());
   builder.AddFunction(fn);

   auto source_code = codegen.Compile(builder.Build());
@@ -1408,18 +1409,18 @@ typedef char int8_t;


 __global__
-void fn4(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C_cache_write_out)
+void fn4(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C)
 {
-  float _C [ 1 * 1 ];
-  float* C = _C;
+  float _C_cache_write_out [ 1 * 1 ];
+  float* C_cache_write_out = _C_cache_write_out;
   if ((blockIdx.x < 40)) {
   {
     if ((threadIdx.x < 40)) {
     {
       if (((((blockIdx.x >= 0) && (blockIdx.x <= 39)) && (threadIdx.x >= 0)) && (threadIdx.x <= 39))) {
-        C[0] = A[((40 * blockIdx.x) + threadIdx.x)];
+        C_cache_write_out[0] = A[((40 * blockIdx.x) + threadIdx.x)];
       };
-      C_cache_write_out[((40 * blockIdx.x) + threadIdx.x)] = C[0];
+      C[((40 * blockIdx.x) + threadIdx.x)] = C_cache_write_out[0];
     }
   };
 }
diff --git a/cinn/ir/ir.cc b/cinn/ir/ir.cc
old mode 100644
new mode 100755
diff --git a/cinn/ir/ir.h b/cinn/ir/ir.h
old mode 100644
new mode 100755
diff --git a/cinn/optim/replace_var_with_expr.cc b/cinn/optim/replace_var_with_expr.cc
old mode 100755
new mode 100644
index 3ad8f3af03170..b312aba8c2b47
--- a/cinn/optim/replace_var_with_expr.cc
+++ b/cinn/optim/replace_var_with_expr.cc
@@ -83,7 +83,7 @@ struct ReplaceVarIndexOfCacheMutator : public ir::IRMutator<> {
     auto* node   = expr->As<ir::Store>();
     auto* tensor = node->tensor.as_tensor();
     VLOG(2) << "Store 's tensor name is : " << tensor->name;
-    if (utils::Endswith(tensor->name, "read_cache") &&
+    if ((utils::Endswith(tensor->name, "_read_cache") || utils::Endswith(tensor->name, "_cache_write_out")) &&
         ((*global_tensor_map_).at(tensor->name)->buffer->memory_type == ir::MemoryType::GPULocal || blockidx_)) {
       bool temp_replace = do_replace;
       do_replace        = true;
@@ -101,7 +101,7 @@ struct ReplaceVarIndexOfCacheMutator : public ir::IRMutator<> {
     auto* node   = op->As<ir::Load>();
     auto* tensor = node->tensor.as_tensor();
     VLOG(2) << "Load's tensor name is : " << tensor->name;
-    if (utils::Endswith(tensor->name, "read_cache") &&
+    if ((utils::Endswith(tensor->name, "_read_cache") || utils::Endswith(tensor->name, "_cache_write_out")) &&
         ((*global_tensor_map_).at(tensor->name)->buffer->memory_type == ir::MemoryType::GPULocal || blockidx_)) {
       bool temp_replace = do_replace;
       do_replace        = true;
diff --git a/cinn/optim/transform_gpu_forloop.cc b/cinn/optim/transform_gpu_forloop.cc
index c77623334e964..c0febadac388c 100644
--- a/cinn/optim/transform_gpu_forloop.cc
+++ b/cinn/optim/transform_gpu_forloop.cc
@@ -181,17 +181,20 @@ void MarkGpuForloop(const std::string &statement,
       Expr var_expr(cuda_var);
       VLOG(2) << "gpu replacing var " << axis_var->name << " to " << cuda_var->name;
       optim::ReplaceVarWithExpr(expr, axis_var, var_expr);
-      if (utils::Endswith(statement, "read_cache") && (*global_tensor_map).count(statement) > 0) {
+      if ((utils::Endswith(statement, "_read_cache") || utils::Endswith(statement, "_cache_write_out")) &&
+          (*global_tensor_map).count(statement) > 0) {
         if ((*global_tensor_map)[statement]->buffer->memory_type == ir::MemoryType::GPULocal) {
-          Expr extent = for_ ? for_->extent : poly_for->ExtractExtent();
-          auto buffer_shape = (*global_tensor_map)[statement]->buffer->shape;
-          Expr prod(1);
-          for (auto i : buffer_shape) {
-            prod = ir::Mul::Make(prod, i);
+          Expr extent = for_ ? for_->extent : poly_for->ExtractExtent();
+          if (extent.defined()) {
+            auto buffer_shape = (*global_tensor_map)[statement]->buffer->shape;
+            Expr prod(1);
+            for (auto i : buffer_shape) {
+              prod = ir::Mul::Make(prod, i);
+            }
+            prod = ir::Div::Make(prod, extent);
+            std::vector<Expr> new_shape{prod};
+            (*global_tensor_map)[statement]->buffer->shape = new_shape;
           }
-          prod = ir::Div::Make(prod, extent);
-          std::vector<Expr> new_shape{prod};
-          (*global_tensor_map)[statement]->buffer->shape = new_shape;
         }
       }
       VLOG(2) << "gpu replacing var " << cuda_var->name << " to Expr(0)";
@@ -201,18 +204,21 @@ void MarkGpuForloop(const std::string &statement,
       Expr var_expr(cuda_var);
       VLOG(2) << "gpu replacing var " << axis_var->name << " to " << cuda_var->name;
       optim::ReplaceVarWithExpr(expr, axis_var, var_expr);
-      if (utils::Endswith(statement, "read_cache") && (*global_tensor_map).count(statement) > 0) {
+      if ((utils::Endswith(statement, "_read_cache") || utils::Endswith(statement, "_cache_write_out")) &&
+          (*global_tensor_map).count(statement) > 0) {
         if (((*global_tensor_map)[statement]->buffer->memory_type == ir::MemoryType::GPULocal) ||
             ((*global_tensor_map)[statement]->buffer->memory_type == ir::MemoryType::GPUShared)) {
-          Expr extent = for_ ? for_->extent : poly_for->ExtractExtent();
-          auto buffer_shape = (*global_tensor_map)[statement]->buffer->shape;
-          Expr prod(1);
-          for (auto i : buffer_shape) {
-            prod = ir::Mul::Make(prod, i);
+          Expr extent = for_ ? for_->extent : poly_for->ExtractExtent();
+          if (extent.defined()) {
+            auto buffer_shape = (*global_tensor_map)[statement]->buffer->shape;
+            Expr prod(1);
+            for (auto i : buffer_shape) {
+              prod = ir::Mul::Make(prod, i);
+            }
+            prod = ir::Div::Make(prod, extent);
+            std::vector<Expr> new_shape{prod};
+            (*global_tensor_map)[statement]->buffer->shape = new_shape;
           }
-          prod = ir::Div::Make(prod, extent);
-          std::vector<Expr> new_shape{prod};
-          (*global_tensor_map)[statement]->buffer->shape = new_shape;
         }
       }
       VLOG(3) << "gpu replacing var " << cuda_var->name << " to Expr(0)";
diff --git a/cinn/poly/stage.cc b/cinn/poly/stage.cc
index 8d9616332a852..8c90f44bb61a8 100755
--- a/cinn/poly/stage.cc
+++ b/cinn/poly/stage.cc
@@ -655,6 +655,33 @@ ir::Tensor Stage::CacheWrite(const std::string &memory_type, StageMap stages) {
   return write_stage;
 }

+/*
+ * Replace the tensor's name to cache_name, and create a cache_stage to copy content from cache to original tensor.
+ */
+ir::Tensor Stage::CacheWrite2(const std::string &memory_type, StageMap stages) {
+  CHECK(tensor_);
+  CHECK(!tensor_->buffer.defined()) << "This tensor is already binded to a buffer, cannot cache write";
+  CHECK(!meta.compute_inline) << "Cannot create a write cache on an inlined tensor";
+
+  std::string cache_name = Context::Global().NewName(tensor_->name + "_cache_write_out");
+  auto original_name = tensor_->name;
+  tensor_->name      = cache_name;
+  auto my_tensor     = ir::Tensor(tensor_);
+  // make my_tensor a cache
+  my_tensor->WithBuffer(memory_type);
+
+  auto write_stage = lang::Compute(
+      tensor_->shape, [=](const std::vector<Expr> &dims) { return my_tensor(dims); }, original_name);
+
+  stages->Insert(my_tensor, CreateStage(my_tensor).get());
+
+  stages->Insert(write_stage, CreateStage(write_stage).get());
+
+  stages[write_stage]->CtrlDepend(my_tensor);
+
+  return write_stage;
+}
+
 void Stage::ComputeInline() {
   CHECK(tensor_);
   meta.compute_inline = true;
diff --git a/cinn/poly/stage.h b/cinn/poly/stage.h
index 844f53a6683c1..de9e3365c71e1 100755
--- a/cinn/poly/stage.h
+++ b/cinn/poly/stage.h
@@ -242,6 +242,8 @@ class Stage : public Object {
    */
   ir::Tensor CacheWrite(const std::string& memory_type, poly::StageMap stages);

+  ir::Tensor CacheWrite2(const std::string& memory_type, poly::StageMap stages);
+
   /**
    * Set thread scope.
   */
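
A minimal usage sketch of the new CacheWrite2/CacheRead2 pair, distilled from TEST(elementwise_add, share_local_cache) in the patch above. The tensors A, B, C and the extents M, N are the ones declared in that test (their definitions are omitted here); note that the updated tests create the write cache before the read cache.

  auto stages = CreateStages({C});

  // CacheWrite2 renames the original stage C to C_cache_write_out (the local cache)
  // and returns a new stage CC that copies the cache back out; CC replaces C as the
  // function's output tensor (see the expected fn4 source above).
  auto CC = stages[C]->CacheWrite2("local", stages);

  // CacheRead2 receives the reader tensors as a std::vector (here just C).
  std::vector<ir::Tensor> temp{C};
  auto AA = stages[A]->CacheRead2("shared", temp, stages);

  stages[C]->Bind(0, "blockIdx.x");
  stages[C]->Bind(1, "threadIdx.x");
  stages[CC]->Bind(0, "blockIdx.x");
  stages[CC]->Bind(1, "threadIdx.x");

  // AA and C become temporary buffers of the lowered function; CC is its output.
  Module::Builder builder("gpu_module", common::DefaultNVGPUTarget());
  auto fn = Lower("elementwise_add", stages, {A, B, CC}, {}, {AA, C}, &builder);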