diff --git a/cinn/backends/codegen_c_test.cc b/cinn/backends/codegen_c_test.cc index 3027de73faa06..5b61e359f3ab6 100644 --- a/cinn/backends/codegen_c_test.cc +++ b/cinn/backends/codegen_c_test.cc @@ -179,13 +179,8 @@ void add1(void* _args, int32_t num_args) }; for (int32_t i_outer = 0; i_outer < 25; i_outer += 1) { for (int32_t i_inner = 0; i_inner < 4; i_inner += 1) { - for (int32_t j_outer = 0; j_outer < 1; j_outer += 1) { - for (int32_t j_inner = 0; j_inner < 16; j_inner += 1) { - D[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))] = ((2 * C[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))]) + (4 * (C[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))] * A[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))]))); - }; - }; - for (int32_t j_outer = 1; j_outer < 2; j_outer += 1) { - for (int32_t j_inner = 0; j_inner < (20 + (-16 * j_outer)); j_inner += 1) { + for (int32_t j_outer = 0; j_outer < 2; j_outer += 1) { + for (int32_t j_inner = 0; j_inner < (1 + ((int32_t)(cinn_min(15, (19 + (-16 * j_outer)))))); j_inner += 1) { D[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))] = ((2 * C[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))]) + (4 * (C[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))] * A[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))]))); }; }; @@ -360,48 +355,10 @@ void matmul(void* _args, int32_t num_args) const float* B = ((const float*)(_B->host_memory)); float* C = ((float*)(_C->host_memory)); float* C_init = ((float*)(_C->host_memory)); - for (int32_t i_outer = 0; i_outer < 3; i_outer += 1) { - for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) { - for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) { - for (int32_t j_inner = 0; j_inner < 32; j_inner += 1) { - C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0; - for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) { - for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * B[((32 * j_outer) + ((500 * k0_inner) + ((2000 * k0_outer) + j_inner)))])); - }; - }; - }; - }; - }; - for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) { - for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) { - for (int32_t j_inner = 0; j_inner < (500 + (-32 * j_outer)); j_inner += 1) { - C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0; - for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) { - for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * B[((32 * j_outer) + ((500 * k0_inner) + ((2000 * k0_outer) + j_inner)))])); - }; - }; - }; - }; - }; - }; - for (int32_t i_outer = 3; i_outer < 4; i_outer += 1) { - for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) { - for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) { - for (int32_t j_inner = 0; j_inner < 32; j_inner += 1) { - C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0; - for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) { - for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * B[((32 * j_outer) + ((500 * k0_inner) + ((2000 * k0_outer) + j_inner)))])); - }; - }; - }; - }; - }; - for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) { - for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) { - for (int32_t j_inner = 0; j_inner < (500 + (-32 * j_outer)); j_inner += 1) { + for (int32_t i_outer = 0; i_outer < 4; i_outer += 1) { + for (int32_t j_outer = 0; j_outer < 16; j_outer += 1) { + for (int32_t i_inner = 0; i_inner < (1 + ((int32_t)(cinn_min(31, (99 + (-32 * i_outer)))))); i_inner += 1) { + for (int32_t j_inner = 0; j_inner < (1 + ((int32_t)(cinn_min(31, (499 + (-32 * j_outer)))))); j_inner += 1) { C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0; for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) { for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) { @@ -485,45 +442,10 @@ void matmul_with_packing(void* _args, int32_t num_args) }; }; }; - for (int32_t i_outer = 0; i_outer < 3; i_outer += 1) { - for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) { - for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) { - for (int32_t j_inner = 0; j_inner < 32; j_inner += 1) { - for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) { - for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * PackedB[((6400 * j_outer) + ((32 * k0_inner) + ((128 * k0_outer) + j_inner)))]); - }; - }; - }; - }; - }; - for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) { - for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) { - for (int32_t j_inner = 0; j_inner < (500 + (-32 * j_outer)); j_inner += 1) { - for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) { - for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * PackedB[((j_inner % 32) + ((6400 * (j_inner / 32)) + ((6400 * j_outer) + ((32 * k0_inner) + (128 * k0_outer)))))]); - }; - }; - }; - }; - }; - }; - for (int32_t i_outer = 3; i_outer < 4; i_outer += 1) { - for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) { - for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) { - for (int32_t j_inner = 0; j_inner < 32; j_inner += 1) { - for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) { - for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * PackedB[((6400 * j_outer) + ((32 * k0_inner) + ((128 * k0_outer) + j_inner)))]); - }; - }; - }; - }; - }; - for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) { - for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) { - for (int32_t j_inner = 0; j_inner < (500 + (-32 * j_outer)); j_inner += 1) { + for (int32_t i_outer = 0; i_outer < 4; i_outer += 1) { + for (int32_t j_outer = 0; j_outer < 16; j_outer += 1) { + for (int32_t i_inner = 0; i_inner < (1 + ((int32_t)(cinn_min(31, (99 + (-32 * i_outer)))))); i_inner += 1) { + for (int32_t j_inner = 0; j_inner < (1 + ((int32_t)(cinn_min(31, (499 + (-32 * j_outer)))))); j_inner += 1) { for (int32_t k0_outer = 0; k0_outer < 50; k0_outer += 1) { for (int32_t k0_inner = 0; k0_inner < 4; k0_inner += 1) { C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k0_outer) + k0_inner)))] * PackedB[((j_inner % 32) + ((6400 * (j_inner / 32)) + ((6400 * j_outer) + ((32 * k0_inner) + (128 * k0_outer)))))]); diff --git a/cinn/backends/codegen_cuda_dev_test.cc b/cinn/backends/codegen_cuda_dev_test.cc index 512a62e7d5068..e94578f60aea8 100644 --- a/cinn/backends/codegen_cuda_dev_test.cc +++ b/cinn/backends/codegen_cuda_dev_test.cc @@ -575,7 +575,7 @@ TEST(Conv, basic) { {rc, ry, rx}); B->WithBuffer(); - B->stage()->CacheRead("share", {B}); + B->stage()->CacheRead("shared", {B}); auto fn = Lower("fn", {A, W, B}); @@ -596,7 +596,7 @@ TEST(elementwise_add, share_local_cache) { auto C = Compute( {M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j); }, "C"); - auto AA = A->stage()->CacheRead("share", {C}); + auto AA = A->stage()->CacheRead("shared", {C}); // NOTE here, the CC replace the C as the output the function. auto CC = C->stage()->CacheWrite("local"); @@ -759,8 +759,8 @@ TEST(Conv, basic_add_cache) { Apad->shape, [=](const std::vector& dims) -> Expr { return Apad(dims); }, "AA"); auto WW = Compute( W->shape, [=](const std::vector& dims) { return W(dims); }, "WW"); - AA->WithBuffer("share"); - WW->WithBuffer("share"); + AA->WithBuffer("shared"); + WW->WithBuffer("shared"); auto AL = Compute( AA->shape, [=](const std::vector& dims) -> Expr { return AA(dims); }, "AL"); @@ -849,8 +849,8 @@ TEST(Conv, optimize) { Apad->stage()->ComputeInline(); - auto AA = Apad->stage()->CacheRead("share", {B}); - auto WW = W->stage()->CacheRead("share", {B}); + auto AA = Apad->stage()->CacheRead("shared", {B}); + auto WW = W->stage()->CacheRead("shared", {B}); auto AL = AA->stage()->CacheRead("local", {B}); auto WL = WW->stage()->CacheRead("local", {B}); auto BL = B->stage()->CacheWrite("local"); @@ -884,7 +884,7 @@ TEST(Conv, optimize) { LOG(INFO) << Lower("conv", {A, W, BL}, {}, {AA, WW, AL, WL, B}); } -TEST(ElementwiseAdd, cache_read) { +TEST(ElementwiseAdd, cache_read_local) { Context::Global().ResetNameId(); Expr M(100); @@ -931,21 +931,27 @@ void fn0_kernel(const float* __restrict__ A, const float* __restrict__ B, float* { float _A_read_cache_3 [ 1 * 10 ]; float* A_read_cache_3 = _A_read_cache_3; + if ((threadIdx.x < 100)) { { - if (((((threadIdx.x >= 0) && (threadIdx.x <= 99)) && (blockIdx.x >= 0)) && (blockIdx.x <= 19))) { - for (int32_t j_inner = 0; j_inner < 10; j_inner += 1) { - A_read_cache_3[j_inner] = A[((10 * blockIdx.x) + ((200 * threadIdx.x) + j_inner))]; + if ((blockIdx.x < 20)) { + { + if (((((threadIdx.x >= 0) && (threadIdx.x <= 99)) && (blockIdx.x >= 0)) && (blockIdx.x <= 19))) { + for (int32_t j_inner = 0; j_inner < 10; j_inner += 1) { + A_read_cache_3[j_inner] = A[((10 * blockIdx.x) + ((200 * threadIdx.x) + j_inner))]; + }; }; + for (int32_t i = 0; i < 10; i += 1) { + C[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))] = (A_read_cache_3[i] + B[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))]); + }; + } }; - for (int32_t i = 0; i < 10; i += 1) { - C[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))] = (A_read_cache_3[i] + B[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))]); - }; + } }; } } )ROC"; - // ASSERT_EQ(utils::Trim(source_target), source_code); + ASSERT_EQ(utils::Trim(source_target), source_code); auto [host_module, device_module] = SplitCudaAndHostModule(module); // NOLINT @@ -988,22 +994,40 @@ void fn0_kernel(const float* __restrict__ A, const float* __restrict__ B, float* } TEST(ElementwiseAdd, cache_read1) { - Context::Global().ResetNameId(); - Expr M(100); Expr N(200); - Placeholder A("A", {M, N}); - Placeholder B("B", {M, N}); + auto create_module = [&] { + Context::Global().ResetNameId(); - auto C = Compute( - {M - 2, N}, [&](Expr i, Expr j) { return A(i, j) + A(i + 1, j) + A(i + 2, j) + B(i, j); }, "C"); - C->stage()->Split(1, 10); + Placeholder A("A", {M, N}); + Placeholder B("B", {M, N}); - auto AL = A->stage()->CacheRead("local", {C}); - AL->stage()->Split(1, 10); + auto C = Compute( + {M - 2, N}, [&](Expr i, Expr j) { return A(i, j) + A(i + 1, j) + A(i + 2, j) + B(i, j); }, "C"); + C->stage()->Split(1, 10); - AL->stage()->ComputeAt(C->stage(), 1, poly::Stage::ComputeAtKind::kComputeAtUnk, A->name); + auto AL = A->stage()->CacheRead("local", {C}); + AL->stage()->Split(1, 10); + + AL->stage()->ComputeAt(C->stage(), 1, poly::Stage::ComputeAtKind::kComputeAtUnk, A->name); + + return std::make_tuple(A, B, C, AL); + }; + { + auto [A, B, C, AL] = create_module(); // NOLINT + auto fn = Lower("fn1", {A, B, C}, {}, {AL}); + CodeGenC codegen_c(common::DefaultHostTarget()); + codegen_c.SetInlineBuiltinCodes(false); + + Module::Builder builder("module", common::DefaultHostTarget()); + builder.AddFunction(fn); + + auto c_source_code = codegen_c.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); + std::cout << "C source code:\n" << c_source_code << std::endl; + } + + auto [A, B, C, AL] = create_module(); // NOLINT C->stage()->Bind(0, "threadIdx.x"); C->stage()->Bind(1, "blockIdx.x"); @@ -1011,12 +1035,11 @@ TEST(ElementwiseAdd, cache_read1) { CodeGenCUDA_Dev codegen(target); auto fn = Lower("fn1", {A, B, C}, {}, {AL}); - - Module::Builder builder("module", target); + Module::Builder builder("module", common::DefaultHostTarget()); builder.AddFunction(fn); auto source_code = codegen.Compile(builder.Build()); - std::cout << "source:\n" << source_code << std::endl; + std::cout << "CUDA source:\n" << source_code << std::endl; std::string source_target = R"ROC( extern "C" { @@ -1055,18 +1078,81 @@ void fn1_kernel(const float* __restrict__ A, const float* __restrict__ B, float* } )ROC"; - ASSERT_EQ(utils::Trim(source_target), source_code); + + std::string source_target1 = R"ROC( +extern "C" { + +#ifdef __CUDACC_RTC__ +typedef int int32_t; +typedef char int8_t; +#endif + + + +__global__ +void fn1_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C) +{ + float _A_read_cache_3 [ 3 * 10 ]; + float* A_read_cache_3 = _A_read_cache_3; + { + if (((((threadIdx.x >= 0) && (threadIdx.x <= 97)) && (blockIdx.x >= 0)) && (blockIdx.x <= 19))) { + for (int32_t i = 0; i < 3; i += 1) { + for (int32_t j_inner = 0; j_inner < 10; j_inner += 1) { + A_read_cache_3[((10 * i) + j_inner)] = A[((10 * blockIdx.x) + ((200 * (i+threadIdx.x)) + j_inner))]; + }; + }; + }; + for (int32_t i = 0; i < 10; i += 1) { + C[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))] = (A_read_cache_3[i] + (A_read_cache_3[(10 + i)] + (A_read_cache_3[(20 + i)] + B[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))]))); + }; + }; +} + +} +)ROC"; + + // ASSERT_EQ(utils::Trim(source_target), source_code); common::CudaModuleTester tester; + // tester.Compile(builder.Build(), source_target1); tester.Compile(builder.Build()); - auto* A_host = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); - auto* B_host = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); - auto* C_host = common::BufferBuilder(Float(32), {M.as_int32() - 2, N.as_int32()}).set_zero().Build(); + auto* A_host = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); + auto* B_host = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_random().Build(); + auto* C_host = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_zero().Build(); + auto* C_target_host = common::BufferBuilder(Float(32), {M.as_int32(), N.as_int32()}).set_zero().Build(); auto* A_dev = tester.CreateDeviceBuffer(A_host); auto* B_dev = tester.CreateDeviceBuffer(B_host); auto* C_dev = tester.CreateDeviceBuffer(C_host); + + cinn_buffer_t* dev_bufs[3]; + for (int i = 0; i < 3; i++) dev_bufs[i] = new cinn_buffer_t; + dev_bufs[0]->host_memory = reinterpret_cast(A_dev); + dev_bufs[1]->host_memory = reinterpret_cast(B_dev); + dev_bufs[2]->host_memory = reinterpret_cast(C_dev); + auto args = common::ArgsBuilder().Add(dev_bufs[0]).Add(dev_bufs[1]).Add(dev_bufs[2]).Build(); + + CUDA_CALL(cudaDeviceSynchronize()); + tester("fn1", args.data(), args.size()); + CUDA_CALL(cudaDeviceSynchronize()); + + CUDA_CALL(cudaMemcpy(reinterpret_cast(C_target_host->host_memory), + C_dev, + C_target_host->num_elements() * sizeof(float), + cudaMemcpyDeviceToHost)); + + auto* C_target_mem = reinterpret_cast(C_target_host->host_memory); + auto* A_mem = reinterpret_cast(A_host->host_memory); + auto* B_mem = reinterpret_cast(B_host->host_memory); + for (int i = 0; i < M.as_int32() - 2; i++) { + for (int j = 0; j < N.as_int32(); j++) { + ASSERT_NEAR(C_target_mem[i * N.as_int32() + j], + A_mem[i * N.as_int32() + j] + A_mem[(i + 1) * N.as_int32() + j] + A_mem[(i + 2) * N.as_int32() + j] + + B_mem[i * N.as_int32() + j], + 1e-5); + } + } } } // namespace backends diff --git a/cinn/backends/extern_func_emitter.cc b/cinn/backends/extern_func_emitter.cc index 8179361e27430..12d44e19db92e 100644 --- a/cinn/backends/extern_func_emitter.cc +++ b/cinn/backends/extern_func_emitter.cc @@ -1,5 +1,6 @@ #include "cinn/backends/extern_func_emitter.h" +#include #include #include #include @@ -7,6 +8,7 @@ #include "cinn/backends/extern_func_emitter_builtin.h" #include "cinn/backends/llvm/runtime_symbol_registry.h" #include "cinn/runtime/cpu/host_intrinsics.h" +#include "cinn/utils/string.h" namespace cinn { namespace backends { @@ -17,7 +19,7 @@ ExternFunctionEmitterRegistry& ExternFunctionEmitterRegistry::Global() { } void ExternFunctionEmitterRegistry::Register(const ExternFuncID& name, ExternFunctionEmitter* x) { - std::cerr << "Register extern function emitter [" << name << "]" << std::endl; + RAW_LOG_INFO("Register extern function emitter [%s]", utils::GetStreamCnt(name).c_str()); CHECK(x); data_[name] = std::unique_ptr(x); } diff --git a/cinn/common/arithmatic.cc b/cinn/common/arithmatic.cc index 6f36b23fee0d5..39ab91742536a 100644 --- a/cinn/common/arithmatic.cc +++ b/cinn/common/arithmatic.cc @@ -29,7 +29,11 @@ std::string ExprToGinacConerter::Repr(const ir::Expr& expr) { auto* var_n = expr.As<_Var_>(); auto* broadcast_n = expr.As(); auto* mod_n = expr.As(); - if (load_n || broadcast_n || mod_n) { + auto* min_n = expr.As(); + auto* max_n = expr.As(); + auto* div_n = expr.As
(); + auto* frac_n = expr.As(); + if (load_n || broadcast_n || mod_n || min_n || max_n || div_n || frac_n) { std::string repr = GetStreamCnt(expr); Replace(&repr, "[", "lsq_"); Replace(&repr, "]", "_rsq"); @@ -65,8 +69,16 @@ GiNaC::ex ExprToGinacConerter::BuildHelper(ir::Expr expr) { auto* broadcast_n = expr.As(); auto* mod_n = expr.As(); auto* frac_n = expr.As(); + auto* min_n = expr.As(); + auto* max_n = expr.As(); + + bool is_integer_math = expr.type().is_int(); - if (load_n || var_n || broadcast_n || mod_n) { + bool is_invalid_arith = load_n || var_n || broadcast_n || mod_n || min_n || max_n; + if (is_integer_math) + is_invalid_arith = is_invalid_arith || div_n || frac_n; // GiNac can't deal with integer division. + + if (is_invalid_arith) { RecordExpr(expr); std::string repr = Repr(expr); return CreateGinacSymbol(repr); @@ -98,8 +110,6 @@ GiNaC::ex ExprToGinacConerter::operator()(Expr expr) { auto complex_nodes = CollectIRNodes(expr, [](const Expr* n) { return n->As() || // n->As() || // - n->As() || // - n->As() || // n->As() || // n->As() || // n->As() || // @@ -118,9 +128,6 @@ GiNaC::ex ExprToGinacConerter::operator()(Expr expr) { n->As(); }); - for (auto& node : complex_nodes) { - LOG(INFO) << "complex nodes: " << node; - } CHECK(complex_nodes.empty()) << "Ginac converter can only deal with simple math expression, but get some complex nodes" << expr; @@ -252,10 +259,12 @@ bool MathContainsSymbol(Expr expr, Var symbol) { // lhs >= rhs. std::tuple Solve(Expr lhs, Expr rhs, Var var) { + VLOG(4) << "Solve: " << lhs << "=" << rhs << " in " << var; ExprToGinacConerter converter; auto lhs_ex = converter(lhs); auto rhs_ex = converter(rhs); ginac::lst eqs{lhs_ex == rhs_ex}; + VLOG(4) << "eqs: " << eqs; const auto& symbol = converter.GetSymbol(var->name); ginac::lst vars{symbol}; ginac::ex res = ginac::lsolve(eqs, vars); diff --git a/cinn/common/cas.cc b/cinn/common/cas.cc index 659b7b820e7c9..07eb6fa63b973 100644 --- a/cinn/common/cas.cc +++ b/cinn/common/cas.cc @@ -1453,7 +1453,8 @@ Expr SolveInequality(Expr inequality, Var val) { Expr all = AutoSimplify(a - b); - if (common::IsPureMath(a) && common::IsPureMath(b)) { + // if (common::IsPureMath(a) && common::IsPureMath(b)) { + if (true) { auto [res, positive] = common::Solve(a, b, val); // NOLINT // Simplify it with CAS to avoid random result from GiNac. res = AutoSimplify(res); diff --git a/cinn/common/cas.h b/cinn/common/cas.h index 43ac97a3cef3f..b9b6feafca76a 100644 --- a/cinn/common/cas.h +++ b/cinn/common/cas.h @@ -36,6 +36,7 @@ Expr CasSimplify(Expr u, const std::unordered_map& var * @return an copied expression looks like x < 100. */ Expr SolveInequality(Expr inequality, Var val); +Expr SolveInequalityInt(Expr inequality, Var val); namespace detail { diff --git a/cinn/common/cas_test.cc b/cinn/common/cas_test.cc index 05a68f6cc13dd..04c80e67337c6 100644 --- a/cinn/common/cas_test.cc +++ b/cinn/common/cas_test.cc @@ -306,12 +306,15 @@ TEST(CAS, IntConnerCase) { TEST(SolveInequality, basic) { Var x("x", Int(32)); + Var y("y", Int(32)); #define TEST_SOLVE(expr__, str__) EXPECT_EQ(GetStreamCnt(SolveInequality(expr__, x)), str__); TEST_SOLVE(x * -1 + 20 < 0, "(x > 20)"); TEST_SOLVE(x * 2 + 3 < x * 10 - 20, "(x > 2)"); TEST_SOLVE(x * -1 < -1, "(x > 1)"); TEST_SOLVE(Expr(2) * x * -1 - x < x + 200, "(x > -50)"); + TEST_SOLVE(Expr(2) * x + 30 - x * 3 + y * 23 < 2, "(x > int32((28 + (23 * y))))"); + TEST_SOLVE(x + ir::Min::Make(Expr(2), Expr(3) * y) < 100, "(x < int32((100 - cinn_min(2, (3 * y)))))"); } } // namespace common diff --git a/cinn/common/cuda_test_helper.cc b/cinn/common/cuda_test_helper.cc index be4dd41ab81dc..e705c616dfe7b 100644 --- a/cinn/common/cuda_test_helper.cc +++ b/cinn/common/cuda_test_helper.cc @@ -11,14 +11,20 @@ namespace cinn { namespace common { #ifdef CINN_WITH_CUDA -void CudaModuleTester::Compile(const lang::Module& m) { +void CudaModuleTester::Compile(const lang::Module& m, const std::string& rewrite_cuda_code) { auto [host_module, device_module] = backends::SplitCudaAndHostModule(m); // NOLINT backends::CodeGenCUDA_Dev codegen(DefaultHostTarget()); auto source_code = codegen.Compile(m); // compile CUDA kernel. backends::NVRTC_Compiler compiler; - auto ptx = compiler(source_code); + + std::string ptx; + if (rewrite_cuda_code.empty()) + ptx = compiler(source_code); + else + ptx = compiler(rewrite_cuda_code); + cuda_module_ = new runtime::cuda::CUDAModule(ptx, runtime::cuda::CUDAModule::Kind::PTX); for (auto& fn : device_module.functions()) { diff --git a/cinn/common/cuda_test_helper.h b/cinn/common/cuda_test_helper.h index a30c03e80a646..758484ce56998 100644 --- a/cinn/common/cuda_test_helper.h +++ b/cinn/common/cuda_test_helper.h @@ -18,7 +18,7 @@ class CudaModuleTester { // Call the host function in JIT. void operator()(const std::string& fn_name, void* args, int arg_num); - void Compile(const lang::Module& m); + void Compile(const lang::Module& m, const std::string& rewrite_cuda_code = ""); void* LookupKernel(const std::string& name); diff --git a/cinn/common/ir_util.cc b/cinn/common/ir_util.cc index 4bfc3bf2d187b..12defb1a9f078 100644 --- a/cinn/common/ir_util.cc +++ b/cinn/common/ir_util.cc @@ -323,5 +323,41 @@ std::vector GatherItersToTensorProducer(const std::string &target_t return Visitor(target_tensor_name)(expr); } +std::vector GetForloopStackToStore(Expr *expr, const std::string &tensor_name) { + VLOG(4) << "search store " << tensor_name << " in expr:\n"; + VLOG(4) << *expr; + struct Mutator : public ir::IRMutator<> { + std::vector forloop_stack; + bool found{false}; + + std::string tensor_name; + + explicit Mutator(const std::string &tensor_name) : tensor_name(tensor_name) {} + + std::vector operator()(Expr *expr) { + ir::IRMutator<>::Visit(expr, expr); + return forloop_stack; + } + + void Visit(const ir::For *op, Expr *expr) { + auto *node = expr->As(); + forloop_stack.push_back(expr); + ir::IRMutator<>::Visit(&node->body, &node->body); + if (!found) forloop_stack.pop_back(); + } + + void Visit(const ir::PolyFor *op, Expr *expr) { + auto *node = expr->As(); + forloop_stack.push_back(expr); + ir::IRMutator<>::Visit(&node->body, &node->body); + if (!found) forloop_stack.pop_back(); + } + + void Visit(const ir::Store *op, Expr *expr) { found = op->tensor.as_tensor()->name == tensor_name; } + }; + + return Mutator(tensor_name)(expr); +} + } // namespace common } // namespace cinn diff --git a/cinn/common/ir_util.h b/cinn/common/ir_util.h index 5a55e42b35507..6ff9454f0c427 100644 --- a/cinn/common/ir_util.h +++ b/cinn/common/ir_util.h @@ -17,22 +17,10 @@ Expr CastIfNeeded(Expr body, Type type); void Substitute(Expr *expr, const std::map &var_map); template -Expr make_const(Type t, T v) { - if (t.is_vector()) { - if (t.type() == Type::type_t::Int) { - return ir::Broadcast::Make(make_shared(t.ElementOf(), v), t.lanes()); - } else { - return ir::Broadcast::Make(make_shared(t.ElementOf(), v), t.lanes()); - } - } else { - if (t.type() == Type::type_t::Int) { - return make_shared(t, v); - } else { - return make_shared(t, v); - } - } - return Expr(); -} +Expr make_const(Type t, T v); + +//! Get a stack of forloops(For and PolyFor nodes) to a Store node target to \p tensor_name +std::vector GetForloopStackToStore(Expr *expr, const std::string &tensor_name); // make const // @{ @@ -84,5 +72,23 @@ Expr or_any(const std::vector &conds); //! Cast the expression \p e to type \type. Expr cast(Expr e, Type type); +template +Expr make_const(Type t, T v) { + if (t.is_vector()) { + if (t.type() == Type::type_t::Int) { + return ir::Broadcast::Make(make_shared(t.ElementOf(), v), t.lanes()); + } else { + return ir::Broadcast::Make(make_shared(t.ElementOf(), v), t.lanes()); + } + } else { + if (t.type() == Type::type_t::Int) { + return make_shared(t, v); + } else { + return make_shared(t, v); + } + } + return Expr(); +} + } // namespace common } // namespace cinn diff --git a/cinn/lang/CMakeLists.txt b/cinn/lang/CMakeLists.txt index 2ad1188c3c830..4d01ba6f4bf98 100644 --- a/cinn/lang/CMakeLists.txt +++ b/cinn/lang/CMakeLists.txt @@ -1,5 +1,6 @@ set(srcs buffer.cc compute.cc placeholder.cc tensor.cc module.cc lower.cc builtin.cc lower_impl.cc + compute_at_postprocess.cc ) foreach(cpp ${srcs}) diff --git a/cinn/lang/compute_at_postprocess.cc b/cinn/lang/compute_at_postprocess.cc new file mode 100644 index 0000000000000..bf15f6042b1e6 --- /dev/null +++ b/cinn/lang/compute_at_postprocess.cc @@ -0,0 +1,466 @@ +#include "cinn/lang/compute_at_postprocess.h" + +#include "cinn/common/ir_util.h" +#include "cinn/ir/ir_mutator.h" +#include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_printer.h" +#include "cinn/lang/tensor.h" +#include "cinn/optim/ir_replace.h" +#include "cinn/optim/ir_simplify.h" +#include "cinn/poly/compute_at_transform.h" + +namespace cinn { +namespace lang { +using ir::ComputeAtInfo; + +namespace detail { + +/** + * Process the producer related Store and Load indices. + */ +struct NormalizeProducerDomainMutator : public ir::IRMutator<> { + std::map offsets; + std::vector consumer_axis; + std::string producer_tuple; + + NormalizeProducerDomainMutator(const std::string& producer_tuple, const std::vector& consumer_axis) + : producer_tuple(producer_tuple), consumer_axis(consumer_axis) {} + + void operator()(Expr* forloop) { ir::IRMutator<>::Visit(forloop, forloop); } + + //! Add offsets to store, e.g. offset is i->3, the original store expr is a[i,j] = b[i*2,j], the result expression + //! will be a[i+3,j] = b[(i+3)*2,j] + void AddOffsetsToStoreExpr(Expr* expr) { + LOG(INFO) << "*AddOffsetsToStoreExpr: " << *expr; + CHECK(expr->As()); + for (auto& offset : offsets) { + LOG(INFO) << "Add to axis " << offset.first << " with offset " << offset.first << " => +" << offset.second; + optim::IrReplace(expr, offset.first, Expr(offset.first) + offset.second); + } + } + + /* Set the producer axis to zero in Store node, e.g. a store node, a[c0,c1] = ... will be a[0,0] + * + * poly_for (i, cinn_max(0, (po0 - 1)), (i <= (po0 + 1)), 1) + * { + * cache[i, po1] = A[i, po1] + * } + * + * will transform to + * + * poly_for (i, cinn_max(0, (po0 - 1)), (i <= (po0 + 1)), 1) + * { + * cache[i, 0] = A[i, po1] + * } + */ + void SetProducerAxisToZeroInStore(Expr* expr) { + auto* node = expr->As(); + CHECK(node); + + VLOG(3) << "SetProducerAxisToZeroInStore: " << *expr; + for (auto& indice : node->indices) { + for (auto& consumer_axis : consumer_axis) { + VLOG(3) << indice << " set producer axis [" << consumer_axis << "] to 0"; + optim::IrReplace(&indice, consumer_axis, common::make_const(0)); + } + } + } + + /* + * Make producer Store indice start from zero. + * + * NOTE the axis here should be producer's axis, `i` in the root function comment. + * + * poly_for (i, cinn_max(0, (po0 - 1)), (i <= (po0 + 1)), 1) + * { + * cache[i, po1] = A[i, po1] + * } + * + * will transform to + * + * poly_for (i, 0, (i + cinn_max(0, (po0 - 1)) <= (po0 + 1)), 1) + * { + * cache[i, po1] = A[i + cinn_max(0, (po0 - 1)), po1] + * } + */ + void AddOffsetToAxisInStoreValue(Expr* expr) { + optim::Simplify(expr); + LOG(INFO) << "AddOffsetToAxisInStoreValue to:\n" << *expr; + + auto* node = expr->As(); + + auto loads_but_producer = ir::CollectIRNodes(node->value, [&](const Expr* x) { + return x->As() && x->As()->tensor.as_tensor()->name != node->tensor.as_tensor()->name; + }); + + for (auto& item : loads_but_producer) { + auto* load = item.As(); + for (auto& indice : load->indices) { + for (auto& offset : offsets) { + LOG(INFO) << "*Add indice to [" << indice << "] => [" << offset.first << "] with offset [" << offset.second + << "]"; + optim::IrReplace(&Reference(&indice), offset.first, Expr(offset.first) + offset.second); + LOG(INFO) << "get: " << indice; + } + } + } + } + + void Visit(const ir::Store* op, Expr* expr) override { + auto* node = expr->As(); + + if (op->tensor.as_tensor()->name == producer_tuple) { + // AddOffsetsToStoreExpr(expr); + + // replace the producer axis in store indice to zero. + SetProducerAxisToZeroInStore(expr); + + // replace the consumer axis in value(not producer) to offset. + AddOffsetToAxisInStoreValue(expr); + } else { + ir::IRMutator<>::Visit(op, expr); + } + } + + void Visit(const ir::For* op, Expr* expr) override { + auto* node = expr->As(); + if (!common::is_zero(op->min)) { + auto offset = op->min; + node->min = common::make_const(0); + node->extent = node->extent - offset; + offsets[node->loop_var] = offset; + } + ir::IRMutator<>::Visit(&node->body, &node->body); + } + + void Visit(const ir::PolyFor* op, Expr* expr) override { + auto* node = expr->As(); + if (!common::is_zero(op->init)) { + auto offset = op->init; + offsets[node->iterator] = offset; + node->init = common::make_const(0); + UpdatePolyForConditionWithOffset(&node->condition, node->iterator, offset); + } + ir::IRMutator<>::Visit(&node->body, &node->body); + } + + void UpdatePolyForConditionWithOffset(Expr* cond, Var iter, Expr offset) { + optim::IrReplace(cond, iter, Expr(iter) + offset); + } +}; + +struct ResetProducerLoadIndiceInConsumerMutator : public ir::IRMutator<> { + const std::string& producer_tensor_name; + const std::vector& consumer_axis; + const ComputeAtInfo& compute_at_info; + + ResetProducerLoadIndiceInConsumerMutator(const std::string& producer_tensor_name, + const std::vector& consumer_axis, + const ComputeAtInfo& compute_at_info) + : producer_tensor_name(producer_tensor_name), consumer_axis(consumer_axis), compute_at_info(compute_at_info) {} + + void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } + + void Visit(const ir::Load* op, Expr* expr) override { + VLOG(3) << "Consumer modify Load " << *expr << "'s axis for producer [" << producer_tensor_name << "]"; + auto* node = expr->As(); + if (op->tensor.as_tensor()->name == producer_tensor_name) { + CHECK_LE(compute_at_info.preceding_offset_for_producer_load.size(), node->indices.size()); + for (auto axis : consumer_axis) { + for (auto& indice : node->indices) { + VLOG(3) << "Consumer Load " << indice << " set axis [" << axis << "] to 0"; + optim::IrReplace(&indice, axis, common::make_const(0)); + } + } + + for (int i = 0; i < compute_at_info.preceding_offset_for_producer_load.size(); i++) { + node->indices[i] = node->indices[i] + compute_at_info.preceding_offset_for_producer_load[i]; + } + } + // Load not recursive, no need to visit it's items. + } +}; + +} // namespace detail + +using ir::ComputeAtInfo; +/** + * Lets define the consumer tensor as C and the producer tensor as P for short. + * First, find the forloop generating C, keep the forloop levels in a stack. + * We need to modify the following + * 1. P's Store indice(change the parameters to zero) + * 2. P's Store value, change the parameters in Load to consumer's precending axis + * 3. replace the precending axis of the P's Load to zero in C + */ +struct CorrectComputeAtRelatedIndiceMutator : public ir::IRMutator<> { + std::string tensor_name; + + explicit CorrectComputeAtRelatedIndiceMutator(const std::string& tensor_name) : tensor_name(tensor_name) {} + + void operator()(Expr* e) { return ir::IRMutator<>::Visit(e, e); } + + void Visit(const ir::PolyFor* op, Expr* expr) override { + forloop_stack.push_back(expr); + ir::IRMutator<>::Visit(op, expr); + forloop_stack.pop_back(); + } + + void Visit(const ir::For* op, Expr* expr) override { + forloop_stack.push_back(expr); + ir::IRMutator<>::Visit(op, expr); + forloop_stack.pop_back(); + } + + /** + * Normalize the producer's domain, make it start from zero. This is essential for shrink the buffer and inference the + * buffer size. + * + * e.g. + * for (i=p0; i<3+p0; i++) { + * p[i] + * } + * will be transformed to + * for (i=0; i<3; i++) { + * p[i+p0] + * } + * + * @param producer_forloop_root The root of the producer's own axis, not the axis of consumer. + * + * About the \p producer_forloop_root, after compute_at schedule, + * // consumer iter ci + * for (ci) { + * // producer iter pi + * for (pi) { + * } + * } + * The pi should be the \p producer_forloop_root + */ + void NormalizeProducerDomain(Expr* producer_forloop_root, + const std::string& producer_tuple, + const std::vector& consumer_axis) { + VLOG(4) << "Normalize producer domain: " << producer_tuple; + VLOG(4) << "producer_forloop_root:\n" << *producer_forloop_root; + VLOG(4) << "consumer_axis:"; + for (auto& var : consumer_axis) { + VLOG(4) << "iter: " << var; + } + + detail::NormalizeProducerDomainMutator(producer_tuple, consumer_axis)(producer_forloop_root); + } + + //! Reset the indice of the producer Load in Consumer. + // Here we just set the minimum consumer axis to zero. e.g., for consumer statement such as + // `C[i] = A[i-1]+A[i]+A[i+1]` and level set to 0, the result statement will be `C[i] = A[0]+A[1]+A[2]`, this includes + // the following steps: + // 1. make the preceding level+1 axis to zero in producer load, we get `C[i] = A[-1]+A[0]+A[1]`. + // 2. for each adjusted axis, add an offset stored in ComputeAtInfo to make the minimum indice zero, then we get `C[i] + // = A[0]+A[1]+A[2]`. + void ResetProducerLoadIndiceInConsumer(const std::vector& consumer_axis, + Expr* consumer_store_expr, + const std::string& producer_tensor_name, + const ComputeAtInfo& compute_at_info) { + detail::ResetProducerLoadIndiceInConsumerMutator( + producer_tensor_name, consumer_axis, compute_at_info)(consumer_store_expr); + } + + void Visit(const ir::Store* op, Expr* expr) override { + auto* node = expr->As(); + + if (op->tensor.as_tensor()->name != tensor_name) { + ir::IRMutator<>::Visit(op, expr); + return; + } + + // get the target consumer + auto& compute_at_infos = op->tensor.as_tensor()->compute_at_infos; + CHECK(!compute_at_infos.empty()); + + std::vector levels; + for (Expr* forloop : forloop_stack) { + auto* for_n = forloop->As(); + auto* poly_for_n = forloop->As(); + if (for_n) + levels.push_back(for_n->loop_var); + else if (poly_for_n) + levels.push_back(poly_for_n->iterator); + else + NOT_IMPLEMENTED + } + + for (auto& compute_at_info : compute_at_infos) { + VLOG(4) << "compute_at: " << compute_at_info.producer_tensor_name; + detail::ReplaceParamWithConsumerAxis(compute_at_info, levels, forloop_stack.front()); + } + + for (auto& compute_at_info : compute_at_infos) { + int level = compute_at_info.level; + std::vector consumer_aixs(levels.begin(), levels.begin() + level + 1); + Expr* producer_forloop_root; + if (forloop_stack[level]->As()) { + producer_forloop_root = &forloop_stack[level]->As()->body; + } else { + producer_forloop_root = &forloop_stack[level]->As()->body; + } + + auto forloop_stack_to_store = GetForloopStackToStore(producer_forloop_root, compute_at_info.producer_tensor_name); + producer_forloop_root = forloop_stack_to_store.empty() ? forloop_stack[level] : forloop_stack_to_store.front(); + NormalizeProducerDomain(producer_forloop_root, compute_at_info.producer_tensor_name, consumer_aixs); + ResetProducerLoadIndiceInConsumer( + consumer_aixs, forloop_stack[level], compute_at_info.producer_tensor_name, compute_at_info); + } + } + + std::vector forloop_stack; +}; + +void ProcessComputeAtInfo(Expr* expr) { + // 1. collect all the consumer tensors thouse have compute_at_infos. + // 2. for each producer tensor, reset the producer tensor loads indice. + + // first, visit the consumer tensor with compute_at info. + // second, in the forloop stack, find the producer tensor + // - set the presending axis to zero in producer's Store node and Load node + // - replace the ISL parameter to the precending axis + // in consumer, reset presending axis in producer's Load to zero. + + auto tensor_with_compute_at_infos = ir::CollectIRNodes( + *expr, [&](const Expr* x) { return x->as_tensor() && !x->as_tensor()->compute_at_infos.empty(); }); + + for (auto& tensor : tensor_with_compute_at_infos) { + VLOG(4) << "consumer: " << tensor; + CorrectComputeAtRelatedIndiceMutator(tensor.as_tensor()->name)(expr); + } +} + +void UpdateComputeAtBufferShape(Expr* expr) { + auto tensor_with_compute_at_infos = ir::CollectIRNodes(*expr, [&](const Expr* x) { + return x->as_tensor() && !x->as_tensor()->inlined() && !x->as_tensor()->compute_at_infos.empty(); + }); + + auto tensor_map = ir::CollectTensorMap( + *expr, [&](const Expr* x) { return !x->as_tensor()->inlined() && x->as_tensor()->buffer.defined(); }); + + std::unordered_map buffer_to_compute_at_info; + for (auto& item : tensor_map) { + auto& compute_at_infos = item.second.as_tensor()->compute_at_infos; + if (compute_at_infos.empty()) continue; + for (auto& compute_at : compute_at_infos) { + auto& producer_tensor = tensor_map.at(compute_at.producer_tensor_name); + buffer_to_compute_at_info[producer_tensor.as_tensor()->buffer->name] = &compute_at_infos.front(); + } + } + + auto process_tensor = [&](ir::_Tensor_* tensor, const ComputeAtInfo& compute_at_info) { + tensor->shape.clear(); + for (int v : compute_at_info.adjusted_producer_shape) { + tensor->shape.push_back(Expr(v)); + } + VLOG(4) << "Updated tensor: " << ir::Tensor(tensor); + }; + + auto process_buffer = [&](ir::_Buffer_* buffer, const ComputeAtInfo& compute_at_info) { + buffer->shape.clear(); + for (int v : compute_at_info.adjusted_producer_shape) { + buffer->shape.push_back(Expr(v)); + } + VLOG(4) << "Updated buffer: " << ir::Buffer(buffer); + }; + + auto process_alloca = [&](ir::Alloc* alloca, const ComputeAtInfo& compute_at_info) { + alloca->extents.clear(); + for (int v : compute_at_info.adjusted_producer_shape) { + alloca->extents.push_back(Expr(v)); + } + VLOG(4) << "Updated alloca: " << Expr(alloca); + }; + + auto tensors = ir::CollectIRNodes(*expr, [&](const Expr* x) { return x->as_tensor() && !x->as_tensor()->inlined(); }); + for (auto& t : tensors) { + if (!t.as_tensor()->buffer.defined() || !buffer_to_compute_at_info.count(t.as_tensor()->buffer->name)) continue; + auto& buffer = t.as_tensor()->buffer; + auto compute_at_it = buffer_to_compute_at_info.find(buffer->name); + if (compute_at_it != buffer_to_compute_at_info.end()) { + process_tensor(&Reference(t.as_tensor()), *compute_at_it->second); + process_buffer(Reference(t.as_tensor()).buffer->self(), *compute_at_it->second); + VLOG(4) << "resizing buffer " << t; + VLOG(4) << "resizing tensor " << t.as_tensor()->buffer; + } + } + + // update lowered func temporay buffers + auto lowered_fns = ir::CollectIRNodes(*expr, [&](const Expr* x) { return x->as_lowered_func(); }); + for (auto& lowered_fn : lowered_fns) { + auto* node = lowered_fn.as_lowered_func(); + for (auto& buf : node->temp_bufs) { + auto compute_at_it = buffer_to_compute_at_info.find(buf->name); + if (compute_at_it != buffer_to_compute_at_info.end()) { + process_buffer(Reference(&buf).operator->(), *compute_at_it->second); + } + } + } +} + +namespace detail { + +/** + * + * e.g. The original code is as follows: + * + * poly_for (po0, 0, (po0 <= 9), 1) + * { + * poly_for (po1, 0, (po1 <= 9), 1) + * { + * { + * if (((((_cp_C_0 >= 0) and (_cp_C_0 <= 9)) and (_cp_C_1 >= 0)) and (_cp_C_1 <= 9))) { + * poly_for (i, cinn_max(0, (_cp_C_0 - 1)), (i <= (_cp_C_0 + 1)), 1) + * { + * cache(i, _cp_C_1) + * } + * } + * C(po0, po1) + * } + * } + * } + * Note that, the _cp_C_0 like variables are ISL parameters. + * + * will transform to + * + * poly_for (po0, 0, (po0 <= 9), 1) + * { + * poly_for (po1, 0, (po1 <= 9), 1) + * { + * { + * if (((((po0 >= 0) and (po0 <= 9)) and (po1 >= 0)) and (po1 <= 9))) { + * poly_for (i, cinn_max(0, (po0 - 1)), (i <= (po0 + 1)), 1) + * { + * cache[i, po1] = A[i, po1] + * } + * } + * C[po0, po1] = select((po0 < 10), (((cache[(po0 - 1), po1] + cache[po0, po1]) + cache[(po0 + 1), po1]) + B[po0, + * po1]), 0) + * } + * } + * } + * + * @param info The compute at information. + * @param axis The consumer axis. + * @param consumer_forloop_root The first level of forloop of consumer. + */ +void ReplaceParamWithConsumerAxis(const ComputeAtInfo& info, + const std::vector& axis, + Expr* consumer_forloop_root) { + CHECK_LE(info.level + 1, axis.size()); + // replace the params to consumer's precending level+1 axis. + for (int i = 0; i < info.level + 1; i++) { + Var var(poly::GenConsumerParamName(info.consumer_tensor_name.c_str(), i)); + VLOG(4) << "replacing " << var << " to " << axis[i]; + optim::IrReplace(consumer_forloop_root, Expr(var), axis[i]); + } + + LOG(INFO) << "After ReplaceParamWithConsumerAxis:\n" << *consumer_forloop_root; +} + +} // namespace detail + +} // namespace lang +} // namespace cinn diff --git a/cinn/lang/compute_at_postprocess.h b/cinn/lang/compute_at_postprocess.h new file mode 100644 index 0000000000000..8909d93ecc2a6 --- /dev/null +++ b/cinn/lang/compute_at_postprocess.h @@ -0,0 +1,72 @@ +//! \file This file contains some post process of ComputeAt schedule. +#pragma once +#include +#include +#include +#include + +#include "cinn/ir/ir.h" +#include "cinn/lang/tensor.h" + +namespace cinn { +namespace lang { + +/** + * Deal with the `compute_at` transform, in the stage transform phase, we modified the domain and transform of the + * producer tensor, after isl Ast generation, there remains some postprocess here include + * + * 1. in producer tensor load, make each axis to zero + * 2. add offset + * + * e.g. + * + * auto A_cache = Compute({M, N}, [&](Expr i, Expr j) { return A(i, j); }, "cache"); + * auto C = Compute( + * {Expr(10), Expr(10)}, [&](Expr i, Expr j) { return A_cache(i, j) + A_cache(i+1,j) + B(i, j); }, "C"); + * A_cache->stage()->ComputeAt2(C->stage(), 0); + * + * \code + * function fn (_A, _B, _cache, _C) + * { + * for (_p0, 10) + * { + * for (i, 10) + * { + * if ((i <= 1)) { + * for (j, 10) + * { + * cache[i, j] = A[i, j] + * } + * } + * C[_p0, i] = (cache[_p0, i] + (cache[(1 + _p0), i] + B[_p0, i])) + * } + * } + * } + * \endcode + * + * The expression `C[_p0, i] = (cache[_p0, i] + (cache[(1 + _p0), i] + B[_p0, i]))` produces tensor `C`, but the cache + * should start from zero. + */ +void ProcessComputeAtInfo(Expr* expr); + +/** + * Resize the compute_at consumer buffer size. + */ +void UpdateComputeAtBufferShape(Expr* expr); + +namespace detail { + +/** + * Replace isl parameters with consumer iterators. + * @param info ComputeAt schedule related information. + * @param axis The consumer axis. + * @param consumer_forloop_root The first forloop level of consumer expression. + */ +void ReplaceParamWithConsumerAxis(const ir::ComputeAtInfo& info, + const std::vector& axis, + Expr* consumer_forloop_root); + +} // namespace detail + +} // namespace lang +} // namespace cinn diff --git a/cinn/lang/lower_impl.cc b/cinn/lang/lower_impl.cc index 1ec0ce0783159..cacdbbed4efe0 100644 --- a/cinn/lang/lower_impl.cc +++ b/cinn/lang/lower_impl.cc @@ -6,9 +6,11 @@ #include "cinn/common/ir_util.h" #include "cinn/ir/ir_printer.h" +#include "cinn/lang/compute_at_postprocess.h" #include "cinn/lang/tensor.h" #include "cinn/optim/cache_read_write_replace.h" #include "cinn/optim/ir_replace.h" +#include "cinn/optim/ir_simplify.h" #include "cinn/poly/compute_at_transform.h" namespace cinn { @@ -408,398 +410,6 @@ Expr LowerImpl::GenerateFunctionBody(const poly::Schedule* schedule) { return body; } -/** - * Lets define the consumer tensor as C and the producer tensor as P for short. - * First, find the forloop generating C, keep the forloop levels in a stack. - * We need to modify the following - * 1. P's Store indice(change the parameters to zero) - * 2. P's Store value, change the parameters in Load to consumer's precending axis - * 3. replace the precending axis of the P's Load to zero in C - */ -struct CorrectComputeAtRelatedIndiceMutator : public ir::IRMutator<> { - std::string tensor_name; - - CorrectComputeAtRelatedIndiceMutator(const std::string& tensor_name) : tensor_name(tensor_name) {} - - void operator()(Expr* e) { return ir::IRMutator<>::Visit(e, e); } - - void Visit(const ir::PolyFor* op, Expr* expr) override { - forloop_stack.push_back(expr); - ir::IRMutator<>::Visit(op, expr); - forloop_stack.pop_back(); - } - - void Visit(const ir::For* op, Expr* expr) override { - forloop_stack.push_back(expr); - ir::IRMutator<>::Visit(op, expr); - forloop_stack.pop_back(); - } - - // Replace the isl params with the real axis like `p0` in consumer. - void ReplaceParamWithConsumerAxis(const ComputeAtInfo& info, - const std::vector& axis, - Expr* consumer_forloop_root) { - CHECK_LE(info.level + 1, axis.size()); - // replace the params to consumer's precending level+1 axis. - for (int i = 0; i < info.level + 1; i++) { - Var var(poly::GenConsumerParamName(info.consumer_tensor_name.c_str(), i)); - VLOG(4) << "replacing " << var << " to " << axis[i]; - optim::IrReplace(consumer_forloop_root, Expr(var), axis[i]); - } - } - - //! Get a stack of forloops to a Store node target to \p tensor_name - std::vector GetForloopStackToStore(Expr* expr, const std::string& tensor_name) { - VLOG(4) << "search store " << tensor_name << " in expr:\n"; - VLOG(4) << *expr; - struct Mutator : public ir::IRMutator<> { - std::vector forloop_stack; - bool found{false}; - - std::string tensor_name; - - Mutator(const std::string& tensor_name) : tensor_name(tensor_name) {} - - std::vector operator()(Expr* expr) { - ir::IRMutator<>::Visit(expr, expr); - return forloop_stack; - } - - void Visit(const ir::For* op, Expr* expr) { - auto* node = expr->As(); - forloop_stack.push_back(expr); - ir::IRMutator<>::Visit(&node->body, &node->body); - if (!found) forloop_stack.pop_back(); - } - - void Visit(const ir::PolyFor* op, Expr* expr) { - auto* node = expr->As(); - forloop_stack.push_back(expr); - ir::IRMutator<>::Visit(&node->body, &node->body); - if (!found) forloop_stack.pop_back(); - } - - void Visit(const ir::Store* op, Expr* expr) { found = op->tensor.as_tensor()->name == tensor_name; } - }; - - return Mutator(tensor_name)(expr); - } - - /** - * Normalize the producer's domain, make it start from zero. This is essential for shrink the buffer and inference the - * buffer size. - * - * e.g. - * for (i=p0; i<3+p0; i++) { - * p[i] - * } - * will be transformed to - * for (i=0; i<3; i++) { - * p[i+p0] - * } - * - * @param producer_forloop_root The root of the producer's own axis, not the axis of consumer. - * - * About the \p producer_forloop_root, after compute_at schedule, - * // consumer iter ci - * for (ci) { - * // producer iter pi - * for (pi) { - * } - * } - * The pi should be the \p producer_forloop_root - */ - void NormalizeProducerDomain(Expr* producer_forloop_root, - const std::string& producer_tuple, - const std::vector& consumer_axis) { - VLOG(4) << "Normalize producer domain: " << producer_tuple; - VLOG(4) << "producer_forloop_root:\n" << *producer_forloop_root; - VLOG(4) << "consumer_axis:"; - for (auto& var : consumer_axis) { - VLOG(4) << "iter: " << var; - } - - struct Mutator : public ir::IRMutator<> { - std::map offsets; - std::vector consumer_axis; - std::string producer_tuple; - - Mutator(const std::string& producer_tuple, const std::vector& consumer_axis) - : producer_tuple(producer_tuple), consumer_axis(consumer_axis) {} - - void operator()(Expr* forloop) { ir::IRMutator<>::Visit(forloop, forloop); } - - //! Add offsets to store, e.g. offset is i->3, the original store expr is a[i,j] = b[i*2,j], the result expression - //! will be a[i+3,j] = b[(i+3)*2,j] - void AddOffsetsToStoreExpr(Expr* expr) { - CHECK(expr->As()); - for (auto& offset : offsets) { - optim::IrReplace(expr, offset.first, Expr(offset.first) + offset.second); - } - } - - //! Set the producer axis to zero in Store node, e.g. a store node, a[c0,c1] = ... will be a[0,0] - void SetProducerAxisToZeroInStore(Expr* expr) { - auto* node = expr->As(); - CHECK(node); - - VLOG(3) << "SetProducerAxisToZeroInStore: " << *expr; - for (auto& indice : node->indices) { - for (auto& consumer_axis : consumer_axis) { - VLOG(3) << indice << " set producer axis [" << consumer_axis << "] to 0"; - optim::IrReplace(&indice, consumer_axis, common::make_const(0)); - } - } - } - - //! NOTE the axis here should be producer's axis, `i` in the root function comment. - void AddOffsetToAxisInStoreValue(Expr* expr) { - auto* node = expr->As(); - - auto loads_but_producer = ir::CollectIRNodes(node->value, [&](const Expr* x) { - return x->As() && x->As()->tensor.as_tensor()->name != node->tensor.as_tensor()->name; - }); - - for (auto& item : loads_but_producer) { - auto* load = item.As(); - for (auto& indice : load->indices) { - for (auto& offset : offsets) { - optim::IrReplace(&Reference(&indice), offset.first, Expr(offset.first) + offset.second); - } - } - } - } - - void Visit(const ir::Store* op, Expr* expr) override { - auto* node = expr->As(); - - if (op->tensor.as_tensor()->name == producer_tuple) { - AddOffsetsToStoreExpr(expr); - - // replace the producer axis in store indice to zero. - SetProducerAxisToZeroInStore(expr); - - // replace the consumer axis in value(not producer) to offset. - AddOffsetToAxisInStoreValue(expr); - } else { - ir::IRMutator<>::Visit(op, expr); - } - } - - void Visit(const ir::For* op, Expr* expr) override { - auto* node = expr->As(); - if (!common::is_zero(op->min)) { - auto offset = op->min; - node->min = common::make_const(0); - node->extent = node->extent - offset; - offsets[node->loop_var] = offset; - } - ir::IRMutator<>::Visit(&node->body, &node->body); - } - - void Visit(const ir::PolyFor* op, Expr* expr) override { - auto* node = expr->As(); - if (!common::is_zero(op->init)) { - auto offset = op->init; - node->init = common::make_const(0); - UpdatePolyForConditionWithOffset(&node->condition, node->iterator, offset); - } - ir::IRMutator<>::Visit(&node->body, &node->body); - } - - void UpdatePolyForConditionWithOffset(Expr* cond, Var iter, Expr offset) { - optim::IrReplace(cond, iter, Expr(iter) + offset); - } - }; - - Mutator(producer_tuple, consumer_axis)(producer_forloop_root); - } - - //! Reset the indice of the producer Load in Consumer. - // Here we just set the minimum consumer axis to zero. e.g., for consumer statement such as - // `C[i] = A[i-1]+A[i]+A[i+1]` and level set to 0, the result statement will be `C[i] = A[0]+A[1]+A[2]`, this includes - // the following steps: - // 1. make the preceding level+1 axis to zero in producer load, we get `C[i] = A[-1]+A[0]+A[1]`. - // 2. for each adjusted axis, add an offset stored in ComputeAtInfo to make the minimum indice zero, then we get `C[i] - // = A[0]+A[1]+A[2]`. - void ResetProducerLoadIndiceInConsumer(const std::vector& consumer_axis, - Expr* consumer_store_expr, - const std::string& producer_tensor_name, - const ComputeAtInfo& compute_at_info) { - struct Mutator : public ir::IRMutator<> { - const std::string& producer_tensor_name; - const std::vector& consumer_axis; - const ComputeAtInfo& compute_at_info; - - Mutator(const std::string& producer_tensor_name, - const std::vector& consumer_axis, - const ComputeAtInfo& compute_at_info) - : producer_tensor_name(producer_tensor_name), - consumer_axis(consumer_axis), - compute_at_info(compute_at_info) {} - - void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } - - void Visit(const ir::Load* op, Expr* expr) override { - VLOG(3) << "Consumer modify Load " << *expr << "'s axis for producer [" << producer_tensor_name << "]"; - auto* node = expr->As(); - if (op->tensor.as_tensor()->name == producer_tensor_name) { - CHECK_LE(compute_at_info.preceding_offset_for_producer_load.size(), node->indices.size()); - for (auto axis : consumer_axis) { - for (auto& indice : node->indices) { - VLOG(3) << "Consumer Load " << indice << " set axis [" << axis << "] to 0"; - optim::IrReplace(&indice, axis, common::make_const(0)); - } - } - - for (int i = 0; i < compute_at_info.preceding_offset_for_producer_load.size(); i++) { - node->indices[i] = node->indices[i] + compute_at_info.preceding_offset_for_producer_load[i]; - } - } - // Load not recursive, no need to visit it's items. - } - }; - - Mutator(producer_tensor_name, consumer_axis, compute_at_info)(consumer_store_expr); - } - - void Visit(const ir::Store* op, Expr* expr) override { - auto* node = expr->As(); - - if (op->tensor.as_tensor()->name != tensor_name) { - ir::IRMutator<>::Visit(op, expr); - return; - } - - // get the target consumer - auto& compute_at_infos = op->tensor.as_tensor()->compute_at_infos; - CHECK(!compute_at_infos.empty()); - - std::vector levels; - for (Expr* forloop : forloop_stack) { - auto* for_n = forloop->As(); - auto* poly_for_n = forloop->As(); - if (for_n) - levels.push_back(for_n->loop_var); - else if (poly_for_n) - levels.push_back(poly_for_n->iterator); - else - NOT_IMPLEMENTED - } - - for (auto& compute_at_info : compute_at_infos) { - VLOG(4) << "compute_at: " << compute_at_info.producer_tensor_name; - ReplaceParamWithConsumerAxis(compute_at_info, levels, forloop_stack.front()); - } - - for (auto& compute_at_info : compute_at_infos) { - int level = compute_at_info.level; - std::vector consumer_aixs(levels.begin(), levels.begin() + level + 1); - Expr* producer_forloop_root; - if (forloop_stack[level]->As()) { - producer_forloop_root = &forloop_stack[level]->As()->body; - } else { - producer_forloop_root = &forloop_stack[level]->As()->body; - } - - auto forloop_stack_to_store = GetForloopStackToStore(producer_forloop_root, compute_at_info.producer_tensor_name); - producer_forloop_root = forloop_stack_to_store.empty() ? forloop_stack[level] : forloop_stack_to_store.back(); - NormalizeProducerDomain(producer_forloop_root, compute_at_info.producer_tensor_name, consumer_aixs); - ResetProducerLoadIndiceInConsumer( - consumer_aixs, forloop_stack[level], compute_at_info.producer_tensor_name, compute_at_info); - } - } - - std::vector forloop_stack; -}; - -void ProcessComputeAtInfo(Expr* expr) { - // 1. collect all the consumer tensors thouse have compute_at_infos. - // 2. for each producer tensor, reset the producer tensor loads indice. - - // first, visit the consumer tensor with compute_at info. - // second, in the forloop stack, find the producer tensor - // - set the presending axis to zero in producer's Store node and Load node - // - replace the ISL parameter to the precending axis - // in consumer, reset presending axis in producer's Load to zero. - - auto tensor_with_compute_at_infos = ir::CollectIRNodes( - *expr, [&](const Expr* x) { return x->as_tensor() && !x->as_tensor()->compute_at_infos.empty(); }); - - for (auto& tensor : tensor_with_compute_at_infos) { - VLOG(4) << "consumer: " << tensor; - CorrectComputeAtRelatedIndiceMutator(tensor.as_tensor()->name)(expr); - } -} - -void UpdateComputeAtBufferShape(Expr* expr) { - auto tensor_with_compute_at_infos = ir::CollectIRNodes(*expr, [&](const Expr* x) { - return x->as_tensor() && !x->as_tensor()->inlined() && !x->as_tensor()->compute_at_infos.empty(); - }); - - auto tensor_map = ir::CollectTensorMap( - *expr, [&](const Expr* x) { return !x->as_tensor()->inlined() && x->as_tensor()->buffer.defined(); }); - - std::unordered_map buffer_to_compute_at_info; - for (auto& item : tensor_map) { - auto& compute_at_infos = item.second.as_tensor()->compute_at_infos; - if (compute_at_infos.empty()) continue; - for (auto& compute_at : compute_at_infos) { - auto& producer_tensor = tensor_map.at(compute_at.producer_tensor_name); - buffer_to_compute_at_info[producer_tensor.as_tensor()->buffer->name] = &compute_at_infos.front(); - } - } - - auto process_tensor = [&](ir::_Tensor_* tensor, const ComputeAtInfo& compute_at_info) { - tensor->shape.clear(); - for (int v : compute_at_info.adjusted_producer_shape) { - tensor->shape.push_back(Expr(v)); - } - VLOG(4) << "Updated tensor: " << ir::Tensor(tensor); - }; - - auto process_buffer = [&](ir::_Buffer_* buffer, const ComputeAtInfo& compute_at_info) { - buffer->shape.clear(); - for (int v : compute_at_info.adjusted_producer_shape) { - buffer->shape.push_back(Expr(v)); - } - VLOG(4) << "Updated buffer: " << ir::Buffer(buffer); - }; - - auto process_alloca = [&](ir::Alloc* alloca, const ComputeAtInfo& compute_at_info) { - alloca->extents.clear(); - for (int v : compute_at_info.adjusted_producer_shape) { - alloca->extents.push_back(Expr(v)); - } - VLOG(4) << "Updated alloca: " << Expr(alloca); - }; - - auto tensors = ir::CollectIRNodes(*expr, [&](const Expr* x) { return x->as_tensor() && !x->as_tensor()->inlined(); }); - for (auto& t : tensors) { - if (!t.as_tensor()->buffer.defined() || !buffer_to_compute_at_info.count(t.as_tensor()->buffer->name)) continue; - auto& buffer = t.as_tensor()->buffer; - auto compute_at_it = buffer_to_compute_at_info.find(buffer->name); - if (compute_at_it != buffer_to_compute_at_info.end()) { - process_tensor(&Reference(t.as_tensor()), *compute_at_it->second); - process_buffer(Reference(t.as_tensor()).buffer->self(), *compute_at_it->second); - VLOG(4) << "resizing buffer " << t; - VLOG(4) << "resizing tensor " << t.as_tensor()->buffer; - } - } - - // update lowered func temporay buffers - auto lowered_fns = ir::CollectIRNodes(*expr, [&](const Expr* x) { return x->as_lowered_func(); }); - for (auto& lowered_fn : lowered_fns) { - auto* node = lowered_fn.as_lowered_func(); - for (auto& buf : node->temp_bufs) { - auto compute_at_it = buffer_to_compute_at_info.find(buf->name); - if (compute_at_it != buffer_to_compute_at_info.end()) { - process_buffer(Reference(&buf).operator->(), *compute_at_it->second); - } - } - } -} - void LowerImpl::AddAxisInfoToFunc(ir::_LoweredFunc_* func) {} } // namespace detail diff --git a/cinn/lang/lower_impl.h b/cinn/lang/lower_impl.h index 3801e8cf3007f..0de252cc42d64 100644 --- a/cinn/lang/lower_impl.h +++ b/cinn/lang/lower_impl.h @@ -177,49 +177,6 @@ class LowerImpl { */ bool TensorContainsGPUInfo(ir::Tensor t); -/** - * Deal with the `compute_at` transform, in the stage transform phase, we modified the domain and transform of the - * producer tensor, after isl Ast generation, there remains some postprocess here include - * - * 1. in producer tensor load, make each axis to zero - * 2. add offset - * - * e.g. - * - * auto A_cache = Compute({M, N}, [&](Expr i, Expr j) { return A(i, j); }, "cache"); - * auto C = Compute( - * {Expr(10), Expr(10)}, [&](Expr i, Expr j) { return A_cache(i, j) + A_cache(i+1,j) + B(i, j); }, "C"); - * A_cache->stage()->ComputeAt2(C->stage(), 0); - * - * \code - * function fn (_A, _B, _cache, _C) - * { - * for (_p0, 10) - * { - * for (i, 10) - * { - * if ((i <= 1)) { - * for (j, 10) - * { - * cache[i, j] = A[i, j] - * } - * } - * C[_p0, i] = (cache[_p0, i] + (cache[(1 + _p0), i] + B[_p0, i])) - * } - * } - * } - * \endcode - * - * The expression `C[_p0, i] = (cache[_p0, i] + (cache[(1 + _p0), i] + B[_p0, i]))` produces tensor `C`, but the cache - * should start from zero. - */ -void ProcessComputeAtInfo(Expr* expr); - -/** - * Resize the compute_at consumer buffer size. - */ -void UpdateComputeAtBufferShape(Expr* expr); - /** * Mark the PolyFor as Vectorized if it is scheduled Vectorize in Stage. */ diff --git a/cinn/lang/tensor.cc b/cinn/lang/tensor.cc index ae2a6cde6107b..724b333835b42 100644 --- a/cinn/lang/tensor.cc +++ b/cinn/lang/tensor.cc @@ -351,7 +351,7 @@ void _Tensor_::WithBuffer(const std::string &memory_type, const Type &type) { buf->target = common::DefaultHostTarget(); Bind(buf); - if (memory_type == "share") { + if (memory_type == "shared") { buf->memory_type = MemoryType::GPUShared; } else if (memory_type == "local") { buf->memory_type = MemoryType::GPULocal; diff --git a/cinn/optim/cache_read_write_replace_test.cc b/cinn/optim/cache_read_write_replace_test.cc index 86ec73057adbe..c26b02fe0032a 100644 --- a/cinn/optim/cache_read_write_replace_test.cc +++ b/cinn/optim/cache_read_write_replace_test.cc @@ -20,7 +20,7 @@ TEST(CacheReadWriteReplace, basic) { {M, N}, [&](Expr i, Expr j) -> Expr { return A(i, j) + B(i, j); }, "C"); // AA cache - auto AA = A->stage()->CacheRead("share", {C}); + auto AA = A->stage()->CacheRead("shared", {C}); auto CC = C->stage()->CacheWrite("local"); auto fn = Lower("fn", {A, B, C}, {}, {AA, CC}); diff --git a/cinn/optim/transform_polyfor_to_for.cc b/cinn/optim/transform_polyfor_to_for.cc index c03d05c1e81ba..a50bd3dec1172 100644 --- a/cinn/optim/transform_polyfor_to_for.cc +++ b/cinn/optim/transform_polyfor_to_for.cc @@ -16,239 +16,8 @@ namespace cinn { namespace optim { -void PolyForAutoSeparate(Expr* expr); - namespace { -/** - * Separate a forloop in the expression. - * *forloop(min, extent) -- to separate - * | - * forloop - * | - * forloop(min, Min(a,b)) - * will be rewriten to - * forloop(min, separator) --- forloop(separator, extent) - * | | - * forloop forloop - * | | - * forloop(min, a) forloop(min, b) - */ -struct ForSeparater : ir::IRMutator { - //! @param forloop The forloop to separate. - //! @param separator The separator to split the domain of the \p forloop. - //! @param sub_forloop The forloop whose extent has a Min node. - ForSeparater(Expr* forloop, Expr separator, ir::For* sub_forloop) - : forloop_(forloop), separator_(separator), sub_forloop_(sub_forloop) {} - - void operator()(Expr* expr) { Visit(expr); } - - private: - void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } - - void Visit(const ir::For* op, Expr* expr) override { - auto* node = expr->As(); - if (expr == forloop_) { // find the forloop to split - is_separated_ = true; - - auto forloop_branch0 = ir::For::Make(op->loop_var, - op->min, - separator_, - op->for_type(), - op->device_api, - optim::IRCopy(op->body), - op->vectorize_info()); - - is_left_branch_ = true; - Visit(&forloop_branch0.As()->body); - - auto forloop_branch1 = ir::For::Make(op->loop_var, - separator_, - op->extent, - op->for_type(), - op->device_api, - optim::IRCopy(op->body), - op->vectorize_info()); - - is_left_branch_ = false; - Visit(&forloop_branch1.As()->body); - - *expr = ir::Block::Make({forloop_branch0, forloop_branch1}); - } else if (MatchSubForloop(op)) { - CHECK(is_separated_); - sub_forloop_counter_++; - - auto* min_n = op->extent.As(); - CHECK(min_n); - - if (is_left_branch_) { // the first - node->extent = min_n->a(); - } else { // the second - node->extent = min_n->b(); - } - } else { - Visit(&node->body); - } - } - - //! Tell whether we meet the target subforloop, the forloop having the same iterator and min, extent in the body of - //! the forloop(to separate) should be the target sub-forloop. - bool MatchSubForloop(const ir::For* forloop) { - return forloop->loop_var == sub_forloop_->loop_var && forloop->min == sub_forloop_->min && - forloop->extent == sub_forloop_->extent; - } - - private: - bool is_separated_{false}; - int sub_forloop_counter_; - Expr* forloop_; - Expr separator_; - // Tell whether the forloop is located at the root's left branch(min, separator_). - bool is_left_branch_{false}; - ir::For* sub_forloop_; -}; - -/* - * Separate a forloop, if successfully found one and seperate it, just return. - * NOTE The input expressions can only deal with PolyFor, not For nodes. - */ -struct ForAutoSeparateMutator : ir::IRMutator { - Expr* operator()(Expr* expr) { - ir::IRMutator<>::Visit(expr, expr); - return separated_forloop; - } - - private: - //! Fill it if a forloop is separated. - Expr* separated_forloop{}; - void Visit(Expr* expr) { - // The root_ might be replaced only if root_ == the forloop_to_separate. - ir::IRMutator<>::Visit(expr, expr); - } - - void Visit(const ir::For* op, Expr* expr) override { - auto* node = expr->As(); - forloop_stack.push_back(expr); - - // the iterators are not treated as constant. - std::set iterators; - - do { // We use a do-while here to break in any position and continue the post-procession after the while block. - auto* min_n = op->extent.As(); - if (!min_n) break; - // TODO(Superjomn) We can support max latter. - - Expr left = min_n->a(); - Expr right = min_n->b(); - - CHECK(common::IsPureMath(left)); - CHECK(common::IsPureMath(right)); - - // find the forloop level to separate - std::vector separate_levels; - int level = 0; - for (auto& _forloop : forloop_stack) { - auto* forloop = _forloop->As(); - iterators.insert(forloop->loop_var->name); - - bool contains = - common::MathContainsSymbol(left, forloop->loop_var) || common::MathContainsSymbol(right, forloop->loop_var); - if (contains) separate_levels.push_back(level); - bool forloop_extent_is_var = !ir::CollectIRNodes(forloop->extent, [&](const Expr* n) { - return n->is_var() && !iterators.count(n->As()->name); - }).empty(); - if (forloop_extent_is_var && !separate_levels.empty()) { - VLOG(3) << "forloop_extent is a var, " << forloop->extent << " quit separation"; - separate_levels.clear(); - break; - } - level++; - } - //! ignore the complex cases. - if (separate_levels.empty()) { - VLOG(3) << "separate_levels is empty, quit separation"; - return; - } - if (separate_levels.size() > 1) break; - CHECK_EQ(separate_levels.size(), 1UL); - - Expr* forloop_to_separate_expr = forloop_stack[separate_levels.front()]; - auto forloop_to_separate = forloop_to_separate_expr->As(); - - // check the min not include the current iterator, or it is illegal. - Expr solve_res; - bool is_positive; - // solve left <= right - std::tie(solve_res, is_positive) = common::Solve(right, left, forloop_to_separate->loop_var); - VLOG(4) << "solve_res: " << solve_res; - VLOG(4) << "is_positive: " << is_positive; - - // make a round if is a float - if (solve_res.type().is_float()) { - float v = solve_res.as_float(); - int32_t x = is_positive ? std::floor(v) : std::ceil(v); - solve_res = Expr(x); - } - - // separate to two forloops with domain: - // 1. (0, solve_res) with min.lhs - // 2. (solve_res, extent) with min.rhs - - ForSeparater for_separater(forloop_to_separate_expr, solve_res, node); - for_separater(forloop_to_separate_expr); - - separated_forloop = forloop_to_separate_expr; - return; - // Visit(forloop_to_separate_expr); - // iterator >= solve_res - } while (false); - - Visit(&node->body); - } - - //! Separate the PolyFor into two PolyFors. - void SeparateForloop(Expr* poly_for_expr, Expr upper_bound) { - auto* node = poly_for_expr->As(); - CHECK(node); - CHECK(common::is_zero(node->min)); - - Expr body = node; - } - - //! Stack of the forloops. - std::vector forloop_stack; -}; - -Expr* PolyForAutoSeparateHelper(Expr* expr) { - ForAutoSeparateMutator mutator; - return mutator(expr); -} - -struct ForAutoSeparateMutatorMain : public ir::IRMutator { - void operator()(Expr* expr) { Visit(expr); } - - private: - void Visit(const ir::Block* op, Expr* expr) { - auto* node = expr->As(); - for (auto& expr : node->stmts) { - auto* res = PolyForAutoSeparateHelper(&expr); - if (res) { - Visit(res); - } - } - } - - void Visit(const ir::For* op, Expr* expr) { - auto* res = PolyForAutoSeparateHelper(expr); - if (res) Visit(res); - } - - void Visit(Expr* expr) { - CHECK(expr); - ir::IRMutator<>::Visit(expr, expr); - } -}; - Expr PlusOneWithMinMax(Expr expr) { auto* min_n = expr.As(); auto* max_n = expr.As(); @@ -276,50 +45,63 @@ struct PolyForWithSimpleConditionToForMutator : public ir::IRMutator { void Visit(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } void Visit(const ir::PolyFor* op, Expr* expr) override { - auto* lt_n = op->condition.As(); - auto* le_n = op->condition.As(); + auto* node = expr->As(); + auto* ge_n = node->condition.As(); + auto* gt_n = node->condition.As(); + if (ge_n) { + node->condition = (ge_n->a() * -1) <= (ge_n->b() * -1); + } + if (gt_n) { + node->condition = (ge_n->a() * -1) < (ge_n->b() * -1); + } + + auto* lt_n = node->condition.As(); + auto* le_n = node->condition.As(); + + if (lt_n) { + if (lt_n->b() != common::make_const(0)) { + node->condition = lt_n->a() - lt_n->b() < 0; + } + } + if (le_n) { + if (le_n->b() != common::make_const(0)) { + node->condition = le_n->a() - le_n->b() <= 0; + } + } + lt_n = node->condition.As(); + le_n = node->condition.As(); if (!(lt_n || le_n)) return; // check the lhs is the iterator bool can_extract_extent = (lt_n && lt_n->a().as_var() && lt_n->a().as_var()->name == op->iterator->name) || (le_n && le_n->a().as_var() && le_n->a().as_var()->name == op->iterator->name); - if (can_extract_extent) { - Expr lhs = lt_n ? lt_n->a() : le_n->a(); - Expr rhs = lt_n ? lt_n->b() : PlusOneWithMinMax(le_n->b()); - rhs = common::AutoSimplify(rhs); - - if (op->is_vectorized()) CHECK(op->vectorize_info().valid()); - Expr new_for = - ir::For::Make(op->iterator, op->init, rhs, op->for_type(), op->device_api, op->body, op->vectorize_info()); - *expr = new_for; - - Visit(&new_for.As()->body); + if (!can_extract_extent) { + node->condition = common::SolveInequality(node->condition, op->iterator); + optim::Simplify(&node->condition); + lt_n = node->condition.As(); + le_n = node->condition.As(); + if (!(lt_n || le_n)) return; } - } -}; -} // namespace + Expr lhs = lt_n ? lt_n->a() : le_n->a(); + Expr rhs = lt_n ? lt_n->b() : PlusOneWithMinMax(le_n->b()); + rhs = common::AutoSimplify(rhs); -namespace detail { + if (op->is_vectorized()) CHECK(op->vectorize_info().valid()); -void PolyForWithSimpleConditionToFor(Expr* expr) { - PolyForWithSimpleConditionToForMutator mutator; - mutator(expr); -} + Expr new_for = + ir::For::Make(op->iterator, op->init, rhs, op->for_type(), op->device_api, op->body, op->vectorize_info()); + *expr = new_for; -void PolyForAutoSeparate(Expr* expr) { - ForAutoSeparateMutatorMain main; - main(expr); -} + Visit(&new_for.As()->body); + } +}; -} // namespace detail +} // namespace -void TransformPolyForToFor(Expr* expr, bool auto_separate) { - detail::PolyForWithSimpleConditionToFor(expr); - if (auto_separate) detail::PolyForAutoSeparate(expr); -} +void TransformPolyForToFor(Expr* expr, bool auto_separate) { PolyForWithSimpleConditionToForMutator()(expr); } } // namespace optim } // namespace cinn diff --git a/cinn/optim/transform_polyfor_to_for.h b/cinn/optim/transform_polyfor_to_for.h index 93cf086b07773..0753c7c83b957 100644 --- a/cinn/optim/transform_polyfor_to_for.h +++ b/cinn/optim/transform_polyfor_to_for.h @@ -12,17 +12,6 @@ namespace detail { void PolyForWithSimpleConditionToFor(Expr* expr); -//! Automatically separate the PolyFor with some specific kind of conditions(such as i < min(a, b)) into two For nodes. -//! e.g. PolyFor(i, 0, 100) { PolyFor(j, 0, min(i, 40))} -//! to -//! \code -//! { -//! PolyFor(i, 0, 40) { PolyFor(j, 0, i) } -//! PolyFor(i, 40, 100) { PolyFor(j, 0, 40) } -//! } -//! \endcode -void PolyForAutoSeparate(Expr* expr); - } // namespace detail } // namespace optim diff --git a/cinn/optim/transform_polyfor_to_for_test.cc b/cinn/optim/transform_polyfor_to_for_test.cc index 99a00e3dc761a..49b00dc037f1d 100644 --- a/cinn/optim/transform_polyfor_to_for_test.cc +++ b/cinn/optim/transform_polyfor_to_for_test.cc @@ -72,15 +72,8 @@ void matmul(void* _args, int32_t num_args) float* C = ((float*)(_C->host_memory)); for (int32_t i_outer = 0; i_outer < 64; i_outer += 1) { for (int32_t i_inner = 0; i_inner < 8; i_inner += 1) { - for (int32_t j_outer = 0; j_outer < 62; j_outer += 1) { - for (int32_t j_inner = 0; j_inner < 8; j_inner += 1) { - for (int32_t k0 = 0; k0 < 200; k0 += 1) { - C[((500 * i_inner) + ((4000 * i_outer) + ((8 * j_outer) + j_inner)))] = (C[((500 * i_inner) + ((4000 * i_outer) + ((8 * j_outer) + j_inner)))] + (A[((200 * i_inner) + ((1600 * i_outer) + k0))] * B[((8 * j_outer) + ((500 * k0) + j_inner))])); - }; - }; - }; - for (int32_t j_outer = 62; j_outer < 63; j_outer += 1) { - for (int32_t j_inner = 0; j_inner < (500 + (-8 * j_outer)); j_inner += 1) { + for (int32_t j_outer = 0; j_outer < 63; j_outer += 1) { + for (int32_t j_inner = 0; j_inner < (1 + ((int32_t)(cinn_min(7, (499 + (-8 * j_outer)))))); j_inner += 1) { for (int32_t k0 = 0; k0 < 200; k0 += 1) { C[((500 * i_inner) + ((4000 * i_outer) + ((8 * j_outer) + j_inner)))] = (C[((500 * i_inner) + ((4000 * i_outer) + ((8 * j_outer) + j_inner)))] + (A[((200 * i_inner) + ((1600 * i_outer) + k0))] * B[((8 * j_outer) + ((500 * k0) + j_inner))])); }; diff --git a/cinn/optim/vectorize_loops_test.cc b/cinn/optim/vectorize_loops_test.cc index 3832a2d02717a..d732014009f76 100644 --- a/cinn/optim/vectorize_loops_test.cc +++ b/cinn/optim/vectorize_loops_test.cc @@ -19,123 +19,6 @@ using namespace ir; // NOLINT using utils::GetStreamCnt; using utils::Trim; -TEST(VectorizeLoops, Split_sperate) { - Expr M(100); - Expr K(200); - Expr N(500); - Expr bn(32); - Placeholder A("A", {M, K}); - Placeholder B("B", {K, N}); - - // C = A * B - lang::Buffer C_buf(Float(32)); - Var k(K.as_int32(), "k0"); - - Tensor C = Compute({M, N}, [&](Var i, Var j) { return lang::Sum(A(i, k) * B(k, j)); }, "C", {k}); - C->Bind(C_buf); - - { - poly::Iterator i_outer, i_inner, j_outer, j_inner, k_outer, k_inner; - std::tie(i_outer, i_inner, j_outer, j_inner) = C->stage()->Tile(0, 1, bn.as_int32(), bn.as_int32()); - std::tie(k_outer, k_inner) = C->stage()->Split(poly::Iterator("k0"), 8); - C->stage()->Reorder({i_outer, j_outer, k_outer, k_inner, i_inner, j_inner}); - C->stage()->Split(j_inner, 8); - } - - // Code gen - auto funcs = Lower("matmul", {A, B, C}); - - Target target; - target.arch = Target::Arch ::X86; - target.bits = Target::Bit ::k32; - target.os = Target::OS ::Linux; - - Expr body = optim::Optimize(Expr(funcs)); - - lang::Module::Builder builder("module1", target); - builder.AddFunction(ir::LoweredFunc(body.As())); - - CodeGenC codegen(target); - codegen.SetInlineBuiltinCodes(false); - auto out = codegen.Compile(builder.Build(), CodeGenC::OutputKind::CImpl); - - auto target_out = R"ROC( -#include -#include - -void matmul(void* _args, int32_t num_args) -{ - const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0])); - const cinn_buffer_t* _B = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1])); - cinn_buffer_t* _C = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[2])); - cinn_buffer_malloc((void*)(0), _C); - const float* A = ((const float*)(_A->host_memory)); - const float* B = ((const float*)(_B->host_memory)); - float* C = ((float*)(_C->host_memory)); - for (int32_t i_outer = 0; i_outer < 3; i_outer += 1) { - for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) { - for (int32_t k0_outer = 0; k0_outer < 25; k0_outer += 1) { - for (int32_t k0_inner = 0; k0_inner < 8; k0_inner += 1) { - for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) { - for (int32_t j_inner_outer = 0; j_inner_outer < 4; j_inner_outer += 1) { - for (int32_t j_inner_inner = 0; j_inner_inner < cinn_min(8, (500 + ((-8 * j_inner_outer) + (-32 * j_outer)))); j_inner_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((8 * k0_outer) + k0_inner)))] * B[((8 * j_inner_outer) + ((32 * j_outer) + ((500 * k0_inner) + ((4000 * k0_outer) + j_inner_inner))))])); - }; - }; - }; - }; - }; - }; - for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) { - for (int32_t k0_outer = 0; k0_outer < 25; k0_outer += 1) { - for (int32_t k0_inner = 0; k0_inner < 8; k0_inner += 1) { - for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) { - for (int32_t j_inner_outer = 0; j_inner_outer < (63 + (-4 * j_outer)); j_inner_outer += 1) { - for (int32_t j_inner_inner = 0; j_inner_inner < cinn_min(8, (500 + ((-8 * j_inner_outer) + (-32 * j_outer)))); j_inner_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((8 * k0_outer) + k0_inner)))] * B[((8 * j_inner_outer) + ((32 * j_outer) + ((500 * k0_inner) + ((4000 * k0_outer) + j_inner_inner))))])); - }; - }; - }; - }; - }; - }; - }; - for (int32_t i_outer = 3; i_outer < 4; i_outer += 1) { - for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) { - for (int32_t k0_outer = 0; k0_outer < 25; k0_outer += 1) { - for (int32_t k0_inner = 0; k0_inner < 8; k0_inner += 1) { - for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) { - for (int32_t j_inner_outer = 0; j_inner_outer < 4; j_inner_outer += 1) { - for (int32_t j_inner_inner = 0; j_inner_inner < cinn_min(8, (500 + ((-8 * j_inner_outer) + (-32 * j_outer)))); j_inner_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((8 * k0_outer) + k0_inner)))] * B[((8 * j_inner_outer) + ((32 * j_outer) + ((500 * k0_inner) + ((4000 * k0_outer) + j_inner_inner))))])); - }; - }; - }; - }; - }; - }; - for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) { - for (int32_t k0_outer = 0; k0_outer < 25; k0_outer += 1) { - for (int32_t k0_inner = 0; k0_inner < 8; k0_inner += 1) { - for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) { - for (int32_t j_inner_outer = 0; j_inner_outer < (63 + (-4 * j_outer)); j_inner_outer += 1) { - for (int32_t j_inner_inner = 0; j_inner_inner < cinn_min(8, (500 + ((-8 * j_inner_outer) + (-32 * j_outer)))); j_inner_inner += 1) { - C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((8 * j_inner_outer) + ((32 * j_outer) + j_inner_inner))))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((8 * k0_outer) + k0_inner)))] * B[((8 * j_inner_outer) + ((32 * j_outer) + ((500 * k0_inner) + ((4000 * k0_outer) + j_inner_inner))))])); - }; - }; - }; - }; - }; - }; - }; - cinn_buffer_free((void*)(0), _C); -} -)ROC"; - - std::cout << "\n" << out << std::endl; - EXPECT_EQ(utils::Trim(target_out), utils::Trim(out)); -} - TEST(Vectorize, replace_var) { using namespace ir; // NOLINT diff --git a/cinn/poly/stage_test.cc b/cinn/poly/stage_test.cc index ee9628dc671ee..5c8879228d724 100644 --- a/cinn/poly/stage_test.cc +++ b/cinn/poly/stage_test.cc @@ -225,6 +225,8 @@ function fn (_A, _cache, _C) } TEST(ComputeAt, level1) { + Context::Global().ResetNameId(); + Expr M(100), N(200); Placeholder A("A", {M, N}); Placeholder B("B", {M, N}); @@ -251,9 +253,9 @@ function fn (_A, _B, _cache, _C) for (po1, 10) { if (((((po0 >= 0) and (po0 <= 9)) and (po1 >= 0)) and (po1 <= 9))) { - poly_for (i, 0, ((i + cinn_max(0, (po0 - 1))) <= (po0 + 1)), 1) + for (i, (1 + int32((1 + (po0 - cinn_max(0, (po0 - 1))))))) { - cache[i, 0] = A[i, po1] + cache[i, 0] = A[(i + cinn_max(0, (po0 - 1))), po1] } } C[po0, po1] = select((po0 < 10), (cache[-1, 0] + (cache[0, 0] + (cache[1, 0] + B[po0, po1]))), 0) @@ -275,6 +277,7 @@ function fn (_A, _B, _cache, _C) } TEST(ComputeAt, simple) { + /* { Expr n(64); auto A = Placeholder("A", {n, n}); @@ -289,6 +292,7 @@ TEST(ComputeAt, simple) { auto fn = Lower("fn", {A, A1, B}); LOG(INFO) << "fn:\n" << fn; } + */ { Expr n(64); @@ -313,11 +317,11 @@ function fn (_A, _A1, _B) for (po1, 16) { if (((((po1 >= 0) and (((16 * po0) + po1) >= 0)) and (po1 <= 15)) and (((16 * po0) + po1) <= 31))) { - for (i, (3 + ((16 * po0) + po1))) + for (i, 3) { for (j, 32) { - A1[i, j] = A[i, j] + A1[i, j] = A[(i + ((16 * po0) + po1)), j] } } }