Add SplitOuter, enhance Conv2d schedule and fix bugs (PaddlePaddle#404)
* fix bugs

* fix codestyles

* add SplitOuter

* fix codestyles

* fix comments
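
In short, as the diff below shows, this commit does three things: it adds a SplitOuter schedule primitive (splitting a loop into a fixed number of outer iterations, with a boundary guard when the extent does not divide evenly) together with a CUDA codegen test for it; it reworks the Conv2d test schedule to stage the padded input and the weights in shared memory while accumulating into a per-thread local buffer; and it simplifies CudaScheduleConv to split the reduction axis unconditionally.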
haozech authored Jul 1, 2021
1 parent 1d7a514 commit a0fe1e4
Showing 13 changed files with 520 additions and 89 deletions.
9 changes: 5 additions & 4 deletions build.sh
@@ -129,15 +129,16 @@ function build {
# build gtest first, it tends to break the CI
make extern_gtest

-make test01_elementwise_add_main -j $JOBS
-make test02_matmul_main -j $JOBS
-make test03_conv_main -j $JOBS
-make test_codegen_c -j $JOBS
if [[ $cuda_config == "ON" ]]; then
make test_codegen_cuda_dev -j $JOBS
ctest -R test_codegen_cuda_dev -V
fi

+make test01_elementwise_add_main -j $JOBS
+make test02_matmul_main -j $JOBS
+make test03_conv_main -j $JOBS
+make test_codegen_c -j $JOBS
+
ctest -R test01_elementwise_add_main
ctest -R test02_matmul_main
ctest -R test03_conv_main
224 changes: 213 additions & 11 deletions cinn/backends/codegen_cuda_dev_test.cc
100644 → 100755
@@ -174,7 +174,100 @@ TEST(CodeGenCUDA2, compile_run_jit2) {
}
}

TEST(CodeGenCUDA2, test_of_splitouter) {
Context::Global().ResetNameId();
Expr M(100);
Expr N(100);

Target target;

Placeholder<float> A("X", {M, N});
Placeholder<float> B("Y", {M, N});

auto C = Compute(
{M, N}, [&](Var i, Var j) { return A(i, j) * B(i, j); }, "C");

auto stages = CreateStages({C});
std::vector<ir::Tensor> readers{C};
stages[C]->SplitOuter(0, 20);
stages[C]->SplitOuter(2, 17);
stages[C]->Bind(0, "blockIdx.x");
stages[C]->Bind(1, "threadIdx.x");
CodeGenCUDA_Dev codegen(target);

auto func = Lower("elementwise_add_splitouter", stages, {A, B, C});

Module::Builder builder("module", target);
builder.AddFunction(func);

auto source_code = codegen.Compile(builder.Build());

LOG(INFO) << "compiled SplitOuter code:\n\n\n" << source_code;

std::string source_target = R"ROC(
extern "C" {
#include "cinn_cuda_runtime_source.cuh"
#ifdef __CUDACC_RTC__
typedef int int32_t;
typedef char int8_t;
#endif
__global__
void elementwise_add_splitouter(const float* __restrict__ X, const float* __restrict__ Y, float* __restrict__ C)
{
if ((blockIdx.x < 20)) {
if ((threadIdx.x < 5)) {
for (int32_t j_outer = 0; j_outer < 17; j_outer += 1) {
for (int32_t j_inner = 0; j_inner < cinn_nvgpu_min_fp32(6, (100 + (-6 * j_outer))); j_inner += 1) {
C[((500 * blockIdx.x) + ((6 * j_outer) + ((100 * threadIdx.x) + j_inner)))] = (X[((500 * blockIdx.x) + ((6 * j_outer) + ((100 * threadIdx.x) + j_inner)))] * Y[((500 * blockIdx.x) + ((6 * j_outer) + ((100 * threadIdx.x) + j_inner)))]);
};
};
};
};
}
}
)ROC";
ASSERT_EQ(utils::Trim(source_target), source_code);

using runtime::cuda::CUDAModule;

backends::NVRTC_Compiler compiler;

auto ptx = compiler(source_code);
CHECK(!ptx.empty());

CUDAModule cuda_module(ptx, CUDAModule::Kind::PTX);

auto [Ad, Bd, Cd, host_data1, host_data2, host_data3] = CreateNVMemory(M.as_int32(), N.as_int32());

// launch the kernel

void* args[] = {&Ad, &Bd, &Cd};

dim3 grid(20, 1, 1);
dim3 block(5, 1, 1);
cuda_module.LaunchKernel(0, "elementwise_add_splitouter", grid, block, args);

CUDA_CALL(cudaMemcpy(host_data3.data(),
reinterpret_cast<void*>(Cd),
M.as_int32() * N.as_int32() * sizeof(float),
cudaMemcpyDeviceToHost));

for (int i = 0; i < M.as_int32(); i++) {
for (int j = 0; j < N.as_int32(); j++) {
int offset = i * N.as_int32() + j;
EXPECT_NEAR(host_data3[offset], host_data1[offset] * host_data2[offset], 1e-5);
}
}
}
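
For reference, here is a minimal self-contained sketch (not part of the commit, and not the CINN API) of the transformation SplitOuter(axis, n) performs, as implied by the expected kernel above: an axis of extent E is split into exactly n outer iterations whose inner extent is ceil(E / n), clamped at the boundary when n does not divide E. Note also how the flattened offset 500 * blockIdx.x + 100 * threadIdx.x + 6 * j_outer + j_inner is exactly 100 * i + j, with i = 5 * blockIdx.x + threadIdx.x and j = 6 * j_outer + j_inner.

#include <algorithm>

// Plain-C++ model of SplitOuter on the j axis of the test above:
// extent E = 100 split into n = 17 outer iterations.
void split_outer_sketch(float* C, const float* X, const float* Y) {
  constexpr int E = 100;
  constexpr int n = 17;
  constexpr int inner = (E + n - 1) / n;  // ceil(100 / 17) == 6
  for (int j_outer = 0; j_outer < n; ++j_outer) {
    // Boundary clamp; this plays the role of the
    // cinn_nvgpu_min_fp32(6, 100 - 6 * j_outer) guard in the generated kernel.
    const int extent = std::min(inner, E - inner * j_outer);
    for (int j_inner = 0; j_inner < extent; ++j_inner) {
      const int j = inner * j_outer + j_inner;  // recovers the original index
      C[j] = X[j] * Y[j];  // one row of the elementwise multiply
    }
  }
}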

TEST(CodeGenCUDA2, test_schedule_conv2d_0) {
Context::Global().ResetNameId();
Expr N(1);
Expr C(128);
Expr H(28);
@@ -196,6 +289,9 @@ TEST(CodeGenCUDA2, test_schedule_conv2d_0) {
optim::Simplify(&(conv->shape[2]));
optim::Simplify(&(conv->shape[3]));

std::vector<ir::Tensor> readers{conv};
auto PR = stages[pad_data]->CacheRead2("shared", readers, stages);
auto KR = stages[B]->CacheRead2("shared", readers, stages);
auto OL = stages[conv]->CacheWrite2("local", stages, conv);

auto tx = stages[conv]->axis(3);
@@ -211,7 +307,6 @@
stages[conv]->Bind(4, "threadIdx.x");

stages[OL]->ComputeAt3(stages[conv], 4);
-
stages[OL]->Split(6, 8);
auto on = stages[OL]->axis(0);
auto obz = stages[OL]->axis(1);
@@ -231,6 +326,27 @@
stages[OL]->Bind(7, "threadIdx.z");
stages[OL]->Bind(8, "threadIdx.x");

stages[KR]->ComputeAt5(stages[OL], 2);
auto OL_init = OL->GetInitTensor(stages, target);

stages[PR]->ComputeAt5(stages[OL], 2);

stages[PR]->SyncThreads({OL}, stages);
stages[PR]->SyncThreads(2, {OL_init}, stages);

stages[KR]->Split(5, 32);
stages[KR]->Split(6, 2);
stages[KR]->Reorder({5, 6, 3, 4, 7});
stages[KR]->Fuse({5, 6, 7});
stages[KR]->Split(5, 8);
stages[KR]->Bind(3, "blockIdx.z");
stages[KR]->Bind(4, "threadIdx.z");
stages[KR]->Bind(6, "threadIdx.x");

stages[PR]->Bind(5, "blockIdx.y");
stages[PR]->Bind(3, "threadIdx.z");
stages[PR]->Bind(6, "threadIdx.x");

CodeGenCUDA_Dev codegen(target);

auto func = Lower("schedule_conv2d_0", stages, {A, B, conv}, {}, {}, nullptr, target);
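
The schedule above is the usual GPU convolution tiling pattern: CacheRead2 stages the padded input (PR) and the weights (KR) in shared memory, CacheWrite2 accumulates the output in a per-thread local buffer (OL), ComputeAt5 nests each staging loop at the given loop level of its consumer, and SyncThreads inserts __syncthreads() barriers so the staged data is visible to every thread before the inner reduction runs. All of this can be seen in the expected kernel below.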
@@ -242,6 +358,93 @@

LOG(INFO) << "compiled schedule_conv2d_0 code:\n\n\n" << source_code;

std::string source_target = R"ROC(
extern "C" {
#include "cinn_cuda_runtime_source.cuh"
#ifdef __CUDACC_RTC__
typedef int int32_t;
typedef char int8_t;
#endif
__global__
void schedule_conv2d_0(const float* __restrict__ X, const float* __restrict__ Y, float* __restrict__ COD)
{
__shared__ float _input_pad_read_cache [ 448 ];
float _COD_cache_write_out [ 2 ];
__shared__ float _Y_read_cache [ 256 ];
float* COD_cache_write_out = _COD_cache_write_out;
float* COD_cache_write_out__reduce_init = _COD_cache_write_out;
float* Y_read_cache = _Y_read_cache;
float* input_pad_read_cache = _input_pad_read_cache;
if ((blockIdx.z < 8)) {
if ((blockIdx.y < 14)) {
if ((threadIdx.z < 16)) {
if ((threadIdx.x < 14)) {
for (int32_t j_inner = 0; j_inner < 2; j_inner += 1) {
COD_cache_write_out__reduce_init[j_inner] = 0;
};
};
};
};
};
for (int32_t rc_outer = 0; rc_outer < 16; rc_outer += 1) {
{
__syncthreads();
if ((blockIdx.z < 8)) {
if ((threadIdx.z < 16)) {
for (int32_t j_outer_outer = 0; j_outer_outer < 2; j_outer_outer += 1) {
if ((threadIdx.x < 8)) {
Y_read_cache[((threadIdx.x / 2) + ((8 * (threadIdx.x % 2)) + ((4 * j_outer_outer) + (16 * threadIdx.z))))] = Y[((threadIdx.x / 2) + ((128 * (threadIdx.x % 2)) + ((4096 * blockIdx.z) + ((4 * j_outer_outer) + ((8 * rc_outer) + (256 * threadIdx.z))))))];
};
};
};
};
};
if ((threadIdx.z < 8)) {
if ((blockIdx.y < 14)) {
if ((threadIdx.x < 14)) {
input_pad_read_cache[((2 * threadIdx.x) + (28 * threadIdx.z))] = X[((56 * blockIdx.y) + ((6272 * rc_outer) + ((2 * threadIdx.x) + (784 * threadIdx.z))))];
};
};
};
__syncthreads();
for (int32_t rc_inner = 0; rc_inner < 8; rc_inner += 1) {
if ((blockIdx.z < 8)) {
if ((blockIdx.y < 14)) {
if ((threadIdx.z < 16)) {
if ((threadIdx.x < 14)) {
for (int32_t j_inner = 0; j_inner < 2; j_inner += 1) {
COD_cache_write_out[j_inner] = (COD_cache_write_out[j_inner] + (input_pad_read_cache[((28 * rc_inner) + (2 * threadIdx.x))] * Y_read_cache[((8 * j_inner) + ((16 * threadIdx.z) + rc_inner))]));
};
};
};
};
};
};
};
if ((blockIdx.z < 8)) {
if ((blockIdx.y < 14)) {
if ((threadIdx.z < 16)) {
if ((threadIdx.x < 14)) {
for (int32_t j_inner = 0; j_inner < 2; j_inner += 1) {
COD[((14 * blockIdx.y) + ((6272 * blockIdx.z) + ((196 * j_inner) + ((392 * threadIdx.z) + threadIdx.x))))] = COD_cache_write_out[j_inner];
};
};
};
};
};
}
}
)ROC";
std::string trimed_source_target = utils::Trim(source_target);
int start_target = trimed_source_target.find("blockIdx");
int start_source = source_code.find("blockIdx");
ASSERT_EQ(trimed_source_target.substr(start_target), source_code.substr(start_source));
using runtime::cuda::CUDAModule;

backends::NVRTC_Compiler compiler;
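
Unlike the SplitOuter test, which asserts on the full source, the comparison above starts at the first occurrence of "blockIdx": the preamble (the extern "C" header and the shared/local buffer declarations) is skipped, presumably because the generated buffer declarations are not guaranteed to appear in a stable order.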
@@ -1092,6 +1295,7 @@ TEST(elementwise_add1, share_local_cache) {
}

TEST(elementwise_add0, share_local_cache) {
Context::Global().ResetNameId();
Expr M(100);
Expr N(20);

@@ -1109,7 +1313,7 @@
// NOTE here, the CC replace the C as the output the function.

stages[CC]->ComputeAt5(stages[C], 1);
-stages[AA]->ComputeAt5(stages[C], 1);
+stages[AA]->ComputeAt5(stages[CC], 1);
stages[C]->Bind(0, "blockIdx.x");
stages[C]->Bind(1, "threadIdx.x");

@@ -1595,7 +1799,7 @@ TEST(ElementwiseAdd, cache_read_compute_at2) {
Placeholder<float> A("AA", {M, M});

auto C = Compute(
-{N, N}, [&](Expr i, Expr j) { return A(i + 50, j) + A(i, j + 50); }, "C");
+{N, N}, [&](Expr i, Expr j) { return A(i + 5, j) + A(i, j + 5); }, "C");

auto stages = CreateStages({A, C});
std::vector<ir::Tensor> temp{C};
@@ -1630,17 +1834,17 @@ typedef char int8_t;
__global__
void fn_cacheread_computeat2(const float* __restrict__ AA, float* __restrict__ C)
{
-float _AA_read_cache [ 51 * 51 ];
+float _AA_read_cache [ 6 * 6 ];
float* AA_read_cache = _AA_read_cache;
if ((blockIdx.x < 50)) {
if ((threadIdx.x < 5)) {
for (int32_t j_inner = 0; j_inner < 10; j_inner += 1) {
-for (int32_t i_at = 0; i_at < 51; i_at += 1) {
-for (int32_t j_at = 0; j_at < 51; j_at += 1) {
-AA_read_cache[((51 * i_at) + j_at)] = AA[((100 * blockIdx.x) + ((100 * i_at) + ((10 * threadIdx.x) + (j_at + j_inner))))];
+for (int32_t i_at = 0; i_at < 6; i_at += 1) {
+for (int32_t j_at = 0; j_at < 6; j_at += 1) {
+AA_read_cache[((6 * i_at) + j_at)] = AA[((100 * blockIdx.x) + ((100 * i_at) + ((10 * threadIdx.x) + (j_at + j_inner))))];
};
};
-C[((50 * blockIdx.x) + ((10 * threadIdx.x) + j_inner))] = (AA_read_cache[2550] + AA_read_cache[50]);
+C[((50 * blockIdx.x) + ((10 * threadIdx.x) + j_inner))] = (AA_read_cache[30] + AA_read_cache[5]);
};
};
};
@@ -1681,7 +1885,7 @@ void fn_cacheread_computeat2(const float* __restrict__ AA, float* __restrict__ C
for (int i = 0; i < N.as_int32(); i++) {
for (int j = 0; j < N.as_int32(); j++) {
ASSERT_NEAR(C_target_mem[i * N.as_int32() + j],
-(A_mem[(i + 50) * M.as_int32() + j] + A_mem[i * M.as_int32() + j + 50]),
+(A_mem[(i + 5) * M.as_int32() + j] + A_mem[i * M.as_int32() + j + 5]),
1e-4);
}
}
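
A quick check of the corrected indices: the cache is now 6 x 6, so A(i + 5, j) lives at AA_read_cache[6 * 5 + 0] = AA_read_cache[30] and A(i, j + 5) at AA_read_cache[5], exactly as in the expected kernel; the old 51 x 51 cache put the same two elements at offsets 2550 and 50 for the previous i + 50 / j + 50 access pattern.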
@@ -1746,8 +1950,6 @@ TEST(ElementwiseAdd, cache_read_shared) {
Expr N(200);

auto create_module = [&] {
-Context::Global().ResetNameId();
-
Placeholder<float> A("A", {M, N});
Placeholder<float> B("B", {M, N});

1 change: 0 additions & 1 deletion cinn/hlir/op/nn.cc
@@ -215,7 +215,6 @@ std::shared_ptr<OpStrategy> StrategyForConv2d(const framework::NodeAttr &attrs,
} else {
LOG(FATAL) << "Only support NCHW and NHWC data layout\n";
}
-
auto stages = CreateStages({A.as_tensor_ref(), B.as_tensor_ref()});

std::vector<CINNValue> res;
Empty file modified cinn/hlir/pe/nn.cc
100644 → 100755
Empty file.
18 changes: 6 additions & 12 deletions cinn/hlir/pe/schedule.cc
100755 → 100644
@@ -786,18 +786,12 @@ void CudaScheduleConv(poly::StageMap stages, ir::Tensor &input_pad, ir::Tensor &
auto ory = stages[OL]->axis(7);
auto orx = stages[OL]->axis(8);
stages[OL]->Reorder({orc, ory, orx, on, obz, oby, otz, otx, ofi});
-if (rc_factor > 1) {
-stages[OL]->Split(0, rc_factor);
-stages[OL]->Bind(5, "blockIdx.z");
-stages[OL]->Bind(6, "blockIdx.y");
-stages[OL]->Bind(7, "threadIdx.z");
-stages[OL]->Bind(8, "threadIdx.x");
-} else {
-stages[OL]->Bind(4, "blockIdx.z");
-stages[OL]->Bind(5, "blockIdx.y");
-stages[OL]->Bind(6, "threadIdx.z");
-stages[OL]->Bind(7, "threadIdx.x");
-}
+stages[OL]->Split(0, rc_factor);
+stages[OL]->Reorder({0, 2, 3, 1});
+stages[OL]->Bind(5, "blockIdx.z");
+stages[OL]->Bind(6, "blockIdx.y");
+stages[OL]->Bind(7, "threadIdx.z");
+stages[OL]->Bind(8, "threadIdx.x");

return;
}
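
After the now-unconditional Split(0, rc_factor), the axis order is [rc_outer, rc_inner, ory, orx, on, obz, oby, otz, otx, ofi]; Reorder({0, 2, 3, 1}) then moves the inner reduction axis after ory and orx, giving [rc_outer, ory, orx, rc_inner, on, obz, oby, otz, otx, ofi], so positions 5 through 8 are exactly the axes the four Bind calls map to blockIdx.z, blockIdx.y, threadIdx.z and threadIdx.x.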