Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Browse files Browse the repository at this point in the history
… develop_scaleed_dot_product_attention_api
  • Loading branch information
liuzhenhai93 committed Aug 1, 2023
2 parents 29a118d + bbdc168 commit ddcff7a
Show file tree
Hide file tree
Showing 614 changed files with 23,256 additions and 9,330 deletions.
2 changes: 1 addition & 1 deletion .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ Checks: '
modernize-make-unique,
-modernize-pass-by-value,
-modernize-raw-string-literal,
-modernize-redundant-void-arg,
modernize-redundant-void-arg,
-modernize-replace-auto-ptr,
-modernize-replace-random-shuffle,
-modernize-shrink-to-fit,
Expand Down
24 changes: 0 additions & 24 deletions cmake/cinn/core.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -433,28 +433,6 @@ function(download_and_uncompress INSTALL_DIR URL FILENAME)
INSTALL_COMMAND "")
endfunction()

set(fusion_pass_file
${CMAKE_CURRENT_BINARY_DIR}/paddle/cinn/hlir/pass/use_general_pass.h
CACHE INTERNAL "use_general_pass.h file")
file(
WRITE ${fusion_pass_file}
"#include \"paddle/cinn/common/macros.h\" // Generated by the paddle/cinn/hlir/pass/CMakeLists.txt. DO NOT EDIT!\n\n"
)

function(find_fusion_pass_register FILENAME ADD_PATH PATTERN)
# set op_name to OUTPUT
file(READ ${FILENAME} CONTENT)
string(REGEX MATCHALL "${PATTERN}\\([a-zA-Z0-9_]*," fusion_pass_patterns
"${CONTENT}")
if(NOT fusion_pass_patterns STREQUAL "")
foreach(pass_pattern ${fusion_pass_patterns})
string(REPLACE "${PATTERN}(" "" pass_pattern "${pass_pattern}")
string(REPLACE "," "" pass_pattern "${pass_pattern}")
file(APPEND ${ADD_PATH} "USE_FUSION_PASS(${pass_pattern});\n")
endforeach()
endif()
endfunction()

function(gather_srcs SRC_GROUP)
set(options)
set(oneValueArgs)
Expand All @@ -464,8 +442,6 @@ function(gather_srcs SRC_GROUP)
set(${SRC_GROUP}
"${${SRC_GROUP}};${CMAKE_CURRENT_SOURCE_DIR}/${cpp}"
CACHE INTERNAL "")
find_fusion_pass_register("${CMAKE_CURRENT_SOURCE_DIR}/${cpp}"
${fusion_pass_file} "CINN_REGISTER_FUSION_PASS")
endforeach()
endfunction()

Expand Down
25 changes: 19 additions & 6 deletions cmake/phi.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ function(kernel_declare TARGET_LIST)
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL|PD_REGISTER_KERNEL_FOR_ALL_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
"(PD_REGISTER_KERNEL|PD_REGISTER_KERNEL_FOR_ALL_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
first_registry
"${kernel_impl}")
set(kernel_declare_id "")
Expand Down Expand Up @@ -115,13 +115,26 @@ function(kernel_declare TARGET_LIST)
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
"(PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
is_all_backend
"${first_registry}")
if(NOT is_all_backend STREQUAL "")
# parse the registerd kernel message
string(
REPLACE "PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM("
"" kernel_msg "${first_registry}")
else()
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
is_all_backend
"${first_registry}")

# parse the registerd kernel message
string(REPLACE "PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(" ""
kernel_msg "${first_registry}")
# parse the registerd kernel message
string(REPLACE "PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(" ""
kernel_msg "${first_registry}")
endif()
string(REPLACE "PD_REGISTER_KERNEL(" "" kernel_msg "${kernel_msg}")
string(REPLACE "PD_REGISTER_KERNEL_FOR_ALL_DTYPE(" "" kernel_msg
"${kernel_msg}")
Expand All @@ -146,7 +159,7 @@ function(kernel_declare TARGET_LIST)
string(
REGEX
MATCH
"(PD_REGISTER_KERNEL|PD_REGISTER_KERNEL_FOR_ALL_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
"(PD_REGISTER_KERNEL|PD_REGISTER_KERNEL_FOR_ALL_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
first_registry
"${kernel_impl}")
endif()
Expand Down
3 changes: 0 additions & 3 deletions paddle/cinn/auto_schedule/auto_tuner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
#include "paddle/cinn/runtime/flags.h"

DECLARE_bool(auto_schedule_use_cost_model);
DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace auto_schedule {
Expand Down Expand Up @@ -70,8 +69,6 @@ class TestAutoTuner : public ::testing::Test {

void SetUp() override {
srand(0);
// AutoTuner is combined with new IR Schedule
FLAGS_cinn_ir_schedule = true;
std::unordered_set<std::string> fetch_ids;
auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
Expand Down
3 changes: 0 additions & 3 deletions paddle/cinn/auto_schedule/measure/measurer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/runtime/flags.h"

DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace auto_schedule {

Expand All @@ -55,7 +53,6 @@ class TestMeasurer : public ::testing::Test {
std::vector<MeasureInput> inputs;

void SetUp() override {
FLAGS_cinn_ir_schedule = true;
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
#include "paddle/cinn/utils/string.h"
#include "test/cpp/cinn/concrete_program_builder.h"

DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace auto_schedule {

Expand Down Expand Up @@ -155,7 +153,6 @@ TEST(AutoInline, AddReluInline) {

frontend::Program program = builder.Build();

FLAGS_cinn_ir_schedule = true;
auto graph = std::make_shared<Graph>(program, target);
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");

Expand Down
2 changes: 0 additions & 2 deletions paddle/cinn/auto_schedule/task/task_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include "paddle/cinn/utils/type_defs.h"

DECLARE_bool(auto_schedule_use_cost_model);
DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace auto_schedule {
Expand Down Expand Up @@ -70,7 +69,6 @@ std::shared_ptr<hlir::framework::Graph> CreateAddProgram(

TEST(TestTaskRegistry, basic) {
FLAGS_auto_schedule_use_cost_model = true;
FLAGS_cinn_ir_schedule = true;

#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
Expand Down
6 changes: 0 additions & 6 deletions paddle/cinn/auto_schedule/task/tune_task_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"

DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace auto_schedule {

Expand All @@ -59,8 +57,6 @@ Program CreateAddProgram() {
}

TEST(TuneTask, GraphToUnoptLoweredFunc_NoPass) {
// Auto tuner is combined with IR schedule
FLAGS_cinn_ir_schedule = true;
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
Expand Down Expand Up @@ -170,8 +166,6 @@ TEST(TuneTask, GraphToUnoptLoweredFunc_NoPass) {
}

TEST(TuneTask, GraphToUnoptLoweredFunc_ApplyPass) {
// Auto tuner is combined with IR schedule
FLAGS_cinn_ir_schedule = true;
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/backends/ir_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ void TestSplitThrow() {
std::vector<Expr> vec_ast{ast_expr};
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(
mod_expr, -1, false, ir::ScheduleErrorMessageLevel::kGeneral);
mod_expr, -1, false, utils::ErrorMessageLevel::kGeneral);
auto fused = ir_sch.Fuse("B", {0, 1});
// statement that cause the exception
auto splited = ir_sch.Split(fused, {-1, -1});
Expand All @@ -196,7 +196,7 @@ void TestSplitThrow() {
auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
}
TEST(IrSchedule, split_throw) {
ASSERT_THROW(TestSplitThrow(), ir::enforce::EnforceNotMet);
ASSERT_THROW(TestSplitThrow(), utils::enforce::EnforceNotMet);
}

TEST(IrSchedule, reorder1) {
Expand Down
11 changes: 0 additions & 11 deletions paddle/cinn/common/macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,6 @@
__test_global_namespace_##uniq_name##__>::value, \
msg)

#define CINN_REGISTER_FUSION_PASS(pass_name, pass_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_pass__##pass_name, \
"CINN_REGISTER_FUSION_PASS must be called in global namespace"); \
static ::cinn::hlir::pass::FusionPassRegistrar<pass_class> \
__pass_registrar_##pass_name##__(#pass_name); \
int TouchFusionPassRegistrar_##pass_name() { \
__pass_registrar_##pass_name##__.Touch(); \
return 0; \
}

#define USE_FUSION_PASS(pass_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_fusion_pass_##pass_name, \
Expand Down
1 change: 0 additions & 1 deletion paddle/cinn/frontend/decomposer/test_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"

namespace cinn::frontend {
Expand Down
1 change: 0 additions & 1 deletion paddle/cinn/frontend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/runtime/flags.h"

Expand Down
1 change: 0 additions & 1 deletion paddle/cinn/frontend/optimize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/runtime/flags.h"

Expand Down
17 changes: 17 additions & 0 deletions paddle/cinn/frontend/pass/expand_zero_dim_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class ExpandZeroDimPass : public ProgramPass {
if (instr->op_type == "transpose") {
builder.AppendInstruction(HandleTranspose(instr));
continue;
} else if (instr->op_type == "fill_constant") {
builder.AppendInstruction(HandleFillConstant(instr));
continue;
}
for (auto& input : instr->inputs) {
if (input->shape.empty()) {
Expand Down Expand Up @@ -101,6 +104,20 @@ class ExpandZeroDimPass : public ProgramPass {
}
return new_instr;
}

// Before: out-0D = fill_constant([], 123.456, "out", "float32")
// After: out-1D = fill_constant([1], 123.456, "out", "float32")
Instruction HandleFillConstant(const Instruction& instr) {
Instruction new_instr = instr;
std::vector<int32_t> shape =
new_instr.GetAttrs<std::vector<int32_t>>("shape");
if (shape.empty()) {
shape.push_back(1);
VLOG(4) << "Change fill_constant's attribute shape from [] to [1]";
}
new_instr.SetAttr<std::vector<int32_t>>("shape", shape);
return new_instr;
}
};

} // namespace pass
Expand Down
1 change: 0 additions & 1 deletion paddle/cinn/frontend/pass/test_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/pass/use_general_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"

namespace cinn::frontend {
Expand Down
8 changes: 8 additions & 0 deletions paddle/cinn/hlir/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ gather_srcs(
buffer.cc
memory.cc
instruction.cc
program.cc
parallel_compiler.cc
graph_compiler.cc
graph.cc
Expand All @@ -20,6 +21,13 @@ gather_srcs(
accuracy_checker.cc
visualize_helper.cc)

# TODO(Aurelius84): new_ir_compiler depends on pd_dialect and could
# not found under CINN_ONLY mode
if(NOT CINN_ONLY)
cinn_cc_library(new_ir_compiler SRCS new_ir_compiler.cc DEPS cinncore
pd_dialect)
endif()

if(WITH_CUDA)
cinn_nv_test(test_hlir_framework_buffer SRCS buffer_test.cc DEPS cinncore)
cinn_cc_test(test_hlir_framework_accuracy_checker SRCS
Expand Down
Loading

0 comments on commit ddcff7a

Please sign in to comment.