Skip to content

Commit

Permalink
Fix gpu CacheRead(local) (PaddlePaddle#134)
Browse files Browse the repository at this point in the history
* fix gpu cache_read

* move compute_at related postprocesses to compute_at_postprocess.cc

for better code read

* remove ForSeperate

* fix transform polyfor to for test
  • Loading branch information
Superjomn authored Jul 30, 2020
1 parent b5ea5e8 commit 5238ff1
Show file tree
Hide file tree
Showing 23 changed files with 815 additions and 986 deletions.
98 changes: 10 additions & 88 deletions cinn/backends/codegen_c_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,8 @@ void add1(void* _args, int32_t num_args)
};
for (int32_t i_outer = 0; i_outer < 25; i_outer += 1) {
for (int32_t i_inner = 0; i_inner < 4; i_inner += 1) {
for (int32_t j_outer = 0; j_outer < 1; j_outer += 1) {
for (int32_t j_inner = 0; j_inner < 16; j_inner += 1) {
D[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))] = ((2 * C[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))]) + (4 * (C[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))] * A[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))])));
};
};
for (int32_t j_outer = 1; j_outer < 2; j_outer += 1) {
for (int32_t j_inner = 0; j_inner < (20 + (-16 * j_outer)); j_inner += 1) {
for (int32_t j_outer = 0; j_outer < 2; j_outer += 1) {
for (int32_t j_inner = 0; j_inner < (1 + ((int32_t)(cinn_min(15, (19 + (-16 * j_outer)))))); j_inner += 1) {
D[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))] = ((2 * C[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))]) + (4 * (C[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))] * A[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))])));
};
};
Expand Down Expand Up @@ -360,48 +355,10 @@ void matmul(void* _args, int32_t num_args)
const float* B = ((const float*)(_B->host_memory));
float* C = ((float*)(_C->host_memory));
float* C_init = ((float*)(_C->host_memory));
for (int32_t i_outer = 0; i_outer < 3; i_outer += 1) {
for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) {
for (int32_t j_inner = 0; j_inner < 32; j_inner += 1) {
C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0;
for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) {
for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * B[((32 * j_outer) + ((500 * k0_inner) + ((2000 * k0_outer) + j_inner)))]));
};
};
};
};
};
for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) {
for (int32_t j_inner = 0; j_inner < (500 + (-32 * j_outer)); j_inner += 1) {
C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0;
for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) {
for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * B[((32 * j_outer) + ((500 * k0_inner) + ((2000 * k0_outer) + j_inner)))]));
};
};
};
};
};
};
for (int32_t i_outer = 3; i_outer < 4; i_outer += 1) {
for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) {
for (int32_t j_inner = 0; j_inner < 32; j_inner += 1) {
C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0;
for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) {
for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * B[((32 * j_outer) + ((500 * k0_inner) + ((2000 * k0_outer) + j_inner)))]));
};
};
};
};
};
for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) {
for (int32_t j_inner = 0; j_inner < (500 + (-32 * j_outer)); j_inner += 1) {
for (int32_t i_outer = 0; i_outer < 4; i_outer += 1) {
for (int32_t j_outer = 0; j_outer < 16; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < (1 + ((int32_t)(cinn_min(31, (99 + (-32 * i_outer)))))); i_inner += 1) {
for (int32_t j_inner = 0; j_inner < (1 + ((int32_t)(cinn_min(31, (499 + (-32 * j_outer)))))); j_inner += 1) {
C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0;
for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) {
for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) {
Expand Down Expand Up @@ -485,45 +442,10 @@ void matmul_with_packing(void* _args, int32_t num_args)
};
};
};
for (int32_t i_outer = 0; i_outer < 3; i_outer += 1) {
for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) {
for (int32_t j_inner = 0; j_inner < 32; j_inner += 1) {
for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) {
for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * PackedB[((6400 * j_outer) + ((32 * k0_inner) + ((128 * k0_outer) + j_inner)))]);
};
};
};
};
};
for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) {
for (int32_t j_inner = 0; j_inner < (500 + (-32 * j_outer)); j_inner += 1) {
for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) {
for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * PackedB[((j_inner % 32) + ((6400 * (j_inner / 32)) + ((6400 * j_outer) + ((32 * k0_inner) + (128 * k0_outer)))))]);
};
};
};
};
};
};
for (int32_t i_outer = 3; i_outer < 4; i_outer += 1) {
for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) {
for (int32_t j_inner = 0; j_inner < 32; j_inner += 1) {
for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) {
for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * PackedB[((6400 * j_outer) + ((32 * k0_inner) + ((128 * k0_outer) + j_inner)))]);
};
};
};
};
};
for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) {
for (int32_t j_inner = 0; j_inner < (500 + (-32 * j_outer)); j_inner += 1) {
for (int32_t i_outer = 0; i_outer < 4; i_outer += 1) {
for (int32_t j_outer = 0; j_outer < 16; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < (1 + ((int32_t)(cinn_min(31, (99 + (-32 * i_outer)))))); i_inner += 1) {
for (int32_t j_inner = 0; j_inner < (1 + ((int32_t)(cinn_min(31, (499 + (-32 * j_outer)))))); j_inner += 1) {
for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) {
for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * PackedB[((j_inner % 32) + ((6400 * (j_inner / 32)) + ((6400 * j_outer) + ((32 * k0_inner) + (128 * k0_outer)))))]);
Expand Down
148 changes: 117 additions & 31 deletions cinn/backends/codegen_cuda_dev_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ TEST(Conv, basic) {
{rc, ry, rx});
B->WithBuffer();

B->stage()->CacheRead("share", {B});
B->stage()->CacheRead("shared", {B});

auto fn = Lower("fn", {A, W, B});

Expand All @@ -596,7 +596,7 @@ TEST(elementwise_add, share_local_cache) {
auto C = Compute(
{M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j); }, "C");

auto AA = A->stage()->CacheRead("share", {C});
auto AA = A->stage()->CacheRead("shared", {C});
// NOTE here, the CC replace the C as the output the function.
auto CC = C->stage()->CacheWrite("local");

Expand Down Expand Up @@ -759,8 +759,8 @@ TEST(Conv, basic_add_cache) {
Apad->shape, [=](const std::vector<Expr>& dims) -> Expr { return Apad(dims); }, "AA");
auto WW = Compute(
W->shape, [=](const std::vector<Expr>& dims) { return W(dims); }, "WW");
AA->WithBuffer("share");
WW->WithBuffer("share");
AA->WithBuffer("shared");
WW->WithBuffer("shared");

auto AL = Compute(
AA->shape, [=](const std::vector<Expr>& dims) -> Expr { return AA(dims); }, "AL");
Expand Down Expand Up @@ -849,8 +849,8 @@ TEST(Conv, optimize) {

Apad->stage()->ComputeInline();

auto AA = Apad->stage()->CacheRead("share", {B});
auto WW = W->stage()->CacheRead("share", {B});
auto AA = Apad->stage()->CacheRead("shared", {B});
auto WW = W->stage()->CacheRead("shared", {B});
auto AL = AA->stage()->CacheRead("local", {B});
auto WL = WW->stage()->CacheRead("local", {B});
auto BL = B->stage()->CacheWrite("local");
Expand Down Expand Up @@ -884,7 +884,7 @@ TEST(Conv, optimize) {
LOG(INFO) << Lower("conv", {A, W, BL}, {}, {AA, WW, AL, WL, B});
}

TEST(ElementwiseAdd, cache_read) {
TEST(ElementwiseAdd, cache_read_local) {
Context::Global().ResetNameId();

Expr M(100);
Expand Down Expand Up @@ -931,21 +931,27 @@ void fn0_kernel(const float* __restrict__ A, const float* __restrict__ B, float*
{
float _A_read_cache_3 [ 1 * 10 ];
float* A_read_cache_3 = _A_read_cache_3;
if ((threadIdx.x < 100)) {
{
if (((((threadIdx.x >= 0) && (threadIdx.x <= 99)) && (blockIdx.x >= 0)) && (blockIdx.x <= 19))) {
for (int32_t j_inner = 0; j_inner < 10; j_inner += 1) {
A_read_cache_3[j_inner] = A[((10 * blockIdx.x) + ((200 * threadIdx.x) + j_inner))];
if ((blockIdx.x < 20)) {
{
if (((((threadIdx.x >= 0) && (threadIdx.x <= 99)) && (blockIdx.x >= 0)) && (blockIdx.x <= 19))) {
for (int32_t j_inner = 0; j_inner < 10; j_inner += 1) {
A_read_cache_3[j_inner] = A[((10 * blockIdx.x) + ((200 * threadIdx.x) + j_inner))];
};
};
for (int32_t i = 0; i < 10; i += 1) {
C[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))] = (A_read_cache_3[i] + B[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))]);
};
}
};
for (int32_t i = 0; i < 10; i += 1) {
C[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))] = (A_read_cache_3[i] + B[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))]);
};
}
};
}
}
)ROC";
// ASSERT_EQ(utils::Trim(source_target), source_code);
ASSERT_EQ(utils::Trim(source_target), source_code);

auto [host_module, device_module] = SplitCudaAndHostModule(module); // NOLINT

Expand Down Expand Up @@ -988,35 +994,52 @@ void fn0_kernel(const float* __restrict__ A, const float* __restrict__ B, float*
}

TEST(ElementwiseAdd, cache_read1) {
Context::Global().ResetNameId();

Expr M(100);
Expr N(200);

Placeholder<float> A("A", {M, N});
Placeholder<float> B("B", {M, N});
auto create_module = [&] {
Context::Global().ResetNameId();

auto C = Compute(
{M - 2, N}, [&](Expr i, Expr j) { return A(i, j) + A(i + 1, j) + A(i + 2, j) + B(i, j); }, "C");
C->stage()->Split(1, 10);
Placeholder<float> A("A", {M, N});
Placeholder<float> B("B", {M, N});

auto AL = A->stage()->CacheRead("local", {C});
AL->stage()->Split(1, 10);
auto C = Compute(
{M - 2, N}, [&](Expr i, Expr j) { return A(i, j) + A(i + 1, j) + A(i + 2, j) + B(i, j); }, "C");
C->stage()->Split(1, 10);

AL->stage()->ComputeAt(C->stage(), 1, poly::Stage::ComputeAtKind::kComputeAtUnk, A->name);
auto AL = A->stage()->CacheRead("local", {C});
AL->stage()->Split(1, 10);

AL->stage()->ComputeAt(C->stage(), 1, poly::Stage::ComputeAtKind::kComputeAtUnk, A->name);

return std::make_tuple(A, B, C, AL);
};
{
auto [A, B, C, AL] = create_module(); // NOLINT
auto fn = Lower("fn1", {A, B, C}, {}, {AL});
CodeGenC codegen_c(common::DefaultHostTarget());
codegen_c.SetInlineBuiltinCodes(false);

Module::Builder builder("module", common::DefaultHostTarget());
builder.AddFunction(fn);

auto c_source_code = codegen_c.Compile(builder.Build(), CodeGenC::OutputKind::CImpl);
std::cout << "C source code:\n" << c_source_code << std::endl;
}

auto [A, B, C, AL] = create_module(); // NOLINT
C->stage()->Bind(0, "threadIdx.x");
C->stage()->Bind(1, "blockIdx.x");

Target target;
CodeGenCUDA_Dev codegen(target);

auto fn = Lower("fn1", {A, B, C}, {}, {AL});

Module::Builder builder("module", target);
Module::Builder builder("module", common::DefaultHostTarget());
builder.AddFunction(fn);

auto source_code = codegen.Compile(builder.Build());
std::cout << "source:\n" << source_code << std::endl;
std::cout << "CUDA source:\n" << source_code << std::endl;

std::string source_target = R"ROC(
extern "C" {
Expand Down Expand Up @@ -1055,18 +1078,81 @@ void fn1_kernel(const float* __restrict__ A, const float* __restrict__ B, float*
}
)ROC";
ASSERT_EQ(utils::Trim(source_target), source_code);

std::string source_target1 = R"ROC(
extern "C" {
#ifdef __CUDACC_RTC__
typedef int int32_t;
typedef char int8_t;
#endif
__global__
void fn1_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C)
{
float _A_read_cache_3 [ 3 * 10 ];
float* A_read_cache_3 = _A_read_cache_3;
{
if (((((threadIdx.x >= 0) && (threadIdx.x <= 97)) && (blockIdx.x >= 0)) && (blockIdx.x <= 19))) {
for (int32_t i = 0; i < 3; i += 1) {
for (int32_t j_inner = 0; j_inner < 10; j_inner += 1) {
A_read_cache_3[((10 * i) + j_inner)] = A[((10 * blockIdx.x) + ((200 * (i+threadIdx.x)) + j_inner))];
};
};
};
for (int32_t i = 0; i < 10; i += 1) {
C[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))] = (A_read_cache_3[i] + (A_read_cache_3[(10 + i)] + (A_read_cache_3[(20 + i)] + B[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))])));
};
};
}
}
)ROC";

// ASSERT_EQ(utils::Trim(source_target), source_code);

common::CudaModuleTester tester;
// tester.Compile(builder.Build(), source_target1);
tester.Compile(builder.Build());

auto* A_host = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build();
auto* B_host = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build();
auto* C_host = common::BufferBuilder(Float(32), {M.as_int32() - 2, N.as_int32()}).set_zero().Build();
auto* A_host = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build();
auto* B_host = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build();
auto* C_host = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_zero().Build();
auto* C_target_host = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_zero().Build();

auto* A_dev = tester.CreateDeviceBuffer(A_host);
auto* B_dev = tester.CreateDeviceBuffer(B_host);
auto* C_dev = tester.CreateDeviceBuffer(C_host);

cinn_buffer_t* dev_bufs[3];
for (int i = 0; i < 3; i++) dev_bufs[i] = new cinn_buffer_t;
dev_bufs[0]->host_memory = reinterpret_cast<uint8_t*>(A_dev);
dev_bufs[1]->host_memory = reinterpret_cast<uint8_t*>(B_dev);
dev_bufs[2]->host_memory = reinterpret_cast<uint8_t*>(C_dev);
auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build();

CUDA_CALL(cudaDeviceSynchronize());
tester("fn1", args.data(), args.size());
CUDA_CALL(cudaDeviceSynchronize());

CUDA_CALL(cudaMemcpy(reinterpret_cast<void*>(C_target_host->host_memory),
C_dev,
C_target_host->num_elements() * sizeof(float),
cudaMemcpyDeviceToHost));

auto* C_target_mem = reinterpret_cast<float*>(C_target_host->host_memory);
auto* A_mem = reinterpret_cast<float*>(A_host->host_memory);
auto* B_mem = reinterpret_cast<float*>(B_host->host_memory);
for (int i = 0; i < M.as_int32() - 2; i++) {
for (int j = 0; j < N.as_int32(); j++) {
ASSERT_NEAR(C_target_mem[i * N.as_int32() + j],
A_mem[i * N.as_int32() + j] + A_mem[(i + 1) * N.as_int32() + j] + A_mem[(i + 2) * N.as_int32() + j] +
B_mem[i * N.as_int32() + j],
1e-5);
}
}
}

} // namespace backends
Expand Down
4 changes: 3 additions & 1 deletion cinn/backends/extern_func_emitter.cc
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#include "cinn/backends/extern_func_emitter.h"

#include <glog/raw_logging.h>
#include <functional>
#include <iostream>
#include <string>

#include "cinn/backends/extern_func_emitter_builtin.h"
#include "cinn/backends/llvm/runtime_symbol_registry.h"
#include "cinn/runtime/cpu/host_intrinsics.h"
#include "cinn/utils/string.h"

namespace cinn {
namespace backends {
Expand All @@ -17,7 +19,7 @@ ExternFunctionEmitterRegistry& ExternFunctionEmitterRegistry::Global() {
}

void ExternFunctionEmitterRegistry::Register(const ExternFuncID& name, ExternFunctionEmitter* x) {
std::cerr << "Register extern function emitter [" << name << "]" << std::endl;
RAW_LOG_INFO("Register extern function emitter [%s]", utils::GetStreamCnt(name).c_str());
CHECK(x);
data_[name] = std::unique_ptr<ExternFunctionEmitter>(x);
}
Expand Down
Loading

0 comments on commit 5238ff1

Please sign in to comment.